Browse Source

move force_close_channel to lnbase, test it, add FORCE_CLOSING state

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
Janus 6 years ago
committed by ThomasV
parent
commit
0ea87278fb
  1. 21
      electrum/lnbase.py
  2. 56
      electrum/lnworker.py
  3. 72
      electrum/tests/test_lnbase.py
  4. 12
      electrum/tests/test_lnchan.py

21
electrum/lnbase.py

@ -19,7 +19,7 @@ import aiorpcx
from .crypto import sha256, sha256d
from . import bitcoin
from . import ecc
from .ecc import sig_string_from_r_and_s, get_r_and_s_from_sig_string
from .ecc import sig_string_from_r_and_s, get_r_and_s_from_sig_string, der_sig_from_sig_string
from . import constants
from .util import PrintError, bh2u, print_error, bfh, log_exceptions, list_enabled_bits, ignore_exceptions
from .transaction import Transaction, TxOutput
@ -1158,6 +1158,25 @@ class Peer(PrintError):
self.print_error('Channel closed', txid)
return txid
async def force_close_channel(self, chan_id):
chan = self.channels[chan_id]
# local_commitment always gives back the next expected local_commitment,
# but in this case, we want the current one. So substract one ctn number
old_local_state = chan.config[LOCAL]
chan.config[LOCAL]=chan.config[LOCAL]._replace(ctn=chan.config[LOCAL].ctn - 1)
tx = chan.pending_local_commitment
chan.config[LOCAL] = old_local_state
tx.sign({bh2u(chan.config[LOCAL].multisig_key.pubkey): (chan.config[LOCAL].multisig_key.privkey, True)})
remote_sig = chan.config[LOCAL].current_commitment_signature
remote_sig = der_sig_from_sig_string(remote_sig) + b"\x01"
none_idx = tx._inputs[0]["signatures"].index(None)
tx.add_signature_to_txin(0, none_idx, bh2u(remote_sig))
assert tx.is_complete()
# TODO persist FORCE_CLOSING state to disk
chan.set_state('FORCE_CLOSING')
self.lnworker.save_channel(chan)
return await self.network.broadcast_transaction(tx)
@log_exceptions
async def on_shutdown(self, payload):
# length of scripts allowed in BOLT-02

56
electrum/lnworker.py

@ -11,6 +11,7 @@ from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING
import threading
import socket
import json
from decimal import Decimal
import dns.resolver
import dns.exception
@ -267,18 +268,13 @@ class LNWorker(PrintError):
return addr, peer, fut
def _pay(self, invoice, amount_sat=None):
addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP)
payment_hash = addr.paymenthash
amount_sat = (addr.amount * COIN) if addr.amount else amount_sat
if amount_sat is None:
raise InvoiceError(_("Missing amount"))
amount_msat = int(amount_sat * 1000)
if addr.get_min_final_cltv_expiry() > 60 * 144:
raise InvoiceError("{}\n{}".format(
_("Invoice wants us to risk locking funds for unreasonably long."),
f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}"))
route = self._create_route_from_invoice(decoded_invoice=addr, amount_msat=amount_msat)
node_id, short_channel_id = route[0].node_id, route[0].short_channel_id
addr = self._check_invoice(invoice, amount_sat)
route = self._create_route_from_invoice(decoded_invoice=addr)
peer = self.peers[route[0].node_id]
return addr, peer, self._pay_to_route(route, addr)
async def _pay_to_route(self, route, addr):
short_channel_id = route[0].short_channel_id
with self.lock:
channels = list(self.channels.values())
for chan in channels:
@ -286,11 +282,24 @@ class LNWorker(PrintError):
break
else:
raise Exception("PathFinder returned path with short_channel_id {} that is not in channel list".format(bh2u(short_channel_id)))
peer = self.peers[node_id]
coro = peer.pay(route, chan, amount_msat, payment_hash, addr.get_min_final_cltv_expiry())
return addr, peer, coro
peer = self.peers[route[0].node_id]
return await peer.pay(route, chan, int(addr.amount * COIN * 1000), addr.paymenthash, addr.get_min_final_cltv_expiry())
def _create_route_from_invoice(self, decoded_invoice, amount_msat) -> List[RouteEdge]:
@staticmethod
def _check_invoice(invoice, amount_sat=None):
addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP)
if amount_sat:
addr.amount = Decimal(amount_sat) / COIN
if addr.amount is None:
raise InvoiceError(_("Missing amount"))
if addr.get_min_final_cltv_expiry() > 60 * 144:
raise InvoiceError("{}\n{}".format(
_("Invoice wants us to risk locking funds for unreasonably long."),
f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}"))
return addr
def _create_route_from_invoice(self, decoded_invoice) -> List[RouteEdge]:
amount_msat = int(decoded_invoice.amount * COIN * 1000)
invoice_pubkey = decoded_invoice.pubkey.serialize()
# use 'r' field from invoice
route = None # type: List[RouteEdge]
@ -441,19 +450,8 @@ class LNWorker(PrintError):
async def force_close_channel(self, chan_id):
chan = self.channels[chan_id]
# local_commitment always gives back the next expected local_commitment,
# but in this case, we want the current one. So substract one ctn number
old_local_state = chan.config[LOCAL]
chan.config[LOCAL]=chan.config[LOCAL]._replace(ctn=chan.config[LOCAL].ctn - 1)
tx = chan.pending_local_commitment
chan.config[LOCAL] = old_local_state
tx.sign({bh2u(chan.config[LOCAL].multisig_key.pubkey): (chan.config[LOCAL].multisig_key.privkey, True)})
remote_sig = chan.config[LOCAL].current_commitment_signature
remote_sig = der_sig_from_sig_string(remote_sig) + b"\x01"
none_idx = tx._inputs[0]["signatures"].index(None)
tx.add_signature_to_txin(0, none_idx, bh2u(remote_sig))
assert tx.is_complete()
return await self.network.broadcast_transaction(tx)
peer = self.peers[chan.node_id]
return await peer.force_close_channel(chan_id)
def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]:
now = time.time()

