diff --git a/electrum/lnbase.py b/electrum/lnbase.py index e8ecf96d7..ee4c630b7 100644 --- a/electrum/lnbase.py +++ b/electrum/lnbase.py @@ -197,15 +197,14 @@ def gen_msg(msg_type: str, **kwargs) -> bytes: class Peer(PrintError): - def __init__(self, lnworker: 'LNWorker', peer_addr: LNPeerAddr, responding=False, - request_initial_sync=False, transport: LNTransportBase=None): - self.responding = responding + def __init__(self, lnworker: 'LNWorker', pubkey:bytes, transport: LNTransportBase, + request_initial_sync=False): self.initialized = asyncio.Event() self.transport = transport - self.peer_addr = peer_addr + self.pubkey = pubkey self.lnworker = lnworker self.privkey = lnworker.node_keypair.privkey - self.node_ids = [peer_addr.pubkey, privkey_to_pubkey(self.privkey)] + self.node_ids = [self.pubkey, privkey_to_pubkey(self.privkey)] self.network = lnworker.network self.lnwatcher = lnworker.network.lnwatcher self.channel_db = lnworker.network.channel_db @@ -233,19 +232,14 @@ class Peer(PrintError): self.transport.send_bytes(gen_msg(message_name, **kwargs)) async def initialize(self): - if not self.transport: - reader, writer = await asyncio.open_connection(self.peer_addr.host, self.peer_addr.port) - transport = LNTransport(self.privkey, self.peer_addr.pubkey, reader, writer) - await transport.handshake() - self.transport = transport self.send_message("init", gflen=0, lflen=1, localfeatures=self.localfeatures) @property def channels(self) -> Dict[bytes, Channel]: - return self.lnworker.channels_for_peer(self.peer_addr.pubkey) + return self.lnworker.channels_for_peer(self.pubkey) def diagnostic_name(self): - return str(self.peer_addr.host) + ':' + str(self.peer_addr.port) + return self.transport.name() def ping_if_required(self): if time.time() - self.ping_time > 120: @@ -352,7 +346,7 @@ class Peer(PrintError): self.print_error("disconnecting gracefully. {}".format(e)) finally: self.close_and_cleanup() - self.lnworker.peers.pop(self.peer_addr.pubkey) + self.lnworker.peers.pop(self.pubkey) return wrapper_func @ignore_exceptions # do not kill main_taskgroup @@ -373,8 +367,6 @@ class Peer(PrintError): except (OSError, asyncio.TimeoutError, HandshakeFailed) as e: self.print_error('initialize failed, disconnecting: {}'.format(repr(e))) return - if not self.responding: - self.channel_db.add_recent_peer(self.peer_addr) # loop async for msg in self.transport.read_messages(): self.process_message(msg) @@ -513,7 +505,7 @@ class Peer(PrintError): # remote commitment transaction channel_id, funding_txid_bytes = channel_id_from_funding_tx(funding_txid, funding_index) chan_dict = { - "node_id": self.peer_addr.pubkey, + "node_id": self.pubkey, "channel_id": channel_id, "short_channel_id": None, "funding_outpoint": Outpoint(funding_txid, funding_index), @@ -587,7 +579,7 @@ class Peer(PrintError): remote_dust_limit_sat = int.from_bytes(payload['dust_limit_satoshis'], byteorder='big') # TODO validate remote_reserve_sat = self.validate_remote_reserve(payload['channel_reserve_satoshis'], remote_dust_limit_sat, funding_sat) chan_dict = { - "node_id": self.peer_addr.pubkey, + "node_id": self.pubkey, "channel_id": channel_id, "short_channel_id": None, "funding_outpoint": Outpoint(funding_txid, funding_idx), @@ -794,7 +786,7 @@ class Peer(PrintError): remote_bitcoin_sig = announcement_signatures_msg["bitcoin_signature"] if not ecc.verify_signature(chan.config[REMOTE].multisig_key.pubkey, remote_bitcoin_sig, h): raise Exception("bitcoin_sig invalid in announcement_signatures") - if not ecc.verify_signature(self.peer_addr.pubkey, remote_node_sig, h): + if not ecc.verify_signature(self.pubkey, remote_node_sig, h): raise Exception("node_sig invalid in announcement_signatures") node_sigs = [remote_node_sig, local_node_sig] diff --git a/electrum/lntransport.py b/electrum/lntransport.py index 88c26f839..671e81564 100644 --- a/electrum/lntransport.py +++ b/electrum/lntransport.py @@ -6,6 +6,7 @@ # Derived from https://gist.github.com/AdamISZ/046d05c156aaeb56cc897f85eecb3eb8 import hashlib +import asyncio from asyncio import StreamReader, StreamWriter from Cryptodome.Cipher import ChaCha20_Poly1305 @@ -87,10 +88,6 @@ def create_ephemeral_key() -> (bytes, bytes): class LNTransportBase: - def __init__(self, reader: StreamReader, writer: StreamWriter): - self.reader = reader - self.writer = writer - def send_bytes(self, msg): l = len(msg).to_bytes(2, 'big') lc = aead_encrypt(self.sk, self.sn(), b'', l) @@ -153,12 +150,16 @@ class LNTransportBase: class LNResponderTransport(LNTransportBase): def __init__(self, privkey: bytes, reader: StreamReader, writer: StreamWriter): - LNTransportBase.__init__(self, reader, writer) + LNTransportBase.__init__(self) + self.reader = reader + self.writer = writer self.privkey = privkey + def name(self): + return "responder" + async def handshake(self, **kwargs): hs = HandshakeState(privkey_to_pubkey(self.privkey)) - act1 = b'' while len(act1) < 50: act1 += await self.reader.read(50 - len(act1)) @@ -205,14 +206,20 @@ class LNResponderTransport(LNTransportBase): return rs class LNTransport(LNTransportBase): - def __init__(self, privkey: bytes, remote_pubkey: bytes, - reader: StreamReader, writer: StreamWriter): - LNTransportBase.__init__(self, reader, writer) + + def __init__(self, privkey: bytes, peer_addr): + LNTransportBase.__init__(self) assert type(privkey) is bytes and len(privkey) == 32 self.privkey = privkey - self.remote_pubkey = remote_pubkey + self.remote_pubkey = peer_addr.pubkey + self.host = peer_addr.host + self.port = peer_addr.port + + def name(self): + return str(self.host) + ':' + str(self.port) async def handshake(self): + self.reader, self.writer = await asyncio.open_connection(self.host, self.port) hs = HandshakeState(self.remote_pubkey) # Get a new ephemeral key epriv, epub = create_ephemeral_key() diff --git a/electrum/lnworker.py b/electrum/lnworker.py index d3f96c109..1453e070b 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -28,7 +28,7 @@ from .crypto import sha256 from .bip32 import bip32_root from .util import bh2u, bfh, PrintError, InvoiceError, resolve_dns_srv, is_ip_address, log_exceptions from .util import timestamp_to_datetime -from .lntransport import LNResponderTransport +from .lntransport import LNTransport, LNResponderTransport from .lnbase import Peer from .lnaddr import lnencode, LnAddr, lndecode from .ecc import der_sig_from_sig_string @@ -244,13 +244,16 @@ class LNWorker(PrintError): return {x: y for (x, y) in self.channels.items() if y.node_id == node_id} async def add_peer(self, host, port, node_id): - port = int(port) - peer_addr = LNPeerAddr(host, port, node_id) if node_id in self.peers: return self.peers[node_id] + port = int(port) + peer_addr = LNPeerAddr(host, port, node_id) + transport = LNTransport(self.node_keypair.privkey, peer_addr) + await transport.handshake() + self.channel_db.add_recent_peer(peer_addr) self._last_tried_peer[peer_addr] = time.time() self.print_error("adding peer", peer_addr) - peer = Peer(self, peer_addr, request_initial_sync=self.config.get("request_initial_sync", True)) + peer = Peer(self, node_id, transport, request_initial_sync=self.config.get("request_initial_sync", True)) await self.network.main_taskgroup.spawn(peer.main_loop()) self.peers[node_id] = peer self.network.trigger_callback('ln_status') @@ -797,16 +800,13 @@ class LNWorker(PrintError): # ipv6 addr = addr[1:-1] async def cb(reader, writer): - t = LNResponderTransport(self.node_keypair.privkey, reader, writer) + transport = LNResponderTransport(self.node_keypair.privkey, reader, writer) try: - node_id = await t.handshake() + node_id = await transport.handshake() except: self.print_error('handshake failure from incoming connection') return - # FIXME extract host and port from transport - peer = Peer(self, LNPeerAddr("bogus", 1337, node_id), responding=True, - request_initial_sync=self.config.get("request_initial_sync", True), - transport=t) + peer = Peer(self, node_id, transport, request_initial_sync=self.config.get("request_initial_sync", True)) self.peers[node_id] = peer await self.network.main_taskgroup.spawn(peer.main_loop()) self.network.trigger_callback('ln_status') diff --git a/electrum/tests/test_lnbase.py b/electrum/tests/test_lnbase.py index be3447d79..8c7400faa 100644 --- a/electrum/tests/test_lnbase.py +++ b/electrum/tests/test_lnbase.py @@ -113,6 +113,9 @@ class MockTransport: def __init__(self): self.queue = asyncio.Queue() + def name(self): + return "" + async def read_messages(self): while True: yield await self.queue.get() @@ -150,7 +153,7 @@ class TestPeer(unittest.TestCase): def test_require_data_loss_protect(self): mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel, tx_queue=None) mock_transport = NoFeaturesTransport() - p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 33), request_initial_sync=False, transport=mock_transport) + p1 = Peer(mock_lnworker, b"\x00" * 33, mock_transport, request_initial_sync=False) mock_lnworker.peer = p1 with self.assertRaises(LightningPeerConnectionClosed): run(asyncio.wait_for(p1._main_loop(), 1)) @@ -161,10 +164,8 @@ class TestPeer(unittest.TestCase): q1, q2 = asyncio.Queue(), asyncio.Queue() w1 = MockLNWorker(k1, k2, self.alice_channel, tx_queue=q1) w2 = MockLNWorker(k2, k1, self.bob_channel, tx_queue=q2) - p1 = Peer(w1, LNPeerAddr("bogus1", 1337, k1.pubkey), - request_initial_sync=False, transport=t1) - p2 = Peer(w2, LNPeerAddr("bogus2", 1337, k2.pubkey), - request_initial_sync=False, transport=t2) + p1 = Peer(w1, k1.pubkey, t1, request_initial_sync=False) + p2 = Peer(w2, k2.pubkey, t2, request_initial_sync=False) w1.peer = p1 w2.peer = p2 # mark_open won't work if state is already OPEN.