Browse Source

move force_close_channel to lnbase, test it, add FORCE_CLOSING state

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
Janus 6 years ago
committed by ThomasV
parent
commit
0ea87278fb
  1. 21
      electrum/lnbase.py
  2. 56
      electrum/lnworker.py
  3. 72
      electrum/tests/test_lnbase.py
  4. 12
      electrum/tests/test_lnchan.py

21
electrum/lnbase.py

@ -19,7 +19,7 @@ import aiorpcx
from .crypto import sha256, sha256d from .crypto import sha256, sha256d
from . import bitcoin from . import bitcoin
from . import ecc from . import ecc
from .ecc import sig_string_from_r_and_s, get_r_and_s_from_sig_string from .ecc import sig_string_from_r_and_s, get_r_and_s_from_sig_string, der_sig_from_sig_string
from . import constants from . import constants
from .util import PrintError, bh2u, print_error, bfh, log_exceptions, list_enabled_bits, ignore_exceptions from .util import PrintError, bh2u, print_error, bfh, log_exceptions, list_enabled_bits, ignore_exceptions
from .transaction import Transaction, TxOutput from .transaction import Transaction, TxOutput
@ -1158,6 +1158,25 @@ class Peer(PrintError):
self.print_error('Channel closed', txid) self.print_error('Channel closed', txid)
return txid return txid
async def force_close_channel(self, chan_id):
chan = self.channels[chan_id]
# local_commitment always gives back the next expected local_commitment,
# but in this case, we want the current one. So substract one ctn number
old_local_state = chan.config[LOCAL]
chan.config[LOCAL]=chan.config[LOCAL]._replace(ctn=chan.config[LOCAL].ctn - 1)
tx = chan.pending_local_commitment
chan.config[LOCAL] = old_local_state
tx.sign({bh2u(chan.config[LOCAL].multisig_key.pubkey): (chan.config[LOCAL].multisig_key.privkey, True)})
remote_sig = chan.config[LOCAL].current_commitment_signature
remote_sig = der_sig_from_sig_string(remote_sig) + b"\x01"
none_idx = tx._inputs[0]["signatures"].index(None)
tx.add_signature_to_txin(0, none_idx, bh2u(remote_sig))
assert tx.is_complete()
# TODO persist FORCE_CLOSING state to disk
chan.set_state('FORCE_CLOSING')
self.lnworker.save_channel(chan)
return await self.network.broadcast_transaction(tx)
@log_exceptions @log_exceptions
async def on_shutdown(self, payload): async def on_shutdown(self, payload):
# length of scripts allowed in BOLT-02 # length of scripts allowed in BOLT-02

56
electrum/lnworker.py

