Browse Source

lnbase: mark initialized later, add tests, etc

- consistent node_id sorting
- require OPTION_DATA_LOSS_PROTECT and test it
dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
Janus 6 years ago
committed by ThomasV
parent
commit
85789d8a09
  1. 50
      electrum/lnbase.py
  2. 66
      electrum/tests/test_lnbase.py

50
electrum/lnbase.py

@ -201,6 +201,7 @@ class Peer(PrintError):
self.peer_addr = peer_addr
self.lnworker = lnworker
self.privkey = lnworker.node_keypair.privkey
self.node_ids = [peer_addr.pubkey, privkey_to_pubkey(self.privkey)]
self.network = lnworker.network
self.lnwatcher = lnworker.network.lnwatcher
self.channel_db = lnworker.network.channel_db
@ -218,7 +219,7 @@ class Peer(PrintError):
self.localfeatures = LnLocalFeatures(0)
if request_initial_sync:
self.localfeatures |= LnLocalFeatures.INITIAL_ROUTING_SYNC
self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_OPT
self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_REQ
self.attempted_route = {}
self.orphan_channel_updates = OrderedDict()
@ -234,7 +235,6 @@ class Peer(PrintError):
await transport.handshake()
self.transport = transport
self.send_message("init", gflen=0, lflen=1, localfeatures=self.localfeatures)
self.initialized.set_result(True)
@property
def channels(self) -> Dict[bytes, Channel]:
@ -310,6 +310,7 @@ class Peer(PrintError):
raise LightningPeerConnectionClosed("remote does not have even flag {}"
.format(str(LnLocalFeatures(1 << flag))))
self.localfeatures ^= 1 << flag # disable flag
self.initialized.set_result(True)
def on_channel_update(self, payload):
try:
@ -349,6 +350,13 @@ class Peer(PrintError):
@log_exceptions
@handle_disconnect
async def main_loop(self):
"""
This is used from the GUI. It is not merged with the other function,
so that we can test if the correct exceptions are getting thrown.
"""
await self._main_loop()
async def _main_loop(self):
try:
await asyncio.wait_for(self.initialize(), 10)
except (OSError, asyncio.TimeoutError, HandshakeFailed) as e:
@ -757,16 +765,17 @@ class Peer(PrintError):
if not ecc.verify_signature(self.peer_addr.pubkey, remote_node_sig, h):
raise Exception("node_sig invalid in announcement_signatures")
node_sigs = [local_node_sig, remote_node_sig]
bitcoin_sigs = [local_bitcoin_sig, remote_bitcoin_sig]
node_ids = [privkey_to_pubkey(self.privkey), self.peer_addr.pubkey]
bitcoin_keys = [chan.config[LOCAL].multisig_key.pubkey, chan.config[REMOTE].multisig_key.pubkey]
node_sigs = [remote_node_sig, local_node_sig]
bitcoin_sigs = [remote_bitcoin_sig, local_bitcoin_sig]
bitcoin_keys = [chan.config[REMOTE].multisig_key.pubkey, chan.config[LOCAL].multisig_key.pubkey]
if node_ids[0] > node_ids[1]:
if self.node_ids[0] > self.node_ids[1]:
node_sigs.reverse()
bitcoin_sigs.reverse()
node_ids.reverse()
node_ids = list(reversed(self.node_ids))
bitcoin_keys.reverse()
else:
node_ids = self.node_ids
self.send_message("channel_announcement",
node_signatures_1=node_sigs[0],
@ -793,14 +802,13 @@ class Peer(PrintError):
chan.set_state("OPEN")
self.network.trigger_callback('channel', chan)
# add channel to database
pubkey_ours = self.lnworker.node_keypair.pubkey
pubkey_theirs = self.peer_addr.pubkey
node_ids = [pubkey_theirs, pubkey_ours]
bitcoin_keys = [chan.config[LOCAL].multisig_key.pubkey, chan.config[REMOTE].multisig_key.pubkey]
sorted_node_ids = list(sorted(node_ids))
if sorted_node_ids != node_ids:
sorted_node_ids = list(sorted(self.node_ids))
if sorted_node_ids != self.node_ids:
node_ids = sorted_node_ids
bitcoin_keys.reverse()
else:
node_ids = self.node_ids
# note: we inject a channel announcement, and a channel update (for outgoing direction)
# This is atm needed for
# - finding routes
@ -813,7 +821,10 @@ class Peer(PrintError):
'bitcoin_key_1': bitcoin_keys[0], 'bitcoin_key_2': bitcoin_keys[1]},
trusted=True)
# only inject outgoing direction:
channel_flags = b'\x00' if node_ids[0] == pubkey_ours else b'\x01'
if node_ids[0] == privkey_to_pubkey(self.privkey):
channel_flags = b'\x00'
else:
channel_flags = b'\x01'
now = int(time.time()).to_bytes(4, byteorder="big")
self.channel_db.on_channel_update({"short_channel_id": chan.short_channel_id, 'channel_flags': channel_flags, 'cltv_expiry_delta': b'\x90',
'htlc_minimum_msat': b'\x03\xe8', 'fee_base_msat': b'\x03\xe8', 'fee_proportional_millionths': b'\x01',
@ -832,16 +843,15 @@ class Peer(PrintError):
def send_announcement_signatures(self, chan):
bitcoin_keys = [chan.config[LOCAL].multisig_key.pubkey,
chan.config[REMOTE].multisig_key.pubkey]
node_ids = [privkey_to_pubkey(self.privkey),
self.peer_addr.pubkey]
bitcoin_keys = [chan.config[REMOTE].multisig_key.pubkey,
chan.config[LOCAL].multisig_key.pubkey]
sorted_node_ids = list(sorted(node_ids))
sorted_node_ids = list(sorted(self.node_ids))
if sorted_node_ids != node_ids:
node_ids = sorted_node_ids
bitcoin_keys.reverse()
else:
node_ids = self.node_ids
chan_ann = gen_msg("channel_announcement",
len=0,

66
electrum/tests/test_lnbase.py

@ -0,0 +1,66 @@
from electrum.lnbase import Peer, decode_msg, gen_msg
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
from electrum.ecc import ECPrivkey
from electrum.lnrouter import ChannelDB
import unittest
import asyncio
from electrum import simple_config
import tempfile
from .test_lnchan import create_test_channels
class MockNetwork:
def __init__(self):
self.lnwatcher = None
user_config = {}
user_dir = tempfile.mkdtemp(prefix="electrum-lnbase-test-")
self.config = simple_config.SimpleConfig(user_config, read_user_dir_function=lambda: user_dir)
self.asyncio_loop = asyncio.get_event_loop()
self.channel_db = ChannelDB(self)
self.interface = None
def register_callback(self, cb, trigger_names):
print("callback registered", repr(trigger_names))
def trigger_callback(self, trigger_name, obj):
print("callback triggered", repr(trigger_name))
class MockLNWorker:
def __init__(self, remote_peer_pubkey, chan):
self.chan = chan
self.remote_peer_pubkey = remote_peer_pubkey
priv = ECPrivkey.generate_random_key().get_secret_bytes()
self.node_keypair = Keypair(
pubkey=privkey_to_pubkey(priv),
privkey=priv)
self.network = MockNetwork()
@property
def peers(self):
return {self.remote_peer_pubkey: self.peer}
def channels_for_peer(self, pubkey):
return {self.chan.channel_id: self.chan}
class MockTransport:
def __init__(self):
self.queue = asyncio.Queue()
async def read_messages(self):
while True:
yield await self.queue.get()
class BadFeaturesTransport(MockTransport):
def send_bytes(self, data):
decoded = decode_msg(data)
print(decoded)
if decoded[0] == 'init':
self.queue.put_nowait(gen_msg('init', lflen=1, gflen=1, localfeatures=b"\x00", globalfeatures=b"\x00"))
class TestPeer(unittest.TestCase):
def setUp(self):
self.alice_channel, self.bob_channel = create_test_channels()
def test_bad_feature_flags(self):
# we should require DATA_LOSS_PROTECT
mock_lnworker = MockLNWorker(b"\x00" * 32, self.alice_channel)
mock_transport = BadFeaturesTransport()
p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 32), request_initial_sync=False, transport=mock_transport)
mock_lnworker.peer = p1
with self.assertRaises(LightningPeerConnectionClosed):
asyncio.get_event_loop().run_until_complete(asyncio.wait_for(p1._main_loop(), 1))
Loading…
Cancel
Save