Browse Source

test_lnbase: add test that pays to another local electrum

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
Janus 6 years ago
committed by ThomasV
parent
commit
7e76e82152
  1. 5
      electrum/lnbase.py
  2. 12
      electrum/lnworker.py
  3. 155
      electrum/tests/test_lnbase.py

5
electrum/lnbase.py

@ -350,6 +350,11 @@ class Peer(PrintError):
@log_exceptions @log_exceptions
@handle_disconnect @handle_disconnect
async def main_loop(self): async def main_loop(self):
"""
This is used in LNWorker and is necessary so that we don't kill the main
task group. It is not merged with _main_loop, so that we can test if the
correct exceptions are getting thrown using _main_loop.
"""
await self._main_loop() await self._main_loop()
async def _main_loop(self): async def _main_loop(self):

12
electrum/lnworker.py

@ -32,7 +32,6 @@ from .lnutil import (Outpoint, calc_short_channel_id, LNPeerAddr,
generate_keypair, LnKeyFamily, LOCAL, REMOTE, generate_keypair, LnKeyFamily, LOCAL, REMOTE,
UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE, UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE,
NUM_MAX_EDGES_IN_PAYMENT_PATH) NUM_MAX_EDGES_IN_PAYMENT_PATH)
from .lnaddr import lndecode
from .i18n import _ from .i18n import _
from .lnrouter import RouteEdge, is_route_sane_to_use from .lnrouter import RouteEdge, is_route_sane_to_use
@ -258,6 +257,15 @@ class LNWorker(PrintError):
return bh2u(chan.node_id) return bh2u(chan.node_id)
def pay(self, invoice, amount_sat=None): def pay(self, invoice, amount_sat=None):
"""
This is not merged with _pay so that we can run the test with
one thread only.
"""
addr, peer, coro = self._pay(invoice, amount_sat)
fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
return addr, peer, fut
def _pay(self, invoice, amount_sat=None):
addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP)
payment_hash = addr.paymenthash payment_hash = addr.paymenthash
amount_sat = (addr.amount * COIN) if addr.amount else amount_sat amount_sat = (addr.amount * COIN) if addr.amount else amount_sat
@ -279,7 +287,7 @@ class LNWorker(PrintError):
raise Exception("PathFinder returned path with short_channel_id {} that is not in channel list".format(bh2u(short_channel_id))) raise Exception("PathFinder returned path with short_channel_id {} that is not in channel list".format(bh2u(short_channel_id)))
peer = self.peers[node_id] peer = self.peers[node_id]
coro = peer.pay(route, chan, amount_msat, payment_hash, addr.get_min_final_cltv_expiry()) coro = peer.pay(route, chan, amount_msat, payment_hash, addr.get_min_final_cltv_expiry())
return addr, peer, asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop) return addr, peer, coro
def _create_route_from_invoice(self, decoded_invoice, amount_msat) -> List[RouteEdge]: def _create_route_from_invoice(self, decoded_invoice, amount_msat) -> List[RouteEdge]:
invoice_pubkey = decoded_invoice.pubkey.serialize() invoice_pubkey = decoded_invoice.pubkey.serialize()

155
electrum/tests/test_lnbase.py