@ -11,6 +11,7 @@ from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING
import threading import threading
import socket import socket
import json import json
from decimal import Decimal
import dns.resolver import dns.resolver
import dns.exception import dns.exception
@ -267,18 +268,13 @@ class LNWorker(PrintError):
return addr, peer, fut return addr, peer, fut
def _pay(self, invoice, amount_sat=None): def _pay(self, invoice, amount_sat=None):
addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) addr = self._check_invoice(invoice, amount_sat)
payment_hash = addr.paymenthash route = self._create_route_from_invoice(decoded_invoice=addr)
amount_sat = (addr.amount * COIN) if addr.amount else amount_sat peer = self.peers[route[0].node_id]
if amount_sat is None: return addr, peer, self._pay_to_route(route, addr)
raise InvoiceError(_("Missing amount"))
amount_msat = int(amount_sat * 1000) async def _pay_to_route(self, route, addr):
if addr.get_min_final_cltv_expiry() > 60 * 144: short_channel_id = route[0].short_channel_id
raise InvoiceError("{}\n{}".format(
_("Invoice wants us to risk locking funds for unreasonably long."),
f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}"))
route = self._create_route_from_invoice(decoded_invoice=addr, amount_msat=amount_msat)
node_id, short_channel_id = route[0].node_id, route[0].short_channel_id
with self.lock: with self.lock:
channels = list(self.channels.values()) channels = list(self.channels.values())
for chan in channels: for chan in channels:
@ -286,11 +282,24 @@ class LNWorker(PrintError):
break break
else: else:
raise Exception("PathFinder returned path with short_channel_id {} that is not in channel list".format(bh2u(short_channel_id))) raise Exception("PathFinder returned path with short_channel_id {} that is not in channel list".format(bh2u(short_channel_id)))
peer = self.peers[node_id] peer = self.peers[route[0].node_id]
coro = peer.pay(route, chan, amount_msat, payment_hash, addr.get_min_final_cltv_expiry()) return await peer.pay(route, chan, int(addr.amount * COIN * 1000), addr.paymenthash, addr.get_min_final_cltv_expiry())
return addr, peer, coro
def _create_route_from_invoice(self, decoded_invoice, amount_msat) -> List[RouteEdge]: @staticmethod
def _check_invoice(invoice, amount_sat=None):
addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP)
if amount_sat:
addr.amount = Decimal(amount_sat) / COIN
if addr.amount is None:
raise InvoiceError(_("Missing amount"))
if addr.get_min_final_cltv_expiry() > 60 * 144:
raise InvoiceError("{}\n{}".format(
_("Invoice wants us to risk locking funds for unreasonably long."),
f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}"))
return addr
def _create_route_from_invoice(self, decoded_invoice) -> List[RouteEdge]:
amount_msat = int(decoded_invoice.amount * COIN * 1000)
invoice_pubkey = decoded_invoice.pubkey.serialize() invoice_pubkey = decoded_invoice.pubkey.serialize()
# use 'r' field from invoice # use 'r' field from invoice
route = None # type: List[RouteEdge] route = None # type: List[RouteEdge]
@ -441,19 +450,8 @@ class LNWorker(PrintError):
async def force_close_channel(self, chan_id): async def force_close_channel(self, chan_id):
chan = self.channels[chan_id] chan = self.channels[chan_id]
# local_commitment always gives back the next expected local_commitment, peer = self.peers[chan.node_id]
# but in this case, we want the current one. So substract one ctn number return await peer.force_close_channel(chan_id)
old_local_state = chan.config[LOCAL]
chan.config[LOCAL]=chan.config[LOCAL]._replace(ctn=chan.config[LOCAL].ctn - 1)
tx = chan.pending_local_commitment
chan.config[LOCAL] = old_local_state
tx.sign({bh2u(chan.config[LOCAL].multisig_key.pubkey): (chan.config[LOCAL].multisig_key.privkey, True)})
remote_sig = chan.config[LOCAL].current_commitment_signature
remote_sig = der_sig_from_sig_string(remote_sig) + b"\x01"
none_idx = tx._inputs[0]["signatures"].index(None)
tx.add_signature_to_txin(0, none_idx, bh2u(remote_sig))
assert tx.is_complete()
return await self.network.broadcast_transaction(tx)
def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]: def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]:
now = time.time() now = time.time()

72
electrum/tests/test_lnbase.py

