Browse Source

Merge pull request #7502 from spesmilo/no_convert_key

No convert key
patch-4
ThomasV 3 years ago
committed by GitHub
parent
commit
d570002b83
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 25
      electrum/json_db.py
  2. 8
      electrum/lnchannel.py
  3. 19
      electrum/lnhtlc.py
  4. 4
      electrum/lnpeer.py
  5. 2
      electrum/tests/test_lnchannel.py
  6. 23
      electrum/wallet_db.py

25
electrum/json_db.py

@ -78,16 +78,8 @@ class StoredDict(dict):
for k, v in list(data.items()): for k, v in list(data.items()):
self.__setitem__(k, v) self.__setitem__(k, v)
def convert_key(self, key):
"""Convert int keys to str keys, as only those are allowed in json."""
# NOTE: this is evil. really hard to keep in mind and reason about. :(
# e.g.: imagine setting int keys everywhere, and then iterating over the dict:
# suddenly the keys are str...
return str(int(key)) if isinstance(key, int) else key
@locked @locked
def __setitem__(self, key, v): def __setitem__(self, key, v):
key = self.convert_key(key)
is_new = key not in self is_new = key not in self
# early return to prevent unnecessary disk writes # early return to prevent unnecessary disk writes
if not is_new and self[key] == v: if not is_new and self[key] == v:
@ -119,24 +111,12 @@ class StoredDict(dict):
@locked @locked
def __delitem__(self, key): def __delitem__(self, key):
key = self.convert_key(key)
dict.__delitem__(self, key) dict.__delitem__(self, key)
if self.db: if self.db:
self.db.set_modified(True) self.db.set_modified(True)
@locked
def __getitem__(self, key):
key = self.convert_key(key)
return dict.__getitem__(self, key)
@locked
def __contains__(self, key):
key = self.convert_key(key)
return dict.__contains__(self, key)
@locked @locked
def pop(self, key, v=_RaiseKeyError): def pop(self, key, v=_RaiseKeyError):
key = self.convert_key(key)
if v is _RaiseKeyError: if v is _RaiseKeyError:
r = dict.pop(self, key) r = dict.pop(self, key)
else: else:
@ -145,11 +125,6 @@ class StoredDict(dict):
self.db.set_modified(True) self.db.set_modified(True)
return r return r
@locked
def get(self, key, default=None):
key = self.convert_key(key)
return dict.get(self, key, default)

8
electrum/lnchannel.py

@ -563,6 +563,8 @@ class Channel(AbstractChannel):
self.onion_keys = state['onion_keys'] # type: Dict[int, bytes] self.onion_keys = state['onion_keys'] # type: Dict[int, bytes]
self.data_loss_protect_remote_pcp = state['data_loss_protect_remote_pcp'] self.data_loss_protect_remote_pcp = state['data_loss_protect_remote_pcp']
self.hm = HTLCManager(log=state['log'], initial_feerate=initial_feerate) self.hm = HTLCManager(log=state['log'], initial_feerate=initial_feerate)
self.fail_htlc_reasons = state["fail_htlc_reasons"]
self.unfulfilled_htlcs = state["unfulfilled_htlcs"]
self._state = ChannelState[state['state']] self._state = ChannelState[state['state']]
self.peer_state = PeerState.DISCONNECTED self.peer_state = PeerState.DISCONNECTED
self._sweep_info = {} self._sweep_info = {}
@ -912,7 +914,7 @@ class Channel(AbstractChannel):
remote_ctn = self.get_latest_ctn(REMOTE) remote_ctn = self.get_latest_ctn(REMOTE)
if onion_packet: if onion_packet:
# TODO neither local_ctn nor remote_ctn are used anymore... no point storing them. # TODO neither local_ctn nor remote_ctn are used anymore... no point storing them.
self.hm.log['unfulfilled_htlcs'][htlc.htlc_id] = local_ctn, remote_ctn, onion_packet.hex(), False self.unfulfilled_htlcs[htlc.htlc_id] = local_ctn, remote_ctn, onion_packet.hex(), False
self.logger.info("receive_htlc") self.logger.info("receive_htlc")
return htlc return htlc
@ -1071,10 +1073,10 @@ class Channel(AbstractChannel):
failure_message: Optional['OnionRoutingFailure']): failure_message: Optional['OnionRoutingFailure']):
error_hex = error_bytes.hex() if error_bytes else None error_hex = error_bytes.hex() if error_bytes else None
failure_hex = failure_message.to_bytes().hex() if failure_message else None failure_hex = failure_message.to_bytes().hex() if failure_message else None
self.hm.log['fail_htlc_reasons'][htlc_id] = (error_hex, failure_hex) self.fail_htlc_reasons[htlc_id] = (error_hex, failure_hex)
def pop_fail_htlc_reason(self, htlc_id): def pop_fail_htlc_reason(self, htlc_id):
error_hex, failure_hex = self.hm.log['fail_htlc_reasons'].pop(htlc_id, (None, None)) error_hex, failure_hex = self.fail_htlc_reasons.pop(htlc_id, (None, None))
error_bytes = bytes.fromhex(error_hex) if error_hex else None error_bytes = bytes.fromhex(error_hex) if error_hex else None
failure_message = OnionRoutingFailure.from_bytes(bytes.fromhex(failure_hex)) if failure_hex else None failure_message = OnionRoutingFailure.from_bytes(bytes.fromhex(failure_hex)) if failure_hex else None
return error_bytes, failure_message return error_bytes, failure_message

