Browse Source

lnrouter: fix primary key conflict in Policy update

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
ThomasV 6 years ago
parent
commit
2c80996fbf
  1. 44
      electrum/lnrouter.py

44
electrum/lnrouter.py

@ -347,6 +347,7 @@ class ChannelDB(SqlDB):
short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads]
channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all()
channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list}
new_policies = {}
for msg_payload in msg_payloads:
short_channel_id = msg_payload['short_channel_id']
if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']:
@ -354,7 +355,27 @@ class ChannelDB(SqlDB):
channel_info = channel_infos.get(short_channel_id)
if not channel_info:
continue
self._update_channel_info(channel_info, msg_payload, trusted=trusted)
flags = int.from_bytes(msg_payload['channel_flags'], 'big')
direction = flags & FLAG_DIRECTION
node_id = channel_info.node1_id if direction == 0 else channel_info.node2_id
if not trusted and not verify_sig_for_channel_update(msg_payload, bytes.fromhex(node_id)):
continue
short_channel_id = channel_info.short_channel_id
new_policy = Policy.from_msg(msg_payload, node_id, channel_info.short_channel_id)
old_policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=node_id).one_or_none()
if old_policy:
if old_policy.timestamp >= new_policy.timestamp:
continue
self.DBSession.delete(old_policy)
p = new_policies.get((short_channel_id, node_id))
if p and p.timestamp >= new_policy.timestamp:
continue
new_policies[(short_channel_id, node_id)] = new_policy
# commit pending removals
self.DBSession.commit()
# add and commit new policies
for new_policy in new_policies.values():
self.DBSession.add(new_policy)
self.DBSession.commit()
@sql
@ -468,27 +489,6 @@ class ChannelDB(SqlDB):
bh2u(node2) if full_ids else bh2u(node2[-4:]),
direction))
def _update_channel_info(self, channel_info, msg: dict, trusted=False):
assert channel_info.short_channel_id == msg['short_channel_id'].hex()
flags = int.from_bytes(msg['channel_flags'], 'big')
direction = flags & FLAG_DIRECTION
node_id = channel_info.node1_id if direction == 0 else channel_info.node2_id
new_policy = Policy.from_msg(msg, node_id, channel_info.short_channel_id)
old_policy = self.DBSession.query(Policy).filter_by(short_channel_id = channel_info.short_channel_id, start_node=node_id).one_or_none()
if not old_policy:
self.DBSession.add(new_policy)
return
if old_policy.timestamp >= new_policy.timestamp:
return # ignore
if not trusted and not verify_sig_for_channel_update(msg, bytes.fromhex(node_id)):
return # ignore
old_policy.cltv_expiry_delta = new_policy.cltv_expiry_delta
old_policy.htlc_minimum_msat = new_policy.htlc_minimum_msat
old_policy.htlc_maximum_msat = new_policy.htlc_maximum_msat
old_policy.fee_base_msat = new_policy.fee_base_msat
old_policy.fee_proportional_millionths = new_policy.fee_proportional_millionths
old_policy.channel_flags = new_policy.channel_flags
old_policy.timestamp = new_policy.timestamp
@sql
def get_policy_for_node(self, channel_info, node) -> Optional['Policy']:

Loading…
Cancel
Save