From c4ab1e6fad559d07b02cb8115302302b26abbff6 Mon Sep 17 00:00:00 2001 From: ThomasV Date: Fri, 11 Oct 2019 10:11:41 +0200 Subject: [PATCH] Encapsulate lightning payment events: - make LNWorker.pending_payments private - public methods: payment_sent, payment_received, await_payment --- electrum/lnchannel.py | 7 ++++++- electrum/lnpeer.py | 29 ++++++++++++++--------------- electrum/lnworker.py | 26 ++++++++++++++++++++------ electrum/tests/test_lnpeer.py | 3 +++ 4 files changed, 43 insertions(+), 22 deletions(-) diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py index 9596ce684..ea7cbac4d 100644 --- a/electrum/lnchannel.py +++ b/electrum/lnchannel.py @@ -31,7 +31,7 @@ from typing import Optional, Dict, List, Tuple, NamedTuple, Set, Callable, Itera import time from . import ecc -from .util import bfh, bh2u, PR_PAID, PR_FAILED +from .util import bfh, bh2u from .bitcoin import TYPE_SCRIPT, TYPE_ADDRESS from .bitcoin import redeem_script_to_address from .crypto import sha256, sha256d @@ -573,6 +573,11 @@ class Channel(Logger): assert htlc_id not in log['settles'] self.hm.send_settle(htlc_id) + def get_payment_hash(self, htlc_id): + log = self.hm.log[REMOTE] + htlc = log['adds'][htlc_id] + return htlc.payment_hash + def receive_htlc_settle(self, preimage, htlc_id): self.logger.info("receive_htlc_settle") log = self.hm.log[LOCAL] diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index ada425b23..3a7314bad 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -41,7 +41,6 @@ from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc, MINIMUM_MAX_HTLC_VALUE_IN_FLIGHT_ACCEPTED, MAXIMUM_HTLC_MINIMUM_MSAT_ACCEPTED, MAXIMUM_REMOTE_TO_SELF_DELAY_ACCEPTED, RemoteMisbehaving, DEFAULT_TO_SELF_DELAY, NBLOCK_OUR_CLTV_EXPIRY_DELTA, format_short_channel_id, ShortChannelID) -from .util import PR_PAID from .lnutil import FeeUpdate from .lntransport import LNTransport, LNTransportBase from .lnmsg import encode_msg, decode_msg @@ -1103,7 +1102,8 @@ class Peer(Logger): async def _on_update_fail_htlc(self, channel_id, htlc_id, local_ctn): chan = self.channels[channel_id] await self.await_local(chan, local_ctn) - self.lnworker.pending_payments[(chan.short_channel_id, htlc_id)].set_result(False) + payment_hash = chan.get_payment_hash(htlc_id) + self.lnworker.payment_sent(payment_hash, False) @log_exceptions async def _handle_error_code_from_failed_htlc(self, payload, channel_id, htlc_id): @@ -1267,14 +1267,13 @@ class Peer(Logger): self.logger.info(f"on_update_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") chan.receive_htlc_settle(preimage, htlc_id) self.lnworker.save_preimage(payment_hash, preimage) - self.lnworker.set_payment_status(payment_hash, PR_PAID) local_ctn = chan.get_latest_ctn(LOCAL) - asyncio.ensure_future(self._on_update_fulfill_htlc(chan, htlc_id, preimage, local_ctn)) + asyncio.ensure_future(self._on_update_fulfill_htlc(chan, local_ctn, payment_hash)) @log_exceptions - async def _on_update_fulfill_htlc(self, chan, htlc_id, preimage, local_ctn): + async def _on_update_fulfill_htlc(self, chan, local_ctn, payment_hash): await self.await_local(chan, local_ctn) - self.lnworker.pending_payments[(chan.short_channel_id, htlc_id)].set_result(True) + self.lnworker.payment_sent(payment_hash, True) def on_update_fail_malformed_htlc(self, payload): self.logger.info(f"on_update_fail_malformed_htlc. error {payload['data'].decode('ascii')}") @@ -1392,13 +1391,14 @@ class Peer(Logger): onion_routing_packet=processed_onion.next_packet.to_bytes() ) await next_peer.await_remote(next_chan, next_remote_ctn) - # wait until we get paid - success = await self.lnworker.pending_payments[(next_chan.short_channel_id, next_htlc.htlc_id)] - assert success - preimage = self.lnworker.get_preimage(next_htlc.payment_hash) - # fulfill the original htlc - await self._fulfill_htlc(chan, htlc.htlc_id, preimage) - self.logger.info("htlc forwarded successfully") + success, preimage = await self.lnworker.await_payment(next_htlc.payment_hash) + if success: + await self._fulfill_htlc(chan, htlc.htlc_id, preimage) + self.logger.info("htlc forwarded successfully") + else: + self.logger.info("htlc not fulfilled") + # TODO: Read error code and forward it, as follows: + await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason) @log_exceptions async def _maybe_fulfill_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *, local_ctn: int, remote_ctn: int, @@ -1443,14 +1443,13 @@ class Peer(Logger): self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") chan.settle_htlc(preimage, htlc_id) payment_hash = sha256(preimage) - self.lnworker.set_payment_status(payment_hash, PR_PAID) + self.lnworker.payment_received(payment_hash) remote_ctn = chan.get_latest_ctn(REMOTE) self.send_message("update_fulfill_htlc", channel_id=chan.channel_id, id=htlc_id, payment_preimage=preimage) await self.await_remote(chan, remote_ctn) - #self.lnworker.payment_received(htlc_id) async def fail_htlc(self, chan: Channel, htlc_id: int, onion_packet: OnionPacket, reason: OnionRoutingFailureMessage): diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 40dbea101..e218f55ba 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -853,6 +853,8 @@ class LNWallet(LNWorker): status = self.get_payment_status(lnaddr.paymenthash) if status == PR_PAID: raise PaymentFailure(_("This invoice has been paid already")) + if status == PR_INFLIGHT: + raise PaymentFailure(_("A payment was already initiated for this invoice")) info = PaymentInfo(lnaddr.paymenthash, amount, SENT, PR_UNPAID) self.save_payment_info(info) self._check_invoice(invoice, amount_sat) @@ -874,8 +876,7 @@ class LNWallet(LNWorker): peer = self.peers[route[0].node_id] htlc = await peer.pay(route, chan, int(lnaddr.amount * COIN * 1000), lnaddr.paymenthash, lnaddr.get_min_final_cltv_expiry()) self.network.trigger_callback('htlc_added', htlc, lnaddr, SENT) - success = await self.pending_payments[(short_channel_id, htlc.htlc_id)] - self.set_payment_status(lnaddr.paymenthash, (PR_PAID if success else PR_UNPAID)) + success, preimage = await self.await_payment(lnaddr.paymenthash) return success @staticmethod @@ -1012,6 +1013,7 @@ class LNWallet(LNWorker): def save_payment_info(self, info): key = info.payment_hash.hex() + assert info.status in [PR_PAID, PR_UNPAID, PR_INFLIGHT] with self.lock: self.payments[key] = info.amount, info.direction, info.status self.storage.put('lightning_payments', self.payments) @@ -1020,9 +1022,15 @@ class LNWallet(LNWorker): def get_payment_status(self, payment_hash): try: info = self.get_payment_info(payment_hash) - return info.status + status = info.status except UnknownPaymentHash: - return PR_UNKNOWN + status = PR_UNPAID + return status + + async def await_payment(self, payment_hash): + success = await self.pending_payments[payment_hash] + preimage = self.get_preimage(payment_hash) + return success, preimage def set_payment_status(self, payment_hash: bytes, status): try: @@ -1032,8 +1040,14 @@ class LNWallet(LNWorker): return info = info._replace(status=status) self.save_payment_info(info) - if info.direction == RECEIVED and info.status == PR_PAID: - self.network.trigger_callback('payment_received', self.wallet, bh2u(payment_hash), PR_PAID) + + def payment_sent(self, payment_hash: bytes, success): + status = PR_PAID if success else PR_UNPAID + self.set_payment_status(payment_hash, status) + self.pending_payments[payment_hash].set_result(success) + + def payment_received(self, payment_hash: bytes): + self.set_payment_status(payment_hash, PR_PAID) async def _calc_routing_hints_for_invoice(self, amount_sat): """calculate routing hints (BOLT-11 'r' field)""" diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index f73d12a07..a7565d7a3 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -124,6 +124,9 @@ class MockLNWallet: save_payment_info = LNWallet.save_payment_info set_payment_status = LNWallet.set_payment_status get_payment_status = LNWallet.get_payment_status + await_payment = LNWallet.await_payment + payment_received = LNWallet.payment_received + payment_sent = LNWallet.payment_sent save_preimage = LNWallet.save_preimage get_preimage = LNWallet.get_preimage _create_route_from_invoice = LNWallet._create_route_from_invoice