Browse Source

Encapsulate lightning payment events:

- make LNWorker.pending_payments private
 - public methods: payment_sent, payment_received, await_payment
dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
ThomasV 5 years ago
parent
commit
c4ab1e6fad
  1. 7
      electrum/lnchannel.py
  2. 25
      electrum/lnpeer.py
  3. 26
      electrum/lnworker.py
  4. 3
      electrum/tests/test_lnpeer.py

7
electrum/lnchannel.py

@ -31,7 +31,7 @@ from typing import Optional, Dict, List, Tuple, NamedTuple, Set, Callable, Itera
import time import time
from . import ecc from . import ecc
from .util import bfh, bh2u, PR_PAID, PR_FAILED from .util import bfh, bh2u
from .bitcoin import TYPE_SCRIPT, TYPE_ADDRESS from .bitcoin import TYPE_SCRIPT, TYPE_ADDRESS
from .bitcoin import redeem_script_to_address from .bitcoin import redeem_script_to_address
from .crypto import sha256, sha256d from .crypto import sha256, sha256d
@ -573,6 +573,11 @@ class Channel(Logger):
assert htlc_id not in log['settles'] assert htlc_id not in log['settles']
self.hm.send_settle(htlc_id) self.hm.send_settle(htlc_id)
def get_payment_hash(self, htlc_id):
log = self.hm.log[REMOTE]
htlc = log['adds'][htlc_id]
return htlc.payment_hash
def receive_htlc_settle(self, preimage, htlc_id): def receive_htlc_settle(self, preimage, htlc_id):
self.logger.info("receive_htlc_settle") self.logger.info("receive_htlc_settle")
log = self.hm.log[LOCAL] log = self.hm.log[LOCAL]

25
electrum/lnpeer.py

