Browse Source

LNGossip: sync channel db using query_channel_range

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
ThomasV 6 years ago
parent
commit
1011245c5e
  1. 84
      electrum/lnpeer.py
  2. 57
      electrum/lnrouter.py
  3. 86
      electrum/lnworker.py

84
electrum/lnpeer.py

@ -57,9 +57,7 @@ class Peer(Logger):
def __init__(self, lnworker: 'LNWorker', pubkey:bytes, transport: LNTransportBase): def __init__(self, lnworker: 'LNWorker', pubkey:bytes, transport: LNTransportBase):
self.initialized = asyncio.Event() self.initialized = asyncio.Event()
self.node_anns = [] self.querying_lock = asyncio.Lock()
self.chan_anns = []
self.chan_upds = []
self.transport = transport self.transport = transport
self.pubkey = pubkey self.pubkey = pubkey
self.lnworker = lnworker self.lnworker = lnworker
@ -70,6 +68,7 @@ class Peer(Logger):
self.lnwatcher = lnworker.network.lnwatcher self.lnwatcher = lnworker.network.lnwatcher
self.channel_db = lnworker.network.channel_db self.channel_db = lnworker.network.channel_db
self.ping_time = 0 self.ping_time = 0
self.reply_channel_range = asyncio.Queue()
self.shutdown_received = defaultdict(asyncio.Future) self.shutdown_received = defaultdict(asyncio.Future)
self.channel_accepted = defaultdict(asyncio.Queue) self.channel_accepted = defaultdict(asyncio.Queue)
self.channel_reestablished = defaultdict(asyncio.Future) self.channel_reestablished = defaultdict(asyncio.Future)
@ -89,7 +88,7 @@ class Peer(Logger):
def send_message(self, message_name: str, **kwargs): def send_message(self, message_name: str, **kwargs):
assert type(message_name) is str assert type(message_name) is str
self.logger.info(f"Sending {message_name.upper()}") self.logger.debug(f"Sending {message_name.upper()}")
self.transport.send_bytes(encode_msg(message_name, **kwargs)) self.transport.send_bytes(encode_msg(message_name, **kwargs))
async def initialize(self): async def initialize(self):
@ -177,13 +176,13 @@ class Peer(Logger):
self.initialized.set() self.initialized.set()
def on_node_announcement(self, payload): def on_node_announcement(self, payload):
self.node_anns.append(payload) self.channel_db.node_anns.append(payload)
def on_channel_update(self, payload): def on_channel_update(self, payload):
self.chan_upds.append(payload) self.channel_db.chan_upds.append(payload)
def on_channel_announcement(self, payload): def on_channel_announcement(self, payload):
self.chan_anns.append(payload) self.channel_db.chan_anns.append(payload)
def on_announcement_signatures(self, payload): def on_announcement_signatures(self, payload):
channel_id = payload['channel_id'] channel_id = payload['channel_id']
@ -207,15 +206,11 @@ class Peer(Logger):
@handle_disconnect @handle_disconnect
async def main_loop(self): async def main_loop(self):
async with aiorpcx.TaskGroup() as group: async with aiorpcx.TaskGroup() as group:
await group.spawn(self._gossip_loop())
await group.spawn(self._message_loop()) await group.spawn(self._message_loop())
# kill group if the peer times out # kill group if the peer times out
await group.spawn(asyncio.wait_for(self.initialized.wait(), 10)) await group.spawn(asyncio.wait_for(self.initialized.wait(), 10))
@log_exceptions def request_gossip(self, timestamp=0):
async def _gossip_loop(self):
await self.initialized.wait()
timestamp = self.channel_db.get_last_timestamp()
if timestamp == 0: if timestamp == 0:
self.logger.info('requesting whole channel graph') self.logger.info('requesting whole channel graph')
else: else:
@ -225,28 +220,47 @@ class Peer(Logger):
chain_hash=constants.net.rev_genesis_bytes(), chain_hash=constants.net.rev_genesis_bytes(),
first_timestamp=timestamp, first_timestamp=timestamp,
timestamp_range=b'\xff'*4) timestamp_range=b'\xff'*4)
while True:
await asyncio.sleep(5) def query_channel_range(self, index, num):
if self.node_anns: self.logger.info(f'query channel range')
self.channel_db.on_node_announcement(self.node_anns) self.send_message(
self.node_anns = [] 'query_channel_range',
if self.chan_anns: chain_hash=constants.net.rev_genesis_bytes(),
self.channel_db.on_channel_announcement(self.chan_anns) first_blocknum=index,
self.chan_anns = [] number_of_blocks=num)
if self.chan_upds:
self.channel_db.on_channel_update(self.chan_upds) def encode_short_ids(self, ids):
self.chan_upds = [] return chr(1) + zlib.compress(bfh(''.join(ids)))
# todo: enable when db is fixed
#need_to_get = sorted(self.channel_db.missing_short_chan_ids()) def decode_short_ids(self, encoded):
#if need_to_get and not self.receiving_channels: if encoded[0] == 0:
# self.logger.info(f'missing {len(need_to_get)} channels') decoded = encoded[1:]
# zlibencoded = zlib.compress(bfh(''.join(need_to_get[0:100]))) elif encoded[0] == 1:
# self.send_message( decoded = zlib.decompress(encoded[1:])
# 'query_short_channel_ids', else:
# chain_hash=constants.net.rev_genesis_bytes(), raise BaseException('zlib')
# len=1+len(zlibencoded), ids = [decoded[i:i+8] for i in range(0, len(decoded), 8)]
# encoded_short_ids=b'\x01' + zlibencoded) return ids
# self.receiving_channels = True
def on_reply_channel_range(self, payload):
first = int.from_bytes(payload['first_blocknum'], 'big')
num = int.from_bytes(payload['number_of_blocks'], 'big')
complete = bool(payload['complete'])
encoded = payload['encoded_short_ids']
ids = self.decode_short_ids(encoded)
self.reply_channel_range.put_nowait((first, num, complete, ids))
async def query_short_channel_ids(self, ids, compressed=True):
await self.querying_lock.acquire()
#self.logger.info('querying {} short_channel_ids'.format(len(ids)))
s = b''.join(ids)
encoded = zlib.compress(s) if compressed else s
prefix = b'\x01' if compressed else b'\x00'
self.send_message(
'query_short_channel_ids',
chain_hash=constants.net.rev_genesis_bytes(),
len=1+len(encoded),
encoded_short_ids=prefix+encoded)
async def _message_loop(self): async def _message_loop(self):
try: try:
@ -260,7 +274,7 @@ class Peer(Logger):
self.ping_if_required() self.ping_if_required()
def on_reply_short_channel_ids_end(self, payload): def on_reply_short_channel_ids_end(self, payload):
self.receiving_channels = False self.querying_lock.release()
def close_and_cleanup(self): def close_and_cleanup(self):
try: try:

