Browse Source

Merge pull request #6003 from spesmilo/htlc_switch

Htlc switch
hard-fail-on-bad-server-string
ThomasV 5 years ago
committed by GitHub
parent
commit
367d30d6c0
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 7
      electrum/lnchannel.py
  2. 1
      electrum/lnhtlc.py
  3. 232
      electrum/lnpeer.py
  4. 2
      electrum/lnworker.py
  5. 12
      electrum/tests/test_lnpeer.py

7
electrum/lnchannel.py

@ -408,7 +408,7 @@ class Channel(Logger):
self.logger.info("add_htlc") self.logger.info("add_htlc")
return htlc return htlc
def receive_htlc(self, htlc: UpdateAddHtlc) -> UpdateAddHtlc: def receive_htlc(self, htlc: UpdateAddHtlc, onion_packet:bytes = None) -> UpdateAddHtlc:
""" """
ReceiveHTLC adds an HTLC to the state machine's remote update log. This ReceiveHTLC adds an HTLC to the state machine's remote update log. This
method should be called in response to receiving a new HTLC from the remote method should be called in response to receiving a new HTLC from the remote
@ -427,6 +427,11 @@ class Channel(Logger):
f' HTLC amount: {htlc.amount_msat}') f' HTLC amount: {htlc.amount_msat}')
with self.db_lock: with self.db_lock:
self.hm.recv_htlc(htlc) self.hm.recv_htlc(htlc)
local_ctn = self.get_latest_ctn(LOCAL)
remote_ctn = self.get_latest_ctn(REMOTE)
if onion_packet:
self.hm.log['unfulfilled_htlcs'][htlc.htlc_id] = local_ctn, remote_ctn, onion_packet.hex(), False
self.logger.info("receive_htlc") self.logger.info("receive_htlc")
return htlc return htlc

1
electrum/lnhtlc.py

@ -25,6 +25,7 @@ class HTLCManager:
log[LOCAL] = deepcopy(initial) log[LOCAL] = deepcopy(initial)
log[REMOTE] = deepcopy(initial) log[REMOTE] = deepcopy(initial)
log['unacked_local_updates2'] = {} log['unacked_local_updates2'] = {}
log['unfulfilled_htlcs'] = {} # htlc_id -> onion_packet
# maybe bootstrap fee_updates if initial_feerate was provided # maybe bootstrap fee_updates if initial_feerate was provided
if initial_feerate is not None: if initial_feerate is not None:

232
electrum/lnpeer.py

