Restructure wallet storage:

- Perform JSON deserializations in wallet_db
- Use a StoredDict class that keeps track of its modifications (see the sketch after the changed-file list below)
Branch: hard-fail-on-bad-server-string
ThomasV committed 5 years ago · commit dbceed2647
Changed files (lines changed):

1. electrum/json_db.py (97)
2. electrum/lnchannel.py (138)
3. electrum/lnhtlc.py (62)
4. electrum/lnpeer.py (21)
5. electrum/lnsweep.py (4)
6. electrum/lnutil.py (26)
7. electrum/lnworker.py (50)
8. electrum/plugins/labels/labels.py (2)
9. electrum/tests/test_lnchannel.py (27)
10. electrum/tests/test_lnhtlc.py (22)
11. electrum/tests/test_lnutil.py (14)
12. electrum/util.py (2)
13. electrum/wallet.py (32)
14. electrum/wallet_db.py (97)
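The core idea of the change, for orientation: wallet data now lives in StoredDict objects whose in-place mutations mark the owning database as modified, so callers only need storage.write() instead of re-serializing whole sub-dicts with storage.put(...). Below is a minimal, self-contained sketch of that pattern; MiniStoredDict, MiniDB and write_if_needed are illustrative names only, not Electrum's API (the actual implementation is the StoredDict/JsonDB code in electrum/json_db.py further down).

import threading

class MiniStoredDict(dict):
    """Toy StoredDict: any write into the (possibly nested) dict flags the owning db as modified."""
    def __init__(self, data, db):
        self.db = db
        self.lock = db.lock if db else threading.RLock()
        for k, v in data.items():
            self[k] = v  # route through __setitem__ so nested dicts get converted too

    def __setitem__(self, key, value):
        if isinstance(value, dict) and not isinstance(value, MiniStoredDict):
            value = MiniStoredDict(value, self.db)  # nested dicts also track changes
        with self.lock:
            dict.__setitem__(self, key, value)
        if self.db:
            self.db.set_modified(True)

class MiniDB:
    def __init__(self, data):
        self.lock = threading.RLock()
        self._modified = False
        self.data = MiniStoredDict(data, self)

    def set_modified(self, b):
        self._modified = b

    def write_if_needed(self):
        # a real wallet db would serialize self.data to disk here
        if self._modified:
            print("writing:", self.data)
            self._modified = False

db = MiniDB({'channels': {}})
db.data['channels']['chan1'] = {'state': 'PREOPENING'}  # in-place nested write marks the db modified
db.write_if_needed()  # -> writing: {'channels': {'chan1': {'state': 'PREOPENING'}}}

With this pattern, the many explicit storage.put(...) calls removed throughout lnworker.py and wallet.py in this commit become unnecessary: mutating the dict in place is enough, and the next storage.write() persists it.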

electrum/json_db.py (97)

@@ -45,6 +45,101 @@ def locked(func):
     return wrapper

+
+class StoredObject:
+
+    db = None
+
+    def __setattr__(self, key, value):
+        if self.db:
+            self.db.set_modified(True)
+        object.__setattr__(self, key, value)
+
+    def set_db(self, db):
+        self.db = db
+
+    def to_json(self):
+        d = dict(vars(self))
+        d.pop('db', None)
+        return d
+
+
+_RaiseKeyError = object()  # singleton for no-default behavior
+
+
+class StoredDict(dict):
+
+    def __init__(self, data, db, path):
+        self.db = db
+        self.lock = self.db.lock if self.db else threading.RLock()
+        self.path = path
+        # recursively convert dicts to StoredDict
+        for k, v in list(data.items()):
+            self.__setitem__(k, v)
+
+    def convert_key(self, key):
+        # convert int, HTLCOwner to str
+        return str(int(key)) if isinstance(key, int) else key
+
+    @locked
+    def __setitem__(self, key, v):
+        key = self.convert_key(key)
+        is_new = key not in self
+        # early return to prevent unnecessary disk writes
+        if not is_new and self[key] == v:
+            return
+        # recursively convert dict to StoredDict.
+        # _convert_dict is called breadth-first
+        if isinstance(v, dict):
+            if self.db:
+                v = self.db._convert_dict(self.path, key, v)
+            v = StoredDict(v, self.db, self.path + [key])
+        # convert_value is called depth-first
+        if isinstance(v, dict) or isinstance(v, str):
+            if self.db:
+                v = self.db._convert_value(self.path, key, v)
+        # set parent of StoredObject
+        if isinstance(v, StoredObject):
+            v.set_db(self.db)
+        # set item
+        dict.__setitem__(self, key, v)
+        if self.db:
+            self.db.set_modified(True)
+
+    @locked
+    def __delitem__(self, key):
+        key = self.convert_key(key)
+        dict.__delitem__(self, key)
+        if self.db:
+            self.db.set_modified(True)
+
+    @locked
+    def __getitem__(self, key):
+        key = self.convert_key(key)
+        return dict.__getitem__(self, key)
+
+    @locked
+    def __contains__(self, key):
+        key = self.convert_key(key)
+        return dict.__contains__(self, key)
+
+    @locked
+    def pop(self, key, v=_RaiseKeyError):
+        key = self.convert_key(key)
+        if v is _RaiseKeyError:
+            r = dict.pop(self, key)
+        else:
+            r = dict.pop(self, key, v)
+        if self.db:
+            self.db.set_modified(True)
+        return r
+
+    @locked
+    def get(self, key, default=None):
+        key = self.convert_key(key)
+        return dict.get(self, key, default)
+
+
 class JsonDB(Logger):

     def __init__(self, data):

@@ -65,8 +160,6 @@ class JsonDB(Logger):
         v = self.data.get(key)
         if v is None:
             v = default
-        else:
-            v = copy.deepcopy(v)
         return v

     @modifier

electrum/lnchannel.py (138)

@@ -54,6 +54,7 @@ from .lnhtlc import HTLCManager
 if TYPE_CHECKING:
     from .lnworker import LNWallet
+    from .json_db import StoredDict


 # lightning channel states

@@ -92,17 +93,6 @@ state_transitions = [
     (cs.CLOSED, cs.REDEEMED),
 ]

-class ChannelJsonEncoder(json.JSONEncoder):
-    def default(self, o):
-        if isinstance(o, bytes):
-            return binascii.hexlify(o).decode("ascii")
-        if isinstance(o, RevocationStore):
-            return o.serialize()
-        if isinstance(o, set):
-            return list(o)
-        if hasattr(o, 'to_json') and callable(o.to_json):
-            return o.to_json()
-        return super().default(o)

 RevokeAndAck = namedtuple("RevokeAndAck", ["per_commitment_secret", "next_per_commitment_point"])

@@ -110,31 +100,9 @@ RevokeAndAck = namedtuple("RevokeAndAck", ["per_commitment_secret", "next_per_commitment_point"])
 class RemoteCtnTooFarInFuture(Exception): pass

-def decodeAll(d, local):
-    for k, v in d.items():
-        if k == 'revocation_store':
-            yield (k, RevocationStore(v))
-        elif k.endswith("_basepoint") or k.endswith("_key"):
-            if local:
-                yield (k, Keypair(**dict(decodeAll(v, local))))
-            else:
-                yield (k, OnlyPubkeyKeypair(**dict(decodeAll(v, local))))
-        elif k in ["node_id", "channel_id", "short_channel_id", "pubkey", "privkey", "current_per_commitment_point", "next_per_commitment_point", "per_commitment_secret_seed", "current_commitment_signature", "current_htlc_signatures"] and v is not None:
-            yield (k, binascii.unhexlify(v))
-        else:
-            yield (k, v)
-
 def htlcsum(htlcs):
     return sum([x.amount_msat for x in htlcs])

-# following two functions are used because json
-# doesn't store int keys and byte string values
-def str_bytes_dict_from_save(x) -> Dict[int, bytes]:
-    return {int(k): bfh(v) for k,v in x.items()}
-
-def str_bytes_dict_to_save(x) -> Dict[str, str]:
-    return {str(k): bh2u(v) for k, v in x.items()}
-
 class Channel(Logger):
     # note: try to avoid naming ctns/ctxs/etc as "current" and "pending".

@@ -149,44 +117,53 @@ class Channel(Logger):
         except:
             return super().diagnostic_name()

-    def __init__(self, state, *, sweep_address=None, name=None, lnworker=None, initial_feerate=None):
+    def __init__(self, state: 'StoredDict', *, sweep_address=None, name=None, lnworker=None, initial_feerate=None):
         self.name = name
         Logger.__init__(self)
         self.lnworker = lnworker  # type: Optional[LNWallet]
         self.sweep_address = sweep_address
-        assert 'local_state' not in state
-        self.db_lock = self.lnworker.wallet.storage.db.lock if self.lnworker else threading.RLock()
+        self.storage = state
+        self.db_lock = self.storage.db.lock if self.storage.db else threading.RLock()
         self.config = {}
         self.config[LOCAL] = state["local_config"]
-        if type(self.config[LOCAL]) is not LocalConfig:
-            conf = dict(decodeAll(self.config[LOCAL], True))
-            self.config[LOCAL] = LocalConfig(**conf)
-        assert type(self.config[LOCAL].htlc_basepoint.privkey) is bytes
         self.config[REMOTE] = state["remote_config"]
-        if type(self.config[REMOTE]) is not RemoteConfig:
-            conf = dict(decodeAll(self.config[REMOTE], False))
-            self.config[REMOTE] = RemoteConfig(**conf)
-        assert type(self.config[REMOTE].htlc_basepoint.pubkey) is bytes
-        self.channel_id = bfh(state["channel_id"]) if type(state["channel_id"]) not in (bytes, type(None)) else state["channel_id"]
-        self.constraints = ChannelConstraints(**state["constraints"]) if type(state["constraints"]) is not ChannelConstraints else state["constraints"]
-        self.funding_outpoint = Outpoint(**dict(decodeAll(state["funding_outpoint"], False))) if type(state["funding_outpoint"]) is not Outpoint else state["funding_outpoint"]
-        self.node_id = bfh(state["node_id"]) if type(state["node_id"]) not in (bytes, type(None)) else state["node_id"]  # type: bytes
+        self.channel_id = bfh(state["channel_id"])
+        self.constraints = state["constraints"]
+        self.funding_outpoint = state["funding_outpoint"]
+        self.node_id = bfh(state["node_id"])
         self.short_channel_id = ShortChannelID.normalize(state["short_channel_id"])
         self.short_channel_id_predicted = self.short_channel_id
-        self.onion_keys = str_bytes_dict_from_save(state.get('onion_keys', {}))
-        self.data_loss_protect_remote_pcp = str_bytes_dict_from_save(state.get('data_loss_protect_remote_pcp', {}))
-        self.remote_update = bfh(state.get('remote_update')) if state.get('remote_update') else None
-        log = state.get('log')
-        self.hm = HTLCManager(log=log, initial_feerate=initial_feerate)
+        self.onion_keys = state['onion_keys']
+        self.data_loss_protect_remote_pcp = state['data_loss_protect_remote_pcp']
+        self.hm = HTLCManager(log=state['log'], initial_feerate=initial_feerate)
         self._state = channel_states[state['state']]
         self.peer_state = peer_states.DISCONNECTED
         self.sweep_info = {}  # type: Dict[str, Dict[str, SweepInfo]]
         self._outgoing_channel_update = None  # type: Optional[bytes]
         self.revocation_store = RevocationStore(state["revocation_store"])

+    def set_onion_key(self, key, value):
+        self.onion_keys[key] = value
+
+    def get_onion_key(self, key):
+        return self.onion_keys.get(key)
+
+    def set_data_loss_protect_remote_pcp(self, key, value):
+        self.data_loss_protect_remote_pcp[key] = value
+
+    def get_data_loss_protect_remote_pcp(self, key):
+        self.data_loss_protect_remote_pcp.get(key)
+
+    def set_remote_update(self, raw):
+        self.storage['remote_update'] = raw.hex()
+
+    def get_remote_update(self):
+        return bfh(self.storage.get('remote_update')) if self.storage.get('remote_update') else None
+
+    def set_short_channel_id(self, short_id):
+        self.short_channel_id = short_id
+        self.storage["short_channel_id"] = short_id
+
     def get_feerate(self, subject, ctn):
         return self.hm.get_feerate(subject, ctn)

@@ -229,8 +206,10 @@ class Channel(Logger):
         old_state = self._state
         if (old_state, state) not in state_transitions:
             raise Exception(f"Transition not allowed: {old_state.name} -> {state.name}")
-        self._state = state
         self.logger.debug(f'Setting channel state: {old_state.name} -> {state.name}')
+        self._state = state
+        self.storage['state'] = self._state.name
+
         if self.lnworker:
             self.lnworker.save_channel(self)
             self.lnworker.network.trigger_callback('channel', self)

@@ -656,51 +635,6 @@ class Channel(Logger):
         else:
             self.hm.recv_update_fee(feerate)

-    def to_save(self):
-        to_save = {
-            "local_config": self.config[LOCAL],
-            "remote_config": self.config[REMOTE],
-            "channel_id": self.channel_id,
-            "short_channel_id": self.short_channel_id,
-            "constraints": self.constraints,
-            "funding_outpoint": self.funding_outpoint,
-            "node_id": self.node_id,
-            "log": self.hm.to_save(),
-            "revocation_store": self.revocation_store,
-            "onion_keys": str_bytes_dict_to_save(self.onion_keys),
-            "state": self._state.name,
-            "data_loss_protect_remote_pcp": str_bytes_dict_to_save(self.data_loss_protect_remote_pcp),
-            "remote_update": self.remote_update.hex() if self.remote_update else None
-        }
-        return to_save
-
-    def serialize(self):
-        namedtuples_to_dict = lambda v: {i: j._asdict() if isinstance(j, tuple) else j for i, j in v._asdict().items()}
-        serialized_channel = {}
-        to_save_ref = self.to_save()
-        for k, v in to_save_ref.items():
-            if isinstance(v, tuple):
-                serialized_channel[k] = namedtuples_to_dict(v)
-            else:
-                serialized_channel[k] = v
-        dumped = ChannelJsonEncoder().encode(serialized_channel)
-        roundtripped = json.loads(dumped)
-        reconstructed = Channel(roundtripped)
-        to_save_new = reconstructed.to_save()
-        if to_save_new != to_save_ref:
-            from pprint import PrettyPrinter
-            pp = PrettyPrinter(indent=168)
-            try:
-                from deepdiff import DeepDiff
-            except ImportError:
-                raise Exception("Channels did not roundtrip serialization without changes:\n" + pp.pformat(to_save_ref) + "\n" + pp.pformat(to_save_new))
-            else:
-                raise Exception("Channels did not roundtrip serialization without changes:\n" + pp.pformat(DeepDiff(to_save_ref, to_save_new)))
-        return roundtripped
-
-    def __str__(self):
-        return str(self.serialize())
-
     def make_commitment(self, subject, this_point, ctn) -> PartialTransaction:
         assert type(subject) is HTLCOwner
         feerate = self.get_feerate(subject, ctn)

electrum/lnhtlc.py (62)

@@ -1,14 +1,17 @@
 from copy import deepcopy
-from typing import Optional, Sequence, Tuple, List, Dict
+from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING

 from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction, FeeUpdate
 from .util import bh2u, bfh

+if TYPE_CHECKING:
+    from .json_db import StoredDict
+

 class HTLCManager:

-    def __init__(self, *, log=None, initial_feerate=None):
-        if log is None:
+    def __init__(self, log:'StoredDict', *, initial_feerate=None):
+        if len(log) == 0:
             initial = {
                 'adds': {},
                 'locked_in': {},

@@ -17,33 +20,18 @@ class HTLCManager:
                 'fee_updates': {},  # "side who initiated fee update" -> action -> list of FeeUpdates
                 'revack_pending': False,
                 'next_htlc_id': 0,
                 'ctn': -1,  # oldest unrevoked ctx of sub
             }
-            log = {LOCAL: deepcopy(initial), REMOTE: deepcopy(initial)}
-        else:
-            assert type(log) is dict
-            log = {(HTLCOwner(int(k)) if k in ("-1", "1") else k): v
-                   for k, v in deepcopy(log).items()}
-            for sub in (LOCAL, REMOTE):
-                log[sub]['adds'] = {int(htlc_id): UpdateAddHtlc(*htlc) for htlc_id, htlc in log[sub]['adds'].items()}
-                coerceHtlcOwner2IntMap = lambda ctns: {HTLCOwner(int(owner)): ctn for owner, ctn in ctns.items()}
-                # "side who offered htlc" -> action -> htlc_id -> whose ctx -> ctn
-                log[sub]['locked_in'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['locked_in'].items()}
-                log[sub]['settles'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['settles'].items()}
-                log[sub]['fails'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['fails'].items()}
-                # "side who initiated fee update" -> action -> list of FeeUpdates
-                log[sub]['fee_updates'] = { int(x): FeeUpdate(**fee_upd) for x,fee_upd in log[sub]['fee_updates'].items() }
-        if 'unacked_local_updates2' not in log:
+            log[LOCAL] = deepcopy(initial)
+            log[REMOTE] = deepcopy(initial)
             log['unacked_local_updates2'] = {}
-        log['unacked_local_updates2'] = {int(ctn): [bfh(msg) for msg in messages]
-                                         for ctn, messages in log['unacked_local_updates2'].items()}

         # maybe bootstrap fee_updates if initial_feerate was provided
         if initial_feerate is not None:
             assert type(initial_feerate) is int
             for sub in (LOCAL, REMOTE):
                 if not log[sub]['fee_updates']:
-                    log[sub]['fee_updates'][0] = FeeUpdate(initial_feerate, ctn_local=0, ctn_remote=0)
+                    log[sub]['fee_updates'][0] = FeeUpdate(rate=initial_feerate, ctn_local=0, ctn_remote=0)
         self.log = log

@@ -66,20 +54,6 @@ class HTLCManager:
     def get_next_htlc_id(self, sub: HTLCOwner) -> int:
         return self.log[sub]['next_htlc_id']

-    def to_save(self):
-        log = deepcopy(self.log)
-        for sub in (LOCAL, REMOTE):
-            # adds
-            d = {}
-            for htlc_id, htlc in log[sub]['adds'].items():
-                d[htlc_id] = (htlc[0], bh2u(htlc[1])) + htlc[2:]
-            log[sub]['adds'] = d
-            # fee_updates
-            log[sub]['fee_updates'] = { x:fee_upd.to_json() for x, fee_upd in self.log[sub]['fee_updates'].items() }
-        log['unacked_local_updates2'] = {ctn: [bh2u(msg) for msg in messages]
-                                         for ctn, messages in log['unacked_local_updates2'].items()}
-        return log
-
     ##### Actions on channel:

     def channel_open_finished(self):

@@ -132,7 +106,7 @@ class HTLCManager:
     def _new_feeupdate(self, fee_update: FeeUpdate, subject: HTLCOwner) -> None:
         # overwrite last fee update if not yet committed to by anyone; otherwise append
         d = self.log[subject]['fee_updates']
-        assert type(d) is dict
+        #assert type(d) is StoredDict
         n = len(d)
         last_fee_update = d[n-1]
         if (last_fee_update.ctn_local is None or last_fee_update.ctn_local > self.ctn_latest(LOCAL)) \

@@ -194,7 +168,7 @@ class HTLCManager:
                 del self.log[REMOTE]['locked_in'][htlc_id]
                 del self.log[REMOTE]['adds'][htlc_id]
         if self.log[REMOTE]['locked_in']:
-            self.log[REMOTE]['next_htlc_id'] = max(self.log[REMOTE]['locked_in']) + 1
+            self.log[REMOTE]['next_htlc_id'] = max([int(x) for x in self.log[REMOTE]['locked_in'].keys()]) + 1
         else:
             self.log[REMOTE]['next_htlc_id'] = 0
         # htlcs removed

@@ -217,12 +191,14 @@ class HTLCManager:
             ctn_idx = self.ctn_latest(REMOTE)
         else:
             ctn_idx = self.ctn_latest(REMOTE) + 1
-        if ctn_idx not in self.log['unacked_local_updates2']:
-            self.log['unacked_local_updates2'][ctn_idx] = []
-        self.log['unacked_local_updates2'][ctn_idx].append(raw_update_msg)
+        l = self.log['unacked_local_updates2'].get(ctn_idx, [])
+        l.append(raw_update_msg.hex())
+        self.log['unacked_local_updates2'][ctn_idx] = l

     def get_unacked_local_updates(self) -> Dict[int, Sequence[bytes]]:
-        return self.log['unacked_local_updates2']
+        #return self.log['unacked_local_updates2']
+        return {int(ctn): [bfh(msg) for msg in messages]
+                for ctn, messages in self.log['unacked_local_updates2'].items()}

     ##### Queries re HTLCs:

electrum/lnpeer.py (21)

@@ -221,7 +221,7 @@ class Peer(Logger):
     def maybe_save_remote_update(self, payload):
         for chan in self.channels.values():
             if chan.short_channel_id == payload['short_channel_id']:
-                chan.remote_update = payload['raw']
+                chan.set_remote_update(payload['raw'])
                 self.logger.info("saved remote_update")

     def on_announcement_signatures(self, payload):

@@ -611,9 +611,15 @@
             "constraints": constraints,
             "remote_update": None,
             "state": channel_states.PREOPENING.name,
+            'onion_keys': {},
+            'data_loss_protect_remote_pcp': {},
+            "log": {},
             "revocation_store": {},
         }
-        return chan_dict
+        channel_id = chan_dict.get('channel_id')
+        channels = self.lnworker.storage.db.get_dict('channels')
+        channels[channel_id] = chan_dict
+        return channels.get(channel_id)

     async def on_open_channel(self, payload):
         # payload['channel_flags']

@@ -684,7 +690,7 @@
             signature=sig_64,
         )
         chan.open_with_first_pcp(payload['first_per_commitment_point'], remote_sig)
-        self.lnworker.save_channel(chan)
+        self.lnworker.add_channel(chan)
         self.lnworker.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())

     def validate_remote_reserve(self, payload_field: bytes, dust_limit: int, funding_sat: int) -> int:

@@ -850,7 +856,7 @@
         else:
             if dlp_enabled and should_close_they_are_ahead:
                 self.logger.warning(f"channel_reestablish: remote is ahead of us! luckily DLP is enabled. remote PCP: {bh2u(their_local_pcp)}")
-                chan.data_loss_protect_remote_pcp[their_next_local_ctn - 1] = their_local_pcp
+                chan.set_data_loss_protect_remote_pcp(their_next_local_ctn - 1, their_local_pcp)
                 self.lnworker.save_channel(chan)
         if should_close_they_are_ahead:
             self.logger.warning(f"channel_reestablish: remote is ahead of us! trying to get them to force-close.")

@@ -885,7 +891,6 @@
         self.logger.info(f"on_funding_locked. channel: {bh2u(channel_id)}")
         chan = self.channels.get(channel_id)
         if not chan:
-            print(self.channels)
             raise Exception("Got unknown funding_locked", channel_id)
         if not chan.config[LOCAL].funding_locked_received:
             our_next_point = chan.config[REMOTE].next_per_commitment_point

@@ -1004,11 +1009,11 @@
         # peer may have sent us a channel update for the incoming direction previously
         pending_channel_update = self.orphan_channel_updates.get(chan.short_channel_id)
         if pending_channel_update:
-            chan.remote_update = pending_channel_update['raw']
+            chan.set_remote_update(pending_channel_update['raw'])
         # add remote update with a fresh timestamp
-        if chan.remote_update:
+        if chan.get_remote_update():
             now = int(time.time())
-            remote_update_decoded = decode_msg(chan.remote_update)[1]
+            remote_update_decoded = decode_msg(chan.get_remote_update())[1]
             remote_update_decoded['timestamp'] = now.to_bytes(4, byteorder="big")
             self.channel_db.add_channel_update(remote_update_decoded)

electrum/lnsweep.py (4)

@@ -299,8 +299,8 @@ def analyze_ctx(chan: 'Channel', ctx: Transaction):
         their_pcp = ecc.ECPrivkey(per_commitment_secret).get_public_key_bytes(compressed=True)
         is_revocation = True
         #_logger.info(f'tx for revoked: {list(txs.keys())}')
-    elif ctn in chan.data_loss_protect_remote_pcp:
-        their_pcp = chan.data_loss_protect_remote_pcp[ctn]
+    elif chan.get_data_loss_protect_remote_pcp(ctn):
+        their_pcp = chan.get_data_loss_protect_remote_pcp(ctn)
         is_revocation = False
     else:
         return

electrum/lnutil.py (26)

@@ -38,12 +38,7 @@ LN_MAX_FUNDING_SAT = pow(2, 24) - 1
 def ln_dummy_address():
     return redeem_script_to_address('p2wsh', '')

-
-class StoredObject:
-
-    def to_json(self):
-        return dict(vars(self))
-
+from .json_db import StoredObject

 @attr.s
 class OnlyPubkeyKeypair(StoredObject):

@@ -180,21 +175,23 @@ class RevocationStore:
     START_INDEX = 2 ** 48 - 1

     def __init__(self, storage):
-        self.index = storage.get('index', self.START_INDEX)
-        buckets = storage.get('buckets', {})
-        decode = lambda to_decode: ShachainElement(bfh(to_decode[0]), int(to_decode[1]))
-        self.buckets = dict((int(k), decode(v)) for k, v in buckets.items())
+        if len(storage) == 0:
+            storage['index'] = self.START_INDEX
+            storage['buckets'] = {}
+        self.storage = storage
+        self.buckets = storage['buckets']

     def add_next_entry(self, hsh):
-        new_element = ShachainElement(index=self.index, secret=hsh)
-        bucket = count_trailing_zeros(self.index)
+        index = self.storage['index']
+        new_element = ShachainElement(index=index, secret=hsh)
+        bucket = count_trailing_zeros(index)
         for i in range(0, bucket):
             this_bucket = self.buckets[i]
             e = shachain_derive(new_element, this_bucket.index)
             if e != this_bucket:
                 raise Exception("hash is not derivable: {} {} {}".format(bh2u(e.secret), bh2u(this_bucket.secret), this_bucket.index))
         self.buckets[bucket] = new_element
-        self.index -= 1
+        self.storage['index'] = index - 1

     def retrieve_secret(self, index: int) -> bytes:
         assert index <= self.START_INDEX, index

@@ -209,9 +206,6 @@
                 return element.secret
         raise UnableToDeriveSecret()

-    def serialize(self):
-        return {"index": self.index, "buckets": dict( (k, [bh2u(v.secret), v.index]) for k, v in self.buckets.items()) }
-
     def __eq__(self, o):
         return type(o) is RevocationStore and self.serialize() == o.serialize()

electrum/lnworker.py (50)

@@ -34,13 +34,14 @@ from .bip32 import BIP32Node
 from .util import bh2u, bfh, InvoiceError, resolve_dns_srv, is_ip_address, log_exceptions
 from .util import ignore_exceptions, make_aiohttp_session
 from .util import timestamp_to_datetime
+from .util import MyEncoder
 from .logging import Logger
 from .lntransport import LNTransport, LNResponderTransport
 from .lnpeer import Peer, LN_P2P_NETWORK_TIMEOUT
 from .lnaddr import lnencode, LnAddr, lndecode
 from .ecc import der_sig_from_sig_string
 from .ecc_fast import is_using_fast_ecc
-from .lnchannel import Channel, ChannelJsonEncoder
+from .lnchannel import Channel
 from .lnchannel import channel_states, peer_states
 from . import lnutil
 from .lnutil import funding_output_script

@@ -106,8 +107,6 @@ FALLBACK_NODE_LIST_MAINNET = [
     LNPeerAddr(host='3.124.63.44', port=9735, pubkey=bfh('0242a4ae0c5bef18048fbecf995094b74bfb0f7391418d71ed394784373f41e4f3')),
 ]

-encoder = ChannelJsonEncoder()

 from typing import NamedTuple

@@ -347,19 +346,20 @@ class LNWallet(LNWorker):
         LNWorker.__init__(self, xprv)
         self.ln_keystore = keystore.from_xprv(xprv)
         self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_REQ
-        self.payments = self.storage.get('lightning_payments', {})      # RHASH -> amount, direction, is_paid
-        self.preimages = self.storage.get('lightning_preimages', {})    # RHASH -> preimage
+        self.payments = self.storage.db.get_dict('lightning_payments')    # RHASH -> amount, direction, is_paid
+        self.preimages = self.storage.db.get_dict('lightning_preimages')  # RHASH -> preimage
         self.sweep_address = wallet.get_receiving_address()
         self.lock = threading.RLock()
         self.logs = defaultdict(list)  # type: Dict[str, List[PaymentAttemptLog]]  # key is RHASH
         # note: accessing channels (besides simple lookup) needs self.lock!
-        self.channels = {}  # type: Dict[bytes, Channel]
-        for x in wallet.storage.get("channels", {}).values():
-            c = Channel(x, sweep_address=self.sweep_address, lnworker=self)
-            self.channels[c.channel_id] = c
+        self.channels = {}
+        channels = self.storage.db.get_dict("channels")
+        for channel_id, c in channels.items():
+            self.channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self)
         # timestamps of opening and closing transactions
-        self.channel_timestamps = self.storage.get('lightning_channel_timestamps', {})
+        self.channel_timestamps = self.storage.db.get_dict('lightning_channel_timestamps')
         self.pending_payments = defaultdict(asyncio.Future)

     @ignore_exceptions

@@ -610,17 +610,9 @@ class LNWallet(LNWorker):
         assert type(chan) is Channel
         if chan.config[REMOTE].next_per_commitment_point == chan.config[REMOTE].current_per_commitment_point:
             raise Exception("Tried to save channel with next_point == current_point, this should not happen")
-        with self.lock:
-            self.channels[chan.channel_id] = chan
-            self.save_channels()
+        self.wallet.storage.write()
         self.network.trigger_callback('channel', chan)

-    def save_channels(self):
-        with self.lock:
-            dumped = dict( (k.hex(), c.serialize()) for k, c in self.channels.items() )
-            self.storage.put("channels", dumped)
-        self.storage.write()
-
     def save_short_chan_id(self, chan):
         """
         Checks if Funding TX has been mined. If it has, save the short channel ID in chan;

@@ -648,8 +640,8 @@
             return
         block_height, tx_pos = self.lnwatcher.get_txpos(chan.funding_outpoint.txid)
         assert tx_pos >= 0
-        chan.short_channel_id = ShortChannelID.from_components(
-            block_height, tx_pos, chan.funding_outpoint.output_index)
+        chan.set_short_channel_id(ShortChannelID.from_components(
+            block_height, tx_pos, chan.funding_outpoint.output_index))
         self.logger.info(f"save_short_channel_id: {chan.short_channel_id}")
         self.save_channel(chan)

@@ -669,7 +661,6 @@
         # save timestamp regardless of state, so that funding tx is returned in get_history
         self.channel_timestamps[bh2u(chan.channel_id)] = chan.funding_outpoint.txid, funding_height.height, funding_height.timestamp, None, None, None
-        self.storage.put('lightning_channel_timestamps', self.channel_timestamps)

         if chan.get_state() == channel_states.OPEN and self.should_channel_be_closed_due_to_expiring_htlcs(chan):
             self.logger.info(f"force-closing due to expiring htlcs")

@@ -714,7 +705,6 @@
             # fixme: this is wasteful
             self.channel_timestamps[bh2u(chan.channel_id)] = funding_txid, funding_height.height, funding_height.timestamp, closing_txid, closing_height.height, closing_height.timestamp
-            self.storage.put('lightning_channel_timestamps', self.channel_timestamps)

             # remove from channel_db
             if chan.short_channel_id is not None:

@@ -836,7 +826,7 @@
             funding_sat=funding_sat,
             push_msat=push_sat * 1000,
             temp_channel_id=os.urandom(32))
-        self.save_channel(chan)
+        self.add_channel(chan)
         self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
         self.network.trigger_callback('channels_updated', self.wallet)
         self.wallet.add_transaction(funding_tx)  # save tx as local into the wallet

@@ -846,6 +836,10 @@
             await asyncio.wait_for(self.network.broadcast_transaction(funding_tx), LN_P2P_NETWORK_TIMEOUT)
         return chan, funding_tx

+    def add_channel(self, chan):
+        with self.lock:
+            self.channels[chan.channel_id] = chan
+
     @log_exceptions
     async def add_peer(self, connect_str: str) -> Peer:
         node_id, rest = extract_nodeid(connect_str)

@@ -1133,7 +1127,6 @@
     def save_preimage(self, payment_hash: bytes, preimage: bytes):
         assert sha256(preimage) == payment_hash
         self.preimages[bh2u(payment_hash)] = bh2u(preimage)
-        self.storage.put('lightning_preimages', self.preimages)
         self.storage.write()

     def get_preimage(self, payment_hash: bytes) -> bytes:

@@ -1152,7 +1145,6 @@
         assert info.status in [PR_PAID, PR_UNPAID, PR_INFLIGHT]
         with self.lock:
             self.payments[key] = info.amount, info.direction, info.status
-        self.storage.put('lightning_payments', self.payments)
         self.storage.write()

     def get_payment_status(self, payment_hash):

@@ -1238,7 +1230,6 @@
             del self.payments[payment_hash_hex]
         except KeyError:
             return
-        self.storage.put('lightning_payments', self.payments)
         self.storage.write()

     def get_balance(self):

@@ -1246,6 +1237,7 @@
         return Decimal(sum(chan.balance(LOCAL) if not chan.is_closed() else 0 for chan in self.channels.values()))/1000

     def list_channels(self):
+        encoder = MyEncoder()
         with self.lock:
             # we output the funding_outpoint instead of the channel_id because lnd uses channel_point (funding outpoint) to identify channels
             for channel_id, chan in self.channels.items():

@@ -1283,7 +1275,9 @@
         assert chan.is_closed()
         with self.lock:
             self.channels.pop(chan_id)
-        self.save_channels()
+            self.channel_timestamps.pop(chan_id.hex())
+            self.storage.get('channels').pop(chan_id.hex())
+
         self.network.trigger_callback('channels_updated', self.wallet)
         self.network.trigger_callback('wallet_updated', self.wallet)

electrum/plugins/labels/labels.py (2)

@@ -149,8 +149,6 @@ class LabelsPlugin(BasePlugin):
                 wallet.labels[key] = value

         self.logger.info(f"received {len(response)} labels")
-        # do not write to disk because we're in a daemon thread
-        wallet.storage.put('labels', wallet.labels)
         self.set_nonce(wallet, response["nonce"] + 1)
         self.on_pulled(wallet)

electrum/tests/test_lnchannel.py (27)

@@ -35,6 +35,7 @@ from electrum.lnutil import FeeUpdate
 from electrum.ecc import sig_string_from_der_sig
 from electrum.logging import console_stderr_handler
 from electrum.lnchannel import channel_states
+from electrum.json_db import StoredDict

 from . import ElectrumTestCase

@@ -45,9 +46,8 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator,
     assert local_amount > 0
     assert remote_amount > 0
     channel_id, _ = lnpeer.channel_id_from_funding_tx(funding_txid, funding_index)
-
-    return {
-        "channel_id":channel_id,
+    state = {
+        "channel_id":channel_id.hex(),
         "short_channel_id":channel_id[:8],
         "funding_outpoint":lnpeer.Outpoint(funding_txid, funding_index),
         "remote_config":lnpeer.RemoteConfig(

@@ -63,7 +63,6 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator,
             initial_msat=remote_amount,
             reserve_sat=0,
             htlc_minimum_msat=1,
-
             next_per_commitment_point=nex,
             current_per_commitment_point=cur,
         ),

@@ -79,7 +78,6 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator,
             max_accepted_htlcs=5,
            initial_msat=local_amount,
            reserve_sat=0,
-
            per_commitment_secret_seed=seed,
            funding_locked_received=True,
            was_announced=False,

@@ -91,11 +89,14 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator,
            is_initiator=is_initiator,
            funding_txn_minimum_depth=3,
        ),
-        "node_id":other_node_id,
+        "node_id":other_node_id.hex(),
        'onion_keys': {},
+        'data_loss_protect_remote_pcp': {},
        'state': 'PREOPENING',
+        'log': {},
        'revocation_store': {},
    }
+    return StoredDict(state, None, [])

 def bip32(sequence):
     node = bip32_utils.BIP32Node.from_rootseed(b"9dk", xtype='standard').subkey_at_private_derivation(sequence)

@@ -317,7 +318,6 @@ class TestChannel(ElectrumTestCase):
         # Bob revokes his prior commitment given to him by Alice, since he now
         # has a valid signature for a newer commitment.
         bobRevocation, _ = bob_channel.revoke_current_commitment()
-        bob_channel.serialize()
         self.assertTrue(bob_channel.signature_fits(bob_channel.get_latest_commitment(LOCAL)))

         # Bob finally sends a signature for Alice's commitment transaction.

@@ -341,18 +341,14 @@
         # her prior commitment transaction. Alice shouldn't have any HTLCs to
         # forward since she's sending an outgoing HTLC.
         alice_channel.receive_revocation(bobRevocation)
-        alice_channel.serialize()
         self.assertTrue(alice_channel.signature_fits(alice_channel.get_latest_commitment(LOCAL)))
-        alice_channel.serialize()

         self.assertEqual(len(alice_channel.get_latest_commitment(LOCAL).outputs()), 2)
         self.assertEqual(len(alice_channel.get_latest_commitment(REMOTE).outputs()), 3)
         self.assertEqual(len(alice_channel.force_close_tx().outputs()), 2)
         self.assertEqual(len(alice_channel.hm.log[LOCAL]['adds']), 1)
-        alice_channel.serialize()

         self.assertEqual(alice_channel.get_next_commitment(LOCAL).outputs(),
                          bob_channel.get_latest_commitment(REMOTE).outputs())

@@ -365,14 +361,12 @@
         self.assertEqual(len(alice_channel.force_close_tx().outputs()), 3)
         self.assertEqual(len(alice_channel.hm.log[LOCAL]['adds']), 1)
-        alice_channel.serialize()

         tx1 = str(alice_channel.force_close_tx())
         self.assertNotEqual(tx0, tx1)

         # Alice then generates a revocation for bob.
         aliceRevocation, _ = alice_channel.revoke_current_commitment()
-        alice_channel.serialize()

         tx2 = str(alice_channel.force_close_tx())
         # since alice already has the signature for the next one, it doesn't change her force close tx (it was already the newer one)

@@ -384,7 +378,6 @@
         # into both commitment transactions.
         self.assertTrue(bob_channel.signature_fits(bob_channel.get_latest_commitment(LOCAL)))
         bob_channel.receive_revocation(aliceRevocation)
-        bob_channel.serialize()

         # At this point, both sides should have the proper number of satoshis
         # sent, and commitment height updated within their local channel

@@ -450,20 +443,16 @@
         self.assertEqual(1, alice_channel.get_oldest_unrevoked_ctn(LOCAL))
         self.assertEqual(len(alice_channel.included_htlcs(LOCAL, RECEIVED, ctn=2)), 0)
         aliceRevocation2, _ = alice_channel.revoke_current_commitment()
-        alice_channel.serialize()
         aliceSig2, aliceHtlcSigs2 = alice_channel.sign_next_commitment()
         self.assertEqual(aliceHtlcSigs2, [], "alice should generate no htlc signatures")
         self.assertEqual(len(bob_channel.get_latest_commitment(LOCAL).outputs()), 3)
         bob_channel.receive_revocation(aliceRevocation2)
-        bob_channel.serialize()
         bob_channel.receive_new_commitment(aliceSig2, aliceHtlcSigs2)

         bobRevocation2, (received, sent) = bob_channel.revoke_current_commitment()
         self.assertEqual(one_bitcoin_in_msat, received)
-        bob_channel.serialize()
         alice_channel.receive_revocation(bobRevocation2)
-        alice_channel.serialize()

         # At this point, Bob should have 6 BTC settled, with Alice still having
         # 4 BTC. Alice's channel should show 1 BTC sent and Bob's channel

@@ -509,8 +498,6 @@
         self.assertEqual(bob_channel.total_msat(RECEIVED), one_bitcoin_in_msat, "bob satoshis received incorrect")
         self.assertEqual(bob_channel.total_msat(SENT), 5 * one_bitcoin_in_msat, "bob satoshis sent incorrect")
-        alice_channel.serialize()
-

     def alice_to_bob_fee_update(self, fee=111):
         aoldctx = self.alice_channel.get_next_commitment(REMOTE).outputs()

electrum/tests/test_lnhtlc.py (22)

@@ -4,18 +4,18 @@ from typing import NamedTuple
 from electrum.lnutil import RECEIVED, LOCAL, REMOTE, SENT, HTLCOwner, Direction
 from electrum.lnhtlc import HTLCManager
+from electrum.json_db import StoredDict

 from . import ElectrumTestCase

 class H(NamedTuple):
     owner : str
     htlc_id : int

 class TestHTLCManager(ElectrumTestCase):
     def test_adding_htlcs_race(self):
-        A = HTLCManager()
-        B = HTLCManager()
+        A = HTLCManager(StoredDict({}, None, []))
+        B = HTLCManager(StoredDict({}, None, []))
         A.channel_open_finished()
         B.channel_open_finished()
         ah0, bh0 = H('A', 0), H('B', 0)

@@ -61,8 +61,8 @@ class TestHTLCManager(ElectrumTestCase):
     def test_single_htlc_full_lifecycle(self):
         def htlc_lifecycle(htlc_success: bool):
-            A = HTLCManager()
-            B = HTLCManager()
+            A = HTLCManager(StoredDict({}, None, []))
+            B = HTLCManager(StoredDict({}, None, []))
             A.channel_open_finished()
             B.channel_open_finished()
             B.recv_htlc(A.send_htlc(H('A', 0)))

@@ -134,8 +134,8 @@ class TestHTLCManager(ElectrumTestCase):
     def test_remove_htlc_while_owing_commitment(self):
         def htlc_lifecycle(htlc_success: bool):
-            A = HTLCManager()
-            B = HTLCManager()
+            A = HTLCManager(StoredDict({}, None, []))
+            B = HTLCManager(StoredDict({}, None, []))
             A.channel_open_finished()
             B.channel_open_finished()
             ah0 = H('A', 0)

@@ -171,8 +171,8 @@ class TestHTLCManager(ElectrumTestCase):
             htlc_lifecycle(htlc_success=False)

     def test_adding_htlc_between_send_ctx_and_recv_rev(self):
-        A = HTLCManager()
-        B = HTLCManager()
+        A = HTLCManager(StoredDict({}, None, []))
+        B = HTLCManager(StoredDict({}, None, []))
         A.channel_open_finished()
         B.channel_open_finished()
         A.send_ctx()

@@ -217,8 +217,8 @@ class TestHTLCManager(ElectrumTestCase):
         self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))

     def test_unacked_local_updates(self):
-        A = HTLCManager()
-        B = HTLCManager()
+        A = HTLCManager(StoredDict({}, None, []))
+        B = HTLCManager(StoredDict({}, None, []))
         A.channel_open_finished()
         B.channel_open_finished()
         self.assertEqual({}, A.get_unacked_local_updates())

electrum/tests/test_lnutil.py (14)

@@ -2,13 +2,14 @@ import unittest
 import json

 from electrum import bitcoin
+from electrum.json_db import StoredDict
 from electrum.lnutil import (RevocationStore, get_per_commitment_secret_from_seed, make_offered_htlc,
                              make_received_htlc, make_commitment, make_htlc_tx_witness, make_htlc_tx_output,
                              make_htlc_tx_inputs, secret_to_pubkey, derive_blinded_pubkey, derive_privkey,
                              derive_pubkey, make_htlc_tx, extract_ctn_from_tx, UnableToDeriveSecret,
                              get_compressed_pubkey_from_bech32, split_host_port, ConnStringFormatError,
                              ScriptHtlc, extract_nodeid, calc_onchain_fees, UpdateAddHtlc)
-from electrum.util import bh2u, bfh
+from electrum.util import bh2u, bfh, MyEncoder
 from electrum.transaction import Transaction, PartialTransaction

 from . import ElectrumTestCase

@@ -422,7 +423,7 @@ class TestLNUtil(ElectrumTestCase):
         ]

         for test in tests:
-            receiver = RevocationStore({})
+            receiver = RevocationStore(StoredDict({}, None, []))
             for insert in test["inserts"]:
                 secret = bytes.fromhex(insert["secret"])

@@ -445,14 +446,19 @@ class TestLNUtil(ElectrumTestCase):
     def test_shachain_produce_consume(self):
         seed = bitcoin.sha256(b"shachaintest")
-        consumer = RevocationStore({})
+        consumer = RevocationStore(StoredDict({}, None, []))
         for i in range(10000):
             secret = get_per_commitment_secret_from_seed(seed, RevocationStore.START_INDEX - i)
             try:
                 consumer.add_next_entry(secret)
             except Exception as e:
                 raise Exception("iteration " + str(i) + ": " + str(e))
-            if i % 1000 == 0: self.assertEqual(consumer.serialize(), RevocationStore(json.loads(json.dumps(consumer.serialize()))).serialize())
+            if i % 1000 == 0:
+                c1 = consumer
+                s1 = json.dumps(c1.storage, cls=MyEncoder)
+                c2 = RevocationStore(StoredDict(json.loads(s1), None, []))
+                s2 = json.dumps(c2.storage, cls=MyEncoder)
+                self.assertEqual(s1, s2)

     def test_commitment_tx_with_all_five_HTLCs_untrimmed_minimum_feerate(self):
         to_local_msat = 6988000000

electrum/util.py (2)

@@ -280,6 +280,8 @@ class MyEncoder(json.JSONEncoder):
             return obj.isoformat(' ')[:-3]
         if isinstance(obj, set):
             return list(obj)
+        if isinstance(obj, bytes): # for nametuples in lnchannel
+            return obj.hex()
         if hasattr(obj, 'to_json') and callable(obj.to_json):
             return obj.to_json()
         return super(MyEncoder, self).default(obj)

32
electrum/wallet.py

@ -240,21 +240,13 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
# saved fields # saved fields
self.use_change = storage.get('use_change', True) self.use_change = storage.get('use_change', True)
self.multiple_change = storage.get('multiple_change', False) self.multiple_change = storage.get('multiple_change', False)
self.labels = storage.get('labels', {}) self.labels = storage.db.get_dict('labels')
self.frozen_addresses = set(storage.get('frozen_addresses', [])) self.frozen_addresses = set(storage.get('frozen_addresses', []))
self.frozen_coins = set(storage.get('frozen_coins', [])) # set of txid:vout strings self.frozen_coins = set(storage.get('frozen_coins', [])) # set of txid:vout strings
self.fiat_value = storage.get('fiat_value', {}) self.fiat_value = storage.db.get_dict('fiat_value')
self.receive_requests = storage.get('payment_requests', {}) self.receive_requests = storage.db.get_dict('payment_requests')
self.invoices = storage.get('invoices', {}) self.invoices = storage.db.get_dict('invoices')
# convert invoices
# TODO invoices being these contextual dicts even internally,
# where certain keys are only present depending on values of other keys...
# it's horrible. we need to change this, at least for the internal representation,
# to something that can be typed.
for invoice_key, invoice in self.invoices.items():
if invoice.get('type') == PR_TYPE_ONCHAIN:
outputs = [PartialTxOutput.from_legacy_tuple(*output) for output in invoice.get('outputs')]
invoice['outputs'] = outputs
self._prepare_onchain_invoice_paid_detection() self._prepare_onchain_invoice_paid_detection()
self.calc_unused_change_addresses() self.calc_unused_change_addresses()
# save wallet type the first time # save wallet type the first time
@ -372,7 +364,6 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
changed = True changed = True
if changed: if changed:
run_hook('set_label', self, name, text) run_hook('set_label', self, name, text)
self.storage.put('labels', self.labels)
return changed return changed
def set_fiat_value(self, txid, ccy, text, fx, value_sat): def set_fiat_value(self, txid, ccy, text, fx, value_sat):
@ -404,7 +395,6 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
if ccy not in self.fiat_value: if ccy not in self.fiat_value:
self.fiat_value[ccy] = {} self.fiat_value[ccy] = {}
self.fiat_value[ccy][txid] = text self.fiat_value[ccy][txid] = text
self.storage.put('fiat_value', self.fiat_value)
return reset return reset
def get_fiat_value(self, txid, ccy): def get_fiat_value(self, txid, ccy):
@ -625,12 +615,10 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
else: else:
raise Exception('Unsupported invoice type') raise Exception('Unsupported invoice type')
self.invoices[key] = invoice self.invoices[key] = invoice
self.storage.put('invoices', self.invoices)
self.storage.write() self.storage.write()
def clear_invoices(self): def clear_invoices(self):
self.invoices = {} self.invoices = {}
self.storage.put('invoices', self.invoices)
self.storage.write() self.storage.write()
def get_invoices(self): def get_invoices(self):
@ -642,7 +630,8 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
def get_invoice(self, key):
if key not in self.invoices:
return
item = copy.copy(self.invoices[key])
# convert StoredDict to dict
item = dict(self.invoices[key])
request_type = item.get('type')
if request_type == PR_TYPE_ONCHAIN:
item['status'] = PR_PAID if self.is_onchain_invoice_paid(item) else PR_UNPAID
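The "# convert StoredDict to dict" comments in this hunk and the next are the flip side of tracked storage: get_invoice() and get_request() add a transient 'status' field for display, and copying into a plain dict first keeps that field out of the on-disk data. A plain-Python illustration of the copy-before-decorate pattern (no Electrum code involved):

# A shallow copy is enough here because only top-level keys are added.
stored_invoice = {'message': 'rent', 'amount': 10_000}   # imagine this lives in the wallet file

item = dict(stored_invoice)        # shallow copy detaches it from storage
item['status'] = 'paid'            # transient, display-only field

assert 'status' not in stored_invoice   # nothing leaked back into storage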
@ -1553,7 +1542,8 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
req = self.receive_requests.get(key)
if not req:
return
req = copy.copy(req)
# convert StoredDict to dict
req = dict(req)
_type = req.get('type')
if _type == PR_TYPE_ONCHAIN:
addr = req['address']
@ -1610,7 +1600,6 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
req['name'] = pr.pki_data
req['sig'] = bh2u(pr.signature)
self.receive_requests[key] = req
self.storage.put('payment_requests', self.receive_requests)
def add_payment_request(self, req):
if req['type'] == PR_TYPE_ONCHAIN:
@ -1628,7 +1617,6 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
raise Exception('Unknown request type')
amount = req.get('amount')
self.receive_requests[key] = req
self.storage.put('payment_requests', self.receive_requests)
self.set_label(key, message) # should be a default label
return req
@ -1643,7 +1631,6 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
""" lightning or on-chain """ """ lightning or on-chain """
if key in self.invoices: if key in self.invoices:
self.invoices.pop(key) self.invoices.pop(key)
self.storage.put('invoices', self.invoices)
elif self.lnworker: elif self.lnworker:
self.lnworker.delete_payment(key) self.lnworker.delete_payment(key)
@ -1651,7 +1638,6 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
if addr not in self.receive_requests:
return False
self.receive_requests.pop(addr)
self.storage.put('payment_requests', self.receive_requests)
return True
def get_sorted_requests(self):

97
electrum/wallet_db.py

@ -29,12 +29,16 @@ import copy
import threading
from collections import defaultdict
from typing import Dict, Optional, List, Tuple, Set, Iterable, NamedTuple, Sequence
import binascii
from . import util, bitcoin
from .util import profiler, WalletFileException, multisig_type, TxMinedInfo, bfh
from .util import profiler, WalletFileException, multisig_type, TxMinedInfo, bfh, PR_TYPE_ONCHAIN
from .keystore import bip44_derivation
from .transaction import Transaction, TxOutpoint, tx_from_any, PartialTransaction
from .transaction import Transaction, TxOutpoint, tx_from_any, PartialTransaction, PartialTxOutput
from .json_db import JsonDB, locked, modifier
from .logging import Logger
from .lnutil import LOCAL, REMOTE, FeeUpdate, UpdateAddHtlc, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKeypair, RevocationStore
from .lnutil import ChannelConstraints, Outpoint, ShachainElement
from .json_db import StoredDict, JsonDB, locked, modifier
# seed_version is now used for the version of the wallet file
@ -44,17 +48,12 @@ FINAL_SEED_VERSION = 24 # electrum >= 2.7 will set this to prevent
# old versions from overwriting new format
class TxFeesValue(NamedTuple):
fee: Optional[int] = None
is_calculated_by_us: bool = False
num_inputs: Optional[int] = None
class WalletDB(JsonDB):
def __init__(self, raw, *, manual_upgrades: bool):
@ -67,7 +66,6 @@ class WalletDB(JsonDB):
self.put('seed_version', FINAL_SEED_VERSION)
self._after_upgrade_tasks()
def load_data(self, s):
try:
self.data = json.loads(s)
@ -833,7 +831,7 @@ class WalletDB(JsonDB):
self.tx_fees.pop(txid, None)
@locked
def get_data_ref(self, name):
def get_dict(self, name):
# Warning: interacts un-intuitively with 'put': certain parts
# of 'data' will have pointers saved as separate variables.
if name not in self.data:
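The rename from get_data_ref to get_dict keeps the warning above intact: the returned dict is a live reference into self.data, not a copy, which is exactly what lets callers mutate it in place. A minimal, hypothetical sketch (not Electrum's JsonDB) of why mixing this with put() on the same key is dangerous:

# get_dict() returns a live reference into self.data, so a later put()
# on the same key silently detaches any previously handed-out reference.
class SketchDB:
    def __init__(self):
        self.data = {}

    def get_dict(self, name):
        if name not in self.data:
            self.data[name] = {}
        return self.data[name]

    def put(self, name, value):
        self.data[name] = value

db = SketchDB()
labels = db.get_dict('labels')
labels['tx1'] = 'rent'
assert db.data['labels'] == {'tx1': 'rent'}   # mutation is visible in data

db.put('labels', {})                          # replaces the stored dict...
labels['tx2'] = 'groceries'                   # ...so this write is now lost
assert 'tx2' not in db.data['labels']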
@ -895,9 +893,9 @@ class WalletDB(JsonDB):
def load_addresses(self, wallet_type):
""" called from Abstract_Wallet.__init__ """
if wallet_type == 'imported':
self.imported_addresses = self.get_data_ref('addresses') # type: Dict[str, dict]
self.imported_addresses = self.get_dict('addresses') # type: Dict[str, dict]
else:
self.get_data_ref('addresses')
self.get_dict('addresses')
for name in ['receiving', 'change']:
if name not in self.data['addresses']:
self.data['addresses'][name] = []
@ -911,26 +909,20 @@ class WalletDB(JsonDB):
@profiler
def _load_transactions(self):
self.data = StoredDict(self.data, self, [])
# references in self.data
# TODO make all these private
# txid -> address -> set of (prev_outpoint, value)
self.txi = self.get_data_ref('txi') # type: Dict[str, Dict[str, Set[Tuple[str, int]]]]
self.txi = self.get_dict('txi') # type: Dict[str, Dict[str, Set[Tuple[str, int]]]]
# txid -> address -> set of (output_index, value, is_coinbase)
self.txo = self.get_data_ref('txo') # type: Dict[str, Dict[str, Set[Tuple[int, int, bool]]]]
self.txo = self.get_dict('txo') # type: Dict[str, Dict[str, Set[Tuple[int, int, bool]]]]
self.transactions = self.get_data_ref('transactions') # type: Dict[str, Transaction]
self.transactions = self.get_dict('transactions') # type: Dict[str, Transaction]
self.spent_outpoints = self.get_data_ref('spent_outpoints') # txid -> output_index -> next_txid
self.spent_outpoints = self.get_dict('spent_outpoints') # txid -> output_index -> next_txid
self.history = self.get_data_ref('addr_history') # address -> list of (txid, height)
self.history = self.get_dict('addr_history') # address -> list of (txid, height)
self.verified_tx = self.get_data_ref('verified_tx3') # txid -> (height, timestamp, txpos, header_hash)
self.verified_tx = self.get_dict('verified_tx3') # txid -> (height, timestamp, txpos, header_hash)
self.tx_fees = self.get_data_ref('tx_fees') # type: Dict[str, TxFeesValue]
self.tx_fees = self.get_dict('tx_fees') # type: Dict[str, TxFeesValue]
# scripthash -> set of (outpoint, value)
self._prevouts_by_scripthash = self.get_data_ref('prevouts_by_scripthash') # type: Dict[str, Set[Tuple[str, int]]]
self._prevouts_by_scripthash = self.get_dict('prevouts_by_scripthash') # type: Dict[str, Set[Tuple[str, int]]]
# convert raw transactions to Transaction objects
for tx_hash, raw_tx in self.transactions.items():
# note: for performance, "deserialize=False" so that we will deserialize these on-demand
self.transactions[tx_hash] = tx_from_any(raw_tx, deserialize=False)
# convert prevouts_by_scripthash: list to set, list to tuple
for scripthash, lst in self._prevouts_by_scripthash.items():
self._prevouts_by_scripthash[scripthash] = {(prevout, value) for prevout, value in lst}
# remove unreferenced tx
for tx_hash in list(self.transactions.keys()):
if not self.get_txi_addresses(tx_hash) and not self.get_txo_addresses(tx_hash):
@ -943,9 +935,15 @@ class WalletDB(JsonDB):
if spending_txid not in self.transactions:
self.logger.info("removing unreferenced spent outpoint")
d.pop(prevout_n)
# convert tx_fees tuples to NamedTuples
for tx_hash, tuple_ in self.tx_fees.items():
self.tx_fees[tx_hash] = TxFeesValue(*tuple_)
# convert invoices
# TODO invoices being these contextual dicts even internally,
# where certain keys are only present depending on values of other keys...
# it's horrible. we need to change this, at least for the internal representation,
# to something that can be typed.
self.invoices = self.get_dict('invoices')
for invoice_key, invoice in self.invoices.items():
if invoice.get('type') == PR_TYPE_ONCHAIN:
invoice['outputs'] = [PartialTxOutput.from_legacy_tuple(*output) for output in invoice.get('outputs')]
@modifier
def clear_history(self):
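The invoice-conversion loop that used to live in Abstract_Wallet.__init__ (removed in the wallet.py hunk earlier in this diff) now runs here, right after the loaded JSON has been wrapped in StoredDict. The shape of that conversion, sketched with a stand-in NamedTuple instead of the real PartialTxOutput, whose constructor and legacy-tuple layout (assumed here to be type, address, value) are not spelled out in this diff:

from typing import NamedTuple

# Stand-in for PartialTxOutput; the real class and its from_legacy_tuple()
# classmethod live in electrum.transaction.
class FakeTxOutput(NamedTuple):
    type: int
    address: str
    value: int

    @classmethod
    def from_legacy_tuple(cls, _type, address, value):
        return cls(_type, address, value)

PR_TYPE_ONCHAIN = 0   # value assumed for this sketch; the diff imports it from electrum.util

invoices = {
    'key1': {'type': PR_TYPE_ONCHAIN,
             'outputs': [[0, 'bc1qexample', 100_000]]},   # JSON gives back plain lists
}
for key, invoice in invoices.items():
    if invoice.get('type') == PR_TYPE_ONCHAIN:
        invoice['outputs'] = [FakeTxOutput.from_legacy_tuple(*o) for o in invoice['outputs']]

assert isinstance(invoices['key1']['outputs'][0], FakeTxOutput)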
@ -956,3 +954,42 @@ class WalletDB(JsonDB):
self.history.clear()
self.verified_tx.clear()
self.tx_fees.clear()
def _convert_dict(self, path, key, v):
if key == 'transactions':
# note: for performance, "deserialize=False" so that we will deserialize these on-demand
v = dict((k, tx_from_any(x, deserialize=False)) for k, x in v.items())
elif key == 'adds':
v = dict((k, UpdateAddHtlc(*x)) for k, x in v.items())
elif key == 'fee_updates':
v = dict((k, FeeUpdate(**x)) for k, x in v.items())
elif key == 'tx_fees':
v = dict((k, TxFeesValue(*x)) for k, x in v.items())
elif key == 'prevouts_by_scripthash':
v = dict((k, {(prevout, value) for (prevout, value) in x}) for k, x in v.items())
elif key == 'buckets':
v = dict((k, ShachainElement(bfh(x[0]), int(x[1]))) for k, x in v.items())
return v
def _convert_value(self, path, key, v):
if key == 'local_config':
v = LocalConfig(**v)
elif key == 'remote_config':
v = RemoteConfig(**v)
elif key == 'constraints':
v = ChannelConstraints(**v)
elif key == 'funding_outpoint':
v = Outpoint(**v)
elif key.endswith("_basepoint") or key.endswith("_key"):
v = Keypair(**v) if len(v)==2 else OnlyPubkeyKeypair(**v)
elif key in [
"short_channel_id",
"current_per_commitment_point",
"next_per_commitment_point",
"per_commitment_secret_seed",
"current_commitment_signature",
"current_htlc_signatures"]:
v = binascii.unhexlify(v) if v is not None else None
elif len(path) > 2 and path[-2] in ['local_config', 'remote_config'] and key in ["pubkey", "privkey"]:
v = binascii.unhexlify(v) if v is not None else None
return v