@ -41,7 +41,6 @@ from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc,
MINIMUM_MAX_HTLC_VALUE_IN_FLIGHT_ACCEPTED, MAXIMUM_HTLC_MINIMUM_MSAT_ACCEPTED, MINIMUM_MAX_HTLC_VALUE_IN_FLIGHT_ACCEPTED, MAXIMUM_HTLC_MINIMUM_MSAT_ACCEPTED,
MAXIMUM_REMOTE_TO_SELF_DELAY_ACCEPTED, RemoteMisbehaving, DEFAULT_TO_SELF_DELAY, MAXIMUM_REMOTE_TO_SELF_DELAY_ACCEPTED, RemoteMisbehaving, DEFAULT_TO_SELF_DELAY,
NBLOCK_OUR_CLTV_EXPIRY_DELTA, format_short_channel_id, ShortChannelID) NBLOCK_OUR_CLTV_EXPIRY_DELTA, format_short_channel_id, ShortChannelID)
from .util import PR_PAID
from .lnutil import FeeUpdate from .lnutil import FeeUpdate
from .lntransport import LNTransport, LNTransportBase from .lntransport import LNTransport, LNTransportBase
from .lnmsg import encode_msg, decode_msg from .lnmsg import encode_msg, decode_msg
@ -1103,7 +1102,8 @@ class Peer(Logger):
async def _on_update_fail_htlc(self, channel_id, htlc_id, local_ctn): async def _on_update_fail_htlc(self, channel_id, htlc_id, local_ctn):
chan = self.channels[channel_id] chan = self.channels[channel_id]
await self.await_local(chan, local_ctn) await self.await_local(chan, local_ctn)
self.lnworker.pending_payments[(chan.short_channel_id, htlc_id)].set_result(False) payment_hash = chan.get_payment_hash(htlc_id)
self.lnworker.payment_sent(payment_hash, False)
@log_exceptions @log_exceptions
async def _handle_error_code_from_failed_htlc(self, payload, channel_id, htlc_id): async def _handle_error_code_from_failed_htlc(self, payload, channel_id, htlc_id):
@ -1267,14 +1267,13 @@ class Peer(Logger):
self.logger.info(f"on_update_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") self.logger.info(f"on_update_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
chan.receive_htlc_settle(preimage, htlc_id) chan.receive_htlc_settle(preimage, htlc_id)
self.lnworker.save_preimage(payment_hash, preimage) self.lnworker.save_preimage(payment_hash, preimage)
self.lnworker.set_payment_status(payment_hash, PR_PAID)
local_ctn = chan.get_latest_ctn(LOCAL) local_ctn = chan.get_latest_ctn(LOCAL)
asyncio.ensure_future(self._on_update_fulfill_htlc(chan, htlc_id, preimage, local_ctn)) asyncio.ensure_future(self._on_update_fulfill_htlc(chan, local_ctn, payment_hash))
@log_exceptions @log_exceptions
async def _on_update_fulfill_htlc(self, chan, htlc_id, preimage, local_ctn): async def _on_update_fulfill_htlc(self, chan, local_ctn, payment_hash):
await self.await_local(chan, local_ctn) await self.await_local(chan, local_ctn)
self.lnworker.pending_payments[(chan.short_channel_id, htlc_id)].set_result(True) self.lnworker.payment_sent(payment_hash, True)
def on_update_fail_malformed_htlc(self, payload): def on_update_fail_malformed_htlc(self, payload):
self.logger.info(f"on_update_fail_malformed_htlc. error {payload['data'].decode('ascii')}") self.logger.info(f"on_update_fail_malformed_htlc. error {payload['data'].decode('ascii')}")
@ -1392,13 +1391,14 @@ class Peer(Logger):
onion_routing_packet=processed_onion.next_packet.to_bytes() onion_routing_packet=processed_onion.next_packet.to_bytes()
) )
await next_peer.await_remote(next_chan, next_remote_ctn) await next_peer.await_remote(next_chan, next_remote_ctn)
# wait until we get paid success, preimage = await self.lnworker.await_payment(next_htlc.payment_hash)
success = await self.lnworker.pending_payments[(next_chan.short_channel_id, next_htlc.htlc_id)] if success:
assert success
preimage = self.lnworker.get_preimage(next_htlc.payment_hash)
# fulfill the original htlc
await self._fulfill_htlc(chan, htlc.htlc_id, preimage) await self._fulfill_htlc(chan, htlc.htlc_id, preimage)
self.logger.info("htlc forwarded successfully") self.logger.info("htlc forwarded successfully")
else:
self.logger.info("htlc not fulfilled")
# TODO: Read error code and forward it, as follows:
await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
@log_exceptions @log_exceptions
async def _maybe_fulfill_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *, local_ctn: int, remote_ctn: int, async def _maybe_fulfill_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *, local_ctn: int, remote_ctn: int,
@ -1443,14 +1443,13 @@ class Peer(Logger):
self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
chan.settle_htlc(preimage, htlc_id) chan.settle_htlc(preimage, htlc_id)
payment_hash = sha256(preimage) payment_hash = sha256(preimage)
self.lnworker.set_payment_status(payment_hash, PR_PAID) self.lnworker.payment_received(payment_hash)
remote_ctn = chan.get_latest_ctn(REMOTE) remote_ctn = chan.get_latest_ctn(REMOTE)
self.send_message("update_fulfill_htlc", self.send_message("update_fulfill_htlc",
channel_id=chan.channel_id, channel_id=chan.channel_id,
id=htlc_id, id=htlc_id,
payment_preimage=preimage) payment_preimage=preimage)
await self.await_remote(chan, remote_ctn) await self.await_remote(chan, remote_ctn)
#self.lnworker.payment_received(htlc_id)
async def fail_htlc(self, chan: Channel, htlc_id: int, onion_packet: OnionPacket, async def fail_htlc(self, chan: Channel, htlc_id: int, onion_packet: OnionPacket,
reason: OnionRoutingFailureMessage): reason: OnionRoutingFailureMessage):

26
electrum/lnworker.py

