Browse Source

lnworker: define use_trampoline() for code clarity

patch-4
ThomasV 2 years ago
parent
commit
2af59e32b2
  1. 4
      electrum/gui/qt/channels_list.py
  2. 2
      electrum/lnchannel.py
  3. 8
      electrum/lnpeer.py
  4. 33
      electrum/lnworker.py
  5. 3
      electrum/tests/test_lnpeer.py

4
electrum/gui/qt/channels_list.py

@ -183,7 +183,7 @@ class ChannelsList(MyTreeView):
WaitingDialog(self, 'please wait..', task, self.on_request_sent, self.on_failure) WaitingDialog(self, 'please wait..', task, self.on_request_sent, self.on_failure)
def freeze_channel_for_sending(self, chan, b): def freeze_channel_for_sending(self, chan, b):
if self.lnworker.channel_db or self.lnworker.is_trampoline_peer(chan.node_id): if not self.lnworker.uses_trampoline() or self.lnworker.is_trampoline_peer(chan.node_id):
chan.set_frozen_for_sending(b) chan.set_frozen_for_sending(b)
else: else:
msg = messages.MSG_NON_TRAMPOLINE_CHANNEL_FROZEN_WITHOUT_GOSSIP msg = messages.MSG_NON_TRAMPOLINE_CHANNEL_FROZEN_WITHOUT_GOSSIP
@ -198,7 +198,7 @@ class ChannelsList(MyTreeView):
channel_id2 = idx2.sibling(idx2.row(), self.Columns.NODE_ALIAS).data(ROLE_CHANNEL_ID) channel_id2 = idx2.sibling(idx2.row(), self.Columns.NODE_ALIAS).data(ROLE_CHANNEL_ID)
chan1 = self.lnworker.channels.get(channel_id1) chan1 = self.lnworker.channels.get(channel_id1)
chan2 = self.lnworker.channels.get(channel_id2) chan2 = self.lnworker.channels.get(channel_id2)
if chan1 and chan2 and (self.lnworker.channel_db or chan1.node_id != chan2.node_id): if chan1 and chan2 and (not self.lnworker.uses_trampoline() or chan1.node_id != chan2.node_id):
return chan1, chan2 return chan1, chan2
return None, None return None, None

2
electrum/lnchannel.py

@ -824,7 +824,7 @@ class Channel(AbstractChannel):
return self.can_send_ctx_updates() and self.is_open() return self.can_send_ctx_updates() and self.is_open()
def is_frozen_for_sending(self) -> bool: def is_frozen_for_sending(self) -> bool:
if self.lnworker and self.lnworker.channel_db is None and not self.lnworker.is_trampoline_peer(self.node_id): if self.lnworker and self.lnworker.uses_trampoline() and not self.lnworker.is_trampoline_peer(self.node_id):
return True return True
return self.storage.get('frozen_for_sending', False) return self.storage.get('frozen_for_sending', False)

8
electrum/lnpeer.py

@ -366,16 +366,16 @@ class Peer(Logger):
self.maybe_set_initialized() self.maybe_set_initialized()
def on_node_announcement(self, payload): def on_node_announcement(self, payload):
if self.lnworker.channel_db: if not self.lnworker.uses_trampoline():
self.gossip_queue.put_nowait(('node_announcement', payload)) self.gossip_queue.put_nowait(('node_announcement', payload))
def on_channel_announcement(self, payload): def on_channel_announcement(self, payload):
if self.lnworker.channel_db: if not self.lnworker.uses_trampoline():
self.gossip_queue.put_nowait(('channel_announcement', payload)) self.gossip_queue.put_nowait(('channel_announcement', payload))
def on_channel_update(self, payload): def on_channel_update(self, payload):
self.maybe_save_remote_update(payload) self.maybe_save_remote_update(payload)
if self.lnworker.channel_db: if not self.lnworker.uses_trampoline():
self.gossip_queue.put_nowait(('channel_update', payload)) self.gossip_queue.put_nowait(('channel_update', payload))
def maybe_save_remote_update(self, payload): def maybe_save_remote_update(self, payload):
@ -702,7 +702,7 @@ class Peer(Logger):
# will raise if init fails # will raise if init fails
await asyncio.wait_for(self.initialized, LN_P2P_NETWORK_TIMEOUT) await asyncio.wait_for(self.initialized, LN_P2P_NETWORK_TIMEOUT)
# trampoline is not yet in features # trampoline is not yet in features
if not self.lnworker.channel_db and not self.lnworker.is_trampoline_peer(self.pubkey): if self.lnworker.uses_trampoline() and not self.lnworker.is_trampoline_peer(self.pubkey):
raise Exception('Not a trampoline node: ' + str(self.their_features)) raise Exception('Not a trampoline node: ' + str(self.their_features))
feerate = self.lnworker.current_feerate_per_kw() feerate = self.lnworker.current_feerate_per_kw()

