|
@ -57,9 +57,7 @@ class Peer(Logger): |
|
|
|
|
|
|
|
|
def __init__(self, lnworker: 'LNWorker', pubkey:bytes, transport: LNTransportBase): |
|
|
def __init__(self, lnworker: 'LNWorker', pubkey:bytes, transport: LNTransportBase): |
|
|
self.initialized = asyncio.Event() |
|
|
self.initialized = asyncio.Event() |
|
|
self.node_anns = [] |
|
|
self.querying_lock = asyncio.Lock() |
|
|
self.chan_anns = [] |
|
|
|
|
|
self.chan_upds = [] |
|
|
|
|
|
self.transport = transport |
|
|
self.transport = transport |
|
|
self.pubkey = pubkey |
|
|
self.pubkey = pubkey |
|
|
self.lnworker = lnworker |
|
|
self.lnworker = lnworker |
|
@ -70,6 +68,7 @@ class Peer(Logger): |
|
|
self.lnwatcher = lnworker.network.lnwatcher |
|
|
self.lnwatcher = lnworker.network.lnwatcher |
|
|
self.channel_db = lnworker.network.channel_db |
|
|
self.channel_db = lnworker.network.channel_db |
|
|
self.ping_time = 0 |
|
|
self.ping_time = 0 |
|
|
|
|
|
self.reply_channel_range = asyncio.Queue() |
|
|
self.shutdown_received = defaultdict(asyncio.Future) |
|
|
self.shutdown_received = defaultdict(asyncio.Future) |
|
|
self.channel_accepted = defaultdict(asyncio.Queue) |
|
|
self.channel_accepted = defaultdict(asyncio.Queue) |
|
|
self.channel_reestablished = defaultdict(asyncio.Future) |
|
|
self.channel_reestablished = defaultdict(asyncio.Future) |
|
@ -89,7 +88,7 @@ class Peer(Logger): |
|
|
|
|
|
|
|
|
def send_message(self, message_name: str, **kwargs): |
|
|
def send_message(self, message_name: str, **kwargs): |
|
|
assert type(message_name) is str |
|
|
assert type(message_name) is str |
|
|
self.logger.info(f"Sending {message_name.upper()}") |
|
|
self.logger.debug(f"Sending {message_name.upper()}") |
|
|
self.transport.send_bytes(encode_msg(message_name, **kwargs)) |
|
|
self.transport.send_bytes(encode_msg(message_name, **kwargs)) |
|
|
|
|
|
|
|
|
async def initialize(self): |
|
|
async def initialize(self): |
|
@ -177,13 +176,13 @@ class Peer(Logger): |
|
|
self.initialized.set() |
|
|
self.initialized.set() |
|
|
|
|
|
|
|
|
def on_node_announcement(self, payload): |
|
|
def on_node_announcement(self, payload): |
|
|
self.node_anns.append(payload) |
|
|
self.channel_db.node_anns.append(payload) |
|
|
|
|
|
|
|
|
def on_channel_update(self, payload): |
|
|
def on_channel_update(self, payload): |
|
|
self.chan_upds.append(payload) |
|
|
self.channel_db.chan_upds.append(payload) |
|
|
|
|
|
|
|
|
def on_channel_announcement(self, payload): |
|
|
def on_channel_announcement(self, payload): |
|
|
self.chan_anns.append(payload) |
|
|
self.channel_db.chan_anns.append(payload) |
|
|
|
|
|
|
|
|
def on_announcement_signatures(self, payload): |
|
|
def on_announcement_signatures(self, payload): |
|
|
channel_id = payload['channel_id'] |
|
|
channel_id = payload['channel_id'] |
|
@ -207,15 +206,11 @@ class Peer(Logger): |
|
|
@handle_disconnect |
|
|
@handle_disconnect |
|
|
async def main_loop(self): |
|
|
async def main_loop(self): |
|
|
async with aiorpcx.TaskGroup() as group: |
|
|
async with aiorpcx.TaskGroup() as group: |
|
|
await group.spawn(self._gossip_loop()) |
|
|
|
|
|
await group.spawn(self._message_loop()) |
|
|
await group.spawn(self._message_loop()) |
|
|
# kill group if the peer times out |
|
|
# kill group if the peer times out |
|
|
await group.spawn(asyncio.wait_for(self.initialized.wait(), 10)) |
|
|
await group.spawn(asyncio.wait_for(self.initialized.wait(), 10)) |
|
|
|
|
|
|
|
|
@log_exceptions |
|
|
def request_gossip(self, timestamp=0): |
|
|
async def _gossip_loop(self): |
|
|
|
|
|
await self.initialized.wait() |
|
|
|
|
|
timestamp = self.channel_db.get_last_timestamp() |
|
|
|
|
|
if timestamp == 0: |
|
|
if timestamp == 0: |
|
|
self.logger.info('requesting whole channel graph') |
|
|
self.logger.info('requesting whole channel graph') |
|
|
else: |
|
|
else: |
|
@ -225,28 +220,47 @@ class Peer(Logger): |
|
|
chain_hash=constants.net.rev_genesis_bytes(), |
|
|
chain_hash=constants.net.rev_genesis_bytes(), |
|
|
first_timestamp=timestamp, |
|
|
first_timestamp=timestamp, |
|
|
timestamp_range=b'\xff'*4) |
|
|
timestamp_range=b'\xff'*4) |
|
|
while True: |
|
|
|
|
|
await asyncio.sleep(5) |
|
|
def query_channel_range(self, index, num): |
|
|
if self.node_anns: |
|
|
self.logger.info(f'query channel range') |
|
|
self.channel_db.on_node_announcement(self.node_anns) |
|
|
self.send_message( |
|
|
self.node_anns = [] |
|
|
'query_channel_range', |
|
|
if self.chan_anns: |
|
|
chain_hash=constants.net.rev_genesis_bytes(), |
|
|
self.channel_db.on_channel_announcement(self.chan_anns) |
|
|
first_blocknum=index, |
|
|
self.chan_anns = [] |
|
|
number_of_blocks=num) |
|
|
if self.chan_upds: |
|
|
|
|
|
self.channel_db.on_channel_update(self.chan_upds) |
|
|
def encode_short_ids(self, ids): |
|
|
self.chan_upds = [] |
|
|
return chr(1) + zlib.compress(bfh(''.join(ids))) |
|
|
# todo: enable when db is fixed |
|
|
|
|
|
#need_to_get = sorted(self.channel_db.missing_short_chan_ids()) |
|
|
def decode_short_ids(self, encoded): |
|
|
#if need_to_get and not self.receiving_channels: |
|
|
if encoded[0] == 0: |
|
|
# self.logger.info(f'missing {len(need_to_get)} channels') |
|
|
decoded = encoded[1:] |
|
|
# zlibencoded = zlib.compress(bfh(''.join(need_to_get[0:100]))) |
|
|
elif encoded[0] == 1: |
|
|
# self.send_message( |
|
|
decoded = zlib.decompress(encoded[1:]) |
|
|
# 'query_short_channel_ids', |
|
|
else: |
|
|
# chain_hash=constants.net.rev_genesis_bytes(), |
|
|
raise BaseException('zlib') |
|
|
# len=1+len(zlibencoded), |
|
|
ids = [decoded[i:i+8] for i in range(0, len(decoded), 8)] |
|
|
# encoded_short_ids=b'\x01' + zlibencoded) |
|
|
return ids |
|
|
# self.receiving_channels = True |
|
|
|
|
|
|
|
|
def on_reply_channel_range(self, payload): |
|
|
|
|
|
first = int.from_bytes(payload['first_blocknum'], 'big') |
|
|
|
|
|
num = int.from_bytes(payload['number_of_blocks'], 'big') |
|
|
|
|
|
complete = bool(payload['complete']) |
|
|
|
|
|
encoded = payload['encoded_short_ids'] |
|
|
|
|
|
ids = self.decode_short_ids(encoded) |
|
|
|
|
|
self.reply_channel_range.put_nowait((first, num, complete, ids)) |
|
|
|
|
|
|
|
|
|
|
|
async def query_short_channel_ids(self, ids, compressed=True): |
|
|
|
|
|
await self.querying_lock.acquire() |
|
|
|
|
|
#self.logger.info('querying {} short_channel_ids'.format(len(ids))) |
|
|
|
|
|
s = b''.join(ids) |
|
|
|
|
|
encoded = zlib.compress(s) if compressed else s |
|
|
|
|
|
prefix = b'\x01' if compressed else b'\x00' |
|
|
|
|
|
self.send_message( |
|
|
|
|
|
'query_short_channel_ids', |
|
|
|
|
|
chain_hash=constants.net.rev_genesis_bytes(), |
|
|
|
|
|
len=1+len(encoded), |
|
|
|
|
|
encoded_short_ids=prefix+encoded) |
|
|
|
|
|
|
|
|
async def _message_loop(self): |
|
|
async def _message_loop(self): |
|
|
try: |
|
|
try: |
|
@ -260,7 +274,7 @@ class Peer(Logger): |
|
|
self.ping_if_required() |
|
|
self.ping_if_required() |
|
|
|
|
|
|
|
|
def on_reply_short_channel_ids_end(self, payload): |
|
|
def on_reply_short_channel_ids_end(self, payload): |
|
|
self.receiving_channels = False |
|
|
self.querying_lock.release() |
|
|
|
|
|
|
|
|
def close_and_cleanup(self): |
|
|
def close_and_cleanup(self): |
|
|
try: |
|
|
try: |
|
|