@ -16,6 +16,7 @@ from electrum.util import bh2u
from electrum.lnbase import Peer, decode_msg, gen_msg from electrum.lnbase import Peer, decode_msg, gen_msg
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
from electrum.lnutil import PaymentFailure
from electrum.lnrouter import ChannelDB, LNPathFinder from electrum.lnrouter import ChannelDB, LNPathFinder
from electrum.lnworker import LNWorker from electrum.lnworker import LNWorker
@ -33,7 +34,7 @@ def noop_lock():
yield yield
class MockNetwork: class MockNetwork:
def __init__(self): def __init__(self, tx_queue):
self.callbacks = defaultdict(list) self.callbacks = defaultdict(list)
self.lnwatcher = None self.lnwatcher = None
user_config = {} user_config = {}
@ -43,6 +44,7 @@ class MockNetwork:
self.channel_db = ChannelDB(self) self.channel_db = ChannelDB(self)
self.interface = None self.interface = None
self.path_finder = LNPathFinder(self.channel_db) self.path_finder = LNPathFinder(self.channel_db)
self.tx_queue = tx_queue
@property @property
def callback_lock(self): def callback_lock(self):
@ -55,12 +57,16 @@ class MockNetwork:
def get_local_height(self): def get_local_height(self):
return 0 return 0
async def broadcast_transaction(self, tx):
if self.tx_queue:
await self.tx_queue.put(tx)
class MockLNWorker: class MockLNWorker:
def __init__(self, remote_keypair, local_keypair, chan): def __init__(self, remote_keypair, local_keypair, chan, tx_queue):
self.chan = chan self.chan = chan
self.remote_keypair = remote_keypair self.remote_keypair = remote_keypair
self.node_keypair = local_keypair self.node_keypair = local_keypair
self.network = MockNetwork() self.network = MockNetwork(tx_queue)
self.channels = {self.chan.channel_id: self.chan} self.channels = {self.chan.channel_id: self.chan}
self.invoices = {} self.invoices = {}
@ -76,10 +82,12 @@ class MockLNWorker:
return self.channels return self.channels
def save_channel(self, chan): def save_channel(self, chan):
pass print("Ignoring channel save")
get_invoice = LNWorker.get_invoice get_invoice = LNWorker.get_invoice
_create_route_from_invoice = LNWorker._create_route_from_invoice _create_route_from_invoice = LNWorker._create_route_from_invoice
_check_invoice = staticmethod(LNWorker._check_invoice)
_pay_to_route = LNWorker._pay_to_route
class MockTransport: class MockTransport:
def __init__(self): def __init__(self):
@ -120,18 +128,19 @@ class TestPeer(unittest.TestCase):
self.alice_channel, self.bob_channel = create_test_channels() self.alice_channel, self.bob_channel = create_test_channels()
def test_require_data_loss_protect(self): def test_require_data_loss_protect(self):
mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel) mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel, tx_queue=None)
mock_transport = NoFeaturesTransport() mock_transport = NoFeaturesTransport()
p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 33), request_initial_sync=False, transport=mock_transport) p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 33), request_initial_sync=False, transport=mock_transport)
mock_lnworker.peer = p1 mock_lnworker.peer = p1
with self.assertRaises(LightningPeerConnectionClosed): with self.assertRaises(LightningPeerConnectionClosed):
asyncio.get_event_loop().run_until_complete(asyncio.wait_for(p1._main_loop(), 1)) run(asyncio.wait_for(p1._main_loop(), 1))
def test_payment(self): def prepare_peers(self):
k1, k2 = keypair(), keypair() k1, k2 = keypair(), keypair()
t1, t2 = transport_pair() t1, t2 = transport_pair()
w1 = MockLNWorker(k1, k2, self.alice_channel) q1, q2 = asyncio.Queue(), asyncio.Queue()
w2 = MockLNWorker(k2, k1, self.bob_channel) w1 = MockLNWorker(k1, k2, self.alice_channel, tx_queue=q1)
w2 = MockLNWorker(k2, k1, self.bob_channel, tx_queue=q2)
p1 = Peer(w1, LNPeerAddr("bogus1", 1337, k1.pubkey), p1 = Peer(w1, LNPeerAddr("bogus1", 1337, k1.pubkey),
request_initial_sync=False, transport=t1) request_initial_sync=False, transport=t1)
p2 = Peer(w2, LNPeerAddr("bogus2", 1337, k2.pubkey), p2 = Peer(w2, LNPeerAddr("bogus2", 1337, k2.pubkey),
@ -145,6 +154,11 @@ class TestPeer(unittest.TestCase):
# this populates the channel graph: # this populates the channel graph:
p1.mark_open(self.alice_channel) p1.mark_open(self.alice_channel)
p2.mark_open(self.bob_channel) p2.mark_open(self.bob_channel)
return p1, p2, w1, w2, q1, q2
@staticmethod
def prepare_invoice(w2 # receiver
):
amount_btc = 100000/Decimal(COIN) amount_btc = 100000/Decimal(COIN)
payment_preimage = os.urandom(32) payment_preimage = os.urandom(32)
RHASH = sha256(payment_preimage) RHASH = sha256(payment_preimage)
@ -156,13 +170,23 @@ class TestPeer(unittest.TestCase):
]) ])
pay_req = lnencode(addr, w2.node_keypair.privkey) pay_req = lnencode(addr, w2.node_keypair.privkey)
w2.invoices[bh2u(RHASH)] = (bh2u(payment_preimage), pay_req) w2.invoices[bh2u(RHASH)] = (bh2u(payment_preimage), pay_req)
l = asyncio.get_event_loop() return pay_req
async def pay():
@staticmethod
def prepare_ln_message_future(w2 # receiver
):
fut = asyncio.Future() fut = asyncio.Future()
def evt_set(event, _lnworker, msg): def evt_set(event, _lnworker, msg):
fut.set_result(msg) fut.set_result(msg)
w2.network.register_callback(evt_set, ['ln_message']) w2.network.register_callback(evt_set, ['ln_message'])
return fut
def test_payment(self):
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers()
pay_req = self.prepare_invoice(w2)
fut = self.prepare_ln_message_future(w2)
async def pay():
addr, peer, coro = LNWorker._pay(w1, pay_req) addr, peer, coro = LNWorker._pay(w1, pay_req)
await coro await coro
print("HTLC ADDED") print("HTLC ADDED")
@ -170,4 +194,28 @@ class TestPeer(unittest.TestCase):
gath.cancel() gath.cancel()
gath = asyncio.gather(pay(), p1._main_loop(), p2._main_loop()) gath = asyncio.gather(pay(), p1._main_loop(), p2._main_loop())
with self.assertRaises(asyncio.CancelledError): with self.assertRaises(asyncio.CancelledError):
l.run_until_complete(gath) run(gath)
def test_channel_usage_after_closing(self):
p1, p2, w1, w2, q1, q2 = self.prepare_peers()
pay_req = self.prepare_invoice(w2)
addr = w1._check_invoice(pay_req)
route = w1._create_route_from_invoice(decoded_invoice=addr)
run(p1.force_close_channel(self.alice_channel.channel_id))
# check if a tx (commitment transaction) was broadcasted:
assert q1.qsize() == 1
with self.assertRaises(PaymentFailure) as e:
w1._create_route_from_invoice(decoded_invoice=addr)
self.assertEqual(str(e.exception), 'No path found')
peer = w1.peers[route[0].node_id]
# AssertionError is ok since we shouldn't use old routes, and the
# route finding should fail when channel is closed
with self.assertRaises(AssertionError):
run(asyncio.gather(w1._pay_to_route(route, addr), p1._main_loop(), p2._main_loop()))
def run(coro):
asyncio.get_event_loop().run_until_complete(coro)

