Browse Source

lnworker: define use_trampoline() for code clarity

patch-4
ThomasV 2 years ago
parent
commit
2af59e32b2
  1. 4
      electrum/gui/qt/channels_list.py
  2. 2
      electrum/lnchannel.py
  3. 8
      electrum/lnpeer.py
  4. 33
      electrum/lnworker.py
  5. 3
      electrum/tests/test_lnpeer.py

4
electrum/gui/qt/channels_list.py

@ -183,7 +183,7 @@ class ChannelsList(MyTreeView):
WaitingDialog(self, 'please wait..', task, self.on_request_sent, self.on_failure)
def freeze_channel_for_sending(self, chan, b):
if self.lnworker.channel_db or self.lnworker.is_trampoline_peer(chan.node_id):
if not self.lnworker.uses_trampoline() or self.lnworker.is_trampoline_peer(chan.node_id):
chan.set_frozen_for_sending(b)
else:
msg = messages.MSG_NON_TRAMPOLINE_CHANNEL_FROZEN_WITHOUT_GOSSIP
@ -198,7 +198,7 @@ class ChannelsList(MyTreeView):
channel_id2 = idx2.sibling(idx2.row(), self.Columns.NODE_ALIAS).data(ROLE_CHANNEL_ID)
chan1 = self.lnworker.channels.get(channel_id1)
chan2 = self.lnworker.channels.get(channel_id2)
if chan1 and chan2 and (self.lnworker.channel_db or chan1.node_id != chan2.node_id):
if chan1 and chan2 and (not self.lnworker.uses_trampoline() or chan1.node_id != chan2.node_id):
return chan1, chan2
return None, None

2
electrum/lnchannel.py

@ -824,7 +824,7 @@ class Channel(AbstractChannel):
return self.can_send_ctx_updates() and self.is_open()
def is_frozen_for_sending(self) -> bool:
if self.lnworker and self.lnworker.channel_db is None and not self.lnworker.is_trampoline_peer(self.node_id):
if self.lnworker and self.lnworker.uses_trampoline() and not self.lnworker.is_trampoline_peer(self.node_id):
return True
return self.storage.get('frozen_for_sending', False)

8
electrum/lnpeer.py

@ -366,16 +366,16 @@ class Peer(Logger):
self.maybe_set_initialized()
def on_node_announcement(self, payload):
if self.lnworker.channel_db:
if not self.lnworker.uses_trampoline():
self.gossip_queue.put_nowait(('node_announcement', payload))
def on_channel_announcement(self, payload):
if self.lnworker.channel_db:
if not self.lnworker.uses_trampoline():
self.gossip_queue.put_nowait(('channel_announcement', payload))
def on_channel_update(self, payload):
self.maybe_save_remote_update(payload)
if self.lnworker.channel_db:
if not self.lnworker.uses_trampoline():
self.gossip_queue.put_nowait(('channel_update', payload))
def maybe_save_remote_update(self, payload):
@ -702,7 +702,7 @@ class Peer(Logger):
# will raise if init fails
await asyncio.wait_for(self.initialized, LN_P2P_NETWORK_TIMEOUT)
# trampoline is not yet in features
if not self.lnworker.channel_db and not self.lnworker.is_trampoline_peer(self.pubkey):
if self.lnworker.uses_trampoline() and not self.lnworker.is_trampoline_peer(self.pubkey):
raise Exception('Not a trampoline node: ' + str(self.their_features))
feerate = self.lnworker.current_feerate_per_kw()

33
electrum/lnworker.py

