Browse Source

create transport and perform handshake before creating Peer

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
ThomasV 6 years ago
parent
commit
b5482e4470
  1. 28
      electrum/lnbase.py
  2. 27
      electrum/lntransport.py
  3. 20
      electrum/lnworker.py
  4. 11
      electrum/tests/test_lnbase.py

28
electrum/lnbase.py

@ -197,15 +197,14 @@ def gen_msg(msg_type: str, **kwargs) -> bytes:
class Peer(PrintError): class Peer(PrintError):
def __init__(self, lnworker: 'LNWorker', peer_addr: LNPeerAddr, responding=False, def __init__(self, lnworker: 'LNWorker', pubkey:bytes, transport: LNTransportBase,
request_initial_sync=False, transport: LNTransportBase=None): request_initial_sync=False):
self.responding = responding
self.initialized = asyncio.Event() self.initialized = asyncio.Event()
self.transport = transport self.transport = transport
self.peer_addr = peer_addr self.pubkey = pubkey
self.lnworker = lnworker self.lnworker = lnworker
self.privkey = lnworker.node_keypair.privkey self.privkey = lnworker.node_keypair.privkey
self.node_ids = [peer_addr.pubkey, privkey_to_pubkey(self.privkey)] self.node_ids = [self.pubkey, privkey_to_pubkey(self.privkey)]
self.network = lnworker.network self.network = lnworker.network
self.lnwatcher = lnworker.network.lnwatcher self.lnwatcher = lnworker.network.lnwatcher
self.channel_db = lnworker.network.channel_db self.channel_db = lnworker.network.channel_db
@ -233,19 +232,14 @@ class Peer(PrintError):
self.transport.send_bytes(gen_msg(message_name, **kwargs)) self.transport.send_bytes(gen_msg(message_name, **kwargs))
async def initialize(self): async def initialize(self):
if not self.transport:
reader, writer = await asyncio.open_connection(self.peer_addr.host, self.peer_addr.port)
transport = LNTransport(self.privkey, self.peer_addr.pubkey, reader, writer)
await transport.handshake()
self.transport = transport
self.send_message("init", gflen=0, lflen=1, localfeatures=self.localfeatures) self.send_message("init", gflen=0, lflen=1, localfeatures=self.localfeatures)
@property @property
def channels(self) -> Dict[bytes, Channel]: def channels(self) -> Dict[bytes, Channel]:
return self.lnworker.channels_for_peer(self.peer_addr.pubkey) return self.lnworker.channels_for_peer(self.pubkey)
def diagnostic_name(self): def diagnostic_name(self):
return str(self.peer_addr.host) + ':' + str(self.peer_addr.port) return self.transport.name()
def ping_if_required(self): def ping_if_required(self):
if time.time() - self.ping_time > 120: if time.time() - self.ping_time > 120:
@ -352,7 +346,7 @@ class Peer(PrintError):
self.print_error("disconnecting gracefully. {}".format(e)) self.print_error("disconnecting gracefully. {}".format(e))
finally: finally:
self.close_and_cleanup() self.close_and_cleanup()
self.lnworker.peers.pop(self.peer_addr.pubkey) self.lnworker.peers.pop(self.pubkey)
return wrapper_func return wrapper_func
@ignore_exceptions # do not kill main_taskgroup @ignore_exceptions # do not kill main_taskgroup
@ -373,8 +367,6 @@ class Peer(PrintError):
except (OSError, asyncio.TimeoutError, HandshakeFailed) as e: except (OSError, asyncio.TimeoutError, HandshakeFailed) as e:
self.print_error('initialize failed, disconnecting: {}'.format(repr(e))) self.print_error('initialize failed, disconnecting: {}'.format(repr(e)))
return return
if not self.responding:
self.channel_db.add_recent_peer(self.peer_addr)
# loop # loop
async for msg in self.transport.read_messages(): async for msg in self.transport.read_messages():
self.process_message(msg) self.process_message(msg)
@ -513,7 +505,7 @@ class Peer(PrintError):
# remote commitment transaction # remote commitment transaction
channel_id, funding_txid_bytes = channel_id_from_funding_tx(funding_txid, funding_index) channel_id, funding_txid_bytes = channel_id_from_funding_tx(funding_txid, funding_index)
chan_dict = { chan_dict = {
"node_id": self.peer_addr.pubkey, "node_id": self.pubkey,
"channel_id": channel_id, "channel_id": channel_id,
"short_channel_id": None, "short_channel_id": None,
"funding_outpoint": Outpoint(funding_txid, funding_index), "funding_outpoint": Outpoint(funding_txid, funding_index),
@ -587,7 +579,7 @@ class Peer(PrintError):
remote_dust_limit_sat = int.from_bytes(payload['dust_limit_satoshis'], byteorder='big') # TODO validate remote_dust_limit_sat = int.from_bytes(payload['dust_limit_satoshis'], byteorder='big') # TODO validate
remote_reserve_sat = self.validate_remote_reserve(payload['channel_reserve_satoshis'], remote_dust_limit_sat, funding_sat) remote_reserve_sat = self.validate_remote_reserve(payload['channel_reserve_satoshis'], remote_dust_limit_sat, funding_sat)
chan_dict = { chan_dict = {
"node_id": self.peer_addr.pubkey, "node_id": self.pubkey,
"channel_id": channel_id, "channel_id": channel_id,
"short_channel_id": None, "short_channel_id": None,
"funding_outpoint": Outpoint(funding_txid, funding_idx), "funding_outpoint": Outpoint(funding_txid, funding_idx),
@ -794,7 +786,7 @@ class Peer(PrintError):
remote_bitcoin_sig = announcement_signatures_msg["bitcoin_signature"] remote_bitcoin_sig = announcement_signatures_msg["bitcoin_signature"]
if not ecc.verify_signature(chan.config[REMOTE].multisig_key.pubkey, remote_bitcoin_sig, h): if not ecc.verify_signature(chan.config[REMOTE].multisig_key.pubkey, remote_bitcoin_sig, h):
raise Exception("bitcoin_sig invalid in announcement_signatures") raise Exception("bitcoin_sig invalid in announcement_signatures")
if not ecc.verify_signature(self.peer_addr.pubkey, remote_node_sig, h): if not ecc.verify_signature(self.pubkey, remote_node_sig, h):
raise Exception("node_sig invalid in announcement_signatures") raise Exception("node_sig invalid in announcement_signatures")
node_sigs = [remote_node_sig, local_node_sig] node_sigs = [remote_node_sig, local_node_sig]

27
electrum/lntransport.py

@ -6,6 +6,7 @@
# Derived from https://gist.github.com/AdamISZ/046d05c156aaeb56cc897f85eecb3eb8 # Derived from https://gist.github.com/AdamISZ/046d05c156aaeb56cc897f85eecb3eb8
import hashlib import hashlib
import asyncio
from asyncio import StreamReader, StreamWriter from asyncio import StreamReader, StreamWriter
from Cryptodome.Cipher import ChaCha20_Poly1305 from Cryptodome.Cipher import ChaCha20_Poly1305
@ -87,10 +88,6 @@ def create_ephemeral_key() -> (bytes, bytes):
class LNTransportBase: class LNTransportBase:
def __init__(self, reader: StreamReader, writer: StreamWriter):
self.reader = reader
self.writer = writer
def send_bytes(self, msg): def send_bytes(self, msg):
l = len(msg).to_bytes(2, 'big') l = len(msg).to_bytes(2, 'big')
lc = aead_encrypt(self.sk, self.sn(), b'', l) lc = aead_encrypt(self.sk, self.sn(), b'', l)
@ -153,12 +150,16 @@ class LNTransportBase:
class LNResponderTransport(LNTransportBase): class LNResponderTransport(LNTransportBase):
def __init__(self, privkey: bytes, reader: StreamReader, writer: StreamWriter): def __init__(self, privkey: bytes, reader: StreamReader, writer: StreamWriter):
LNTransportBase.__init__(self, reader, writer) LNTransportBase.__init__(self)
self.reader = reader
self.writer = writer
self.privkey = privkey self.privkey = privkey
def name(self):
return "responder"
async def handshake(self, **kwargs): async def handshake(self, **kwargs):
hs = HandshakeState(privkey_to_pubkey(self.privkey)) hs = HandshakeState(privkey_to_pubkey(self.privkey))
act1 = b'' act1 = b''
while len(act1) < 50: while len(act1) < 50:
act1 += await self.reader.read(50 - len(act1)) act1 += await self.reader.read(50 - len(act1))
@ -205,14 +206,20 @@ class LNResponderTransport(LNTransportBase):
return rs return rs
class LNTransport(LNTransportBase): class LNTransport(LNTransportBase):
def __init__(self, privkey: bytes, remote_pubkey: bytes,
reader: StreamReader, writer: StreamWriter): def __init__(self, privkey: bytes, peer_addr):
LNTransportBase.__init__(self, reader, writer) LNTransportBase.__init__(self)
assert type(privkey) is bytes and len(privkey) == 32 assert type(privkey) is bytes and len(privkey) == 32
self.privkey = privkey self.privkey = privkey
self.remote_pubkey = remote_pubkey self.remote_pubkey = peer_addr.pubkey
self.host = peer_addr.host
self.port = peer_addr.port
def name(self):
return str(self.host) + ':' + str(self.port)
async def handshake(self): async def handshake(self):
self.reader, self.writer = await asyncio.open_connection(self.host, self.port)
hs = HandshakeState(self.remote_pubkey) hs = HandshakeState(self.remote_pubkey)
# Get a new ephemeral key # Get a new ephemeral key
epriv, epub = create_ephemeral_key() epriv, epub = create_ephemeral_key()

20
electrum/lnworker.py

@ -28,7 +28,7 @@ from .crypto import sha256
from .bip32 import bip32_root from .bip32 import bip32_root
from .util import bh2u, bfh, PrintError, InvoiceError, resolve_dns_srv, is_ip_address, log_exceptions from .util import bh2u, bfh, PrintError, InvoiceError, resolve_dns_srv, is_ip_address, log_exceptions
from .util import timestamp_to_datetime from .util import timestamp_to_datetime
from .lntransport import LNResponderTransport from .lntransport import LNTransport, LNResponderTransport
from .lnbase import Peer from .lnbase import Peer
from .lnaddr import lnencode, LnAddr, lndecode from .lnaddr import lnencode, LnAddr, lndecode
from .ecc import der_sig_from_sig_string from .ecc import der_sig_from_sig_string
@ -244,13 +244,16 @@ class LNWorker(PrintError):
return {x: y for (x, y) in self.channels.items() if y.node_id == node_id} return {x: y for (x, y) in self.channels.items() if y.node_id == node_id}
async def add_peer(self, host, port, node_id): async def add_peer(self, host, port, node_id):
port = int(port)
peer_addr = LNPeerAddr(host, port, node_id)
if node_id in self.peers: if node_id in self.peers:
return self.peers[node_id] return self.peers[node_id]
port = int(port)
peer_addr = LNPeerAddr(host, port, node_id)
transport = LNTransport(self.node_keypair.privkey, peer_addr)
await transport.handshake()
self.channel_db.add_recent_peer(peer_addr)
self._last_tried_peer[peer_addr] = time.time() self._last_tried_peer[peer_addr] = time.time()
self.print_error("adding peer", peer_addr) self.print_error("adding peer", peer_addr)
peer = Peer(self, peer_addr, request_initial_sync=self.config.get("request_initial_sync", True)) peer = Peer(self, node_id, transport, request_initial_sync=self.config.get("request_initial_sync", True))
await self.network.main_taskgroup.spawn(peer.main_loop()) await self.network.main_taskgroup.spawn(peer.main_loop())
self.peers[node_id] = peer self.peers[node_id] = peer
self.network.trigger_callback('ln_status') self.network.trigger_callback('ln_status')
@ -797,16 +800,13 @@ class LNWorker(PrintError):
# ipv6 # ipv6
addr = addr[1:-1] addr = addr[1:-1]
async def cb(reader, writer): async def cb(reader, writer):
t = LNResponderTransport(self.node_keypair.privkey, reader, writer) transport = LNResponderTransport(self.node_keypair.privkey, reader, writer)
try: try:
node_id = await t.handshake() node_id = await transport.handshake()
except: except:
self.print_error('handshake failure from incoming connection') self.print_error('handshake failure from incoming connection')
return return
# FIXME extract host and port from transport peer = Peer(self, node_id, transport, request_initial_sync=self.config.get("request_initial_sync", True))
peer = Peer(self, LNPeerAddr("bogus", 1337, node_id), responding=True,
request_initial_sync=self.config.get("request_initial_sync", True),
transport=t)
self.peers[node_id] = peer self.peers[node_id] = peer
await self.network.main_taskgroup.spawn(peer.main_loop()) await self.network.main_taskgroup.spawn(peer.main_loop())
self.network.trigger_callback('ln_status') self.network.trigger_callback('ln_status')

11
electrum/tests/test_lnbase.py

@ -113,6 +113,9 @@ class MockTransport:
def __init__(self): def __init__(self):
self.queue = asyncio.Queue() self.queue = asyncio.Queue()
def name(self):
return ""
async def read_messages(self): async def read_messages(self):
while True: while True:
yield await self.queue.get() yield await self.queue.get()
@ -150,7 +153,7 @@ class TestPeer(unittest.TestCase):
def test_require_data_loss_protect(self): def test_require_data_loss_protect(self):
mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel, tx_queue=None) mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel, tx_queue=None)
mock_transport = NoFeaturesTransport() mock_transport = NoFeaturesTransport()
p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 33), request_initial_sync=False, transport=mock_transport) p1 = Peer(mock_lnworker, b"\x00" * 33, mock_transport, request_initial_sync=False)
mock_lnworker.peer = p1 mock_lnworker.peer = p1
with self.assertRaises(LightningPeerConnectionClosed): with self.assertRaises(LightningPeerConnectionClosed):
run(asyncio.wait_for(p1._main_loop(), 1)) run(asyncio.wait_for(p1._main_loop(), 1))
@ -161,10 +164,8 @@ class TestPeer(unittest.TestCase):
q1, q2 = asyncio.Queue(), asyncio.Queue() q1, q2 = asyncio.Queue(), asyncio.Queue()
w1 = MockLNWorker(k1, k2, self.alice_channel, tx_queue=q1) w1 = MockLNWorker(k1, k2, self.alice_channel, tx_queue=q1)
w2 = MockLNWorker(k2, k1, self.bob_channel, tx_queue=q2) w2 = MockLNWorker(k2, k1, self.bob_channel, tx_queue=q2)
p1 = Peer(w1, LNPeerAddr("bogus1", 1337, k1.pubkey), p1 = Peer(w1, k1.pubkey, t1, request_initial_sync=False)
request_initial_sync=False, transport=t1) p2 = Peer(w2, k2.pubkey, t2, request_initial_sync=False)
p2 = Peer(w2, LNPeerAddr("bogus2", 1337, k2.pubkey),
request_initial_sync=False, transport=t2)
w1.peer = p1 w1.peer = p1
w2.peer = p2 w2.peer = p2
# mark_open won't work if state is already OPEN. # mark_open won't work if state is already OPEN.

Loading…
Cancel
Save