|
|
@ -114,22 +114,16 @@ class Policy(Base): |
|
|
|
timestamp = Column(Integer, nullable=False) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def from_msg(payload, start_node, short_channel_id): |
|
|
|
cltv_expiry_delta = payload['cltv_expiry_delta'] |
|
|
|
htlc_minimum_msat = payload['htlc_minimum_msat'] |
|
|
|
fee_base_msat = payload['fee_base_msat'] |
|
|
|
fee_proportional_millionths = payload['fee_proportional_millionths'] |
|
|
|
channel_flags = payload['channel_flags'] |
|
|
|
timestamp = payload['timestamp'] |
|
|
|
htlc_maximum_msat = payload.get('htlc_maximum_msat') # optional |
|
|
|
|
|
|
|
cltv_expiry_delta = int.from_bytes(cltv_expiry_delta, "big") |
|
|
|
htlc_minimum_msat = int.from_bytes(htlc_minimum_msat, "big") |
|
|
|
htlc_maximum_msat = int.from_bytes(htlc_maximum_msat, "big") if htlc_maximum_msat else None |
|
|
|
fee_base_msat = int.from_bytes(fee_base_msat, "big") |
|
|
|
fee_proportional_millionths = int.from_bytes(fee_proportional_millionths, "big") |
|
|
|
channel_flags = int.from_bytes(channel_flags, "big") |
|
|
|
timestamp = int.from_bytes(timestamp, "big") |
|
|
|
def from_msg(payload): |
|
|
|
cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big") |
|
|
|
htlc_minimum_msat = int.from_bytes(payload['htlc_minimum_msat'], "big") |
|
|
|
htlc_maximum_msat = int.from_bytes(payload['htlc_maximum_msat'], "big") if 'htlc_maximum_msat' in payload else None |
|
|
|
fee_base_msat = int.from_bytes(payload['fee_base_msat'], "big") |
|
|
|
fee_proportional_millionths = int.from_bytes(payload['fee_proportional_millionths'], "big") |
|
|
|
channel_flags = int.from_bytes(payload['channel_flags'], "big") |
|
|
|
timestamp = int.from_bytes(payload['timestamp'], "big") |
|
|
|
start_node = payload['start_node'].hex() |
|
|
|
short_channel_id = payload['short_channel_id'].hex() |
|
|
|
|
|
|
|
return Policy(start_node=start_node, |
|
|
|
short_channel_id=short_channel_id, |
|
|
@ -341,71 +335,98 @@ class ChannelDB(SqlDB): |
|
|
|
r = self.DBSession.query(func.max(Policy.timestamp).label('max_timestamp')).one() |
|
|
|
return r.max_timestamp or 0 |
|
|
|
|
|
|
|
def print_change(self, old_policy, new_policy): |
|
|
|
# print what changed between policies |
|
|
|
if old_policy.cltv_expiry_delta != new_policy.cltv_expiry_delta: |
|
|
|
self.logger.info(f'cltv_expiry_delta: {old_policy.cltv_expiry_delta} -> {new_policy.cltv_expiry_delta}') |
|
|
|
if old_policy.htlc_minimum_msat != new_policy.htlc_minimum_msat: |
|
|
|
self.logger.info(f'htlc_minimum_msat: {old_policy.htlc_minimum_msat} -> {new_policy.htlc_minimum_msat}') |
|
|
|
if old_policy.htlc_maximum_msat != new_policy.htlc_maximum_msat: |
|
|
|
self.logger.info(f'htlc_maximum_msat: {old_policy.htlc_maximum_msat} -> {new_policy.htlc_maximum_msat}') |
|
|
|
if old_policy.fee_base_msat != new_policy.fee_base_msat: |
|
|
|
self.logger.info(f'fee_base_msat: {old_policy.fee_base_msat} -> {new_policy.fee_base_msat}') |
|
|
|
if old_policy.fee_proportional_millionths != new_policy.fee_proportional_millionths: |
|
|
|
self.logger.info(f'fee_proportional_millionths: {old_policy.fee_proportional_millionths} -> {new_policy.fee_proportional_millionths}') |
|
|
|
if old_policy.channel_flags != new_policy.channel_flags: |
|
|
|
self.logger.info(f'channel_flags: {old_policy.channel_flags} -> {new_policy.channel_flags}') |
|
|
|
|
|
|
|
@sql |
|
|
|
def get_info_for_updates(self, msg_payloads): |
|
|
|
short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads] |
|
|
|
def get_info_for_updates(self, payloads): |
|
|
|
short_channel_ids = [payload['short_channel_id'].hex() for payload in payloads] |
|
|
|
channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all() |
|
|
|
channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list} |
|
|
|
return channel_infos |
|
|
|
|
|
|
|
@sql |
|
|
|
def get_policies_for_updates(self, payloads): |
|
|
|
out = {} |
|
|
|
for payload in payloads: |
|
|
|
short_channel_id = payload['short_channel_id'].hex() |
|
|
|
start_node = payload['start_node'].hex() |
|
|
|
policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=start_node).one_or_none() |
|
|
|
if policy: |
|
|
|
out[short_channel_id+start_node] = policy |
|
|
|
return out |
|
|
|
|
|
|
|
@profiler |
|
|
|
def filter_channel_updates(self, payloads): |
|
|
|
# add 'node_id' to payload |
|
|
|
channel_infos = self.get_info_for_updates(payloads) |
|
|
|
def filter_channel_updates(self, payloads, max_age=None): |
|
|
|
orphaned = [] # no channel announcement for channel update |
|
|
|
expired = [] # update older than two weeks |
|
|
|
deprecated = [] # update older than database entry |
|
|
|
good = [] # good updates |
|
|
|
to_delete = [] # database entries to delete |
|
|
|
# filter orphaned and expired first |
|
|
|
known = [] |
|
|
|
unknown = [] |
|
|
|
now = int(time.time()) |
|
|
|
channel_infos = self.get_info_for_updates(payloads) |
|
|
|
for payload in payloads: |
|
|
|
short_channel_id = payload['short_channel_id'] |
|
|
|
timestamp = int.from_bytes(payload['timestamp'], "big") |
|
|
|
if max_age and now - timestamp > max_age: |
|
|
|
expired.append(short_channel_id) |
|
|
|
continue |
|
|
|
channel_info = channel_infos.get(short_channel_id) |
|
|
|
if not channel_info: |
|
|
|
unknown.append(short_channel_id) |
|
|
|
orphaned.append(short_channel_id) |
|
|
|
continue |
|
|
|
flags = int.from_bytes(payload['channel_flags'], 'big') |
|
|
|
direction = flags & FLAG_DIRECTION |
|
|
|
node_id = bfh(channel_info.node1_id if direction == 0 else channel_info.node2_id) |
|
|
|
payload['node_id'] = node_id |
|
|
|
start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id |
|
|
|
payload['start_node'] = bfh(start_node) |
|
|
|
known.append(payload) |
|
|
|
return known, unknown |
|
|
|
# compare updates to existing database entries |
|
|
|
old_policies = self.get_policies_for_updates(known) |
|
|
|
for payload in known: |
|
|
|
timestamp = int.from_bytes(payload['timestamp'], "big") |
|
|
|
start_node = payload['start_node'].hex() |
|
|
|
short_channel_id = payload['short_channel_id'].hex() |
|
|
|
old_policy = old_policies.get(short_channel_id+start_node) |
|
|
|
if old_policy: |
|
|
|
if timestamp <= old_policy.timestamp: |
|
|
|
deprecated.append(short_channel_id) |
|
|
|
else: |
|
|
|
good.append(payload) |
|
|
|
to_delete.append(old_policy) |
|
|
|
else: |
|
|
|
good.append(payload) |
|
|
|
return orphaned, expired, deprecated, good, to_delete |
|
|
|
|
|
|
|
def add_channel_update(self, payload): |
|
|
|
# called in tests/test_lnrouter |
|
|
|
good, bad = self.filter_channel_updates([payload]) |
|
|
|
assert len(bad) == 0 |
|
|
|
self.on_channel_update(good) |
|
|
|
orphaned, expired, deprecated, good, to_delete = self.filter_channel_updates([payload]) |
|
|
|
assert len(good) == 1 |
|
|
|
self.update_policies(good, to_delete) |
|
|
|
|
|
|
|
@sql |
|
|
|
@profiler |
|
|
|
def on_channel_update(self, msg_payloads): |
|
|
|
now = int(time.time()) |
|
|
|
if type(msg_payloads) is dict: |
|
|
|
msg_payloads = [msg_payloads] |
|
|
|
new_policies = {} |
|
|
|
for msg_payload in msg_payloads: |
|
|
|
short_channel_id = msg_payload['short_channel_id'].hex() |
|
|
|
node_id = msg_payload['node_id'].hex() |
|
|
|
new_policy = Policy.from_msg(msg_payload, node_id, short_channel_id) |
|
|
|
# must not be older than two weeks |
|
|
|
if new_policy.timestamp < now - 14*24*3600: |
|
|
|
continue |
|
|
|
old_policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=node_id).one_or_none() |
|
|
|
if old_policy: |
|
|
|
if old_policy.timestamp >= new_policy.timestamp: |
|
|
|
continue |
|
|
|
self.DBSession.delete(old_policy) |
|
|
|
p = new_policies.get((short_channel_id, node_id)) |
|
|
|
if p and p.timestamp >= new_policy.timestamp: |
|
|
|
continue |
|
|
|
new_policies[(short_channel_id, node_id)] = new_policy |
|
|
|
# commit pending removals |
|
|
|
def update_policies(self, to_add, to_delete): |
|
|
|
for policy in to_delete: |
|
|
|
self.DBSession.delete(policy) |
|
|
|
self.DBSession.commit() |
|
|
|
# add and commit new policies |
|
|
|
for new_policy in new_policies.values(): |
|
|
|
self.DBSession.add(new_policy) |
|
|
|
for payload in to_add: |
|
|
|
policy = Policy.from_msg(payload) |
|
|
|
self.DBSession.add(policy) |
|
|
|
self.DBSession.commit() |
|
|
|
if new_policies: |
|
|
|
self.logger.debug(f'on_channel_update: {len(new_policies)}/{len(msg_payloads)}') |
|
|
|
#self.logger.info(f'last timestamp: {datetime.fromtimestamp(self._get_last_timestamp()).ctime()}') |
|
|
|
self._update_counts() |
|
|
|
self._update_counts() |
|
|
|
|
|
|
|
@sql |
|
|
|
@profiler |
|
|
@ -454,7 +475,7 @@ class ChannelDB(SqlDB): |
|
|
|
msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id)) |
|
|
|
if not msg: |
|
|
|
return None |
|
|
|
return Policy.from_msg(msg, None, short_channel_id) # won't actually be written to DB |
|
|
|
return Policy.from_msg(msg) # won't actually be written to DB |
|
|
|
|
|
|
|
@sql |
|
|
|
@profiler |
|
|
@ -496,6 +517,7 @@ class ChannelDB(SqlDB): |
|
|
|
if not verify_sig_for_channel_update(msg_payload, start_node_id): |
|
|
|
return # ignore |
|
|
|
short_channel_id = msg_payload['short_channel_id'] |
|
|
|
msg_payload['start_node'] = start_node_id |
|
|
|
self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload |
|
|
|
|
|
|
|
@sql |
|
|
|