Browse Source

create transport and perform handshake before creating Peer

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
ThomasV 6 years ago
parent
commit
b5482e4470
  1. 28
      electrum/lnbase.py
  2. 27
      electrum/lntransport.py
  3. 20
      electrum/lnworker.py
  4. 11
      electrum/tests/test_lnbase.py

28
electrum/lnbase.py

@ -197,15 +197,14 @@ def gen_msg(msg_type: str, **kwargs) -> bytes:
class Peer(PrintError):
def __init__(self, lnworker: 'LNWorker', peer_addr: LNPeerAddr, responding=False,
request_initial_sync=False, transport: LNTransportBase=None):
self.responding = responding
def __init__(self, lnworker: 'LNWorker', pubkey:bytes, transport: LNTransportBase,
request_initial_sync=False):
self.initialized = asyncio.Event()
self.transport = transport
self.peer_addr = peer_addr
self.pubkey = pubkey
self.lnworker = lnworker
self.privkey = lnworker.node_keypair.privkey
self.node_ids = [peer_addr.pubkey, privkey_to_pubkey(self.privkey)]
self.node_ids = [self.pubkey, privkey_to_pubkey(self.privkey)]
self.network = lnworker.network
self.lnwatcher = lnworker.network.lnwatcher
self.channel_db = lnworker.network.channel_db
@ -233,19 +232,14 @@ class Peer(PrintError):
self.transport.send_bytes(gen_msg(message_name, **kwargs))
async def initialize(self):
if not self.transport:
reader, writer = await asyncio.open_connection(self.peer_addr.host, self.peer_addr.port)
transport = LNTransport(self.privkey, self.peer_addr.pubkey, reader, writer)
await transport.handshake()
self.transport = transport
self.send_message("init", gflen=0, lflen=1, localfeatures=self.localfeatures)
@property
def channels(self) -> Dict[bytes, Channel]:
return self.lnworker.channels_for_peer(self.peer_addr.pubkey)
return self.lnworker.channels_for_peer(self.pubkey)
def diagnostic_name(self):
return str(self.peer_addr.host) + ':' + str(self.peer_addr.port)
return self.transport.name()
def ping_if_required(self):
if time.time() - self.ping_time > 120:
@ -352,7 +346,7 @@ class Peer(PrintError):
self.print_error("disconnecting gracefully. {}".format(e))
finally:
self.close_and_cleanup()
self.lnworker.peers.pop(self.peer_addr.pubkey)
self.lnworker.peers.pop(self.pubkey)
return wrapper_func
@ignore_exceptions # do not kill main_taskgroup
@ -373,8 +367,6 @@ class Peer(PrintError):
except (OSError, asyncio.TimeoutError, HandshakeFailed) as e:
self.print_error('initialize failed, disconnecting: {}'.format(repr(e)))
return
if not self.responding:
self.channel_db.add_recent_peer(self.peer_addr)
# loop
async for msg in self.transport.read_messages():
self.process_message(msg)
@ -513,7 +505,7 @@ class Peer(PrintError):
# remote commitment transaction
channel_id, funding_txid_bytes = channel_id_from_funding_tx(funding_txid, funding_index)
chan_dict = {
"node_id": self.peer_addr.pubkey,
"node_id": self.pubkey,
"channel_id": channel_id,
"short_channel_id": None,
"funding_outpoint": Outpoint(funding_txid, funding_index),
@ -587,7 +579,7 @@ class Peer(PrintError):
remote_dust_limit_sat = int.from_bytes(payload['dust_limit_satoshis'], byteorder='big') # TODO validate
remote_reserve_sat = self.validate_remote_reserve(payload['channel_reserve_satoshis'], remote_dust_limit_sat, funding_sat)
chan_dict = {
"node_id": self.peer_addr.pubkey,
"node_id": self.pubkey,
"channel_id": channel_id,
"short_channel_id": None,
"funding_outpoint": Outpoint(funding_txid, funding_idx),
@ -794,7 +786,7 @@ class Peer(PrintError):
remote_bitcoin_sig = announcement_signatures_msg["bitcoin_signature"]
if not ecc.verify_signature(chan.config[REMOTE].multisig_key.pubkey, remote_bitcoin_sig, h):
raise Exception("bitcoin_sig invalid in announcement_signatures")
if not ecc.verify_signature(self.peer_addr.pubkey, remote_node_sig, h):
if not ecc.verify_signature(self.pubkey, remote_node_sig, h):
raise Exception("node_sig invalid in announcement_signatures")
node_sigs = [remote_node_sig, local_node_sig]

27
electrum/lntransport.py

@ -6,6 +6,7 @@
# Derived from https://gist.github.com/AdamISZ/046d05c156aaeb56cc897f85eecb3eb8
import hashlib
import asyncio
from asyncio import StreamReader, StreamWriter
from Cryptodome.Cipher import ChaCha20_Poly1305
@ -87,10 +88,6 @@ def create_ephemeral_key() -> (bytes, bytes):
class LNTransportBase:
def __init__(self, reader: StreamReader, writer: StreamWriter):
self.reader = reader
self.writer = writer
def send_bytes(self, msg):
l = len(msg).to_bytes(2, 'big')
lc = aead_encrypt(self.sk, self.sn(), b'', l)
@ -153,12 +150,16 @@ class LNTransportBase:
class LNResponderTransport(LNTransportBase):
def __init__(self, privkey: bytes, reader: StreamReader, writer: StreamWriter):
LNTransportBase.__init__(self, reader, writer)
LNTransportBase.__init__(self)
self.reader = reader
self.writer = writer
self.privkey = privkey
def name(self):
return "responder"
async def handshake(self, **kwargs):
hs = HandshakeState(privkey_to_pubkey(self.privkey))
act1 = b''
while len(act1) < 50:
act1 += await self.reader.read(50 - len(act1))
@ -205,14 +206,20 @@ class LNResponderTransport(LNTransportBase):
return rs
class LNTransport(LNTransportBase):
def __init__(self, privkey: bytes, remote_pubkey: bytes,
reader: StreamReader, writer: StreamWriter):
LNTransportBase.__init__(self, reader, writer)
def __init__(self, privkey: bytes, peer_addr):
LNTransportBase.__init__(self)
assert type(privkey) is bytes and len(privkey) == 32
self.privkey = privkey
self.remote_pubkey = remote_pubkey
self.remote_pubkey = peer_addr.pubkey
self.host = peer_addr.host
self.port = peer_addr.port
def name(self):
return str(self.host) + ':' + str(self.port)
async def handshake(self):
self.reader, self.writer = await asyncio.open_connection(self.host, self.port)
hs = HandshakeState(self.remote_pubkey)
# Get a new ephemeral key
epriv, epub = create_ephemeral_key()

20
electrum/lnworker.py

@ -28,7 +28,7 @@ from .crypto import sha256
from .bip32 import bip32_root
from .util import bh2u, bfh, PrintError, InvoiceError, resolve_dns_srv, is_ip_address, log_exceptions
from .util import timestamp_to_datetime
from .lntransport import LNResponderTransport
from .lntransport import LNTransport, LNResponderTransport
from .lnbase import Peer
from .lnaddr import lnencode, LnAddr, lndecode
from .ecc import der_sig_from_sig_string
@ -244,13 +244,16 @@ class LNWorker(PrintError):
return {x: y for (x, y) in self.channels.items() if y.node_id == node_id}
async def add_peer(self, host, port, node_id):
port = int(port)
peer_addr = LNPeerAddr(host, port, node_id)
if node_id in self.peers:
return self.peers[node_id]
port = int(port)
peer_addr = LNPeerAddr(host, port, node_id)
transport = LNTransport(self.node_keypair.privkey, peer_addr)
await transport.handshake()
self.channel_db.add_recent_peer(peer_addr)
self._last_tried_peer[peer_addr] = time.time()
self.print_error("adding peer", peer_addr)
peer = Peer(self, peer_addr, request_initial_sync=self.config.get("request_initial_sync", True))
peer = Peer(self, node_id, transport, request_initial_sync=self.config.get("request_initial_sync", True))
await self.network.main_taskgroup.spawn(peer.main_loop())
self.peers[node_id] = peer
self.network.trigger_callback('ln_status')
@ -797,16 +800,13 @@ class LNWorker(PrintError):
# ipv6
addr = addr[1:-1]
async def cb(reader, writer):
t = LNResponderTransport(self.node_keypair.privkey, reader, writer)
transport = LNResponderTransport(self.node_keypair.privkey, reader, writer)
try:
node_id = await t.handshake()
node_id = await transport.handshake()
except:
self.print_error('handshake failure from incoming connection')
return
# FIXME extract host and port from transport
peer = Peer(self, LNPeerAddr("bogus", 1337, node_id), responding=True,
request_initial_sync=self.config.get("request_initial_sync", True),
transport=t)
peer = Peer(self, node_id, transport, request_initial_sync=self.config.get("request_initial_sync", True))
self.peers[node_id] = peer
await self.network.main_taskgroup.spawn(peer.main_loop())
self.network.trigger_callback('ln_status')

11
electrum/tests/test_lnbase.py

@ -113,6 +113,9 @@ class MockTransport:
def __init__(self):
self.queue = asyncio.Queue()
def name(self):
return ""
async def read_messages(self):
while True:
yield await self.queue.get()
@ -150,7 +153,7 @@ class TestPeer(unittest.TestCase):
def test_require_data_loss_protect(self):
mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel, tx_queue=None)
mock_transport = NoFeaturesTransport()
p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 33), request_initial_sync=False, transport=mock_transport)
p1 = Peer(mock_lnworker, b"\x00" * 33, mock_transport, request_initial_sync=False)
mock_lnworker.peer = p1
with self.assertRaises(LightningPeerConnectionClosed):
run(asyncio.wait_for(p1._main_loop(), 1))
@ -161,10 +164,8 @@ class TestPeer(unittest.TestCase):
q1, q2 = asyncio.Queue(), asyncio.Queue()
w1 = MockLNWorker(k1, k2, self.alice_channel, tx_queue=q1)
w2 = MockLNWorker(k2, k1, self.bob_channel, tx_queue=q2)
p1 = Peer(w1, LNPeerAddr("bogus1", 1337, k1.pubkey),
request_initial_sync=False, transport=t1)
p2 = Peer(w2, LNPeerAddr("bogus2", 1337, k2.pubkey),
request_initial_sync=False, transport=t2)
p1 = Peer(w1, k1.pubkey, t1, request_initial_sync=False)
p2 = Peer(w2, k2.pubkey, t2, request_initial_sync=False)
w1.peer = p1
w2.peer = p2
# mark_open won't work if state is already OPEN.

Loading…
Cancel
Save