diff --git a/electrum/lnbase.py b/electrum/lnbase.py index f8c5c6c06..86dfe19e7 100644 --- a/electrum/lnbase.py +++ b/electrum/lnbase.py @@ -19,7 +19,7 @@ import aiorpcx from .crypto import sha256, sha256d from . import bitcoin from . import ecc -from .ecc import sig_string_from_r_and_s, get_r_and_s_from_sig_string +from .ecc import sig_string_from_r_and_s, get_r_and_s_from_sig_string, der_sig_from_sig_string from . import constants from .util import PrintError, bh2u, print_error, bfh, log_exceptions, list_enabled_bits, ignore_exceptions from .transaction import Transaction, TxOutput @@ -1158,6 +1158,25 @@ class Peer(PrintError): self.print_error('Channel closed', txid) return txid + async def force_close_channel(self, chan_id): + chan = self.channels[chan_id] + # local_commitment always gives back the next expected local_commitment, + # but in this case, we want the current one. So substract one ctn number + old_local_state = chan.config[LOCAL] + chan.config[LOCAL]=chan.config[LOCAL]._replace(ctn=chan.config[LOCAL].ctn - 1) + tx = chan.pending_local_commitment + chan.config[LOCAL] = old_local_state + tx.sign({bh2u(chan.config[LOCAL].multisig_key.pubkey): (chan.config[LOCAL].multisig_key.privkey, True)}) + remote_sig = chan.config[LOCAL].current_commitment_signature + remote_sig = der_sig_from_sig_string(remote_sig) + b"\x01" + none_idx = tx._inputs[0]["signatures"].index(None) + tx.add_signature_to_txin(0, none_idx, bh2u(remote_sig)) + assert tx.is_complete() + # TODO persist FORCE_CLOSING state to disk + chan.set_state('FORCE_CLOSING') + self.lnworker.save_channel(chan) + return await self.network.broadcast_transaction(tx) + @log_exceptions async def on_shutdown(self, payload): # length of scripts allowed in BOLT-02 diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 6ce269118..9df2f8ed9 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -11,6 +11,7 @@ from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING import threading import socket import json +from decimal import Decimal import dns.resolver import dns.exception @@ -267,18 +268,13 @@ class LNWorker(PrintError): return addr, peer, fut def _pay(self, invoice, amount_sat=None): - addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) - payment_hash = addr.paymenthash - amount_sat = (addr.amount * COIN) if addr.amount else amount_sat - if amount_sat is None: - raise InvoiceError(_("Missing amount")) - amount_msat = int(amount_sat * 1000) - if addr.get_min_final_cltv_expiry() > 60 * 144: - raise InvoiceError("{}\n{}".format( - _("Invoice wants us to risk locking funds for unreasonably long."), - f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}")) - route = self._create_route_from_invoice(decoded_invoice=addr, amount_msat=amount_msat) - node_id, short_channel_id = route[0].node_id, route[0].short_channel_id + addr = self._check_invoice(invoice, amount_sat) + route = self._create_route_from_invoice(decoded_invoice=addr) + peer = self.peers[route[0].node_id] + return addr, peer, self._pay_to_route(route, addr) + + async def _pay_to_route(self, route, addr): + short_channel_id = route[0].short_channel_id with self.lock: channels = list(self.channels.values()) for chan in channels: @@ -286,11 +282,24 @@ class LNWorker(PrintError): break else: raise Exception("PathFinder returned path with short_channel_id {} that is not in channel list".format(bh2u(short_channel_id))) - peer = self.peers[node_id] - coro = peer.pay(route, chan, amount_msat, payment_hash, addr.get_min_final_cltv_expiry()) - return addr, peer, coro + peer = self.peers[route[0].node_id] + return await peer.pay(route, chan, int(addr.amount * COIN * 1000), addr.paymenthash, addr.get_min_final_cltv_expiry()) - def _create_route_from_invoice(self, decoded_invoice, amount_msat) -> List[RouteEdge]: + @staticmethod + def _check_invoice(invoice, amount_sat=None): + addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) + if amount_sat: + addr.amount = Decimal(amount_sat) / COIN + if addr.amount is None: + raise InvoiceError(_("Missing amount")) + if addr.get_min_final_cltv_expiry() > 60 * 144: + raise InvoiceError("{}\n{}".format( + _("Invoice wants us to risk locking funds for unreasonably long."), + f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}")) + return addr + + def _create_route_from_invoice(self, decoded_invoice) -> List[RouteEdge]: + amount_msat = int(decoded_invoice.amount * COIN * 1000) invoice_pubkey = decoded_invoice.pubkey.serialize() # use 'r' field from invoice route = None # type: List[RouteEdge] @@ -441,19 +450,8 @@ class LNWorker(PrintError): async def force_close_channel(self, chan_id): chan = self.channels[chan_id] - # local_commitment always gives back the next expected local_commitment, - # but in this case, we want the current one. So substract one ctn number - old_local_state = chan.config[LOCAL] - chan.config[LOCAL]=chan.config[LOCAL]._replace(ctn=chan.config[LOCAL].ctn - 1) - tx = chan.pending_local_commitment - chan.config[LOCAL] = old_local_state - tx.sign({bh2u(chan.config[LOCAL].multisig_key.pubkey): (chan.config[LOCAL].multisig_key.privkey, True)}) - remote_sig = chan.config[LOCAL].current_commitment_signature - remote_sig = der_sig_from_sig_string(remote_sig) + b"\x01" - none_idx = tx._inputs[0]["signatures"].index(None) - tx.add_signature_to_txin(0, none_idx, bh2u(remote_sig)) - assert tx.is_complete() - return await self.network.broadcast_transaction(tx) + peer = self.peers[chan.node_id] + return await peer.force_close_channel(chan_id) def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]: now = time.time() diff --git a/electrum/tests/test_lnbase.py b/electrum/tests/test_lnbase.py index c64e1ced0..cf57f6a91 100644 --- a/electrum/tests/test_lnbase.py +++ b/electrum/tests/test_lnbase.py @@ -16,6 +16,7 @@ from electrum.util import bh2u from electrum.lnbase import Peer, decode_msg, gen_msg from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving +from electrum.lnutil import PaymentFailure from electrum.lnrouter import ChannelDB, LNPathFinder from electrum.lnworker import LNWorker @@ -33,7 +34,7 @@ def noop_lock(): yield class MockNetwork: - def __init__(self): + def __init__(self, tx_queue): self.callbacks = defaultdict(list) self.lnwatcher = None user_config = {} @@ -43,6 +44,7 @@ class MockNetwork: self.channel_db = ChannelDB(self) self.interface = None self.path_finder = LNPathFinder(self.channel_db) + self.tx_queue = tx_queue @property def callback_lock(self): @@ -55,12 +57,16 @@ class MockNetwork: def get_local_height(self): return 0 + async def broadcast_transaction(self, tx): + if self.tx_queue: + await self.tx_queue.put(tx) + class MockLNWorker: - def __init__(self, remote_keypair, local_keypair, chan): + def __init__(self, remote_keypair, local_keypair, chan, tx_queue): self.chan = chan self.remote_keypair = remote_keypair self.node_keypair = local_keypair - self.network = MockNetwork() + self.network = MockNetwork(tx_queue) self.channels = {self.chan.channel_id: self.chan} self.invoices = {} @@ -76,10 +82,12 @@ class MockLNWorker: return self.channels def save_channel(self, chan): - pass + print("Ignoring channel save") get_invoice = LNWorker.get_invoice _create_route_from_invoice = LNWorker._create_route_from_invoice + _check_invoice = staticmethod(LNWorker._check_invoice) + _pay_to_route = LNWorker._pay_to_route class MockTransport: def __init__(self): @@ -120,18 +128,19 @@ class TestPeer(unittest.TestCase): self.alice_channel, self.bob_channel = create_test_channels() def test_require_data_loss_protect(self): - mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel) + mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel, tx_queue=None) mock_transport = NoFeaturesTransport() p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 33), request_initial_sync=False, transport=mock_transport) mock_lnworker.peer = p1 with self.assertRaises(LightningPeerConnectionClosed): - asyncio.get_event_loop().run_until_complete(asyncio.wait_for(p1._main_loop(), 1)) + run(asyncio.wait_for(p1._main_loop(), 1)) - def test_payment(self): + def prepare_peers(self): k1, k2 = keypair(), keypair() t1, t2 = transport_pair() - w1 = MockLNWorker(k1, k2, self.alice_channel) - w2 = MockLNWorker(k2, k1, self.bob_channel) + q1, q2 = asyncio.Queue(), asyncio.Queue() + w1 = MockLNWorker(k1, k2, self.alice_channel, tx_queue=q1) + w2 = MockLNWorker(k2, k1, self.bob_channel, tx_queue=q2) p1 = Peer(w1, LNPeerAddr("bogus1", 1337, k1.pubkey), request_initial_sync=False, transport=t1) p2 = Peer(w2, LNPeerAddr("bogus2", 1337, k2.pubkey), @@ -145,6 +154,11 @@ class TestPeer(unittest.TestCase): # this populates the channel graph: p1.mark_open(self.alice_channel) p2.mark_open(self.bob_channel) + return p1, p2, w1, w2, q1, q2 + + @staticmethod + def prepare_invoice(w2 # receiver + ): amount_btc = 100000/Decimal(COIN) payment_preimage = os.urandom(32) RHASH = sha256(payment_preimage) @@ -156,13 +170,23 @@ class TestPeer(unittest.TestCase): ]) pay_req = lnencode(addr, w2.node_keypair.privkey) w2.invoices[bh2u(RHASH)] = (bh2u(payment_preimage), pay_req) - l = asyncio.get_event_loop() - async def pay(): - fut = asyncio.Future() - def evt_set(event, _lnworker, msg): - fut.set_result(msg) - w2.network.register_callback(evt_set, ['ln_message']) + return pay_req + + @staticmethod + def prepare_ln_message_future(w2 # receiver + ): + fut = asyncio.Future() + def evt_set(event, _lnworker, msg): + fut.set_result(msg) + w2.network.register_callback(evt_set, ['ln_message']) + return fut + + def test_payment(self): + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers() + pay_req = self.prepare_invoice(w2) + fut = self.prepare_ln_message_future(w2) + async def pay(): addr, peer, coro = LNWorker._pay(w1, pay_req) await coro print("HTLC ADDED") @@ -170,4 +194,28 @@ class TestPeer(unittest.TestCase): gath.cancel() gath = asyncio.gather(pay(), p1._main_loop(), p2._main_loop()) with self.assertRaises(asyncio.CancelledError): - l.run_until_complete(gath) + run(gath) + + def test_channel_usage_after_closing(self): + p1, p2, w1, w2, q1, q2 = self.prepare_peers() + pay_req = self.prepare_invoice(w2) + + addr = w1._check_invoice(pay_req) + route = w1._create_route_from_invoice(decoded_invoice=addr) + + run(p1.force_close_channel(self.alice_channel.channel_id)) + # check if a tx (commitment transaction) was broadcasted: + assert q1.qsize() == 1 + + with self.assertRaises(PaymentFailure) as e: + w1._create_route_from_invoice(decoded_invoice=addr) + self.assertEqual(str(e.exception), 'No path found') + + peer = w1.peers[route[0].node_id] + # AssertionError is ok since we shouldn't use old routes, and the + # route finding should fail when channel is closed + with self.assertRaises(AssertionError): + run(asyncio.gather(w1._pay_to_route(route, addr), p1._main_loop(), p2._main_loop())) + +def run(coro): + asyncio.get_event_loop().run_until_complete(coro) diff --git a/electrum/tests/test_lnchan.py b/electrum/tests/test_lnchan.py index 8b99b4f6f..4fa6d2257 100644 --- a/electrum/tests/test_lnchan.py +++ b/electrum/tests/test_lnchan.py @@ -29,6 +29,7 @@ from electrum import lnchan from electrum import lnutil from electrum import bip32 as bip32_utils from electrum.lnutil import SENT, LOCAL, REMOTE, RECEIVED +from electrum.ecc import sig_string_from_der_sig one_bitcoin_in_msat = bitcoin.COIN * 1000 @@ -81,7 +82,8 @@ def create_channel_state(funding_txid, funding_index, funding_sat, local_feerate per_commitment_secret_seed=seed, funding_locked_received=True, was_announced=False, - current_commitment_signature=None, + # just a random signature + current_commitment_signature=sig_string_from_der_sig(bytes.fromhex('3046022100c66e112e22b91b96b795a6dd5f4b004f3acccd9a2a31bf104840f256855b7aa3022100e711b868b62d87c7edd95a2370e496b9cb6a38aff13c9f64f9ff2f3b2a0052dd')), current_htlc_signatures=None, ), "constraints":lnbase.ChannelConstraints( @@ -185,6 +187,14 @@ class TestChannel(unittest.TestCase): self.htlc = self.bob_channel.log[lnutil.REMOTE].adds[0] + def test_concurrent_reversed_payment(self): + self.htlc_dict['payment_hash'] = bitcoin.sha256(32 * b'\x02') + self.htlc_dict['amount_msat'] += 1000 + bob_idx = self.bob_channel.add_htlc(self.htlc_dict) + alice_idx = self.alice_channel.receive_htlc(self.htlc_dict) + self.alice_channel.receive_new_commitment(*self.bob_channel.sign_next_commitment()) + self.assertEquals(len(self.alice_channel.pending_remote_commitment.outputs()), 3) + def test_SimpleAddSettleWorkflow(self): alice_channel, bob_channel = self.alice_channel, self.bob_channel htlc = self.htlc