From 34f22e6681f9360f743dd1cea1675845c43b8b55 Mon Sep 17 00:00:00 2001 From: ThomasV Date: Thu, 21 Mar 2019 12:44:32 +0100 Subject: [PATCH] lnrouter: load data before finding path --- electrum/lnrouter.py | 73 ++++++++++++++++++++++---------------------- electrum/lnworker.py | 1 + 2 files changed, 38 insertions(+), 36 deletions(-) diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py index 2e4e7ce20..ad03b8f41 100644 --- a/electrum/lnrouter.py +++ b/electrum/lnrouter.py @@ -229,7 +229,9 @@ class ChannelDB(SqlDB): def _update_counts(self): self.num_channels = self.DBSession.query(ChannelInfo).count() + self.num_policies = self.DBSession.query(Policy).count() self.num_nodes = self.DBSession.query(NodeInfo).count() + self.print_error('update counts', self.num_channels, self.num_policies) @sql def add_recent_peer(self, peer: LNPeerAddr): @@ -272,19 +274,6 @@ class ChannelDB(SqlDB): r = self.DBSession.query(Address).filter(Address.last_connected_date.isnot(None)).order_by(Address.last_connected_date.desc()).limit(self.NUM_MAX_RECENT_PEERS).all() return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in r] - @sql - def get_channel_info(self, channel_id: bytes): - return self._chan_query_for_id(channel_id).one_or_none() - - @sql - def get_channels_for_node(self, node_id): - """Returns the set of channels that have node_id as one of the endpoints.""" - condition = or_( - ChannelInfo.node1_id == node_id.hex(), - ChannelInfo.node2_id == node_id.hex()) - rows = self.DBSession.query(ChannelInfo).filter(condition).all() - return [bytes.fromhex(x.short_channel_id) for x in rows] - @sql def missing_short_chan_ids(self) -> Set[int]: expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id))) @@ -296,7 +285,7 @@ class ChannelDB(SqlDB): @sql def add_verified_channel_info(self, short_id, capacity): # called from lnchannelverifier - channel_info = self._chan_query_for_id(short_id).one_or_none() + channel_info = self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_id.hex()).one_or_none() channel_info.trusted = True channel_info.capacity = capacity self.DBSession.commit() @@ -372,7 +361,6 @@ class ChannelDB(SqlDB): if p and p.timestamp >= new_policy.timestamp: continue new_policies[(short_channel_id, node_id)] = new_policy - #self.print_error('on_channel_update: %d/%d'%(len(new_policies), len(msg_payloads))) # commit pending removals self.DBSession.commit() # add and commit new policies @@ -380,7 +368,9 @@ class ChannelDB(SqlDB): self.DBSession.add(new_policy) self.DBSession.commit() if new_policies: + self.print_error('on_channel_update: %d/%d'%(len(new_policies), len(msg_payloads))) self.print_error('last timestamp:', datetime.fromtimestamp(self._get_last_timestamp()).ctime()) + self._update_counts() @sql #@profiler @@ -432,7 +422,7 @@ class ChannelDB(SqlDB): if not start_node_id or not short_channel_id: return None channel_info = self.get_channel_info(short_channel_id) if channel_info is not None: - return self.get_policy_for_node(channel_info, start_node_id) + return self.get_policy_for_node(short_channel_id, start_node_id) msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id)) if not msg: return None @@ -446,12 +436,12 @@ class ChannelDB(SqlDB): @sql def remove_channel(self, short_channel_id): - self._chan_query_for_id(short_channel_id).delete('evaluate') + r = self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex()).one_or_none() + if not r: + return + self.DBSession.delete(r) self.DBSession.commit() - def _chan_query_for_id(self, short_channel_id) -> Query: - return self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex()) - def print_graph(self, full_ids=False): # used for debugging. # FIXME there is a race here - iterables could change size from another thread @@ -488,23 +478,33 @@ class ChannelDB(SqlDB): direction)) - @sql - def get_policy_for_node(self, channel_info, node) -> Optional['Policy']: - """ - raises when initiator/non-initiator both unequal node - """ - if node.hex() not in (channel_info.node1_id, channel_info.node2_id): - raise Exception("the given node is not a party in this channel") - n1 = self.DBSession.query(Policy).filter_by(short_channel_id = channel_info.short_channel_id, start_node = channel_info.node1_id).one_or_none() - if n1: - return n1 - n2 = self.DBSession.query(Policy).filter_by(short_channel_id = channel_info.short_channel_id, start_node = channel_info.node2_id).one_or_none() - return n2 - @sql def get_node_addresses(self, node_info): return self.DBSession.query(Address).join(NodeInfo).filter_by(node_id = node_info.node_id).all() + @sql + @profiler + def load_data(self): + r = self.DBSession.query(ChannelInfo).all() + self._channels = dict([(bfh(x.short_channel_id), x) for x in r]) + r = self.DBSession.query(Policy).filter_by().all() + self._policies = dict([((bfh(x.start_node), bfh(x.short_channel_id)), x) for x in r]) + self._channels_for_node = defaultdict(set) + for channel_info in self._channels.values(): + self._channels_for_node[bfh(channel_info.node1_id)].add(bfh(channel_info.short_channel_id)) + self._channels_for_node[bfh(channel_info.node2_id)].add(bfh(channel_info.short_channel_id)) + self.print_error('load data', len(self._channels), len(self._policies), len(self._channels_for_node)) + + def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']: + return self._policies.get((node_id, short_channel_id)) + + def get_channel_info(self, channel_id: bytes): + return self._channels.get(channel_id) + + def get_channels_for_node(self, node_id): + """Returns the set of channels that have node_id as one of the endpoints.""" + return self._channels_for_node.get(node_id) + class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes), @@ -586,9 +586,9 @@ class LNPathFinder(PrintError): channel_info = self.channel_db.get_channel_info(short_channel_id) # type: ChannelInfo if channel_info is None: return float('inf'), 0 - - channel_policy = self.channel_db.get_policy_for_node(channel_info, start_node) - if channel_policy is None: return float('inf'), 0 + channel_policy = self.channel_db.get_policy_for_node(short_channel_id, start_node) + if channel_policy is None: + return float('inf'), 0 if channel_policy.is_disabled(): return float('inf'), 0 route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node) if payment_amt_msat < channel_policy.htlc_minimum_msat: @@ -618,6 +618,7 @@ class LNPathFinder(PrintError): To get from node ret[n][0] to ret[n+1][0], use channel ret[n+1][1]; i.e. an element reads as, "to get to node_id, travel through short_channel_id" """ + self.channel_db.load_data() assert type(nodeA) is bytes assert type(nodeB) is bytes assert type(invoice_amount_msat) is int diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 44d19ab70..883c54720 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -611,6 +611,7 @@ class LNWorker(PrintError): def _calc_routing_hints_for_invoice(self, amount_sat): """calculate routing hints (BOLT-11 'r' field)""" + self.channel_db.load_data() routing_hints = [] with self.lock: channels = list(self.channels.values())