33
electrum/lnworker.py

@ -223,6 +223,9 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
def channel_db(self): def channel_db(self):
return self.network.channel_db if self.network else None return self.network.channel_db if self.network else None
def uses_trampoline(self):
return not bool(self.channel_db)
@property @property
def peers(self) -> Mapping[bytes, Peer]: def peers(self) -> Mapping[bytes, Peer]:
"""Returns a read-only copy of peers.""" """Returns a read-only copy of peers."""
@ -235,7 +238,7 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
def get_node_alias(self, node_id: bytes) -> Optional[str]: def get_node_alias(self, node_id: bytes) -> Optional[str]:
"""Returns the alias of the node, or None if unknown.""" """Returns the alias of the node, or None if unknown."""
node_alias = None node_alias = None
if self.channel_db: if not self.uses_trampoline():
node_info = self.channel_db.get_node_info_for_node_id(node_id) node_info = self.channel_db.get_node_info_for_node_id(node_id)
if node_info: if node_info:
node_alias = node_info.alias node_alias = node_info.alias
@ -371,8 +374,8 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
peer_addr = peer.transport.peer_addr peer_addr = peer.transport.peer_addr
# reset connection attempt count # reset connection attempt count
self._on_connection_successfully_established(peer_addr) self._on_connection_successfully_established(peer_addr)
# add into channel db if not self.uses_trampoline():
if self.channel_db: # add into channel db
self.channel_db.add_recent_peer(peer_addr) self.channel_db.add_recent_peer(peer_addr)
# save network address into channels we might have with peer # save network address into channels we might have with peer
for chan in peer.channels.values(): for chan in peer.channels.values():
@ -492,7 +495,7 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
if rest is not None: if rest is not None:
host, port = split_host_port(rest) host, port = split_host_port(rest)
else: else:
if not self.channel_db: if self.uses_trampoline():
addr = trampolines_by_id().get(node_id) addr = trampolines_by_id().get(node_id)
if not addr: if not addr:
raise ConnStringFormatError(_('Address unknown for node:') + ' ' + bh2u(node_id)) raise ConnStringFormatError(_('Address unknown for node:') + ' ' + bh2u(node_id))
@ -1299,7 +1302,7 @@ class LNWallet(LNWorker):
if code == OnionFailureCode.MPP_TIMEOUT: if code == OnionFailureCode.MPP_TIMEOUT:
raise PaymentFailure(failure_msg.code_name()) raise PaymentFailure(failure_msg.code_name())
# trampoline # trampoline
if not self.channel_db: if self.uses_trampoline():
def maybe_raise_trampoline_fee(htlc_log): def maybe_raise_trampoline_fee(htlc_log):
if htlc_log.trampoline_fee_level == self.trampoline_fee_level: if htlc_log.trampoline_fee_level == self.trampoline_fee_level:
self.trampoline_fee_level += 1 self.trampoline_fee_level += 1
@ -1370,7 +1373,7 @@ class LNWallet(LNWorker):
key = (payment_hash, short_channel_id, htlc.htlc_id) key = (payment_hash, short_channel_id, htlc.htlc_id)
self.sent_htlcs_info[key] = route, payment_secret, amount_msat, total_msat, amount_receiver_msat, trampoline_fee_level, trampoline_route self.sent_htlcs_info[key] = route, payment_secret, amount_msat, total_msat, amount_receiver_msat, trampoline_fee_level, trampoline_route
# if we sent MPP to a trampoline, add item to sent_buckets # if we sent MPP to a trampoline, add item to sent_buckets
if not self.channel_db and amount_msat != total_msat: if self.uses_trampoline() and amount_msat != total_msat:
if payment_secret not in self.sent_buckets: if payment_secret not in self.sent_buckets:
self.sent_buckets[payment_secret] = (0, 0) self.sent_buckets[payment_secret] = (0, 0)
amount_sent, amount_failed = self.sent_buckets[payment_secret] amount_sent, amount_failed = self.sent_buckets[payment_secret]
@ -1531,7 +1534,7 @@ class LNWallet(LNWorker):
return False return False
def suggest_peer(self) -> Optional[bytes]: def suggest_peer(self) -> Optional[bytes]:
if self.channel_db: if not self.uses_trampoline():
return self.lnrater.suggest_peer() return self.lnrater.suggest_peer()
else: else:
return random.choice(list(hardcoded_trampoline_nodes().values())).pubkey return random.choice(list(hardcoded_trampoline_nodes().values())).pubkey
@ -1572,7 +1575,7 @@ class LNWallet(LNWorker):
try: try:
self.logger.info("trying single-part payment") self.logger.info("trying single-part payment")
# try to send over a single channel # try to send over a single channel
if not self.channel_db: if self.uses_trampoline():
for chan in my_active_channels: for chan in my_active_channels:
if not self.is_trampoline_peer(chan.node_id): if not self.is_trampoline_peer(chan.node_id):
continue continue
@ -1640,7 +1643,7 @@ class LNWallet(LNWorker):
for chan in my_active_channels} for chan in my_active_channels}
self.logger.info(f"channels_with_funds: {channels_with_funds}") self.logger.info(f"channels_with_funds: {channels_with_funds}")
if not self.channel_db: if self.uses_trampoline():
# in the case of a legacy payment, we don't allow splitting via different # in the case of a legacy payment, we don't allow splitting via different
# trampoline nodes, because of https://github.com/ACINQ/eclair/issues/2127 # trampoline nodes, because of https://github.com/ACINQ/eclair/issues/2127
use_single_node, _ = is_legacy_relay(invoice_features, r_tags) use_single_node, _ = is_legacy_relay(invoice_features, r_tags)
@ -2043,7 +2046,7 @@ class LNWallet(LNWorker):
self.logger.info(f"htlc_failed {failure_message}") self.logger.info(f"htlc_failed {failure_message}")
# check sent_buckets if we use trampoline # check sent_buckets if we use trampoline
if not self.channel_db and payment_secret in self.sent_buckets: if self.uses_trampoline() and payment_secret in self.sent_buckets:
amount_sent, amount_failed = self.sent_buckets[payment_secret] amount_sent, amount_failed = self.sent_buckets[payment_secret]
amount_failed += amount_receiver_msat amount_failed += amount_receiver_msat
self.sent_buckets[payment_secret] = amount_sent, amount_failed self.sent_buckets[payment_secret] = amount_sent, amount_failed
@ -2163,7 +2166,7 @@ class LNWallet(LNWorker):
can_send_dict = defaultdict(int) can_send_dict = defaultdict(int)
with self.lock: with self.lock:
for c in self.get_channels_for_sending(): for c in self.get_channels_for_sending():
if self.channel_db: if not self.uses_trampoline():
can_send_dict[0] += send_capacity(c) can_send_dict[0] += send_capacity(c)
else: else:
can_send_dict[c.node_id] += send_capacity(c) can_send_dict[c.node_id] += send_capacity(c)
@ -2271,7 +2274,7 @@ class LNWallet(LNWorker):
continue continue
if chan1 == chan2: if chan1 == chan2:
continue continue
if not self.channel_db and chan1.node_id == chan2.node_id: if self.uses_trampoline() and chan1.node_id == chan2.node_id:
continue continue
if direction == SENT: if direction == SENT:
if chan1.can_pay(delta*1000): if chan1.can_pay(delta*1000):
@ -2326,7 +2329,7 @@ class LNWallet(LNWorker):
async def rebalance_channels(self, chan1, chan2, amount_msat): async def rebalance_channels(self, chan1, chan2, amount_msat):
if chan1 == chan2: if chan1 == chan2:
raise Exception('Rebalance requires two different channels') raise Exception('Rebalance requires two different channels')
if not self.channel_db and chan1.node_id == chan2.node_id: if self.uses_trampoline() and chan1.node_id == chan2.node_id:
raise Exception('Rebalance requires channels from different trampolines') raise Exception('Rebalance requires channels from different trampolines')
lnaddr, invoice = self.create_invoice( lnaddr, invoice = self.create_invoice(
amount_msat=amount_msat, amount_msat=amount_msat,
@ -2408,7 +2411,7 @@ class LNWallet(LNWorker):
async def reestablish_peer_for_given_channel(self, chan: Channel) -> None: async def reestablish_peer_for_given_channel(self, chan: Channel) -> None:
now = time.time() now = time.time()
peer_addresses = [] peer_addresses = []
if not self.channel_db: if self.uses_trampoline():
addr = trampolines_by_id().get(chan.node_id) addr = trampolines_by_id().get(chan.node_id)
if addr: if addr:
peer_addresses.append(addr) peer_addresses.append(addr)
@ -2590,7 +2593,7 @@ class LNWallet(LNWorker):
if success: if success:
return return
# try with gossip db # try with gossip db
if not self.channel_db: if self.uses_trampoline():
raise Exception(_('Please enable gossip')) raise Exception(_('Please enable gossip'))
node_id = self.network.channel_db.get_node_by_prefix(cb.node_id_prefix) node_id = self.network.channel_db.get_node_by_prefix(cb.node_id_prefix)
addresses_from_gossip = self.network.channel_db.get_node_addresses(node_id) addresses_from_gossip = self.network.channel_db.get_node_addresses(node_id)

3
electrum/tests/test_lnpeer.py

@ -185,6 +185,9 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
def channel_db(self): def channel_db(self):
return self.network.channel_db if self.network else None return self.network.channel_db if self.network else None
def uses_trampoline(self):
return not bool(self.channel_db)
@property @property
def channels(self): def channels(self):
return self._channels return self._channels

Loading…
Cancel
Save