19
electrum/lnhtlc.py

@ -27,12 +27,7 @@ class HTLCManager:
# note: "htlc_id" keys in dict are str! but due to json_db magic they can *almost* be treated as int... # note: "htlc_id" keys in dict are str! but due to json_db magic they can *almost* be treated as int...
log[LOCAL] = deepcopy(initial) log[LOCAL] = deepcopy(initial)
log[REMOTE] = deepcopy(initial) log[REMOTE] = deepcopy(initial)
log['unacked_local_updates2'] = {} log[LOCAL]['unacked_updates'] = {}
if 'unfulfilled_htlcs' not in log:
log['unfulfilled_htlcs'] = {} # htlc_id -> onion_packet
if 'fail_htlc_reasons' not in log:
log['fail_htlc_reasons'] = {} # htlc_id -> error_bytes, failure_message
# maybe bootstrap fee_updates if initial_feerate was provided # maybe bootstrap fee_updates if initial_feerate was provided
if initial_feerate is not None: if initial_feerate is not None:
@ -209,7 +204,7 @@ class HTLCManager:
fee_update.ctn_local = self.ctn_latest(LOCAL) + 1 fee_update.ctn_local = self.ctn_latest(LOCAL) + 1
# no need to keep local update raw msgs anymore, they have just been ACKed. # no need to keep local update raw msgs anymore, they have just been ACKed.
self.log['unacked_local_updates2'].pop(self.log[REMOTE]['ctn'], None) self.log[LOCAL]['unacked_updates'].pop(self.log[REMOTE]['ctn'], None)
@with_lock @with_lock
def _update_maybe_active_htlc_ids(self) -> None: def _update_maybe_active_htlc_ids(self) -> None:
@ -276,21 +271,21 @@ class HTLCManager:
"""We need to be able to replay unacknowledged updates we sent to the remote """We need to be able to replay unacknowledged updates we sent to the remote
in case of disconnections. Hence, raw update and commitment_signed messages in case of disconnections. Hence, raw update and commitment_signed messages
are stored temporarily (until they are acked).""" are stored temporarily (until they are acked)."""
# self.log['unacked_local_updates2'][ctn_idx] is a list of raw messages # self.log[LOCAL]['unacked_updates'][ctn_idx] is a list of raw messages
# containing some number of updates and then a single commitment_signed # containing some number of updates and then a single commitment_signed
if is_commitment_signed: if is_commitment_signed:
ctn_idx = self.ctn_latest(REMOTE) ctn_idx = self.ctn_latest(REMOTE)
else: else:
ctn_idx = self.ctn_latest(REMOTE) + 1 ctn_idx = self.ctn_latest(REMOTE) + 1
l = self.log['unacked_local_updates2'].get(ctn_idx, []) l = self.log[LOCAL]['unacked_updates'].get(ctn_idx, [])
l.append(raw_update_msg.hex()) l.append(raw_update_msg.hex())
self.log['unacked_local_updates2'][ctn_idx] = l self.log[LOCAL]['unacked_updates'][ctn_idx] = l
@with_lock @with_lock
def get_unacked_local_updates(self) -> Dict[int, Sequence[bytes]]: def get_unacked_local_updates(self) -> Dict[int, Sequence[bytes]]:
#return self.log['unacked_local_updates2'] #return self.log[LOCAL]['unacked_updates']
return {int(ctn): [bfh(msg) for msg in messages] return {int(ctn): [bfh(msg) for msg in messages]
for ctn, messages in self.log['unacked_local_updates2'].items()} for ctn, messages in self.log[LOCAL]['unacked_updates'].items()}
##### Queries re HTLCs: ##### Queries re HTLCs:

4
electrum/lnpeer.py

