Browse Source

sqlite in lnrouter: lnpeer: introduce _gossip_loop for gossip handling separated from message handling

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
Janus 6 years ago
committed by ThomasV
parent
commit
95a2174789
  1. 44
      electrum/lnpeer.py
  2. 16
      electrum/lnrouter.py
  3. 6
      electrum/tests/test_lnpeer.py

44
electrum/lnpeer.py

@ -59,7 +59,6 @@ class Peer(PrintError):
self.node_anns = [] self.node_anns = []
self.chan_anns = [] self.chan_anns = []
self.chan_upds = [] self.chan_upds = []
self.last_chan_db_upd = time.time()
self.transport = transport self.transport = transport
self.pubkey = pubkey self.pubkey = pubkey
self.lnworker = lnworker self.lnworker = lnworker
@ -209,39 +208,40 @@ class Peer(PrintError):
@log_exceptions @log_exceptions
@handle_disconnect @handle_disconnect
async def main_loop(self): async def main_loop(self):
""" async with aiorpcx.TaskGroup() as group:
This is used in LNWorker and is necessary so that we don't kill the main await group.spawn(self._gossip_loop())
task group. It is not merged with _main_loop, so that we can test if the await group.spawn(self._message_loop())
correct exceptions are getting thrown using _main_loop.
"""
await self._main_loop()
async def _main_loop(self): async def _gossip_loop(self):
"""This is separate from main_loop for the tests.""" await self.initialized.wait()
try: while True:
await asyncio.wait_for(self.initialize(), 10) await asyncio.sleep(5)
except (OSError, asyncio.TimeoutError, HandshakeFailed) as e: if self.node_anns:
self.print_error('initialize failed, disconnecting: {}'.format(repr(e)))
return
# loop
async for msg in self.transport.read_messages():
self.process_message(msg)
await asyncio.sleep(.01)
if time.time() - self.last_chan_db_upd > 5:
self.last_chan_db_upd = time.time()
self.channel_db.on_node_announcement(self.node_anns) self.channel_db.on_node_announcement(self.node_anns)
self.node_anns = [] self.node_anns = []
if self.chan_anns:
self.channel_db.on_channel_announcement(self.chan_anns) self.channel_db.on_channel_announcement(self.chan_anns)
self.chan_anns = [] self.chan_anns = []
if self.chan_upds:
self.channel_db.on_channel_update(self.chan_upds) self.channel_db.on_channel_update(self.chan_upds)
self.chan_upds = [] self.chan_upds = []
need_to_get = self.channel_db.missing_short_chan_ids() #type: Set[int] need_to_get = self.channel_db.missing_short_chan_ids() #type: Set[int]
if need_to_get and not self.receiving_channels: if need_to_get and not self.receiving_channels:
self.print_error('QUERYING SHORT CHANNEL IDS; ', len(need_to_get)) self.print_error('QUERYING SHORT CHANNEL IDS; missing', len(need_to_get), 'channels')
zlibencoded = zlib.compress(b"".join(x.to_bytes(byteorder='big', length=8) for x in need_to_get)) zlibencoded = zlib.compress(bfh(''.join(need_to_get)))
self.send_message('query_short_channel_ids', chain_hash=bytes.fromhex(bitcoin.rev_hex(constants.net.GENESIS)), len=1+len(zlibencoded), encoded_short_ids=b'\x01' + zlibencoded) self.send_message('query_short_channel_ids', chain_hash=bytes.fromhex(bitcoin.rev_hex(constants.net.GENESIS)), len=1+len(zlibencoded), encoded_short_ids=b'\x01' + zlibencoded)
self.receiving_channels = True self.receiving_channels = True
async def _message_loop(self):
try:
await asyncio.wait_for(self.initialize(), 10)
except (OSError, asyncio.TimeoutError, HandshakeFailed) as e:
self.print_error('initialize failed, disconnecting: {}'.format(repr(e)))
return
# loop
async for msg in self.transport.read_messages():
self.process_message(msg)
await asyncio.sleep(.01)
self.ping_if_required() self.ping_if_required()
def on_reply_short_channel_ids_end(self, payload): def on_reply_short_channel_ids_end(self, payload):

