Browse Source

channel_db: raise specific exception if database is not loaded when we try to find a route

hard-fail-on-bad-server-string
ThomasV 5 years ago
parent
commit
beac1c4ddc
  1. 5
      electrum/channel_db.py
  2. 1
      electrum/tests/test_lnpeer.py
  3. 1
      electrum/tests/test_lnrouter.py

5
electrum/channel_db.py

@ -313,7 +313,8 @@ class ChannelDB(SqlDB):
return None return None
def get_recent_peers(self): def get_recent_peers(self):
assert self.data_loaded.is_set(), "channelDB load_data did not finish yet!" if not self.data_loaded.is_set():
raise Exception("channelDB data not loaded yet!")
with self.lock: with self.lock:
ret = [self.get_last_good_address(node_id) ret = [self.get_last_good_address(node_id)
for node_id in self._recent_peers] for node_id in self._recent_peers]
@ -693,6 +694,8 @@ class ChannelDB(SqlDB):
def get_channels_for_node(self, node_id: bytes, *, def get_channels_for_node(self, node_id: bytes, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Set[bytes]: my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Set[bytes]:
"""Returns the set of short channel IDs where node_id is one of the channel participants.""" """Returns the set of short channel IDs where node_id is one of the channel participants."""
if not self.data_loaded.is_set():
raise Exception("channelDB data not loaded yet!")
relevant_channels = self._channels_for_node.get(node_id) or set() relevant_channels = self._channels_for_node.get(node_id) or set()
relevant_channels = set(relevant_channels) # copy relevant_channels = set(relevant_channels) # copy
# add our own channels # TODO maybe slow? # add our own channels # TODO maybe slow?

1
electrum/tests/test_lnpeer.py

@ -55,6 +55,7 @@ class MockNetwork:
self.config = simple_config.SimpleConfig(user_config, read_user_dir_function=lambda: user_dir) self.config = simple_config.SimpleConfig(user_config, read_user_dir_function=lambda: user_dir)
self.asyncio_loop = asyncio.get_event_loop() self.asyncio_loop = asyncio.get_event_loop()
self.channel_db = ChannelDB(self) self.channel_db = ChannelDB(self)
self.channel_db.data_loaded.set()
self.path_finder = LNPathFinder(self.channel_db) self.path_finder = LNPathFinder(self.channel_db)
self.tx_queue = tx_queue self.tx_queue = tx_queue

1
electrum/tests/test_lnrouter.py

@ -49,6 +49,7 @@ class Test_LNRouter(TestCaseForTestnet):
register_callback = lambda *args: None register_callback = lambda *args: None
interface = None interface = None
fake_network.channel_db = lnrouter.ChannelDB(fake_network()) fake_network.channel_db = lnrouter.ChannelDB(fake_network())
fake_network.channel_db.data_loaded.set()
cdb = fake_network.channel_db cdb = fake_network.channel_db
path_finder = lnrouter.LNPathFinder(cdb) path_finder = lnrouter.LNPathFinder(cdb)
self.assertEqual(cdb.num_channels, 0) self.assertEqual(cdb.num_channels, 0)

Loading…
Cancel
Save