@ -249,6 +249,7 @@ class Peer(Logger):
async def main_loop(self): async def main_loop(self):
async with self.taskgroup as group: async with self.taskgroup as group:
await group.spawn(self._message_loop()) await group.spawn(self._message_loop())
await group.spawn(self.htlc_switch())
await group.spawn(self.query_gossip()) await group.spawn(self.query_gossip())
await group.spawn(self.process_gossip()) await group.spawn(self.process_gossip())
@ -1131,195 +1132,137 @@ class Peer(Logger):
self.logger.info(f"on_update_add_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") self.logger.info(f"on_update_add_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
cltv_expiry = int.from_bytes(payload["cltv_expiry"], 'big') cltv_expiry = int.from_bytes(payload["cltv_expiry"], 'big')
amount_msat_htlc = int.from_bytes(payload["amount_msat"], 'big') amount_msat_htlc = int.from_bytes(payload["amount_msat"], 'big')
onion_packet = OnionPacket.from_bytes(payload["onion_routing_packet"]) onion_packet = payload["onion_routing_packet"]
processed_onion = process_onion_packet(onion_packet, associated_data=payment_hash, our_onion_private_key=self.privkey)
if chan.get_state() != channel_states.OPEN: if chan.get_state() != channel_states.OPEN:
raise RemoteMisbehaving(f"received update_add_htlc while chan.get_state() != OPEN. state was {chan.get_state()}") raise RemoteMisbehaving(f"received update_add_htlc while chan.get_state() != OPEN. state was {chan.get_state()}")
if cltv_expiry > bitcoin.NLOCKTIME_BLOCKHEIGHT_MAX: if cltv_expiry > bitcoin.NLOCKTIME_BLOCKHEIGHT_MAX:
asyncio.ensure_future(self.lnworker.force_close_channel(channel_id)) asyncio.ensure_future(self.lnworker.force_close_channel(channel_id))
raise RemoteMisbehaving(f"received update_add_htlc with cltv_expiry > BLOCKHEIGHT_MAX. value was {cltv_expiry}") raise RemoteMisbehaving(f"received update_add_htlc with cltv_expiry > BLOCKHEIGHT_MAX. value was {cltv_expiry}")
# add htlc # add htlc
htlc = UpdateAddHtlc(amount_msat=amount_msat_htlc, htlc = UpdateAddHtlc(
payment_hash=payment_hash, amount_msat=amount_msat_htlc,
cltv_expiry=cltv_expiry, payment_hash=payment_hash,
timestamp=int(time.time()), cltv_expiry=cltv_expiry,
htlc_id=htlc_id) timestamp=int(time.time()),
htlc = chan.receive_htlc(htlc) htlc_id=htlc_id)
# TODO: fulfilling/failing/forwarding of htlcs should be robust to going offline. chan.receive_htlc(htlc, onion_packet)
# instead of storing state implicitly in coroutines, we could decouple it from receiving the htlc.
# maybe persist the required details, and have a long-running task that makes these decisions. def maybe_forward_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *,
local_ctn = chan.get_latest_ctn(LOCAL) onion_packet: OnionPacket, processed_onion: ProcessedOnionPacket):
remote_ctn = chan.get_latest_ctn(REMOTE)
if processed_onion.are_we_final:
asyncio.ensure_future(self._maybe_fulfill_htlc(chan=chan,
htlc=htlc,
local_ctn=local_ctn,
remote_ctn=remote_ctn,
onion_packet=onion_packet,
processed_onion=processed_onion))
else:
asyncio.ensure_future(self._maybe_forward_htlc(chan=chan,
htlc=htlc,
local_ctn=local_ctn,
remote_ctn=remote_ctn,
onion_packet=onion_packet,
processed_onion=processed_onion))
@log_exceptions
async def _maybe_forward_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *, local_ctn: int, remote_ctn: int,
onion_packet: OnionPacket, processed_onion: ProcessedOnionPacket):
await self.await_local(chan, local_ctn)
await self.await_remote(chan, remote_ctn)
# Forward HTLC # Forward HTLC
# FIXME: this is not robust to us going offline before payment is fulfilled
# FIXME: there are critical safety checks MISSING here # FIXME: there are critical safety checks MISSING here
forwarding_enabled = self.network.config.get('lightning_forward_payments', False) forwarding_enabled = self.network.config.get('lightning_forward_payments', False)
if not forwarding_enabled: if not forwarding_enabled:
self.logger.info(f"forwarding is disabled. failing htlc.") self.logger.info(f"forwarding is disabled. failing htlc.")
reason = OnionRoutingFailureMessage(code=OnionFailureCode.PERMANENT_CHANNEL_FAILURE, data=b'') return OnionRoutingFailureMessage(code=OnionFailureCode.PERMANENT_CHANNEL_FAILURE, data=b'')
await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
return
dph = processed_onion.hop_data.per_hop dph = processed_onion.hop_data.per_hop
next_chan = self.lnworker.get_channel_by_short_id(dph.short_channel_id) next_chan = self.lnworker.get_channel_by_short_id(dph.short_channel_id)
next_chan_scid = dph.short_channel_id next_chan_scid = dph.short_channel_id
next_peer = self.lnworker.peers[next_chan.node_id]
local_height = self.network.get_local_height() local_height = self.network.get_local_height()
if next_chan is None: if next_chan is None:
self.logger.info(f"cannot forward htlc. cannot find next_chan {next_chan_scid}") self.logger.info(f"cannot forward htlc. cannot find next_chan {next_chan_scid}")
reason = OnionRoutingFailureMessage(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'') return OnionRoutingFailureMessage(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'')
await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
return
outgoing_chan_upd = next_chan.get_outgoing_gossip_channel_update()[2:] outgoing_chan_upd = next_chan.get_outgoing_gossip_channel_update()[2:]
outgoing_chan_upd_len = len(outgoing_chan_upd).to_bytes(2, byteorder="big") outgoing_chan_upd_len = len(outgoing_chan_upd).to_bytes(2, byteorder="big")
if not next_chan.can_send_update_add_htlc(): if not next_chan.can_send_update_add_htlc():
self.logger.info(f"cannot forward htlc. next_chan {next_chan_scid} cannot send ctx updates. " self.logger.info(f"cannot forward htlc. next_chan {next_chan_scid} cannot send ctx updates. "
f"chan state {next_chan.get_state()}, peer state: {next_chan.peer_state}") f"chan state {next_chan.get_state()}, peer state: {next_chan.peer_state}")
reason = OnionRoutingFailureMessage(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data = outgoing_chan_upd_len + outgoing_chan_upd
data=outgoing_chan_upd_len+outgoing_chan_upd) return OnionRoutingFailureMessage(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=data)
await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
return
next_cltv_expiry = int.from_bytes(dph.outgoing_cltv_value, 'big') next_cltv_expiry = int.from_bytes(dph.outgoing_cltv_value, 'big')
if htlc.cltv_expiry - next_cltv_expiry < NBLOCK_OUR_CLTV_EXPIRY_DELTA: if htlc.cltv_expiry - next_cltv_expiry < NBLOCK_OUR_CLTV_EXPIRY_DELTA:
reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_CLTV_EXPIRY, data = htlc.cltv_expiry.to_bytes(4, byteorder="big") + outgoing_chan_upd_len + outgoing_chan_upd
data=(htlc.cltv_expiry.to_bytes(4, byteorder="big") return OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_CLTV_EXPIRY, data=data)
+ outgoing_chan_upd_len + outgoing_chan_upd))
await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
return
if htlc.cltv_expiry - lnutil.NBLOCK_DEADLINE_BEFORE_EXPIRY_FOR_RECEIVED_HTLCS <= local_height \ if htlc.cltv_expiry - lnutil.NBLOCK_DEADLINE_BEFORE_EXPIRY_FOR_RECEIVED_HTLCS <= local_height \
or next_cltv_expiry <= local_height: or next_cltv_expiry <= local_height:
reason = OnionRoutingFailureMessage(code=OnionFailureCode.EXPIRY_TOO_SOON, data = outgoing_chan_upd_len + outgoing_chan_upd
data=outgoing_chan_upd_len+outgoing_chan_upd) return OnionRoutingFailureMessage(code=OnionFailureCode.EXPIRY_TOO_SOON, data=data)
await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
return
if max(htlc.cltv_expiry, next_cltv_expiry) > local_height + lnutil.NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE: if max(htlc.cltv_expiry, next_cltv_expiry) > local_height + lnutil.NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE:
reason = OnionRoutingFailureMessage(code=OnionFailureCode.EXPIRY_TOO_FAR, data=b'') return OnionRoutingFailureMessage(code=OnionFailureCode.EXPIRY_TOO_FAR, data=b'')
await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
return
next_amount_msat_htlc = int.from_bytes(dph.amt_to_forward, 'big') next_amount_msat_htlc = int.from_bytes(dph.amt_to_forward, 'big')
forwarding_fees = fee_for_edge_msat(forwarded_amount_msat=next_amount_msat_htlc, forwarding_fees = fee_for_edge_msat(
fee_base_msat=lnutil.OUR_FEE_BASE_MSAT, forwarded_amount_msat=next_amount_msat_htlc,
fee_proportional_millionths=lnutil.OUR_FEE_PROPORTIONAL_MILLIONTHS) fee_base_msat=lnutil.OUR_FEE_BASE_MSAT,
fee_proportional_millionths=lnutil.OUR_FEE_PROPORTIONAL_MILLIONTHS)
if htlc.amount_msat - next_amount_msat_htlc < forwarding_fees: if htlc.amount_msat - next_amount_msat_htlc < forwarding_fees:
reason = OnionRoutingFailureMessage(code=OnionFailureCode.FEE_INSUFFICIENT, data = next_amount_msat_htlc.to_bytes(8, byteorder="big") + outgoing_chan_upd_len + outgoing_chan_upd
data=(next_amount_msat_htlc.to_bytes(8, byteorder="big") return OnionRoutingFailureMessage(code=OnionFailureCode.FEE_INSUFFICIENT, data=data)
+ outgoing_chan_upd_len + outgoing_chan_upd))
await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
return
self.logger.info(f'forwarding htlc to {next_chan.node_id}') self.logger.info(f'forwarding htlc to {next_chan.node_id}')
next_htlc = UpdateAddHtlc(amount_msat=next_amount_msat_htlc, payment_hash=htlc.payment_hash, cltv_expiry=next_cltv_expiry, timestamp=int(time.time())) next_htlc = UpdateAddHtlc(
amount_msat=next_amount_msat_htlc,
payment_hash=htlc.payment_hash,
cltv_expiry=next_cltv_expiry,
timestamp=int(time.time()))
next_htlc = next_chan.add_htlc(next_htlc) next_htlc = next_chan.add_htlc(next_htlc)
next_remote_ctn = next_chan.get_latest_ctn(REMOTE) next_peer = self.lnworker.peers[next_chan.node_id]
next_peer.send_message( try:
"update_add_htlc", next_peer.send_message(
channel_id=next_chan.channel_id, "update_add_htlc",
id=next_htlc.htlc_id, channel_id=next_chan.channel_id,
cltv_expiry=dph.outgoing_cltv_value, id=next_htlc.htlc_id,
amount_msat=dph.amt_to_forward, cltv_expiry=dph.outgoing_cltv_value,
payment_hash=next_htlc.payment_hash, amount_msat=dph.amt_to_forward,
onion_routing_packet=processed_onion.next_packet.to_bytes() payment_hash=next_htlc.payment_hash,
) onion_routing_packet=processed_onion.next_packet.to_bytes()
await next_peer.await_remote(next_chan, next_remote_ctn) )
success, preimage, reason = await self.lnworker.await_payment(next_htlc.payment_hash) except BaseException as e:
if success: self.logger.info(f"failed to forward htlc: error sending message. {e}")
await self._fulfill_htlc(chan, htlc.htlc_id, preimage) data = outgoing_chan_upd_len + outgoing_chan_upd
self.logger.info("htlc forwarded successfully") return OnionRoutingFailureMessage(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=data)
else: return None
# TODO: test this
self.logger.info(f"forwarded htlc has failed, {reason}") def maybe_fulfill_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *,
await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason) onion_packet: OnionPacket, processed_onion: ProcessedOnionPacket):
@log_exceptions
async def _maybe_fulfill_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *, local_ctn: int, remote_ctn: int,
onion_packet: OnionPacket, processed_onion: ProcessedOnionPacket):
await self.await_local(chan, local_ctn)
await self.await_remote(chan, remote_ctn)
try: try:
info = self.lnworker.get_payment_info(htlc.payment_hash) info = self.lnworker.get_payment_info(htlc.payment_hash)
preimage = self.lnworker.get_preimage(htlc.payment_hash) preimage = self.lnworker.get_preimage(htlc.payment_hash)
except UnknownPaymentHash: except UnknownPaymentHash:
reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'') reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'')
await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason) return False, reason
return
expected_received_msat = int(info.amount * 1000) if info.amount is not None else None expected_received_msat = int(info.amount * 1000) if info.amount is not None else None
if expected_received_msat is not None and \ if expected_received_msat is not None and \
not (expected_received_msat <= htlc.amount_msat <= 2 * expected_received_msat): not (expected_received_msat <= htlc.amount_msat <= 2 * expected_received_msat):
reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'') reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'')
await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason) return False, reason
return
local_height = self.network.get_local_height() local_height = self.network.get_local_height()
if local_height + MIN_FINAL_CLTV_EXPIRY_ACCEPTED > htlc.cltv_expiry: if local_height + MIN_FINAL_CLTV_EXPIRY_ACCEPTED > htlc.cltv_expiry:
reason = OnionRoutingFailureMessage(code=OnionFailureCode.FINAL_EXPIRY_TOO_SOON, data=b'') reason = OnionRoutingFailureMessage(code=OnionFailureCode.FINAL_EXPIRY_TOO_SOON, data=b'')
await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason) return False, reason
return
cltv_from_onion = int.from_bytes(processed_onion.hop_data.per_hop.outgoing_cltv_value, byteorder="big") cltv_from_onion = int.from_bytes(processed_onion.hop_data.per_hop.outgoing_cltv_value, byteorder="big")
if cltv_from_onion != htlc.cltv_expiry: if cltv_from_onion != htlc.cltv_expiry:
reason = OnionRoutingFailureMessage(code=OnionFailureCode.FINAL_INCORRECT_CLTV_EXPIRY, reason = OnionRoutingFailureMessage(code=OnionFailureCode.FINAL_INCORRECT_CLTV_EXPIRY,
data=htlc.cltv_expiry.to_bytes(4, byteorder="big")) data=htlc.cltv_expiry.to_bytes(4, byteorder="big"))
await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason) return False, reason
return
amount_from_onion = int.from_bytes(processed_onion.hop_data.per_hop.amt_to_forward, byteorder="big") amount_from_onion = int.from_bytes(processed_onion.hop_data.per_hop.amt_to_forward, byteorder="big")
if amount_from_onion > htlc.amount_msat: if amount_from_onion > htlc.amount_msat:
reason = OnionRoutingFailureMessage(code=OnionFailureCode.FINAL_INCORRECT_HTLC_AMOUNT, reason = OnionRoutingFailureMessage(code=OnionFailureCode.FINAL_INCORRECT_HTLC_AMOUNT,
data=htlc.amount_msat.to_bytes(8, byteorder="big")) data=htlc.amount_msat.to_bytes(8, byteorder="big"))
await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason) return False, reason
return # all good
#self.network.trigger_callback('htlc_added', htlc, invoice, RECEIVED) return preimage, None
await self.lnworker.enable_htlc_settle.wait()
await self._fulfill_htlc(chan, htlc.htlc_id, preimage)
async def _fulfill_htlc(self, chan: Channel, htlc_id: int, preimage: bytes): def fulfill_htlc(self, chan: Channel, htlc_id: int, preimage: bytes):
self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
if not chan.can_send_ctx_updates(): assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}"
self.logger.info(f"dropping chan update (fulfill htlc {htlc_id}) for {chan.short_channel_id}. "
f"cannot send updates")
return
chan.settle_htlc(preimage, htlc_id) chan.settle_htlc(preimage, htlc_id)
payment_hash = sha256(preimage) payment_hash = sha256(preimage)
self.lnworker.payment_received(payment_hash) self.lnworker.payment_received(payment_hash)
remote_ctn = chan.get_latest_ctn(REMOTE)
self.send_message("update_fulfill_htlc", self.send_message("update_fulfill_htlc",
channel_id=chan.channel_id, channel_id=chan.channel_id,
id=htlc_id, id=htlc_id,
payment_preimage=preimage) payment_preimage=preimage)
await self.await_remote(chan, remote_ctn)
async def fail_htlc(self, chan: Channel, htlc_id: int, onion_packet: OnionPacket, def fail_htlc(self, chan: Channel, htlc_id: int, onion_packet: OnionPacket,
reason: OnionRoutingFailureMessage): reason: OnionRoutingFailureMessage):
self.logger.info(f"fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}. reason: {reason}") self.logger.info(f"fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}. reason: {reason}")
if not chan.can_send_ctx_updates(): assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}"
self.logger.info(f"dropping chan update (fail htlc {htlc_id}) for {chan.short_channel_id}. "
f"cannot send updates")
return
chan.fail_htlc(htlc_id) chan.fail_htlc(htlc_id)
remote_ctn = chan.get_latest_ctn(REMOTE)
error_packet = construct_onion_error(reason, onion_packet, our_onion_private_key=self.privkey) error_packet = construct_onion_error(reason, onion_packet, our_onion_private_key=self.privkey)
self.send_message("update_fail_htlc", self.send_message("update_fail_htlc",
channel_id=chan.channel_id, channel_id=chan.channel_id,
id=htlc_id, id=htlc_id,
len=len(error_packet), len=len(error_packet),
reason=error_packet) reason=error_packet)
await self.await_remote(chan, remote_ctn)
def on_revoke_and_ack(self, payload): def on_revoke_and_ack(self, payload):
channel_id = payload["channel_id"] channel_id = payload["channel_id"]
@ -1484,3 +1427,52 @@ class Peer(Logger):
# broadcast # broadcast
await self.network.try_broadcasting(closing_tx, 'closing') await self.network.try_broadcasting(closing_tx, 'closing')
return closing_tx.txid() return closing_tx.txid()
async def htlc_switch(self):
while True:
await asyncio.sleep(0.1)
for chan_id, chan in self.channels.items():
if not chan.can_send_ctx_updates():
continue
self.maybe_send_commitment(chan)
done = set()
unfulfilled = chan.hm.log.get('unfulfilled_htlcs', {})
for htlc_id, (local_ctn, remote_ctn, onion_packet_hex, forwarded) in unfulfilled.items():
if chan.get_oldest_unrevoked_ctn(LOCAL) <= local_ctn:
continue
if chan.get_oldest_unrevoked_ctn(REMOTE) <= remote_ctn:
continue
chan.logger.info(f'found unfulfilled htlc: {htlc_id}')
onion_packet = OnionPacket.from_bytes(bytes.fromhex(onion_packet_hex))
htlc = chan.hm.log[REMOTE]['adds'][htlc_id]
payment_hash = htlc.payment_hash
processed_onion = process_onion_packet(onion_packet, associated_data=payment_hash, our_onion_private_key=self.privkey)
preimage, error = None, None
if processed_onion.are_we_final:
preimage, error = self.maybe_fulfill_htlc(
chan=chan,
htlc=htlc,
onion_packet=onion_packet,
processed_onion=processed_onion)
elif not forwarded:
error = self.maybe_forward_htlc(
chan=chan,
htlc=htlc,
onion_packet=onion_packet,
processed_onion=processed_onion)
if not error:
unfulfilled[htlc_id] = local_ctn, remote_ctn, onion_packet_hex, True
else:
f = self.lnworker.pending_payments[payment_hash]
if f.done():
success, preimage, error = f.result()
if preimage:
await self.lnworker.enable_htlc_settle.wait()
self.fulfill_htlc(chan, htlc.htlc_id, preimage)
done.add(htlc_id)
if error:
self.fail_htlc(chan, htlc.htlc_id, onion_packet, error)
done.add(htlc_id)
# cleanup
for htlc_id in done:
unfulfilled.pop(htlc_id)

