diff --git a/electrum/lnbase.py b/electrum/lnbase.py index 8ea9b4e9e..4bfb36bfa 100644 --- a/electrum/lnbase.py +++ b/electrum/lnbase.py @@ -10,7 +10,7 @@ import asyncio import os import time from functools import partial -from typing import List, Tuple, Dict, TYPE_CHECKING +from typing import List, Tuple, Dict, TYPE_CHECKING, Optional, Callable import traceback import sys @@ -40,7 +40,7 @@ if TYPE_CHECKING: from .lnworker import LNWorker -def channel_id_from_funding_tx(funding_txid, funding_index): +def channel_id_from_funding_tx(funding_txid: str, funding_index: int) -> Tuple[bytes, bytes]: funding_txid_bytes = bytes.fromhex(funding_txid)[::-1] i = int.from_bytes(funding_txid_bytes, 'big') ^ funding_index return i.to_bytes(32, 'big'), funding_txid_bytes @@ -48,7 +48,7 @@ def channel_id_from_funding_tx(funding_txid, funding_index): message_types = {} -def handlesingle(x, ma): +def handlesingle(x, ma: dict) -> int: """ Evaluate a term of the simple language used to specify lightning message field lengths. @@ -57,7 +57,7 @@ def handlesingle(x, ma): otherwise it is treated as a variable and looked up in `ma`. - It the value in `ma` was no integer, it is + If the value in `ma` was no integer, it is assumed big-endian bytes and decoded. Returns int @@ -72,7 +72,7 @@ def handlesingle(x, ma): x = int.from_bytes(x, byteorder='big') return x -def calcexp(exp, ma): +def calcexp(exp, ma: dict) -> int: """ Evaluate simple mathematical expression given in `exp` with variables assigned in the dict `ma` @@ -88,7 +88,7 @@ def calcexp(exp, ma): return result return sum(handlesingle(x, ma) for x in exp.split("+")) -def make_handler(k, v): +def make_handler(k: str, v: dict) -> Callable[[bytes], Tuple[str, dict]]: """ Generate a message handler function (taking bytes) for message type `k` with specification `v` @@ -100,7 +100,7 @@ def make_handler(k, v): Returns function taking bytes """ - def handler(data): + def handler(data: bytes) -> Tuple[str, dict]: nonlocal k, v ma = {} pos = 0 @@ -223,7 +223,7 @@ class Peer(PrintError): self.attempted_route = {} self.orphan_channel_updates = OrderedDict() - def send_message(self, message_name, **kwargs): + def send_message(self, message_name: str, **kwargs): assert type(message_name) is str self.print_error("Sending '%s'"%message_name.upper()) self.transport.send_bytes(gen_msg(message_name, **kwargs)) @@ -350,13 +350,10 @@ class Peer(PrintError): @log_exceptions @handle_disconnect async def main_loop(self): - """ - This is used from the GUI. It is not merged with the other function, - so that we can test if the correct exceptions are getting thrown. - """ await self._main_loop() async def _main_loop(self): + """This is separate from main_loop for the tests.""" try: await asyncio.wait_for(self.initialize(), 10) except (OSError, asyncio.TimeoutError, HandshakeFailed) as e: @@ -378,7 +375,7 @@ class Peer(PrintError): chan.set_state('DISCONNECTED') self.network.trigger_callback('channel', chan) - def make_local_config(self, funding_sat, push_msat, initiator: HTLCOwner): + def make_local_config(self, funding_sat: int, push_msat: int, initiator: HTLCOwner) -> Tuple[ChannelConfig, bytes]: # key derivation channel_counter = self.lnworker.get_and_inc_counter_for_channel_keys() keypair_generator = lambda family: generate_keypair(self.lnworker.ln_keystore, family, channel_counter) @@ -406,7 +403,8 @@ class Peer(PrintError): return local_config, per_commitment_secret_seed @log_exceptions - async def channel_establishment_flow(self, password, funding_sat, push_msat, temp_channel_id): + async def channel_establishment_flow(self, password: Optional[str], funding_sat: int, + push_msat: int, temp_channel_id: bytes) -> Channel: wallet = self.lnworker.wallet # dry run creating funding tx to see if we even have enough funds funding_tx_test = wallet.mktx([TxOutput(bitcoin.TYPE_ADDRESS, wallet.dummy_address(), funding_sat)], @@ -478,7 +476,7 @@ class Peer(PrintError): funding_index = funding_tx.outputs().index(funding_output) # remote commitment transaction channel_id, funding_txid_bytes = channel_id_from_funding_tx(funding_txid, funding_index) - chan = { + chan_dict = { "node_id": self.peer_addr.pubkey, "channel_id": channel_id, "short_channel_id": None, @@ -495,10 +493,10 @@ class Peer(PrintError): "constraints": ChannelConstraints(capacity=funding_sat, is_initiator=True, funding_txn_minimum_depth=funding_txn_minimum_depth, feerate=feerate), "remote_commitment_to_be_revoked": None, } - m = Channel(chan) - m.lnwatcher = self.lnwatcher - m.sweep_address = self.lnworker.sweep_address - sig_64, _ = m.sign_next_commitment() + chan = Channel(chan_dict) + chan.lnwatcher = self.lnwatcher + chan.sweep_address = self.lnworker.sweep_address + sig_64, _ = chan.sign_next_commitment() self.send_message("funding_created", temporary_channel_id=temp_channel_id, funding_txid=funding_txid_bytes, @@ -507,14 +505,14 @@ class Peer(PrintError): payload = await self.funding_signed[channel_id].get() self.print_error('received funding_signed') remote_sig = payload['signature'] - m.receive_new_commitment(remote_sig, []) + chan.receive_new_commitment(remote_sig, []) # broadcast funding tx await self.network.broadcast_transaction(funding_tx) - m.remote_commitment_to_be_revoked = m.pending_remote_commitment - m.config[REMOTE] = m.config[REMOTE]._replace(ctn=0) - m.config[LOCAL] = m.config[LOCAL]._replace(ctn=0, current_commitment_signature=remote_sig) - m.set_state('OPENING') - return m + chan.remote_commitment_to_be_revoked = chan.pending_remote_commitment + chan.config[REMOTE] = chan.config[REMOTE]._replace(ctn=0) + chan.config[LOCAL] = chan.config[LOCAL]._replace(ctn=0, current_commitment_signature=remote_sig) + chan.set_state('OPENING') + return chan async def on_open_channel(self, payload): # payload['channel_flags'] @@ -555,8 +553,9 @@ class Peer(PrintError): channel_id, funding_txid_bytes = channel_id_from_funding_tx(funding_txid, funding_idx) their_revocation_store = RevocationStore() remote_balance_sat = funding_sat * 1000 - push_msat + remote_dust_limit_sat = int.from_bytes(payload['dust_limit_satoshis'], byteorder='big') remote_reserve_sat = self.validate_remote_reserve(payload['channel_reserve_satoshis'], remote_dust_limit_sat, funding_sat) - chan = { + chan_dict = { "node_id": self.peer_addr.pubkey, "channel_id": channel_id, "short_channel_id": None, @@ -568,7 +567,7 @@ class Peer(PrintError): delayed_basepoint=OnlyPubkeyKeypair(payload['delayed_payment_basepoint']), revocation_basepoint=OnlyPubkeyKeypair(payload['revocation_basepoint']), to_self_delay=int.from_bytes(payload['to_self_delay'], 'big'), - dust_limit_sat=int.from_bytes(payload['dust_limit_satoshis'], 'big'), + dust_limit_sat=remote_dust_limit_sat, max_htlc_value_in_flight_msat=int.from_bytes(payload['max_htlc_value_in_flight_msat'], 'big'), max_accepted_htlcs=int.from_bytes(payload['max_accepted_htlcs'], 'big'), initial_msat=remote_balance_sat, @@ -592,22 +591,22 @@ class Peer(PrintError): "constraints": ChannelConstraints(capacity=funding_sat, is_initiator=False, funding_txn_minimum_depth=min_depth, feerate=feerate), "remote_commitment_to_be_revoked": None, } - m = Channel(chan) - m.lnwatcher = self.lnwatcher - m.sweep_address = self.lnworker.sweep_address + chan = Channel(chan_dict) + chan.lnwatcher = self.lnwatcher + chan.sweep_address = self.lnworker.sweep_address remote_sig = funding_created['signature'] - m.receive_new_commitment(remote_sig, []) - sig_64, _ = m.sign_next_commitment() + chan.receive_new_commitment(remote_sig, []) + sig_64, _ = chan.sign_next_commitment() self.send_message('funding_signed', channel_id=channel_id, signature=sig_64, ) - m.set_state('OPENING') - m.remote_commitment_to_be_revoked = m.pending_remote_commitment - m.config[REMOTE] = m.config[REMOTE]._replace(ctn=0) - m.config[LOCAL] = m.config[LOCAL]._replace(ctn=0, current_commitment_signature=remote_sig) - self.lnworker.save_channel(m) - self.lnwatcher.watch_channel(m.get_funding_address(), m.funding_outpoint.to_str()) + chan.set_state('OPENING') + chan.remote_commitment_to_be_revoked = chan.pending_remote_commitment + chan.config[REMOTE] = chan.config[REMOTE]._replace(ctn=0) + chan.config[LOCAL] = chan.config[LOCAL]._replace(ctn=0, current_commitment_signature=remote_sig) + self.lnworker.save_channel(chan) + self.lnwatcher.watch_channel(chan.get_funding_address(), chan.funding_outpoint.to_str()) self.lnworker.on_channels_updated() while True: try: @@ -618,13 +617,13 @@ class Peer(PrintError): else: break outp = funding_tx.outputs()[funding_idx] - redeem_script = funding_output_script(m.config[REMOTE], m.config[LOCAL]) + redeem_script = funding_output_script(chan.config[REMOTE], chan.config[LOCAL]) funding_address = bitcoin.redeem_script_to_address('p2wsh', redeem_script) if outp != TxOutput(bitcoin.TYPE_ADDRESS, funding_address, funding_sat): - m.set_state('DISCONNECTED') + chan.set_state('DISCONNECTED') raise Exception('funding outpoint mismatch') - def validate_remote_reserve(self, payload_field, dust_limit, funding_sat): + def validate_remote_reserve(self, payload_field: bytes, dust_limit: int, funding_sat: int) -> int: remote_reserve_sat = int.from_bytes(payload_field, 'big') if remote_reserve_sat < dust_limit: raise Exception('protocol violation: reserve < dust_limit') @@ -633,7 +632,7 @@ class Peer(PrintError): return remote_reserve_sat @log_exceptions - async def reestablish_channel(self, chan): + async def reestablish_channel(self, chan: Channel): await self.initialized chan_id = chan.channel_id if chan.get_state() != 'DISCONNECTED': @@ -712,7 +711,7 @@ class Peer(PrintError): # checks done self.channel_reestablished[chan_id].set_result(True) - def funding_locked(self, chan): + def funding_locked(self, chan: Channel): channel_id = chan.channel_id per_commitment_secret_index = RevocationStore.START_INDEX - 1 per_commitment_point_second = secret_to_pubkey(int.from_bytes( @@ -739,7 +738,7 @@ class Peer(PrintError): if chan.short_channel_id: self.mark_open(chan) - def on_network_update(self, chan, funding_tx_depth): + def on_network_update(self, chan: Channel, funding_tx_depth: int): """ Only called when the channel is OPEN. @@ -794,7 +793,7 @@ class Peer(PrintError): print("SENT CHANNEL ANNOUNCEMENT") - def mark_open(self, chan): + def mark_open(self, chan: Channel): if chan.get_state() == "OPEN": return # NOTE: even closed channels will be temporarily marked "OPEN" @@ -841,13 +840,13 @@ class Peer(PrintError): self.print_error("CHANNEL OPENING COMPLETED") - def send_announcement_signatures(self, chan): + def send_announcement_signatures(self, chan: Channel): bitcoin_keys = [chan.config[REMOTE].multisig_key.pubkey, chan.config[LOCAL].multisig_key.pubkey] sorted_node_ids = list(sorted(self.node_ids)) - if sorted_node_ids != node_ids: + if sorted_node_ids != self.node_ids: node_ids = sorted_node_ids bitcoin_keys.reverse() else: @@ -944,12 +943,12 @@ class Peer(PrintError): else: self.network.path_finder.blacklist.add(short_chan_id) - def send_commitment(self, chan): + def send_commitment(self, chan: Channel): sig_64, htlc_sigs = chan.sign_next_commitment() self.send_message("commitment_signed", channel_id=chan.channel_id, signature=sig_64, num_htlcs=len(htlc_sigs), htlc_signature=b"".join(htlc_sigs)) return len(htlc_sigs) - async def update_channel(self, chan, message_name, **kwargs): + async def update_channel(self, chan: Channel, message_name: str, **kwargs): """ generic channel update flow """ self.send_message(message_name, **kwargs) self.send_commitment(chan) @@ -957,7 +956,8 @@ class Peer(PrintError): await self.receive_commitment(chan) self.revoke(chan) - async def pay(self, route: List[RouteEdge], chan, amount_msat, payment_hash, min_final_cltv_expiry): + async def pay(self, route: List[RouteEdge], chan: Channel, amount_msat: int, + payment_hash: bytes, min_final_cltv_expiry: int): assert chan.get_state() == "OPEN", chan.get_state() assert amount_msat > 0, "amount_msat is not greater zero" # create onion packet @@ -974,25 +974,25 @@ class Peer(PrintError): self.print_error(f"starting payment. route: {route}") await self.update_channel(chan, "update_add_htlc", channel_id=chan.channel_id, id=htlc_id, cltv_expiry=cltv, amount_msat=amount_msat, payment_hash=payment_hash, onion_routing_packet=onion.to_bytes()) - async def receive_revoke(self, m): - revoke_and_ack_msg = await self.revoke_and_ack[m.channel_id].get() - m.receive_revocation(RevokeAndAck(revoke_and_ack_msg["per_commitment_secret"], revoke_and_ack_msg["next_per_commitment_point"])) - self.lnworker.save_channel(m) + async def receive_revoke(self, chan: Channel): + revoke_and_ack_msg = await self.revoke_and_ack[chan.channel_id].get() + chan.receive_revocation(RevokeAndAck(revoke_and_ack_msg["per_commitment_secret"], revoke_and_ack_msg["next_per_commitment_point"])) + self.lnworker.save_channel(chan) - def revoke(self, m): - rev, _ = m.revoke_current_commitment() - self.lnworker.save_channel(m) + def revoke(self, chan: Channel): + rev, _ = chan.revoke_current_commitment() + self.lnworker.save_channel(chan) self.send_message("revoke_and_ack", - channel_id=m.channel_id, + channel_id=chan.channel_id, per_commitment_secret=rev.per_commitment_secret, next_per_commitment_point=rev.next_per_commitment_point) - async def receive_commitment(self, m, commitment_signed_msg=None): + async def receive_commitment(self, chan: Channel, commitment_signed_msg=None): if commitment_signed_msg is None: - commitment_signed_msg = await self.commitment_signed[m.channel_id].get() + commitment_signed_msg = await self.commitment_signed[chan.channel_id].get() data = commitment_signed_msg["htlc_signature"] htlc_sigs = [data[i:i+64] for i in range(0, len(data), 64)] - m.receive_new_commitment(commitment_signed_msg["signature"], htlc_sigs) + chan.receive_new_commitment(commitment_signed_msg["signature"], htlc_sigs) return len(htlc_sigs) def on_commitment_signed(self, payload): @@ -1109,7 +1109,7 @@ class Peer(PrintError): channel_id = payload["channel_id"] self.channels[channel_id].receive_update_fee(int.from_bytes(payload["feerate_per_kw"], "big")) - async def bitcoin_fee_update(self, chan): + async def bitcoin_fee_update(self, chan: Channel): """ called when our fee estimates change """ @@ -1144,7 +1144,7 @@ class Peer(PrintError): self.closing_signed[chan_id].put_nowait(payload) @log_exceptions - async def close_channel(self, chan_id): + async def close_channel(self, chan_id: bytes): chan = self.channels[chan_id] self.shutdown_received[chan_id] = asyncio.Future() self.send_shutdown(chan) @@ -1167,12 +1167,12 @@ class Peer(PrintError): txid = await self._shutdown(chan, payload) self.print_error('Channel closed by remote peer', txid) - def send_shutdown(self, chan): + def send_shutdown(self, chan: Channel): scriptpubkey = bfh(bitcoin.address_to_script(chan.sweep_address)) self.send_message('shutdown', channel_id=chan.channel_id, len=len(scriptpubkey), scriptpubkey=scriptpubkey) @log_exceptions - async def _shutdown(self, chan, payload): + async def _shutdown(self, chan: Channel, payload): scriptpubkey = bfh(bitcoin.address_to_script(chan.sweep_address)) signature, fee, txid = chan.make_closing_tx(scriptpubkey, payload['scriptpubkey']) self.send_message('closing_signed', channel_id=chan.channel_id, fee_satoshis=fee, signature=signature) diff --git a/electrum/lnchan.py b/electrum/lnchan.py index 3783098a4..c38184cbf 100644 --- a/electrum/lnchan.py +++ b/electrum/lnchan.py @@ -781,7 +781,8 @@ class Channel(PrintError): ), htlcs=htlcs) - def make_closing_tx(self, local_script: bytes, remote_script: bytes, fee_sat: Optional[int] = None) -> (bytes, int): + def make_closing_tx(self, local_script: bytes, remote_script: bytes, + fee_sat: Optional[int]=None) -> Tuple[bytes, int, str]: if fee_sat is None: fee_sat = self.pending_local_fee