Browse Source

lnhtlc: speed-up methods for recent ctns

we maintain a set of interesting htlc_ids
hard-fail-on-bad-server-string
SomberNight 5 years ago
parent
commit
ec7473789e
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 1
      electrum/lnchannel.py
  2. 109
      electrum/lnhtlc.py
  3. 3
      electrum/lnworker.py
  4. 46
      electrum/tests/test_lnpeer.py

1
electrum/lnchannel.py

@ -624,6 +624,7 @@ class Channel(Logger):
assert type(whose) is HTLCOwner assert type(whose) is HTLCOwner
initial = self.config[whose].initial_msat initial = self.config[whose].initial_msat
# TODO slow. -- and 'balance' is called from a decent number of places (e.g. 'make_commitment')
for direction, htlc in self.hm.all_settled_htlcs_ever(ctx_owner, ctn): for direction, htlc in self.hm.all_settled_htlcs_ever(ctx_owner, ctn):
# note: could "simplify" to (whose * ctx_owner == direction * SENT) # note: could "simplify" to (whose * ctx_owner == direction * SENT)
if whose == ctx_owner: if whose == ctx_owner:

109
electrum/lnhtlc.py

@ -13,10 +13,10 @@ class HTLCManager:
if len(log) == 0: if len(log) == 0:
initial = { initial = {
'adds': {}, 'adds': {}, # "side who offered htlc" -> htlc_id -> htlc
'locked_in': {}, 'locked_in': {}, # "side who offered htlc" -> action -> htlc_id -> whose ctx -> ctn
'settles': {}, 'settles': {}, # "side who offered htlc" -> action -> htlc_id -> whose ctx -> ctn
'fails': {}, 'fails': {}, # "side who offered htlc" -> action -> htlc_id -> whose ctx -> ctn
'fee_updates': {}, # "side who initiated fee update" -> action -> list of FeeUpdates 'fee_updates': {}, # "side who initiated fee update" -> action -> list of FeeUpdates
'revack_pending': False, 'revack_pending': False,
'next_htlc_id': 0, 'next_htlc_id': 0,
@ -36,6 +36,7 @@ class HTLCManager:
if not log[sub]['fee_updates']: if not log[sub]['fee_updates']:
log[sub]['fee_updates'][0] = FeeUpdate(rate=initial_feerate, ctn_local=0, ctn_remote=0) log[sub]['fee_updates'][0] = FeeUpdate(rate=initial_feerate, ctn_local=0, ctn_remote=0)
self.log = log self.log = log
self._init_maybe_active_htlc_ids()
def ctn_latest(self, sub: HTLCOwner) -> int: def ctn_latest(self, sub: HTLCOwner) -> int:
"""Return the ctn for the latest (newest that has a valid sig) ctx of sub""" """Return the ctn for the latest (newest that has a valid sig) ctx of sub"""
@ -73,6 +74,7 @@ class HTLCManager:
self.log[LOCAL]['adds'][htlc_id] = htlc self.log[LOCAL]['adds'][htlc_id] = htlc
self.log[LOCAL]['locked_in'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE)+1} self.log[LOCAL]['locked_in'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE)+1}
self.log[LOCAL]['next_htlc_id'] += 1 self.log[LOCAL]['next_htlc_id'] += 1
self._maybe_active_htlc_ids[LOCAL].add(htlc_id)
return htlc return htlc
def recv_htlc(self, htlc: UpdateAddHtlc) -> None: def recv_htlc(self, htlc: UpdateAddHtlc) -> None:
@ -83,6 +85,7 @@ class HTLCManager:
self.log[REMOTE]['adds'][htlc_id] = htlc self.log[REMOTE]['adds'][htlc_id] = htlc
self.log[REMOTE]['locked_in'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL)+1, REMOTE: None} self.log[REMOTE]['locked_in'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL)+1, REMOTE: None}
self.log[REMOTE]['next_htlc_id'] += 1 self.log[REMOTE]['next_htlc_id'] += 1
self._maybe_active_htlc_ids[REMOTE].add(htlc_id)
def send_settle(self, htlc_id: int) -> None: def send_settle(self, htlc_id: int) -> None:
self.log[REMOTE]['settles'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE) + 1} self.log[REMOTE]['settles'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE) + 1}
@ -130,13 +133,17 @@ class HTLCManager:
self.log[LOCAL]['ctn'] += 1 self.log[LOCAL]['ctn'] += 1
self._set_revack_pending(LOCAL, False) self._set_revack_pending(LOCAL, False)
# htlcs # htlcs
for ctns in self.log[REMOTE]['locked_in'].values(): for htlc_id in self._maybe_active_htlc_ids[REMOTE]:
ctns = self.log[REMOTE]['locked_in'][htlc_id]
if ctns[REMOTE] is None and ctns[LOCAL] <= self.ctn_latest(LOCAL): if ctns[REMOTE] is None and ctns[LOCAL] <= self.ctn_latest(LOCAL):
ctns[REMOTE] = self.ctn_latest(REMOTE) + 1 ctns[REMOTE] = self.ctn_latest(REMOTE) + 1
for log_action in ('settles', 'fails'): for log_action in ('settles', 'fails'):
for ctns in self.log[LOCAL][log_action].values(): for htlc_id in self._maybe_active_htlc_ids[LOCAL]:
ctns = self.log[LOCAL][log_action].get(htlc_id, None)
if ctns is None: continue
if ctns[REMOTE] is None and ctns[LOCAL] <= self.ctn_latest(LOCAL): if ctns[REMOTE] is None and ctns[LOCAL] <= self.ctn_latest(LOCAL):
ctns[REMOTE] = self.ctn_latest(REMOTE) + 1 ctns[REMOTE] = self.ctn_latest(REMOTE) + 1
self._update_maybe_active_htlc_ids()
# fee updates # fee updates
for k, fee_update in list(self.log[REMOTE]['fee_updates'].items()): for k, fee_update in list(self.log[REMOTE]['fee_updates'].items()):
if fee_update.ctn_remote is None and fee_update.ctn_local <= self.ctn_latest(LOCAL): if fee_update.ctn_remote is None and fee_update.ctn_local <= self.ctn_latest(LOCAL):
@ -146,13 +153,17 @@ class HTLCManager:
self.log[REMOTE]['ctn'] += 1 self.log[REMOTE]['ctn'] += 1
self._set_revack_pending(REMOTE, False) self._set_revack_pending(REMOTE, False)
# htlcs # htlcs
for ctns in self.log[LOCAL]['locked_in'].values(): for htlc_id in self._maybe_active_htlc_ids[LOCAL]:
ctns = self.log[LOCAL]['locked_in'][htlc_id]
if ctns[LOCAL] is None and ctns[REMOTE] <= self.ctn_latest(REMOTE): if ctns[LOCAL] is None and ctns[REMOTE] <= self.ctn_latest(REMOTE):
ctns[LOCAL] = self.ctn_latest(LOCAL) + 1 ctns[LOCAL] = self.ctn_latest(LOCAL) + 1
for log_action in ('settles', 'fails'): for log_action in ('settles', 'fails'):
for ctns in self.log[REMOTE][log_action].values(): for htlc_id in self._maybe_active_htlc_ids[REMOTE]:
ctns = self.log[REMOTE][log_action].get(htlc_id, None)
if ctns is None: continue
if ctns[LOCAL] is None and ctns[REMOTE] <= self.ctn_latest(REMOTE): if ctns[LOCAL] is None and ctns[REMOTE] <= self.ctn_latest(REMOTE):
ctns[LOCAL] = self.ctn_latest(LOCAL) + 1 ctns[LOCAL] = self.ctn_latest(LOCAL) + 1
self._update_maybe_active_htlc_ids()
# fee updates # fee updates
for k, fee_update in list(self.log[LOCAL]['fee_updates'].items()): for k, fee_update in list(self.log[LOCAL]['fee_updates'].items()):
if fee_update.ctn_local is None and fee_update.ctn_remote <= self.ctn_latest(REMOTE): if fee_update.ctn_local is None and fee_update.ctn_remote <= self.ctn_latest(REMOTE):
@ -161,6 +172,32 @@ class HTLCManager:
# no need to keep local update raw msgs anymore, they have just been ACKed. # no need to keep local update raw msgs anymore, they have just been ACKed.
self.log['unacked_local_updates2'].pop(self.log[REMOTE]['ctn'], None) self.log['unacked_local_updates2'].pop(self.log[REMOTE]['ctn'], None)
def _update_maybe_active_htlc_ids(self) -> None:
# Loosely, we want a set that contains the htlcs that are
# not "removed and revoked from all ctxs of both parties".
# It is guaranteed that those htlcs are in the set, but older htlcs might be there too:
# there is a sanity margin of 1 ctn -- this relaxes the care needed re order of method calls.
sanity_margin = 1
for htlc_proposer in (LOCAL, REMOTE):
for log_action in ('settles', 'fails'):
for htlc_id in list(self._maybe_active_htlc_ids[htlc_proposer]):
ctns = self.log[htlc_proposer][log_action].get(htlc_id, None)
if ctns is None: continue
if (ctns[LOCAL] is not None
and ctns[LOCAL] <= self.ctn_oldest_unrevoked(LOCAL) - sanity_margin
and ctns[REMOTE] is not None
and ctns[REMOTE] <= self.ctn_oldest_unrevoked(REMOTE) - sanity_margin):
self._maybe_active_htlc_ids[htlc_proposer].remove(htlc_id)
def _init_maybe_active_htlc_ids(self):
self._maybe_active_htlc_ids = {LOCAL: set(), REMOTE: set()} # first idx is "side who offered htlc"
# add all htlcs
for htlc_proposer in (LOCAL, REMOTE):
for htlc_id in self.log[htlc_proposer]['adds']:
self._maybe_active_htlc_ids[htlc_proposer].add(htlc_id)
# remove old htlcs
self._update_maybe_active_htlc_ids()
def discard_unsigned_remote_updates(self): def discard_unsigned_remote_updates(self):
"""Discard updates sent by the remote, that the remote itself """Discard updates sent by the remote, that the remote itself
did not yet sign (i.e. there was no corresponding commitment_signed msg) did not yet sign (i.e. there was no corresponding commitment_signed msg)
@ -170,6 +207,7 @@ class HTLCManager:
if ctns[LOCAL] > self.ctn_latest(LOCAL): if ctns[LOCAL] > self.ctn_latest(LOCAL):
del self.log[REMOTE]['locked_in'][htlc_id] del self.log[REMOTE]['locked_in'][htlc_id]
del self.log[REMOTE]['adds'][htlc_id] del self.log[REMOTE]['adds'][htlc_id]
self._maybe_active_htlc_ids[REMOTE].discard(htlc_id)
if self.log[REMOTE]['locked_in']: if self.log[REMOTE]['locked_in']:
self.log[REMOTE]['next_htlc_id'] = max([int(x) for x in self.log[REMOTE]['locked_in'].keys()]) + 1 self.log[REMOTE]['next_htlc_id'] = max([int(x) for x in self.log[REMOTE]['locked_in'].keys()]) + 1
else: else:
@ -222,7 +260,12 @@ class HTLCManager:
party = subject if direction == SENT else subject.inverted() party = subject if direction == SENT else subject.inverted()
settles = self.log[party]['settles'] settles = self.log[party]['settles']
fails = self.log[party]['fails'] fails = self.log[party]['fails']
for htlc_id, ctns in self.log[party]['locked_in'].items(): if ctn >= self.ctn_oldest_unrevoked(subject):
considered_htlc_ids = self._maybe_active_htlc_ids[party]
else: # ctn is too old; need to consider full log (slow...)
considered_htlc_ids = self.log[party]['locked_in']
for htlc_id in considered_htlc_ids:
ctns = self.log[party]['locked_in'][htlc_id]
if ctns[subject] is not None and ctns[subject] <= ctn: if ctns[subject] is not None and ctns[subject] <= ctn:
not_settled = htlc_id not in settles or settles[htlc_id][subject] is None or settles[htlc_id][subject] > ctn not_settled = htlc_id not in settles or settles[htlc_id][subject] is None or settles[htlc_id][subject] > ctn
not_failed = htlc_id not in fails or fails[htlc_id][subject] is None or fails[htlc_id][subject] > ctn not_failed = htlc_id not in fails or fails[htlc_id][subject] is None or fails[htlc_id][subject] > ctn
@ -290,32 +333,50 @@ class HTLCManager:
received = [(RECEIVED, x) for x in self.all_settled_htlcs_ever_by_direction(subject, RECEIVED, ctn)] received = [(RECEIVED, x) for x in self.all_settled_htlcs_ever_by_direction(subject, RECEIVED, ctn)]
return sent + received return sent + received
def received_in_ctn(self, ctn: int) -> Sequence[UpdateAddHtlc]: def _get_htlcs_that_got_removed_exactly_at_ctn(
self, ctn: int, *, ctx_owner: HTLCOwner, htlc_proposer: HTLCOwner, log_action: str,
) -> Sequence[UpdateAddHtlc]:
if ctn >= self.ctn_oldest_unrevoked(ctx_owner):
considered_htlc_ids = self._maybe_active_htlc_ids[htlc_proposer]
else: # ctn is too old; need to consider full log (slow...)
considered_htlc_ids = self.log[htlc_proposer][log_action]
htlcs = []
for htlc_id in considered_htlc_ids:
ctns = self.log[htlc_proposer][log_action].get(htlc_id, None)
if ctns is None: continue
if ctns[ctx_owner] == ctn:
htlcs.append(self.log[htlc_proposer]['adds'][htlc_id])
return htlcs
def received_in_ctn(self, local_ctn: int) -> Sequence[UpdateAddHtlc]:
""" """
received htlcs that became fulfilled when we send a revocation. received htlcs that became fulfilled when we send a revocation.
we check only local, because they are commited in the remote ctx first. we check only local, because they are committed in the remote ctx first.
""" """
return [self.log[REMOTE]['adds'][htlc_id] return self._get_htlcs_that_got_removed_exactly_at_ctn(local_ctn,
for htlc_id, ctns in self.log[REMOTE]['settles'].items() ctx_owner=LOCAL,
if ctns[LOCAL] == ctn] htlc_proposer=REMOTE,
log_action='settles')
def sent_in_ctn(self, ctn: int) -> Sequence[UpdateAddHtlc]: def sent_in_ctn(self, remote_ctn: int) -> Sequence[UpdateAddHtlc]:
""" """
sent htlcs that became fulfilled when we received a revocation sent htlcs that became fulfilled when we received a revocation
we check only remote, because they are commited in the local ctx first. we check only remote, because they are committed in the local ctx first.
""" """
return [self.log[LOCAL]['adds'][htlc_id] return self._get_htlcs_that_got_removed_exactly_at_ctn(remote_ctn,
for htlc_id, ctns in self.log[LOCAL]['settles'].items() ctx_owner=REMOTE,
if ctns[REMOTE] == ctn] htlc_proposer=LOCAL,
log_action='settles')
def failed_in_ctn(self, ctn: int) -> Sequence[UpdateAddHtlc]: def failed_in_ctn(self, remote_ctn: int) -> Sequence[UpdateAddHtlc]:
""" """
sent htlcs that became failed when we received a revocation sent htlcs that became failed when we received a revocation
we check only remote, because they are commited in the local ctx first. we check only remote, because they are committed in the local ctx first.
""" """
return [self.log[LOCAL]['adds'][htlc_id] return self._get_htlcs_that_got_removed_exactly_at_ctn(remote_ctn,
for htlc_id, ctns in self.log[LOCAL]['fails'].items() ctx_owner=REMOTE,
if ctns[REMOTE] == ctn] htlc_proposer=LOCAL,
log_action='fails')
##### Queries re Fees: ##### Queries re Fees:

3
electrum/lnworker.py

@ -124,7 +124,7 @@ FALLBACK_NODE_LIST_MAINNET = [
class PaymentInfo(NamedTuple): class PaymentInfo(NamedTuple):
payment_hash: bytes payment_hash: bytes
amount: int amount: int # in satoshis
direction: int direction: int
status: int status: int
@ -934,6 +934,7 @@ class LNWallet(LNWorker):
success = payment_attempt_log.success success = payment_attempt_log.success
if success: if success:
break break
self.logger.debug(f'payment attempts log for RHASH {key}: {repr(log)}')
self.network.trigger_callback('invoice_status', key) self.network.trigger_callback('invoice_status', key)
return success return success

46
electrum/tests/test_lnpeer.py

@ -7,6 +7,9 @@ from collections import defaultdict
import logging import logging
import concurrent import concurrent
from concurrent import futures from concurrent import futures
import unittest
from aiorpcx import TaskGroup
from electrum import constants from electrum import constants
from electrum.network import Network from electrum.network import Network
@ -18,13 +21,13 @@ from electrum.util import bh2u, create_and_start_event_loop
from electrum.lnpeer import Peer from electrum.lnpeer import Peer
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
from electrum.lnutil import PaymentFailure, LnLocalFeatures from electrum.lnutil import PaymentFailure, LnLocalFeatures, HTLCOwner
from electrum.lnchannel import channel_states, peer_states, Channel from electrum.lnchannel import channel_states, peer_states, Channel
from electrum.lnrouter import LNPathFinder from electrum.lnrouter import LNPathFinder
from electrum.channel_db import ChannelDB from electrum.channel_db import ChannelDB
from electrum.lnworker import LNWallet, NoPathFound from electrum.lnworker import LNWallet, NoPathFound
from electrum.lnmsg import encode_msg, decode_msg from electrum.lnmsg import encode_msg, decode_msg
from electrum.logging import console_stderr_handler from electrum.logging import console_stderr_handler, Logger
from electrum.lnworker import PaymentInfo, RECEIVED, PR_UNPAID from electrum.lnworker import PaymentInfo, RECEIVED, PR_UNPAID
from .test_lnchannel import create_test_channels from .test_lnchannel import create_test_channels
@ -81,8 +84,9 @@ class MockWallet:
def is_lightning_backup(self): def is_lightning_backup(self):
return False return False
class MockLNWallet: class MockLNWallet(Logger):
def __init__(self, remote_keypair, local_keypair, chan: 'Channel', tx_queue): def __init__(self, remote_keypair, local_keypair, chan: 'Channel', tx_queue):
Logger.__init__(self)
self.remote_keypair = remote_keypair self.remote_keypair = remote_keypair
self.node_keypair = local_keypair self.node_keypair = local_keypair
self.network = MockNetwork(tx_queue) self.network = MockNetwork(tx_queue)
@ -216,9 +220,11 @@ class TestPeer(ElectrumTestCase):
return p1, p2, w1, w2, q1, q2 return p1, p2, w1, w2, q1, q2
@staticmethod @staticmethod
def prepare_invoice(w2 # receiver def prepare_invoice(
w2, # receiver
*,
amount_sat=100_000,
): ):
amount_sat = 100000
amount_btc = amount_sat/Decimal(COIN) amount_btc = amount_sat/Decimal(COIN)
payment_preimage = os.urandom(32) payment_preimage = os.urandom(32)
RHASH = sha256(payment_preimage) RHASH = sha256(payment_preimage)
@ -300,6 +306,35 @@ class TestPeer(ElectrumTestCase):
with self.assertRaises(concurrent.futures.CancelledError): with self.assertRaises(concurrent.futures.CancelledError):
run(f()) run(f())
@unittest.skip("too expensive")
#@needs_test_with_all_chacha20_implementations
def test_payments_stresstest(self):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
alice_init_balance_msat = alice_channel.balance(HTLCOwner.LOCAL)
bob_init_balance_msat = bob_channel.balance(HTLCOwner.LOCAL)
num_payments = 1000
#pay_reqs1 = [self.prepare_invoice(w1, amount_sat=1) for i in range(num_payments)]
pay_reqs2 = [self.prepare_invoice(w2, amount_sat=1) for i in range(num_payments)]
max_htlcs_in_flight = asyncio.Semaphore(5)
async def single_payment(pay_req):
async with max_htlcs_in_flight:
await w1._pay(pay_req)
async def many_payments():
async with TaskGroup() as group:
for pay_req in pay_reqs2:
await group.spawn(single_payment(pay_req))
gath.cancel()
gath = asyncio.gather(many_payments(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
async def f():
await gath
with self.assertRaises(concurrent.futures.CancelledError):
run(f())
self.assertEqual(alice_init_balance_msat - num_payments * 1000, alice_channel.balance(HTLCOwner.LOCAL))
self.assertEqual(alice_init_balance_msat - num_payments * 1000, bob_channel.balance(HTLCOwner.REMOTE))
self.assertEqual(bob_init_balance_msat + num_payments * 1000, bob_channel.balance(HTLCOwner.LOCAL))
self.assertEqual(bob_init_balance_msat + num_payments * 1000, alice_channel.balance(HTLCOwner.REMOTE))
@needs_test_with_all_chacha20_implementations @needs_test_with_all_chacha20_implementations
def test_close(self): def test_close(self):
alice_channel, bob_channel = create_test_channels() alice_channel, bob_channel = create_test_channels()
@ -352,5 +387,6 @@ class TestPeer(ElectrumTestCase):
with self.assertRaises(PaymentFailure): with self.assertRaises(PaymentFailure):
run(f()) run(f())
def run(coro): def run(coro):
return asyncio.run_coroutine_threadsafe(coro, loop=asyncio.get_event_loop()).result() return asyncio.run_coroutine_threadsafe(coro, loop=asyncio.get_event_loop()).result()

Loading…
Cancel
Save