2
electrum/lnworker.py

@ -55,7 +55,7 @@ from .lnutil import (Outpoint, LNPeerAddr,
ShortChannelID, PaymentAttemptLog, PaymentAttemptFailureDetails) ShortChannelID, PaymentAttemptLog, PaymentAttemptFailureDetails)
from .lnutil import ln_dummy_address, ln_compare_features from .lnutil import ln_dummy_address, ln_compare_features
from .transaction import PartialTxOutput, PartialTransaction, PartialTxInput from .transaction import PartialTxOutput, PartialTransaction, PartialTxInput
from .lnonion import OnionFailureCode from .lnonion import OnionFailureCode, process_onion_packet, OnionPacket
from .lnmsg import decode_msg from .lnmsg import decode_msg
from .i18n import _ from .i18n import _
from .lnrouter import RouteEdge, LNPaymentRoute, is_route_sane_to_use from .lnrouter import RouteEdge, LNPaymentRoute, is_route_sane_to_use

12
electrum/tests/test_lnpeer.py

@ -238,7 +238,7 @@ class TestPeer(ElectrumTestCase):
self.assertEqual(alice_channel.peer_state, peer_states.GOOD) self.assertEqual(alice_channel.peer_state, peer_states.GOOD)
self.assertEqual(bob_channel.peer_state, peer_states.GOOD) self.assertEqual(bob_channel.peer_state, peer_states.GOOD)
gath.cancel() gath.cancel()
gath = asyncio.gather(reestablish(), p1._message_loop(), p2._message_loop()) gath = asyncio.gather(reestablish(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p1.htlc_switch())
async def f(): async def f():
await gath await gath
with self.assertRaises(concurrent.futures.CancelledError): with self.assertRaises(concurrent.futures.CancelledError):
@ -253,7 +253,7 @@ class TestPeer(ElectrumTestCase):
result = await LNWallet._pay(w1, pay_req) result = await LNWallet._pay(w1, pay_req)
self.assertEqual(result, True) self.assertEqual(result, True)
gath.cancel() gath.cancel()
gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop()) gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
async def f(): async def f():
await gath await gath
with self.assertRaises(concurrent.futures.CancelledError): with self.assertRaises(concurrent.futures.CancelledError):
@ -271,7 +271,7 @@ class TestPeer(ElectrumTestCase):
# wait so that pending messages are processed # wait so that pending messages are processed
#await asyncio.sleep(1) #await asyncio.sleep(1)
gath.cancel() gath.cancel()
gath = asyncio.gather(reestablish(), p1._message_loop(), p2._message_loop()) gath = asyncio.gather(reestablish(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
async def f(): async def f():
await gath await gath
with self.assertRaises(concurrent.futures.CancelledError): with self.assertRaises(concurrent.futures.CancelledError):
@ -285,7 +285,7 @@ class TestPeer(ElectrumTestCase):
result = await LNWallet._pay(w1, pay_req) result = await LNWallet._pay(w1, pay_req)
self.assertTrue(result) self.assertTrue(result)
gath.cancel() gath.cancel()
gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop()) gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
async def f(): async def f():
await gath await gath
with self.assertRaises(concurrent.futures.CancelledError): with self.assertRaises(concurrent.futures.CancelledError):
@ -313,7 +313,7 @@ class TestPeer(ElectrumTestCase):
async def set_settle(): async def set_settle():
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
w2.enable_htlc_settle.set() w2.enable_htlc_settle.set()
gath = asyncio.gather(pay(), set_settle(), p1._message_loop(), p2._message_loop()) gath = asyncio.gather(pay(), set_settle(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
async def f(): async def f():
await gath await gath
with self.assertRaises(concurrent.futures.CancelledError): with self.assertRaises(concurrent.futures.CancelledError):
@ -338,7 +338,7 @@ class TestPeer(ElectrumTestCase):
# AssertionError is ok since we shouldn't use old routes, and the # AssertionError is ok since we shouldn't use old routes, and the
# route finding should fail when channel is closed # route finding should fail when channel is closed
async def f(): async def f():
await asyncio.gather(w1._pay_to_route(route, addr), p1._message_loop(), p2._message_loop()) await asyncio.gather(w1._pay_to_route(route, addr), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
with self.assertRaises(PaymentFailure): with self.assertRaises(PaymentFailure):
run(f()) run(f())

Loading…
Cancel
Save