12
electrum/tests/test_lnchan.py

@ -29,6 +29,7 @@ from electrum import lnchan
from electrum import lnutil from electrum import lnutil
from electrum import bip32 as bip32_utils from electrum import bip32 as bip32_utils
from electrum.lnutil import SENT, LOCAL, REMOTE, RECEIVED from electrum.lnutil import SENT, LOCAL, REMOTE, RECEIVED
from electrum.ecc import sig_string_from_der_sig
one_bitcoin_in_msat = bitcoin.COIN * 1000 one_bitcoin_in_msat = bitcoin.COIN * 1000
@ -81,7 +82,8 @@ def create_channel_state(funding_txid, funding_index, funding_sat, local_feerate
per_commitment_secret_seed=seed, per_commitment_secret_seed=seed,
funding_locked_received=True, funding_locked_received=True,
was_announced=False, was_announced=False,
current_commitment_signature=None, # just a random signature
current_commitment_signature=sig_string_from_der_sig(bytes.fromhex('3046022100c66e112e22b91b96b795a6dd5f4b004f3acccd9a2a31bf104840f256855b7aa3022100e711b868b62d87c7edd95a2370e496b9cb6a38aff13c9f64f9ff2f3b2a0052dd')),
current_htlc_signatures=None, current_htlc_signatures=None,
), ),
"constraints":lnbase.ChannelConstraints( "constraints":lnbase.ChannelConstraints(
@ -185,6 +187,14 @@ class TestChannel(unittest.TestCase):
self.htlc = self.bob_channel.log[lnutil.REMOTE].adds[0] self.htlc = self.bob_channel.log[lnutil.REMOTE].adds[0]
def test_concurrent_reversed_payment(self):
self.htlc_dict['payment_hash'] = bitcoin.sha256(32 * b'\x02')
self.htlc_dict['amount_msat'] += 1000
bob_idx = self.bob_channel.add_htlc(self.htlc_dict)
alice_idx = self.alice_channel.receive_htlc(self.htlc_dict)
self.alice_channel.receive_new_commitment(*self.bob_channel.sign_next_commitment())
self.assertEquals(len(self.alice_channel.pending_remote_commitment.outputs()), 3)
def test_SimpleAddSettleWorkflow(self): def test_SimpleAddSettleWorkflow(self):
alice_channel, bob_channel = self.alice_channel, self.bob_channel alice_channel, bob_channel = self.alice_channel, self.bob_channel
htlc = self.htlc htlc = self.htlc

Loading…
Cancel
Save