Browse Source

lnhtlc: add lock to make methods thread-safe

many methods are accessed from both the asyncio thread and the GUI thread

fixes #6373
patch-4
SomberNight 4 years ago
parent
commit
51f42a25f9
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 40
      electrum/lnhtlc.py

40
electrum/lnhtlc.py

@ -1,5 +1,6 @@
from copy import deepcopy from copy import deepcopy
from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING, Set from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING, Set
import threading
from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction, FeeUpdate from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction, FeeUpdate
from .util import bh2u, bfh from .util import bh2u, bfh
@ -7,6 +8,7 @@ from .util import bh2u, bfh
if TYPE_CHECKING: if TYPE_CHECKING:
from .json_db import StoredDict from .json_db import StoredDict
class HTLCManager: class HTLCManager:
def __init__(self, log:'StoredDict', *, initial_feerate=None): def __init__(self, log:'StoredDict', *, initial_feerate=None):
@ -39,8 +41,16 @@ class HTLCManager:
if not log[sub]['fee_updates']: if not log[sub]['fee_updates']:
log[sub]['fee_updates'][0] = FeeUpdate(rate=initial_feerate, ctn_local=0, ctn_remote=0) log[sub]['fee_updates'][0] = FeeUpdate(rate=initial_feerate, ctn_local=0, ctn_remote=0)
self.log = log self.log = log
self.lock = threading.RLock()
self._init_maybe_active_htlc_ids() self._init_maybe_active_htlc_ids()
def with_lock(func):
def func_wrapper(self, *args, **kwargs):
with self.lock:
return func(self, *args, **kwargs)
return func_wrapper
@with_lock
def ctn_latest(self, sub: HTLCOwner) -> int: def ctn_latest(self, sub: HTLCOwner) -> int:
"""Return the ctn for the latest (newest that has a valid sig) ctx of sub""" """Return the ctn for the latest (newest that has a valid sig) ctx of sub"""
return self.ctn_oldest_unrevoked(sub) + int(self.is_revack_pending(sub)) return self.ctn_oldest_unrevoked(sub) + int(self.is_revack_pending(sub))
@ -63,12 +73,14 @@ class HTLCManager:
##### Actions on channel: ##### Actions on channel:
@with_lock
def channel_open_finished(self): def channel_open_finished(self):
self.log[LOCAL]['ctn'] = 0 self.log[LOCAL]['ctn'] = 0
self.log[REMOTE]['ctn'] = 0 self.log[REMOTE]['ctn'] = 0
self._set_revack_pending(LOCAL, False) self._set_revack_pending(LOCAL, False)
self._set_revack_pending(REMOTE, False) self._set_revack_pending(REMOTE, False)
@with_lock
def send_htlc(self, htlc: UpdateAddHtlc) -> UpdateAddHtlc: def send_htlc(self, htlc: UpdateAddHtlc) -> UpdateAddHtlc:
htlc_id = htlc.htlc_id htlc_id = htlc.htlc_id
if htlc_id != self.get_next_htlc_id(LOCAL): if htlc_id != self.get_next_htlc_id(LOCAL):
@ -80,6 +92,7 @@ class HTLCManager:
self._maybe_active_htlc_ids[LOCAL].add(htlc_id) self._maybe_active_htlc_ids[LOCAL].add(htlc_id)
return htlc return htlc
@with_lock
def recv_htlc(self, htlc: UpdateAddHtlc) -> None: def recv_htlc(self, htlc: UpdateAddHtlc) -> None:
htlc_id = htlc.htlc_id htlc_id = htlc.htlc_id
if htlc_id != self.get_next_htlc_id(REMOTE): if htlc_id != self.get_next_htlc_id(REMOTE):
@ -90,40 +103,47 @@ class HTLCManager:
self.log[REMOTE]['next_htlc_id'] += 1 self.log[REMOTE]['next_htlc_id'] += 1
self._maybe_active_htlc_ids[REMOTE].add(htlc_id) self._maybe_active_htlc_ids[REMOTE].add(htlc_id)
@with_lock
def send_settle(self, htlc_id: int) -> None: def send_settle(self, htlc_id: int) -> None:
next_ctn = self.ctn_latest(REMOTE) + 1 next_ctn = self.ctn_latest(REMOTE) + 1
if not self.is_htlc_active_at_ctn(ctx_owner=REMOTE, ctn=next_ctn, htlc_proposer=REMOTE, htlc_id=htlc_id): if not self.is_htlc_active_at_ctn(ctx_owner=REMOTE, ctn=next_ctn, htlc_proposer=REMOTE, htlc_id=htlc_id):
raise Exception(f"(local) cannot remove htlc that is not there...") raise Exception(f"(local) cannot remove htlc that is not there...")
self.log[REMOTE]['settles'][htlc_id] = {LOCAL: None, REMOTE: next_ctn} self.log[REMOTE]['settles'][htlc_id] = {LOCAL: None, REMOTE: next_ctn}
@with_lock
def recv_settle(self, htlc_id: int) -> None: def recv_settle(self, htlc_id: int) -> None:
next_ctn = self.ctn_latest(LOCAL) + 1 next_ctn = self.ctn_latest(LOCAL) + 1
if not self.is_htlc_active_at_ctn(ctx_owner=LOCAL, ctn=next_ctn, htlc_proposer=LOCAL, htlc_id=htlc_id): if not self.is_htlc_active_at_ctn(ctx_owner=LOCAL, ctn=next_ctn, htlc_proposer=LOCAL, htlc_id=htlc_id):
raise Exception(f"(remote) cannot remove htlc that is not there...") raise Exception(f"(remote) cannot remove htlc that is not there...")
self.log[LOCAL]['settles'][htlc_id] = {LOCAL: next_ctn, REMOTE: None} self.log[LOCAL]['settles'][htlc_id] = {LOCAL: next_ctn, REMOTE: None}
@with_lock
def send_fail(self, htlc_id: int) -> None: def send_fail(self, htlc_id: int) -> None:
next_ctn = self.ctn_latest(REMOTE) + 1 next_ctn = self.ctn_latest(REMOTE) + 1
if not self.is_htlc_active_at_ctn(ctx_owner=REMOTE, ctn=next_ctn, htlc_proposer=REMOTE, htlc_id=htlc_id): if not self.is_htlc_active_at_ctn(ctx_owner=REMOTE, ctn=next_ctn, htlc_proposer=REMOTE, htlc_id=htlc_id):
raise Exception(f"(local) cannot remove htlc that is not there...") raise Exception(f"(local) cannot remove htlc that is not there...")
self.log[REMOTE]['fails'][htlc_id] = {LOCAL: None, REMOTE: next_ctn} self.log[REMOTE]['fails'][htlc_id] = {LOCAL: None, REMOTE: next_ctn}
@with_lock
def recv_fail(self, htlc_id: int) -> None: def recv_fail(self, htlc_id: int) -> None:
next_ctn = self.ctn_latest(LOCAL) + 1 next_ctn = self.ctn_latest(LOCAL) + 1
if not self.is_htlc_active_at_ctn(ctx_owner=LOCAL, ctn=next_ctn, htlc_proposer=LOCAL, htlc_id=htlc_id): if not self.is_htlc_active_at_ctn(ctx_owner=LOCAL, ctn=next_ctn, htlc_proposer=LOCAL, htlc_id=htlc_id):
raise Exception(f"(remote) cannot remove htlc that is not there...") raise Exception(f"(remote) cannot remove htlc that is not there...")
self.log[LOCAL]['fails'][htlc_id] = {LOCAL: next_ctn, REMOTE: None} self.log[LOCAL]['fails'][htlc_id] = {LOCAL: next_ctn, REMOTE: None}
@with_lock
def send_update_fee(self, feerate: int) -> None: def send_update_fee(self, feerate: int) -> None:
fee_update = FeeUpdate(rate=feerate, fee_update = FeeUpdate(rate=feerate,
ctn_local=None, ctn_remote=self.ctn_latest(REMOTE) + 1) ctn_local=None, ctn_remote=self.ctn_latest(REMOTE) + 1)
self._new_feeupdate(fee_update, subject=LOCAL) self._new_feeupdate(fee_update, subject=LOCAL)
@with_lock
def recv_update_fee(self, feerate: int) -> None: def recv_update_fee(self, feerate: int) -> None:
fee_update = FeeUpdate(rate=feerate, fee_update = FeeUpdate(rate=feerate,
ctn_local=self.ctn_latest(LOCAL) + 1, ctn_remote=None) ctn_local=self.ctn_latest(LOCAL) + 1, ctn_remote=None)
self._new_feeupdate(fee_update, subject=REMOTE) self._new_feeupdate(fee_update, subject=REMOTE)
@with_lock
def _new_feeupdate(self, fee_update: FeeUpdate, subject: HTLCOwner) -> None: def _new_feeupdate(self, fee_update: FeeUpdate, subject: HTLCOwner) -> None:
# overwrite last fee update if not yet committed to by anyone; otherwise append # overwrite last fee update if not yet committed to by anyone; otherwise append
d = self.log[subject]['fee_updates'] d = self.log[subject]['fee_updates']
@ -136,14 +156,17 @@ class HTLCManager:
else: else:
d[n] = fee_update d[n] = fee_update
@with_lock
def send_ctx(self) -> None: def send_ctx(self) -> None:
assert self.ctn_latest(REMOTE) == self.ctn_oldest_unrevoked(REMOTE), (self.ctn_latest(REMOTE), self.ctn_oldest_unrevoked(REMOTE)) assert self.ctn_latest(REMOTE) == self.ctn_oldest_unrevoked(REMOTE), (self.ctn_latest(REMOTE), self.ctn_oldest_unrevoked(REMOTE))
self._set_revack_pending(REMOTE, True) self._set_revack_pending(REMOTE, True)
@with_lock
def recv_ctx(self) -> None: def recv_ctx(self) -> None:
assert self.ctn_latest(LOCAL) == self.ctn_oldest_unrevoked(LOCAL), (self.ctn_latest(LOCAL), self.ctn_oldest_unrevoked(LOCAL)) assert self.ctn_latest(LOCAL) == self.ctn_oldest_unrevoked(LOCAL), (self.ctn_latest(LOCAL), self.ctn_oldest_unrevoked(LOCAL))
self._set_revack_pending(LOCAL, True) self._set_revack_pending(LOCAL, True)
@with_lock
def send_rev(self) -> None: def send_rev(self) -> None:
self.log[LOCAL]['ctn'] += 1 self.log[LOCAL]['ctn'] += 1
self._set_revack_pending(LOCAL, False) self._set_revack_pending(LOCAL, False)
@ -164,6 +187,7 @@ class HTLCManager:
if fee_update.ctn_remote is None and fee_update.ctn_local <= self.ctn_latest(LOCAL): if fee_update.ctn_remote is None and fee_update.ctn_local <= self.ctn_latest(LOCAL):
fee_update.ctn_remote = self.ctn_latest(REMOTE) + 1 fee_update.ctn_remote = self.ctn_latest(REMOTE) + 1
@with_lock
def recv_rev(self) -> None: def recv_rev(self) -> None:
self.log[REMOTE]['ctn'] += 1 self.log[REMOTE]['ctn'] += 1
self._set_revack_pending(REMOTE, False) self._set_revack_pending(REMOTE, False)
@ -187,6 +211,7 @@ class HTLCManager:
# no need to keep local update raw msgs anymore, they have just been ACKed. # no need to keep local update raw msgs anymore, they have just been ACKed.
self.log['unacked_local_updates2'].pop(self.log[REMOTE]['ctn'], None) self.log['unacked_local_updates2'].pop(self.log[REMOTE]['ctn'], None)
@with_lock
def _update_maybe_active_htlc_ids(self) -> None: def _update_maybe_active_htlc_ids(self) -> None:
# - Loosely, we want a set that contains the htlcs that are # - Loosely, we want a set that contains the htlcs that are
# not "removed and revoked from all ctxs of both parties". (self._maybe_active_htlc_ids) # not "removed and revoked from all ctxs of both parties". (self._maybe_active_htlc_ids)
@ -209,6 +234,7 @@ class HTLCManager:
htlc = self.log[htlc_proposer]['adds'][htlc_id] # type: UpdateAddHtlc htlc = self.log[htlc_proposer]['adds'][htlc_id] # type: UpdateAddHtlc
self._balance_delta -= htlc.amount_msat * htlc_proposer self._balance_delta -= htlc.amount_msat * htlc_proposer
@with_lock
def _init_maybe_active_htlc_ids(self): def _init_maybe_active_htlc_ids(self):
# first idx is "side who offered htlc": # first idx is "side who offered htlc":
self._maybe_active_htlc_ids = {LOCAL: set(), REMOTE: set()} # type: Dict[HTLCOwner, Set[int]] self._maybe_active_htlc_ids = {LOCAL: set(), REMOTE: set()} # type: Dict[HTLCOwner, Set[int]]
@ -220,6 +246,7 @@ class HTLCManager:
# remove old htlcs # remove old htlcs
self._update_maybe_active_htlc_ids() self._update_maybe_active_htlc_ids()
@with_lock
def discard_unsigned_remote_updates(self): def discard_unsigned_remote_updates(self):
"""Discard updates sent by the remote, that the remote itself """Discard updates sent by the remote, that the remote itself
did not yet sign (i.e. there was no corresponding commitment_signed msg) did not yet sign (i.e. there was no corresponding commitment_signed msg)
@ -244,6 +271,7 @@ class HTLCManager:
if fee_update.ctn_local > self.ctn_latest(LOCAL): if fee_update.ctn_local > self.ctn_latest(LOCAL):
self.log[REMOTE]['fee_updates'].pop(k) self.log[REMOTE]['fee_updates'].pop(k)
@with_lock
def store_local_update_raw_msg(self, raw_update_msg: bytes, *, is_commitment_signed: bool) -> None: def store_local_update_raw_msg(self, raw_update_msg: bytes, *, is_commitment_signed: bool) -> None:
"""We need to be able to replay unacknowledged updates we sent to the remote """We need to be able to replay unacknowledged updates we sent to the remote
in case of disconnections. Hence, raw update and commitment_signed messages in case of disconnections. Hence, raw update and commitment_signed messages
@ -258,6 +286,7 @@ class HTLCManager:
l.append(raw_update_msg.hex()) l.append(raw_update_msg.hex())
self.log['unacked_local_updates2'][ctn_idx] = l self.log['unacked_local_updates2'][ctn_idx] = l
@with_lock
def get_unacked_local_updates(self) -> Dict[int, Sequence[bytes]]: def get_unacked_local_updates(self) -> Dict[int, Sequence[bytes]]:
#return self.log['unacked_local_updates2'] #return self.log['unacked_local_updates2']
return {int(ctn): [bfh(msg) for msg in messages] return {int(ctn): [bfh(msg) for msg in messages]
@ -265,6 +294,7 @@ class HTLCManager:
##### Queries re HTLCs: ##### Queries re HTLCs:
@with_lock
def is_htlc_active_at_ctn(self, *, ctx_owner: HTLCOwner, ctn: int, def is_htlc_active_at_ctn(self, *, ctx_owner: HTLCOwner, ctn: int,
htlc_proposer: HTLCOwner, htlc_id: int) -> bool: htlc_proposer: HTLCOwner, htlc_id: int) -> bool:
htlc_id = int(htlc_id) htlc_id = int(htlc_id)
@ -280,6 +310,7 @@ class HTLCManager:
return True return True
return False return False
@with_lock
def htlcs_by_direction(self, subject: HTLCOwner, direction: Direction, def htlcs_by_direction(self, subject: HTLCOwner, direction: Direction,
ctn: int = None) -> Dict[int, UpdateAddHtlc]: ctn: int = None) -> Dict[int, UpdateAddHtlc]:
"""Return the dict of received or sent (depending on direction) HTLCs """Return the dict of received or sent (depending on direction) HTLCs
@ -305,6 +336,7 @@ class HTLCManager:
d[htlc_id] = self.log[party]['adds'][htlc_id] d[htlc_id] = self.log[party]['adds'][htlc_id]
return d return d
@with_lock
def htlcs(self, subject: HTLCOwner, ctn: int = None) -> Sequence[Tuple[Direction, UpdateAddHtlc]]: def htlcs(self, subject: HTLCOwner, ctn: int = None) -> Sequence[Tuple[Direction, UpdateAddHtlc]]:
"""Return the list of HTLCs in subject's ctx at ctn.""" """Return the list of HTLCs in subject's ctx at ctn."""
assert type(subject) is HTLCOwner assert type(subject) is HTLCOwner
@ -315,16 +347,19 @@ class HTLCManager:
l += [(RECEIVED, x) for x in self.htlcs_by_direction(subject, RECEIVED, ctn).values()] l += [(RECEIVED, x) for x in self.htlcs_by_direction(subject, RECEIVED, ctn).values()]
return l return l
@with_lock
def get_htlcs_in_oldest_unrevoked_ctx(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]: def get_htlcs_in_oldest_unrevoked_ctx(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]:
assert type(subject) is HTLCOwner assert type(subject) is HTLCOwner
ctn = self.ctn_oldest_unrevoked(subject) ctn = self.ctn_oldest_unrevoked(subject)
return self.htlcs(subject, ctn) return self.htlcs(subject, ctn)
@with_lock
def get_htlcs_in_latest_ctx(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]: def get_htlcs_in_latest_ctx(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]:
assert type(subject) is HTLCOwner assert type(subject) is HTLCOwner
ctn = self.ctn_latest(subject) ctn = self.ctn_latest(subject)
return self.htlcs(subject, ctn) return self.htlcs(subject, ctn)
@with_lock
def get_htlcs_in_next_ctx(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]: def get_htlcs_in_next_ctx(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]:
assert type(subject) is HTLCOwner assert type(subject) is HTLCOwner
ctn = self.ctn_latest(subject) + 1 ctn = self.ctn_latest(subject) + 1
@ -336,6 +371,7 @@ class HTLCManager:
return False return False
return settles[htlc_id][htlc_sender] is not None return settles[htlc_id][htlc_sender] is not None
@with_lock
def all_settled_htlcs_ever_by_direction(self, subject: HTLCOwner, direction: Direction, def all_settled_htlcs_ever_by_direction(self, subject: HTLCOwner, direction: Direction,
ctn: int = None) -> Sequence[UpdateAddHtlc]: ctn: int = None) -> Sequence[UpdateAddHtlc]:
"""Return the list of all HTLCs that have been ever settled in subject's """Return the list of all HTLCs that have been ever settled in subject's
@ -353,6 +389,7 @@ class HTLCManager:
d.append(self.log[party]['adds'][htlc_id]) d.append(self.log[party]['adds'][htlc_id])
return d return d
@with_lock
def all_settled_htlcs_ever(self, subject: HTLCOwner, ctn: int = None) \ def all_settled_htlcs_ever(self, subject: HTLCOwner, ctn: int = None) \
-> Sequence[Tuple[Direction, UpdateAddHtlc]]: -> Sequence[Tuple[Direction, UpdateAddHtlc]]:
"""Return the list of all HTLCs that have been ever settled in subject's """Return the list of all HTLCs that have been ever settled in subject's
@ -365,6 +402,7 @@ class HTLCManager:
received = [(RECEIVED, x) for x in self.all_settled_htlcs_ever_by_direction(subject, RECEIVED, ctn)] received = [(RECEIVED, x) for x in self.all_settled_htlcs_ever_by_direction(subject, RECEIVED, ctn)]
return sent + received return sent + received
@with_lock
def get_balance_msat(self, whose: HTLCOwner, *, ctx_owner=HTLCOwner.LOCAL, ctn: int = None, def get_balance_msat(self, whose: HTLCOwner, *, ctx_owner=HTLCOwner.LOCAL, ctn: int = None,
initial_balance_msat: int) -> int: initial_balance_msat: int) -> int:
"""Returns the balance of 'whose' in 'ctx' at 'ctn'. """Returns the balance of 'whose' in 'ctx' at 'ctn'.
@ -396,6 +434,7 @@ class HTLCManager:
balance += htlc.amount_msat balance += htlc.amount_msat
return balance return balance
@with_lock
def _get_htlcs_that_got_removed_exactly_at_ctn( def _get_htlcs_that_got_removed_exactly_at_ctn(
self, ctn: int, *, ctx_owner: HTLCOwner, htlc_proposer: HTLCOwner, log_action: str, self, ctn: int, *, ctx_owner: HTLCOwner, htlc_proposer: HTLCOwner, log_action: str,
) -> Sequence[UpdateAddHtlc]: ) -> Sequence[UpdateAddHtlc]:
@ -443,6 +482,7 @@ class HTLCManager:
##### Queries re Fees: ##### Queries re Fees:
@with_lock
def get_feerate(self, subject: HTLCOwner, ctn: int) -> int: def get_feerate(self, subject: HTLCOwner, ctn: int) -> int:
"""Return feerate used in subject's commitment txn at ctn.""" """Return feerate used in subject's commitment txn at ctn."""
ctn = max(0, ctn) # FIXME rm this ctn = max(0, ctn) # FIXME rm this

Loading…
Cancel
Save