57
electrum/lnrouter.py

@ -223,6 +223,20 @@ class ChannelDB(SqlDB):
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict] self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
self.ca_verifier = LNChannelVerifier(network, self) self.ca_verifier = LNChannelVerifier(network, self)
self.update_counts() self.update_counts()
self.node_anns = []
self.chan_anns = []
self.chan_upds = []
def process_gossip(self):
if self.node_anns:
self.on_node_announcement(self.node_anns)
self.node_anns = []
if self.chan_anns:
self.on_channel_announcement(self.chan_anns)
self.chan_anns = []
if self.chan_upds:
self.on_channel_update(self.chan_upds)
self.chan_upds = []
@sql @sql
def update_counts(self): def update_counts(self):
@ -232,7 +246,32 @@ class ChannelDB(SqlDB):
self.num_channels = self.DBSession.query(ChannelInfo).count() self.num_channels = self.DBSession.query(ChannelInfo).count()
self.num_policies = self.DBSession.query(Policy).count() self.num_policies = self.DBSession.query(Policy).count()
self.num_nodes = self.DBSession.query(NodeInfo).count() self.num_nodes = self.DBSession.query(NodeInfo).count()
self.logger.info(f'update counts {self.num_channels} {self.num_policies}')
@sql
@profiler
def purge_unknown_channels(self, channel_ids):
ids = [x.hex() for x in channel_ids]
missing = self.DBSession \
.query(ChannelInfo) \
.filter(not_(ChannelInfo.short_channel_id.in_(ids))) \
.all()
if missing:
self.logger.info("deleting {} channels".format(len(missing)))
delete_query = ChannelInfo.__table__.delete().where(not_(ChannelInfo.short_channel_id.in_(ids)))
self.DBSession.execute(delete_query)
self.DBSession.commit()
@sql
@profiler
def compare_channels(self, channel_ids):
ids = [x.hex() for x in channel_ids]
# I need to get the unknown, and also the channels that need refresh
known = self.DBSession \
.query(ChannelInfo) \
.filter(ChannelInfo.short_channel_id.in_(ids)) \
.all()
known = [bfh(r.short_channel_id) for r in known]
return known
@sql @sql
def add_recent_peer(self, peer: LNPeerAddr): def add_recent_peer(self, peer: LNPeerAddr):
@ -276,12 +315,14 @@ class ChannelDB(SqlDB):
return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in r] return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in r]
@sql @sql
def missing_short_chan_ids(self) -> Set[int]: def missing_channel_announcements(self) -> Set[int]:
expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id))) expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id)))
chan_ids_from_policy = set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all()) return set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all())
if chan_ids_from_policy:
return chan_ids_from_policy @sql
return set() def missing_channel_updates(self) -> Set[int]:
expr = not_(ChannelInfo.short_channel_id.in_(self.DBSession.query(Policy.short_channel_id)))
return set(x[0] for x in self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).all())
@sql @sql
def add_verified_channel_info(self, short_id, capacity): def add_verified_channel_info(self, short_id, capacity):
@ -316,8 +357,8 @@ class ChannelDB(SqlDB):
for channel_info in new_channels.values(): for channel_info in new_channels.values():
self.DBSession.add(channel_info) self.DBSession.add(channel_info)
self.DBSession.commit() self.DBSession.commit()
#self.logger.info('on_channel_announcement: %d/%d'%(len(new_channels), len(msg_payloads)))
self._update_counts() self._update_counts()
self.logger.info('on_channel_announcement: %d/%d'%(len(new_channels), len(msg_payloads)))
self.network.trigger_callback('ln_status') self.network.trigger_callback('ln_status')
@sql @sql
@ -370,7 +411,7 @@ class ChannelDB(SqlDB):
self.DBSession.commit() self.DBSession.commit()
if new_policies: if new_policies:
self.logger.info(f'on_channel_update: {len(new_policies)}/{len(msg_payloads)}') self.logger.info(f'on_channel_update: {len(new_policies)}/{len(msg_payloads)}')
self.logger.info(f'last timestamp: {datetime.fromtimestamp(self._get_last_timestamp()).ctime()}') #self.logger.info(f'last timestamp: {datetime.fromtimestamp(self._get_last_timestamp()).ctime()}')
self._update_counts() self._update_counts()
@sql @sql

