Browse Source

improve filter_channel_updates

blacklist channels that do not really get updated
dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
ThomasV 6 years ago
parent
commit
eb4e6bb0de
  1. 38
      electrum/lnpeer.py
  2. 142
      electrum/lnrouter.py

38
electrum/lnpeer.py

@ -241,12 +241,14 @@ class Peer(Logger):
self.verify_node_announcements(node_anns)
self.channel_db.on_node_announcement(node_anns)
# channel updates
good, bad = self.channel_db.filter_channel_updates(chan_upds)
if bad:
self.logger.info(f'adding {len(bad)} unknown channel ids')
self.network.lngossip.add_new_ids(bad)
self.verify_channel_updates(good)
self.channel_db.on_channel_update(good)
orphaned, expired, deprecated, good, to_delete = self.channel_db.filter_channel_updates(chan_upds, max_age=self.network.lngossip.max_age)
if orphaned:
self.logger.info(f'adding {len(orphaned)} unknown channel ids')
self.network.lngossip.add_new_ids(orphaned)
if good:
self.logger.debug(f'on_channel_update: {len(good)}/{len(chan_upds)}')
self.verify_channel_updates(good)
self.channel_db.update_policies(good, to_delete)
# refresh gui
if chan_anns or node_anns or chan_upds:
self.network.lngossip.refresh_gui()
@ -273,7 +275,7 @@ class Peer(Logger):
short_channel_id = payload['short_channel_id']
if constants.net.rev_genesis_bytes() != payload['chain_hash']:
raise Exception('wrong chain hash')
if not verify_sig_for_channel_update(payload, payload['node_id']):
if not verify_sig_for_channel_update(payload, payload['start_node']):
raise BaseException('verify error')
@log_exceptions
@ -990,21 +992,29 @@ class Peer(Logger):
OnionFailureCode.EXPIRY_TOO_SOON: 2,
OnionFailureCode.CHANNEL_DISABLED: 4,
}
offset = failure_codes.get(code)
if offset:
if code in failure_codes:
offset = failure_codes[code]
channel_update = (258).to_bytes(length=2, byteorder="big") + data[offset:]
message_type, payload = decode_msg(channel_update)
payload['raw'] = channel_update
try:
self.logger.info(f"trying to apply channel update on our db {payload}")
self.channel_db.add_channel_update(payload)
self.logger.info("successfully applied channel update on our db")
except NotFoundChanAnnouncementForUpdate:
orphaned, expired, deprecated, good, to_delete = self.channel_db.filter_channel_updates([payload])
if good:
self.verify_channel_updates(good)
self.channel_db.update_policies(good, to_delete)
self.logger.info("applied channel update on our db")
elif orphaned:
# maybe it is a private channel (and data in invoice was outdated)
self.logger.info("maybe channel update is for private channel?")
start_node_id = route[sender_idx].node_id
self.channel_db.add_channel_update_for_private_channel(payload, start_node_id)
elif expired:
blacklist = True
elif deprecated:
self.logger.info(f'channel update is not more recent. blacklisting channel')
blacklist = True
else:
blacklist = True
if blacklist:
# blacklist channel after reporter node
# TODO this should depend on the error (even more granularity)
# also, we need finer blacklisting (directed edges; nodes)

142
electrum/lnrouter.py