72
electrum/tests/test_lnbase.py

@ -16,6 +16,7 @@ from electrum.util import bh2u
from electrum.lnbase import Peer, decode_msg, gen_msg
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
from electrum.lnutil import PaymentFailure
from electrum.lnrouter import ChannelDB, LNPathFinder
from electrum.lnworker import LNWorker
@ -33,7 +34,7 @@ def noop_lock():
yield
class MockNetwork:
def __init__(self):
def __init__(self, tx_queue):
self.callbacks = defaultdict(list)
self.lnwatcher = None
user_config = {}
@ -43,6 +44,7 @@ class MockNetwork:
self.channel_db = ChannelDB(self)
self.interface = None
self.path_finder = LNPathFinder(self.channel_db)
self.tx_queue = tx_queue
@property
def callback_lock(self):
@ -55,12 +57,16 @@ class MockNetwork:
def get_local_height(self):
return 0
async def broadcast_transaction(self, tx):
if self.tx_queue:
await self.tx_queue.put(tx)
class MockLNWorker:
def __init__(self, remote_keypair, local_keypair, chan):
def __init__(self, remote_keypair, local_keypair, chan, tx_queue):
self.chan = chan
self.remote_keypair = remote_keypair
self.node_keypair = local_keypair
self.network = MockNetwork()
self.network = MockNetwork(tx_queue)
self.channels = {self.chan.channel_id: self.chan}
self.invoices = {}
@ -76,10 +82,12 @@ class MockLNWorker:
return self.channels
def save_channel(self, chan):
pass
print("Ignoring channel save")
get_invoice = LNWorker.get_invoice
_create_route_from_invoice = LNWorker._create_route_from_invoice
_check_invoice = staticmethod(LNWorker._check_invoice)
_pay_to_route = LNWorker._pay_to_route
class MockTransport:
def __init__(self):
@ -120,18 +128,19 @@ class TestPeer(unittest.TestCase):
self.alice_channel, self.bob_channel = create_test_channels()
def test_require_data_loss_protect(self):
mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel)
mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel, tx_queue=None)
mock_transport = NoFeaturesTransport()
p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 33), request_initial_sync=False, transport=mock_transport)
mock_lnworker.peer = p1
with self.assertRaises(LightningPeerConnectionClosed):
asyncio.get_event_loop().run_until_complete(asyncio.wait_for(p1._main_loop(), 1))
run(asyncio.wait_for(p1._main_loop(), 1))
def test_payment(self):
def prepare_peers(self):
k1, k2 = keypair(), keypair()
t1, t2 = transport_pair()
w1 = MockLNWorker(k1, k2, self.alice_channel)
w2 = MockLNWorker(k2, k1, self.bob_channel)
q1, q2 = asyncio.Queue(), asyncio.Queue()
w1 = MockLNWorker(k1, k2, self.alice_channel, tx_queue=q1)
w2 = MockLNWorker(k2, k1, self.bob_channel, tx_queue=q2)
p1 = Peer(w1, LNPeerAddr("bogus1", 1337, k1.pubkey),
request_initial_sync=False, transport=t1)
p2 = Peer(w2, LNPeerAddr("bogus2", 1337, k2.pubkey),
@ -145,6 +154,11 @@ class TestPeer(unittest.TestCase):
# this populates the channel graph:
p1.mark_open(self.alice_channel)
p2.mark_open(self.bob_channel)
return p1, p2, w1, w2, q1, q2
@staticmethod
def prepare_invoice(w2 # receiver
):
amount_btc = 100000/Decimal(COIN)
payment_preimage = os.urandom(32)
RHASH = sha256(payment_preimage)
@ -156,13 +170,23 @@ class TestPeer(unittest.TestCase):
])
pay_req = lnencode(addr, w2.node_keypair.privkey)
w2.invoices[bh2u(RHASH)] = (bh2u(payment_preimage), pay_req)
l = asyncio.get_event_loop()
async def pay():
return pay_req
@staticmethod
def prepare_ln_message_future(w2 # receiver
):
fut = asyncio.Future()
def evt_set(event, _lnworker, msg):
fut.set_result(msg)
w2.network.register_callback(evt_set, ['ln_message'])
return fut
def test_payment(self):
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers()
pay_req = self.prepare_invoice(w2)
fut = self.prepare_ln_message_future(w2)
async def pay():
addr, peer, coro = LNWorker._pay(w1, pay_req)
await coro
print("HTLC ADDED")
@ -170,4 +194,28 @@ class TestPeer(unittest.TestCase):
gath.cancel()
gath = asyncio.gather(pay(), p1._main_loop(), p2._main_loop())
with self.assertRaises(asyncio.CancelledError):
l.run_until_complete(gath)
run(gath)
def test_channel_usage_after_closing(self):
p1, p2, w1, w2, q1, q2 = self.prepare_peers()
pay_req = self.prepare_invoice(w2)
addr = w1._check_invoice(pay_req)
route = w1._create_route_from_invoice(decoded_invoice=addr)
run(p1.force_close_channel(self.alice_channel.channel_id))
# check if a tx (commitment transaction) was broadcasted:
assert q1.qsize() == 1
with self.assertRaises(PaymentFailure) as e:
w1._create_route_from_invoice(decoded_invoice=addr)
self.assertEqual(str(e.exception), 'No path found')
peer = w1.peers[route[0].node_id]
# AssertionError is ok since we shouldn't use old routes, and the
# route finding should fail when channel is closed
with self.assertRaises(AssertionError):
run(asyncio.gather(w1._pay_to_route(route, addr), p1._main_loop(), p2._main_loop()))
def run(coro):
asyncio.get_event_loop().run_until_complete(coro)