86
electrum/lnworker.py

@ -133,9 +133,7 @@ class LNWorker(Logger):
self.channel_db = self.network.channel_db self.channel_db = self.network.channel_db
self._last_tried_peer = {} # LNPeerAddr -> unix timestamp self._last_tried_peer = {} # LNPeerAddr -> unix timestamp
self._add_peers_from_config() self._add_peers_from_config()
# wait until we see confirmations
asyncio.run_coroutine_threadsafe(self.network.main_taskgroup.spawn(self.main_loop()), self.network.asyncio_loop) asyncio.run_coroutine_threadsafe(self.network.main_taskgroup.spawn(self.main_loop()), self.network.asyncio_loop)
self.first_timestamp_requested = None
def _add_peers_from_config(self): def _add_peers_from_config(self):
peer_list = self.config.get('lightning_peers', []) peer_list = self.config.get('lightning_peers', [])
@ -215,9 +213,24 @@ class LNWorker(Logger):
self.logger.info('got {} ln peers from dns seed'.format(len(peers))) self.logger.info('got {} ln peers from dns seed'.format(len(peers)))
return peers return peers
@staticmethod
def choose_preferred_address(addr_list: List[Tuple[str, int]]) -> Tuple[str, int]:
assert len(addr_list) >= 1
# choose first one that is an IP
for addr_in_db in addr_list:
host = addr_in_db.host
port = addr_in_db.port
if is_ip_address(host):
return host, port
# otherwise choose one at random
# TODO maybe filter out onion if not on tor?
choice = random.choice(addr_list)
return choice.host, choice.port
class LNGossip(LNWorker): class LNGossip(LNWorker):
# height of first channel announcements
first_block = 497000
def __init__(self, network): def __init__(self, network):
seed = os.urandom(32) seed = os.urandom(32)
@ -226,6 +239,61 @@ class LNGossip(LNWorker):
super().__init__(xprv) super().__init__(xprv)
self.localfeatures |= LnLocalFeatures.GOSSIP_QUERIES_REQ self.localfeatures |= LnLocalFeatures.GOSSIP_QUERIES_REQ
def start_network(self, network: 'Network'):
super().start_network(network)
asyncio.run_coroutine_threadsafe(self.network.main_taskgroup.spawn(self.gossip_task()), self.network.asyncio_loop)
async def gossip_task(self):
req_index = self.first_block
req_num = self.network.get_local_height() - req_index
while len(self.peers) == 0:
await asyncio.sleep(1)
continue
# todo: parallelize over peers
peer = list(self.peers.values())[0]
await peer.initialized.wait()
# send channels_range query. peer will reply with several intervals
peer.query_channel_range(req_index, req_num)
intervals = []
ids = set()
# wait until requested range is covered
while True:
index, num, complete, _ids = await peer.reply_channel_range.get()
ids.update(_ids)
intervals.append((index, index+num))
intervals.sort()
while len(intervals) > 1:
a,b = intervals[0]
c,d = intervals[1]
if b == c:
intervals = [(a,d)] + intervals[2:]
else:
break
if len(intervals) == 1:
a, b = intervals[0]
if a <= req_index and b >= req_index + req_num:
break
self.logger.info('Received {} channel ids. (complete: {})'.format(len(ids), complete))
# TODO: filter results by date of last channel update, purge DB
#if complete:
# self.channel_db.purge_unknown_channels(ids)
known = self.channel_db.compare_channels(ids)
unknown = list(ids - set(known))
total = len(unknown)
N = 500
while unknown:
self.channel_db.process_gossip()
await peer.query_short_channel_ids(unknown[0:N])
unknown = unknown[N:]
self.logger.info(f'Querying channels: {total - len(unknown)}/{total}. Count: {self.channel_db.num_channels}')
# request gossip fromm current time
now = int(time.time())
peer.request_gossip(now)
while True:
await asyncio.sleep(5)
self.channel_db.process_gossip()
class LNWallet(LNWorker): class LNWallet(LNWorker):
@ -548,20 +616,6 @@ class LNWallet(LNWorker):
def on_channels_updated(self): def on_channels_updated(self):
self.network.trigger_callback('channels') self.network.trigger_callback('channels')
@staticmethod
def choose_preferred_address(addr_list: List[Tuple[str, int]]) -> Tuple[str, int]:
assert len(addr_list) >= 1
# choose first one that is an IP
for addr_in_db in addr_list:
host = addr_in_db.host
port = addr_in_db.port
if is_ip_address(host):
return host, port
# otherwise choose one at random
# TODO maybe filter out onion if not on tor?
choice = random.choice(addr_list)
return choice.host, choice.port
def open_channel(self, connect_contents, local_amt_sat, push_amt_sat, password=None, timeout=20): def open_channel(self, connect_contents, local_amt_sat, push_amt_sat, password=None, timeout=20):
node_id, rest = extract_nodeid(connect_contents) node_id, rest = extract_nodeid(connect_contents)
peer = self.peers.get(node_id) peer = self.peers.get(node_id)

Loading…
Cancel
Save