@ -114,22 +114,16 @@ class Policy(Base):
timestamp = Column(Integer, nullable=False)
@staticmethod
def from_msg(payload, start_node, short_channel_id):
cltv_expiry_delta = payload['cltv_expiry_delta']
htlc_minimum_msat = payload['htlc_minimum_msat']
fee_base_msat = payload['fee_base_msat']
fee_proportional_millionths = payload['fee_proportional_millionths']
channel_flags = payload['channel_flags']
timestamp = payload['timestamp']
htlc_maximum_msat = payload.get('htlc_maximum_msat') # optional
cltv_expiry_delta = int.from_bytes(cltv_expiry_delta, "big")
htlc_minimum_msat = int.from_bytes(htlc_minimum_msat, "big")
htlc_maximum_msat = int.from_bytes(htlc_maximum_msat, "big") if htlc_maximum_msat else None
fee_base_msat = int.from_bytes(fee_base_msat, "big")
fee_proportional_millionths = int.from_bytes(fee_proportional_millionths, "big")
channel_flags = int.from_bytes(channel_flags, "big")
timestamp = int.from_bytes(timestamp, "big")
def from_msg(payload):
cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big")
htlc_minimum_msat = int.from_bytes(payload['htlc_minimum_msat'], "big")
htlc_maximum_msat = int.from_bytes(payload['htlc_maximum_msat'], "big") if 'htlc_maximum_msat' in payload else None
fee_base_msat = int.from_bytes(payload['fee_base_msat'], "big")
fee_proportional_millionths = int.from_bytes(payload['fee_proportional_millionths'], "big")
channel_flags = int.from_bytes(payload['channel_flags'], "big")
timestamp = int.from_bytes(payload['timestamp'], "big")
start_node = payload['start_node'].hex()
short_channel_id = payload['short_channel_id'].hex()
return Policy(start_node=start_node,
short_channel_id=short_channel_id,
@ -341,71 +335,98 @@ class ChannelDB(SqlDB):
r = self.DBSession.query(func.max(Policy.timestamp).label('max_timestamp')).one()
return r.max_timestamp or 0
def print_change(self, old_policy, new_policy):
# print what changed between policies
if old_policy.cltv_expiry_delta != new_policy.cltv_expiry_delta:
self.logger.info(f'cltv_expiry_delta: {old_policy.cltv_expiry_delta} -> {new_policy.cltv_expiry_delta}')
if old_policy.htlc_minimum_msat != new_policy.htlc_minimum_msat:
self.logger.info(f'htlc_minimum_msat: {old_policy.htlc_minimum_msat} -> {new_policy.htlc_minimum_msat}')
if old_policy.htlc_maximum_msat != new_policy.htlc_maximum_msat:
self.logger.info(f'htlc_maximum_msat: {old_policy.htlc_maximum_msat} -> {new_policy.htlc_maximum_msat}')
if old_policy.fee_base_msat != new_policy.fee_base_msat:
self.logger.info(f'fee_base_msat: {old_policy.fee_base_msat} -> {new_policy.fee_base_msat}')
if old_policy.fee_proportional_millionths != new_policy.fee_proportional_millionths:
self.logger.info(f'fee_proportional_millionths: {old_policy.fee_proportional_millionths} -> {new_policy.fee_proportional_millionths}')
if old_policy.channel_flags != new_policy.channel_flags:
self.logger.info(f'channel_flags: {old_policy.channel_flags} -> {new_policy.channel_flags}')
@sql
def get_info_for_updates(self, msg_payloads):
short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads]
def get_info_for_updates(self, payloads):
short_channel_ids = [payload['short_channel_id'].hex() for payload in payloads]
channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all()
channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list}
return channel_infos
@sql
def get_policies_for_updates(self, payloads):
out = {}
for payload in payloads:
short_channel_id = payload['short_channel_id'].hex()
start_node = payload['start_node'].hex()
policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=start_node).one_or_none()
if policy:
out[short_channel_id+start_node] = policy
return out
@profiler
def filter_channel_updates(self, payloads):
# add 'node_id' to payload
channel_infos = self.get_info_for_updates(payloads)
def filter_channel_updates(self, payloads, max_age=None):
orphaned = [] # no channel announcement for channel update
expired = [] # update older than two weeks
deprecated = [] # update older than database entry
good = [] # good updates
to_delete = [] # database entries to delete
# filter orphaned and expired first
known = []
unknown = []
now = int(time.time())
channel_infos = self.get_info_for_updates(payloads)
for payload in payloads:
short_channel_id = payload['short_channel_id']
timestamp = int.from_bytes(payload['timestamp'], "big")
if max_age and now - timestamp > max_age:
expired.append(short_channel_id)
continue
channel_info = channel_infos.get(short_channel_id)
if not channel_info:
unknown.append(short_channel_id)
orphaned.append(short_channel_id)
continue
flags = int.from_bytes(payload['channel_flags'], 'big')
direction = flags & FLAG_DIRECTION
node_id = bfh(channel_info.node1_id if direction == 0 else channel_info.node2_id)
payload['node_id'] = node_id
start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id
payload['start_node'] = bfh(start_node)
known.append(payload)
return known, unknown
# compare updates to existing database entries
old_policies = self.get_policies_for_updates(known)
for payload in known:
timestamp = int.from_bytes(payload['timestamp'], "big")
start_node = payload['start_node'].hex()
short_channel_id = payload['short_channel_id'].hex()
old_policy = old_policies.get(short_channel_id+start_node)
if old_policy:
if timestamp <= old_policy.timestamp:
deprecated.append(short_channel_id)
else:
good.append(payload)
to_delete.append(old_policy)
else:
good.append(payload)
return orphaned, expired, deprecated, good, to_delete
def add_channel_update(self, payload):
# called in tests/test_lnrouter
good, bad = self.filter_channel_updates([payload])
assert len(bad) == 0
self.on_channel_update(good)
orphaned, expired, deprecated, good, to_delete = self.filter_channel_updates([payload])
assert len(good) == 1
self.update_policies(good, to_delete)
@sql
@profiler
def on_channel_update(self, msg_payloads):
now = int(time.time())
if type(msg_payloads) is dict:
msg_payloads = [msg_payloads]
new_policies = {}
for msg_payload in msg_payloads:
short_channel_id = msg_payload['short_channel_id'].hex()
node_id = msg_payload['node_id'].hex()
new_policy = Policy.from_msg(msg_payload, node_id, short_channel_id)
# must not be older than two weeks
if new_policy.timestamp < now - 14*24*3600:
continue
old_policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=node_id).one_or_none()
if old_policy:
if old_policy.timestamp >= new_policy.timestamp:
continue
self.DBSession.delete(old_policy)
p = new_policies.get((short_channel_id, node_id))
if p and p.timestamp >= new_policy.timestamp:
continue
new_policies[(short_channel_id, node_id)] = new_policy
# commit pending removals
def update_policies(self, to_add, to_delete):
for policy in to_delete:
self.DBSession.delete(policy)
self.DBSession.commit()
# add and commit new policies
for new_policy in new_policies.values():
self.DBSession.add(new_policy)
for payload in to_add:
policy = Policy.from_msg(payload)
self.DBSession.add(policy)
self.DBSession.commit()
if new_policies:
self.logger.debug(f'on_channel_update: {len(new_policies)}/{len(msg_payloads)}')
#self.logger.info(f'last timestamp: {datetime.fromtimestamp(self._get_last_timestamp()).ctime()}')
self._update_counts()
self._update_counts()
@sql
@profiler
@ -454,7 +475,7 @@ class ChannelDB(SqlDB):
msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
if not msg:
return None
return Policy.from_msg(msg, None, short_channel_id) # won't actually be written to DB
return Policy.from_msg(msg) # won't actually be written to DB
@sql
@profiler
@ -496,6 +517,7 @@ class ChannelDB(SqlDB):
if not verify_sig_for_channel_update(msg_payload, start_node_id):
return # ignore
short_channel_id = msg_payload['short_channel_id']
msg_payload['start_node'] = start_node_id
self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload
@sql

Loading…
Cancel
Save