@ -770,6 +770,8 @@ class Peer(Logger):
'onion_keys': {}, 'onion_keys': {},
'data_loss_protect_remote_pcp': {}, 'data_loss_protect_remote_pcp': {},
"log": {}, "log": {},
"fail_htlc_reasons": {}, # htlc_id -> onion_packet
"unfulfilled_htlcs": {}, # htlc_id -> error_bytes, failure_message
"revocation_store": {}, "revocation_store": {},
"static_remotekey_enabled": self.is_static_remotekey(), # stored because it cannot be "downgraded", per BOLT2 "static_remotekey_enabled": self.is_static_remotekey(), # stored because it cannot be "downgraded", per BOLT2
} }
@ -1862,7 +1864,7 @@ class Peer(Logger):
continue continue
self.maybe_send_commitment(chan) self.maybe_send_commitment(chan)
done = set() done = set()
unfulfilled = chan.hm.log.get('unfulfilled_htlcs', {}) unfulfilled = chan.unfulfilled_htlcs
for htlc_id, (local_ctn, remote_ctn, onion_packet_hex, forwarding_info) in unfulfilled.items(): for htlc_id, (local_ctn, remote_ctn, onion_packet_hex, forwarding_info) in unfulfilled.items():
if not chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id): if not chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id):
continue continue

2
electrum/tests/test_lnchannel.py

@ -104,6 +104,8 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator,
'data_loss_protect_remote_pcp': {}, 'data_loss_protect_remote_pcp': {},
'state': 'PREOPENING', 'state': 'PREOPENING',
'log': {}, 'log': {},
'fail_htlc_reasons': {},
'unfulfilled_htlcs': {},
'revocation_store': {}, 'revocation_store': {},
} }
return StoredDict(state, None, []) return StoredDict(state, None, [])

23
electrum/wallet_db.py

@ -53,7 +53,7 @@ if TYPE_CHECKING:
OLD_SEED_VERSION = 4 # electrum versions < 2.0 OLD_SEED_VERSION = 4 # electrum versions < 2.0
NEW_SEED_VERSION = 11 # electrum versions >= 2.0 NEW_SEED_VERSION = 11 # electrum versions >= 2.0
FINAL_SEED_VERSION = 42 # electrum >= 2.7 will set this to prevent FINAL_SEED_VERSION = 43 # electrum >= 2.7 will set this to prevent
# old versions from overwriting new format # old versions from overwriting new format
@ -191,6 +191,7 @@ class WalletDB(JsonDB):
self._convert_version_40() self._convert_version_40()
self._convert_version_41() self._convert_version_41()
self._convert_version_42() self._convert_version_42()
self._convert_version_43()
self.put('seed_version', FINAL_SEED_VERSION) # just to be sure self.put('seed_version', FINAL_SEED_VERSION) # just to be sure
self._after_upgrade_tasks() self._after_upgrade_tasks()
@ -837,6 +838,18 @@ class WalletDB(JsonDB):
for _type, addr, val in item['outputs']] for _type, addr, val in item['outputs']]
self.data['seed_version'] = 42 self.data['seed_version'] = 42
def _convert_version_43(self):
if not self._is_upgrade_method_needed(42, 42):
return
channels = self.data.pop('channels', {})
for k, c in channels.items():
log = c['log']
c['fail_htlc_reasons'] = log.pop('fail_htlc_reasons', {})
c['unfulfilled_htlcs'] = log.pop('unfulfilled_htlcs', {})
log["1"]['unacked_updates'] = log.pop('unacked_local_updates2', {})
self.data['channels'] = channels
self.data['seed_version'] = 43
def _convert_imported(self): def _convert_imported(self):
if not self._is_upgrade_method_needed(0, 13): if not self._is_upgrade_method_needed(0, 13):
return return
@ -1344,6 +1357,14 @@ class WalletDB(JsonDB):
v = dict((k, ShachainElement(bfh(x[0]), int(x[1]))) for k, x in v.items()) v = dict((k, ShachainElement(bfh(x[0]), int(x[1]))) for k, x in v.items())
elif key == 'data_loss_protect_remote_pcp': elif key == 'data_loss_protect_remote_pcp':
v = dict((k, bfh(x)) for k, x in v.items()) v = dict((k, bfh(x)) for k, x in v.items())
# convert htlc_id keys to int
if key in ['adds', 'locked_in', 'settles', 'fails', 'fee_updates', 'buckets']:
v = dict((int(k), x) for k, x in v.items())
# convert keys to HTLCOwner
if key == 'log' or (path and path[-1] in ['locked_in', 'fails', 'settles']):
if "1" in v:
v[LOCAL] = v.pop("1")
v[REMOTE] = v.pop("-1")
return v return v
def _convert_value(self, path, key, v): def _convert_value(self, path, key, v):

Loading…
Cancel
Save