@ -853,6 +853,8 @@ class LNWallet(LNWorker):
status = self.get_payment_status(lnaddr.paymenthash) status = self.get_payment_status(lnaddr.paymenthash)
if status == PR_PAID: if status == PR_PAID:
raise PaymentFailure(_("This invoice has been paid already")) raise PaymentFailure(_("This invoice has been paid already"))
if status == PR_INFLIGHT:
raise PaymentFailure(_("A payment was already initiated for this invoice"))
info = PaymentInfo(lnaddr.paymenthash, amount, SENT, PR_UNPAID) info = PaymentInfo(lnaddr.paymenthash, amount, SENT, PR_UNPAID)
self.save_payment_info(info) self.save_payment_info(info)
self._check_invoice(invoice, amount_sat) self._check_invoice(invoice, amount_sat)
@ -874,8 +876,7 @@ class LNWallet(LNWorker):
peer = self.peers[route[0].node_id] peer = self.peers[route[0].node_id]
htlc = await peer.pay(route, chan, int(lnaddr.amount * COIN * 1000), lnaddr.paymenthash, lnaddr.get_min_final_cltv_expiry()) htlc = await peer.pay(route, chan, int(lnaddr.amount * COIN * 1000), lnaddr.paymenthash, lnaddr.get_min_final_cltv_expiry())
self.network.trigger_callback('htlc_added', htlc, lnaddr, SENT) self.network.trigger_callback('htlc_added', htlc, lnaddr, SENT)
success = await self.pending_payments[(short_channel_id, htlc.htlc_id)] success, preimage = await self.await_payment(lnaddr.paymenthash)
self.set_payment_status(lnaddr.paymenthash, (PR_PAID if success else PR_UNPAID))
return success return success
@staticmethod @staticmethod
@ -1012,6 +1013,7 @@ class LNWallet(LNWorker):
def save_payment_info(self, info): def save_payment_info(self, info):
key = info.payment_hash.hex() key = info.payment_hash.hex()
assert info.status in [PR_PAID, PR_UNPAID, PR_INFLIGHT]
with self.lock: with self.lock:
self.payments[key] = info.amount, info.direction, info.status self.payments[key] = info.amount, info.direction, info.status
self.storage.put('lightning_payments', self.payments) self.storage.put('lightning_payments', self.payments)
@ -1020,9 +1022,15 @@ class LNWallet(LNWorker):
def get_payment_status(self, payment_hash): def get_payment_status(self, payment_hash):
try: try:
info = self.get_payment_info(payment_hash) info = self.get_payment_info(payment_hash)
return info.status status = info.status
except UnknownPaymentHash: except UnknownPaymentHash:
return PR_UNKNOWN status = PR_UNPAID
return status
async def await_payment(self, payment_hash):
success = await self.pending_payments[payment_hash]
preimage = self.get_preimage(payment_hash)
return success, preimage
def set_payment_status(self, payment_hash: bytes, status): def set_payment_status(self, payment_hash: bytes, status):
try: try:
@ -1032,8 +1040,14 @@ class LNWallet(LNWorker):
return return
info = info._replace(status=status) info = info._replace(status=status)
self.save_payment_info(info) self.save_payment_info(info)
if info.direction == RECEIVED and info.status == PR_PAID:
self.network.trigger_callback('payment_received', self.wallet, bh2u(payment_hash), PR_PAID) def payment_sent(self, payment_hash: bytes, success):
status = PR_PAID if success else PR_UNPAID
self.set_payment_status(payment_hash, status)
self.pending_payments[payment_hash].set_result(success)
def payment_received(self, payment_hash: bytes):
self.set_payment_status(payment_hash, PR_PAID)
async def _calc_routing_hints_for_invoice(self, amount_sat): async def _calc_routing_hints_for_invoice(self, amount_sat):
"""calculate routing hints (BOLT-11 'r' field)""" """calculate routing hints (BOLT-11 'r' field)"""

3
electrum/tests/test_lnpeer.py

@ -124,6 +124,9 @@ class MockLNWallet:
save_payment_info = LNWallet.save_payment_info save_payment_info = LNWallet.save_payment_info
set_payment_status = LNWallet.set_payment_status set_payment_status = LNWallet.set_payment_status
get_payment_status = LNWallet.get_payment_status get_payment_status = LNWallet.get_payment_status
await_payment = LNWallet.await_payment
payment_received = LNWallet.payment_received
payment_sent = LNWallet.payment_sent
save_preimage = LNWallet.save_preimage save_preimage = LNWallet.save_preimage
get_preimage = LNWallet.get_preimage get_preimage = LNWallet.get_preimage
_create_route_from_invoice = LNWallet._create_route_from_invoice _create_route_from_invoice = LNWallet._create_route_from_invoice

Loading…
Cancel
Save