@ -1,16 +1,40 @@
from electrum.lnbase import Peer, decode_msg, gen_msg
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
from electrum.ecc import ECPrivkey
from electrum.lnrouter import ChannelDB
import unittest import unittest
import asyncio import asyncio
from electrum import simple_config
import tempfile import tempfile
from decimal import Decimal
import os
from contextlib import contextmanager
from collections import defaultdict
from electrum.network import Network
from electrum.ecc import ECPrivkey
from electrum import simple_config, lnutil
from electrum.lnaddr import lnencode, LnAddr, lndecode
from electrum.bitcoin import COIN, sha256
from electrum.util import bh2u
from electrum.lnbase import Peer, decode_msg, gen_msg
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
from electrum.lnrouter import ChannelDB, LNPathFinder
from electrum.lnworker import LNWorker
from .test_lnchan import create_test_channels from .test_lnchan import create_test_channels
def keypair():
priv = ECPrivkey.generate_random_key().get_secret_bytes()
k1 = Keypair(
pubkey=privkey_to_pubkey(priv),
privkey=priv)
return k1
@contextmanager
def noop_lock():
yield
class MockNetwork: class MockNetwork:
def __init__(self): def __init__(self):
self.callbacks = defaultdict(list)
self.lnwatcher = None self.lnwatcher = None
user_config = {} user_config = {}
user_dir = tempfile.mkdtemp(prefix="electrum-lnbase-test-") user_dir = tempfile.mkdtemp(prefix="electrum-lnbase-test-")
@ -18,49 +42,132 @@ class MockNetwork:
self.asyncio_loop = asyncio.get_event_loop() self.asyncio_loop = asyncio.get_event_loop()
self.channel_db = ChannelDB(self) self.channel_db = ChannelDB(self)
self.interface = None self.interface = None
def register_callback(self, cb, trigger_names): self.path_finder = LNPathFinder(self.channel_db)
print("callback registered", repr(trigger_names))
def trigger_callback(self, trigger_name, obj): @property
print("callback triggered", repr(trigger_name)) def callback_lock(self):
return noop_lock()
register_callback = Network.register_callback
unregister_callback = Network.unregister_callback
trigger_callback = Network.trigger_callback
def get_local_height(self):
return 0
class MockLNWorker: class MockLNWorker:
def __init__(self, remote_peer_pubkey, chan): def __init__(self, remote_keypair, local_keypair, chan):
self.chan = chan self.chan = chan
self.remote_peer_pubkey = remote_peer_pubkey self.remote_keypair = remote_keypair
priv = ECPrivkey.generate_random_key().get_secret_bytes() self.node_keypair = local_keypair
self.node_keypair = Keypair(
pubkey=privkey_to_pubkey(priv),
privkey=priv)
self.network = MockNetwork() self.network = MockNetwork()
self.channels = {self.chan.channel_id: self.chan}
self.invoices = {}
@property
def lock(self):
return noop_lock()
@property @property
def peers(self): def peers(self):
return {self.remote_peer_pubkey: self.peer} return {self.remote_keypair.pubkey: self.peer}
def channels_for_peer(self, pubkey): def channels_for_peer(self, pubkey):
return {self.chan.channel_id: self.chan} return self.channels
def save_channel(self, chan):
pass
get_invoice = LNWorker.get_invoice
_create_route_from_invoice = LNWorker._create_route_from_invoice
class MockTransport: class MockTransport:
def __init__(self): def __init__(self):
self.queue = asyncio.Queue() self.queue = asyncio.Queue()
async def read_messages(self): async def read_messages(self):
while True: while True:
yield await self.queue.get() yield await self.queue.get()
class BadFeaturesTransport(MockTransport): class NoFeaturesTransport(MockTransport):
"""
This answers the init message with a init that doesn't signal any features.
Used for testing that we require DATA_LOSS_PROTECT.
"""
def send_bytes(self, data): def send_bytes(self, data):
decoded = decode_msg(data) decoded = decode_msg(data)
print(decoded) print(decoded)
if decoded[0] == 'init': if decoded[0] == 'init':
self.queue.put_nowait(gen_msg('init', lflen=1, gflen=1, localfeatures=b"\x00", globalfeatures=b"\x00")) self.queue.put_nowait(gen_msg('init', lflen=1, gflen=1, localfeatures=b"\x00", globalfeatures=b"\x00"))
class PutIntoOthersQueueTransport(MockTransport):
def __init__(self):
super().__init__()
self.other_mock_transport = None
def send_bytes(self, data):
self.other_mock_transport.queue.put_nowait(data)
def transport_pair():
t1 = PutIntoOthersQueueTransport()
t2 = PutIntoOthersQueueTransport()
t1.other_mock_transport = t2
t2.other_mock_transport = t1
return t1, t2
class TestPeer(unittest.TestCase): class TestPeer(unittest.TestCase):
def setUp(self): def setUp(self):
self.alice_channel, self.bob_channel = create_test_channels() self.alice_channel, self.bob_channel = create_test_channels()
def test_bad_feature_flags(self):
# we should require DATA_LOSS_PROTECT def test_require_data_loss_protect(self):
mock_lnworker = MockLNWorker(b"\x00" * 32, self.alice_channel) mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel)
mock_transport = BadFeaturesTransport() mock_transport = NoFeaturesTransport()
p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 32), request_initial_sync=False, transport=mock_transport) p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 33), request_initial_sync=False, transport=mock_transport)
mock_lnworker.peer = p1 mock_lnworker.peer = p1
with self.assertRaises(LightningPeerConnectionClosed): with self.assertRaises(LightningPeerConnectionClosed):
asyncio.get_event_loop().run_until_complete(asyncio.wait_for(p1._main_loop(), 1)) asyncio.get_event_loop().run_until_complete(asyncio.wait_for(p1._main_loop(), 1))
def test_payment(self):
k1, k2 = keypair(), keypair()
t1, t2 = transport_pair()
w1 = MockLNWorker(k1, k2, self.alice_channel)
w2 = MockLNWorker(k2, k1, self.bob_channel)
p1 = Peer(w1, LNPeerAddr("bogus1", 1337, k1.pubkey),
request_initial_sync=False, transport=t1)
p2 = Peer(w2, LNPeerAddr("bogus2", 1337, k2.pubkey),
request_initial_sync=False, transport=t2)
w1.peer = p1
w2.peer = p2
# mark_open won't work if state is already OPEN.
# so set it to OPENING
self.alice_channel.set_state("OPENING")
self.bob_channel.set_state("OPENING")
# this populates the channel graph:
p1.mark_open(self.alice_channel)
p2.mark_open(self.bob_channel)
amount_btc = 100000/Decimal(COIN)
payment_preimage = os.urandom(32)
RHASH = sha256(payment_preimage)
addr = LnAddr(
RHASH,
amount_btc,
tags=[('c', lnutil.MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE),
('d', 'coffee')
])
pay_req = lnencode(addr, w2.node_keypair.privkey)
w2.invoices[bh2u(RHASH)] = (bh2u(payment_preimage), pay_req)
l = asyncio.get_event_loop()
async def pay():
fut = asyncio.Future()
def evt_set(event, _lnworker, msg):
fut.set_result(msg)
w2.network.register_callback(evt_set, ['ln_message'])
addr, peer, coro = LNWorker._pay(w1, pay_req)
await coro
print("HTLC ADDED")
self.assertEqual(await fut, 'Payment received')
gath.cancel()
gath = asyncio.gather(pay(), p1._main_loop(), p2._main_loop())
with self.assertRaises(asyncio.CancelledError):
l.run_until_complete(gath)

Loading…
Cancel
Save