@ -223,6 +223,9 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
def channel_db(self):
return self.network.channel_db if self.network else None
def uses_trampoline(self):
return not bool(self.channel_db)
@property
def peers(self) -> Mapping[bytes, Peer]:
"""Returns a read-only copy of peers."""
@ -235,7 +238,7 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
def get_node_alias(self, node_id: bytes) -> Optional[str]:
"""Returns the alias of the node, or None if unknown."""
node_alias = None
if self.channel_db:
if not self.uses_trampoline():
node_info = self.channel_db.get_node_info_for_node_id(node_id)
if node_info:
node_alias = node_info.alias
@ -371,8 +374,8 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
peer_addr = peer.transport.peer_addr
# reset connection attempt count
self._on_connection_successfully_established(peer_addr)
# add into channel db
if self.channel_db:
if not self.uses_trampoline():
# add into channel db
self.channel_db.add_recent_peer(peer_addr)
# save network address into channels we might have with peer
for chan in peer.channels.values():
@ -492,7 +495,7 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
if rest is not None:
host, port = split_host_port(rest)
else:
if not self.channel_db:
if self.uses_trampoline():
addr = trampolines_by_id().get(node_id)
if not addr:
raise ConnStringFormatError(_('Address unknown for node:') + ' ' + bh2u(node_id))
@ -1299,7 +1302,7 @@ class LNWallet(LNWorker):
if code == OnionFailureCode.MPP_TIMEOUT:
raise PaymentFailure(failure_msg.code_name())
# trampoline
if not self.channel_db:
if self.uses_trampoline():
def maybe_raise_trampoline_fee(htlc_log):
if htlc_log.trampoline_fee_level == self.trampoline_fee_level:
self.trampoline_fee_level += 1
@ -1370,7 +1373,7 @@ class LNWallet(LNWorker):
key = (payment_hash, short_channel_id, htlc.htlc_id)
self.sent_htlcs_info[key] = route, payment_secret, amount_msat, total_msat, amount_receiver_msat, trampoline_fee_level, trampoline_route
# if we sent MPP to a trampoline, add item to sent_buckets
if not self.channel_db and amount_msat != total_msat:
if self.uses_trampoline() and amount_msat != total_msat:
if payment_secret not in self.sent_buckets:
self.sent_buckets[payment_secret] = (0, 0)
amount_sent, amount_failed = self.sent_buckets[payment_secret]
@ -1531,7 +1534,7 @@ class LNWallet(LNWorker):
return False
def suggest_peer(self) -> Optional[bytes]:
if self.channel_db:
if not self.uses_trampoline():
return self.lnrater.suggest_peer()
else:
return random.choice(list(hardcoded_trampoline_nodes().values())).pubkey
@ -1572,7 +1575,7 @@ class LNWallet(LNWorker):
try:
self.logger.info("trying single-part payment")
# try to send over a single channel
if not self.channel_db:
if self.uses_trampoline():
for chan in my_active_channels:
if not self.is_trampoline_peer(chan.node_id):
continue
@ -1640,7 +1643,7 @@ class LNWallet(LNWorker):
for chan in my_active_channels}
self.logger.info(f"channels_with_funds: {channels_with_funds}")
if not self.channel_db:
if self.uses_trampoline():
# in the case of a legacy payment, we don't allow splitting via different
# trampoline nodes, because of https://github.com/ACINQ/eclair/issues/2127
use_single_node, _ = is_legacy_relay(invoice_features, r_tags)
@ -2043,7 +2046,7 @@ class LNWallet(LNWorker):
self.logger.info(f"htlc_failed {failure_message}")
# check sent_buckets if we use trampoline
if not self.channel_db and payment_secret in self.sent_buckets:
if self.uses_trampoline() and payment_secret in self.sent_buckets:
amount_sent, amount_failed = self.sent_buckets[payment_secret]
amount_failed += amount_receiver_msat
self.sent_buckets[payment_secret] = amount_sent, amount_failed
@ -2163,7 +2166,7 @@ class LNWallet(LNWorker):
can_send_dict = defaultdict(int)
with self.lock:
for c in self.get_channels_for_sending():
if self.channel_db:
if not self.uses_trampoline():
can_send_dict[0] += send_capacity(c)
else:
can_send_dict[c.node_id] += send_capacity(c)
@ -2271,7 +2274,7 @@ class LNWallet(LNWorker):
continue
if chan1 == chan2:
continue
if not self.channel_db and chan1.node_id == chan2.node_id:
if self.uses_trampoline() and chan1.node_id == chan2.node_id:
continue
if direction == SENT:
if chan1.can_pay(delta*1000):
@ -2326,7 +2329,7 @@ class LNWallet(LNWorker):
async def rebalance_channels(self, chan1, chan2, amount_msat):
if chan1 == chan2:
raise Exception('Rebalance requires two different channels')
if not self.channel_db and chan1.node_id == chan2.node_id:
if self.uses_trampoline() and chan1.node_id == chan2.node_id:
raise Exception('Rebalance requires channels from different trampolines')
lnaddr, invoice = self.create_invoice(
amount_msat=amount_msat,
@ -2408,7 +2411,7 @@ class LNWallet(LNWorker):
async def reestablish_peer_for_given_channel(self, chan: Channel) -> None:
now = time.time()
peer_addresses = []
if not self.channel_db:
if self.uses_trampoline():
addr = trampolines_by_id().get(chan.node_id)
if addr:
peer_addresses.append(addr)
@ -2590,7 +2593,7 @@ class LNWallet(LNWorker):
if success:
return
# try with gossip db
if not self.channel_db:
if self.uses_trampoline():
raise Exception(_('Please enable gossip'))
node_id = self.network.channel_db.get_node_by_prefix(cb.node_id_prefix)
addresses_from_gossip = self.network.channel_db.get_node_addresses(node_id)

3
electrum/tests/test_lnpeer.py

@ -185,6 +185,9 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
def channel_db(self):
return self.network.channel_db if self.network else None
def uses_trampoline(self):
return not bool(self.channel_db)
@property
def channels(self):
return self._channels

Loading…
Cancel
Save