16
electrum/lnrouter.py

@ -347,7 +347,19 @@ class ChannelDB:
def missing_short_chan_ids(self) -> Set[int]: def missing_short_chan_ids(self) -> Set[int]:
expr = not_(Policy.short_channel_id.in_(DBSession.query(ChannelInfo.short_channel_id))) expr = not_(Policy.short_channel_id.in_(DBSession.query(ChannelInfo.short_channel_id)))
return set(DBSession.query(Policy.short_channel_id).filter(expr).all()) chan_ids_from_policy = set(x[0] for x in DBSession.query(Policy.short_channel_id).filter(expr).all())
if chan_ids_from_policy:
return chan_ids_from_policy
# fetch channels for node_ids missing in node_info. that will also give us node_announcement
expr = not_(ChannelInfo.node1_id.in_(DBSession.query(NodeInfo.node_id)))
chan_ids_from_id1 = set(x[0] for x in DBSession.query(ChannelInfo.short_channel_id).filter(expr).all())
if chan_ids_from_id1:
return chan_ids_from_id1
expr = not_(ChannelInfo.node2_id.in_(DBSession.query(NodeInfo.node_id)))
chan_ids_from_id2 = set(x[0] for x in DBSession.query(ChannelInfo.short_channel_id).filter(expr).all())
if chan_ids_from_id2:
return chan_ids_from_id2
return set()
def add_verified_channel_info(self, short_id, capacity): def add_verified_channel_info(self, short_id, capacity):
# called from lnchannelverifier # called from lnchannelverifier
@ -390,6 +402,8 @@ class ChannelDB:
if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']: if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']:
continue continue
channel_info = channel_infos.get(short_channel_id) channel_info = channel_infos.get(short_channel_id)
if not channel_info:
continue
channel_info.on_channel_update(msg_payload, trusted=trusted) channel_info.on_channel_update(msg_payload, trusted=trusted)
DBSession.commit() DBSession.commit()

6
electrum/tests/test_lnpeer.py

@ -173,7 +173,7 @@ class TestPeer(unittest.TestCase):
p1 = Peer(mock_lnworker, b"\x00" * 33, mock_transport, request_initial_sync=False) p1 = Peer(mock_lnworker, b"\x00" * 33, mock_transport, request_initial_sync=False)
mock_lnworker.peer = p1 mock_lnworker.peer = p1
with self.assertRaises(LightningPeerConnectionClosed): with self.assertRaises(LightningPeerConnectionClosed):
run(asyncio.wait_for(p1._main_loop(), 1)) run(asyncio.wait_for(p1._message_loop(), 1))
def prepare_peers(self): def prepare_peers(self):
k1, k2 = keypair(), keypair() k1, k2 = keypair(), keypair()
@ -231,7 +231,7 @@ class TestPeer(unittest.TestCase):
print("HTLC ADDED") print("HTLC ADDED")
self.assertEqual(await fut, 'Payment received') self.assertEqual(await fut, 'Payment received')
gath.cancel() gath.cancel()
gath = asyncio.gather(pay(), p1._main_loop(), p2._main_loop()) gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop())
with self.assertRaises(asyncio.CancelledError): with self.assertRaises(asyncio.CancelledError):
run(gath) run(gath)
@ -254,7 +254,7 @@ class TestPeer(unittest.TestCase):
# AssertionError is ok since we shouldn't use old routes, and the # AssertionError is ok since we shouldn't use old routes, and the
# route finding should fail when channel is closed # route finding should fail when channel is closed
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
run(asyncio.gather(w1._pay_to_route(route, addr, pay_req), p1._main_loop(), p2._main_loop())) run(asyncio.gather(w1._pay_to_route(route, addr, pay_req), p1._message_loop(), p2._message_loop()))
def run(coro): def run(coro):
return asyncio.get_event_loop().run_until_complete(coro) return asyncio.get_event_loop().run_until_complete(coro)

Loading…
Cancel
Save