|
|
@ -347,6 +347,7 @@ class ChannelDB(SqlDB): |
|
|
|
short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads] |
|
|
|
channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all() |
|
|
|
channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list} |
|
|
|
new_policies = {} |
|
|
|
for msg_payload in msg_payloads: |
|
|
|
short_channel_id = msg_payload['short_channel_id'] |
|
|
|
if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']: |
|
|
@ -354,7 +355,27 @@ class ChannelDB(SqlDB): |
|
|
|
channel_info = channel_infos.get(short_channel_id) |
|
|
|
if not channel_info: |
|
|
|
continue |
|
|
|
self._update_channel_info(channel_info, msg_payload, trusted=trusted) |
|
|
|
flags = int.from_bytes(msg_payload['channel_flags'], 'big') |
|
|
|
direction = flags & FLAG_DIRECTION |
|
|
|
node_id = channel_info.node1_id if direction == 0 else channel_info.node2_id |
|
|
|
if not trusted and not verify_sig_for_channel_update(msg_payload, bytes.fromhex(node_id)): |
|
|
|
continue |
|
|
|
short_channel_id = channel_info.short_channel_id |
|
|
|
new_policy = Policy.from_msg(msg_payload, node_id, channel_info.short_channel_id) |
|
|
|
old_policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=node_id).one_or_none() |
|
|
|
if old_policy: |
|
|
|
if old_policy.timestamp >= new_policy.timestamp: |
|
|
|
continue |
|
|
|
self.DBSession.delete(old_policy) |
|
|
|
p = new_policies.get((short_channel_id, node_id)) |
|
|
|
if p and p.timestamp >= new_policy.timestamp: |
|
|
|
continue |
|
|
|
new_policies[(short_channel_id, node_id)] = new_policy |
|
|
|
# commit pending removals |
|
|
|
self.DBSession.commit() |
|
|
|
# add and commit new policies |
|
|
|
for new_policy in new_policies.values(): |
|
|
|
self.DBSession.add(new_policy) |
|
|
|
self.DBSession.commit() |
|
|
|
|
|
|
|
@sql |
|
|
@ -468,27 +489,6 @@ class ChannelDB(SqlDB): |
|
|
|
bh2u(node2) if full_ids else bh2u(node2[-4:]), |
|
|
|
direction)) |
|
|
|
|
|
|
|
def _update_channel_info(self, channel_info, msg: dict, trusted=False): |
|
|
|
assert channel_info.short_channel_id == msg['short_channel_id'].hex() |
|
|
|
flags = int.from_bytes(msg['channel_flags'], 'big') |
|
|
|
direction = flags & FLAG_DIRECTION |
|
|
|
node_id = channel_info.node1_id if direction == 0 else channel_info.node2_id |
|
|
|
new_policy = Policy.from_msg(msg, node_id, channel_info.short_channel_id) |
|
|
|
old_policy = self.DBSession.query(Policy).filter_by(short_channel_id = channel_info.short_channel_id, start_node=node_id).one_or_none() |
|
|
|
if not old_policy: |
|
|
|
self.DBSession.add(new_policy) |
|
|
|
return |
|
|
|
if old_policy.timestamp >= new_policy.timestamp: |
|
|
|
return # ignore |
|
|
|
if not trusted and not verify_sig_for_channel_update(msg, bytes.fromhex(node_id)): |
|
|
|
return # ignore |
|
|
|
old_policy.cltv_expiry_delta = new_policy.cltv_expiry_delta |
|
|
|
old_policy.htlc_minimum_msat = new_policy.htlc_minimum_msat |
|
|
|
old_policy.htlc_maximum_msat = new_policy.htlc_maximum_msat |
|
|
|
old_policy.fee_base_msat = new_policy.fee_base_msat |
|
|
|
old_policy.fee_proportional_millionths = new_policy.fee_proportional_millionths |
|
|
|
old_policy.channel_flags = new_policy.channel_flags |
|
|
|
old_policy.timestamp = new_policy.timestamp |
|
|
|
|
|
|
|
@sql |
|
|
|
def get_policy_for_node(self, channel_info, node) -> Optional['Policy']: |
|
|
|