12
electrum/tests/test_lnchan.py

@ -29,6 +29,7 @@ from electrum import lnchan
from electrum import lnutil
from electrum import bip32 as bip32_utils
from electrum.lnutil import SENT, LOCAL, REMOTE, RECEIVED
from electrum.ecc import sig_string_from_der_sig
one_bitcoin_in_msat = bitcoin.COIN * 1000
@ -81,7 +82,8 @@ def create_channel_state(funding_txid, funding_index, funding_sat, local_feerate
per_commitment_secret_seed=seed,
funding_locked_received=True,
was_announced=False,
current_commitment_signature=None,
# just a random signature
current_commitment_signature=sig_string_from_der_sig(bytes.fromhex('3046022100c66e112e22b91b96b795a6dd5f4b004f3acccd9a2a31bf104840f256855b7aa3022100e711b868b62d87c7edd95a2370e496b9cb6a38aff13c9f64f9ff2f3b2a0052dd')),
current_htlc_signatures=None,
),
"constraints":lnbase.ChannelConstraints(
@ -185,6 +187,14 @@ class TestChannel(unittest.TestCase):
self.htlc = self.bob_channel.log[lnutil.REMOTE].adds[0]
def test_concurrent_reversed_payment(self):
self.htlc_dict['payment_hash'] = bitcoin.sha256(32 * b'\x02')
self.htlc_dict['amount_msat'] += 1000
bob_idx = self.bob_channel.add_htlc(self.htlc_dict)
alice_idx = self.alice_channel.receive_htlc(self.htlc_dict)
self.alice_channel.receive_new_commitment(*self.bob_channel.sign_next_commitment())
self.assertEquals(len(self.alice_channel.pending_remote_commitment.outputs()), 3)
def test_SimpleAddSettleWorkflow(self):
alice_channel, bob_channel = self.alice_channel, self.bob_channel
htlc = self.htlc

Loading…
Cancel
Save