Browse Source

replace await_local/remote

hard-fail-on-bad-server-string
ThomasV 5 years ago
parent
commit
b9eaba3e85
  1. 10
      electrum/lnchannel.py
  2. 5
      electrum/lnhtlc.py
  3. 63
      electrum/lnpeer.py
  4. 13
      electrum/lnworker.py
  5. 2
      electrum/tests/test_lnchannel.py
  6. 7
      electrum/tests/test_lnpeer.py

10
electrum/lnchannel.py

@ -152,6 +152,7 @@ class Channel(Logger):
self._chan_ann_without_sigs = None # type: Optional[bytes] self._chan_ann_without_sigs = None # type: Optional[bytes]
self.revocation_store = RevocationStore(state["revocation_store"]) self.revocation_store = RevocationStore(state["revocation_store"])
self._can_send_ctx_updates = True # type: bool self._can_send_ctx_updates = True # type: bool
self._receive_fail_reasons = {}
def get_id_for_log(self) -> str: def get_id_for_log(self) -> str:
scid = self.short_channel_id scid = self.short_channel_id
@ -562,11 +563,15 @@ class Channel(Logger):
self.hm.send_rev() self.hm.send_rev()
received = self.hm.received_in_ctn(new_ctn) received = self.hm.received_in_ctn(new_ctn)
sent = self.hm.sent_in_ctn(new_ctn) sent = self.hm.sent_in_ctn(new_ctn)
failed = self.hm.failed_in_ctn(new_ctn)
if self.lnworker: if self.lnworker:
for htlc in received: for htlc in received:
self.lnworker.payment_completed(self, RECEIVED, htlc) self.lnworker.payment_completed(self, RECEIVED, htlc)
for htlc in sent: for htlc in sent:
self.lnworker.payment_completed(self, SENT, htlc) self.lnworker.payment_completed(self, SENT, htlc)
for htlc in failed:
reason = self._receive_fail_reasons.get(htlc.htlc_id)
self.lnworker.payment_failed(htlc.payment_hash, reason)
received_this_batch = htlcsum(received) received_this_batch = htlcsum(received)
sent_this_batch = htlcsum(sent) sent_this_batch = htlcsum(sent)
last_secret, last_point = self.get_secret_and_point(LOCAL, new_ctn - 1) last_secret, last_point = self.get_secret_and_point(LOCAL, new_ctn - 1)
@ -575,12 +580,10 @@ class Channel(Logger):
def receive_revocation(self, revocation: RevokeAndAck): def receive_revocation(self, revocation: RevokeAndAck):
self.logger.info("receive_revocation") self.logger.info("receive_revocation")
cur_point = self.config[REMOTE].current_per_commitment_point cur_point = self.config[REMOTE].current_per_commitment_point
derived_point = ecc.ECPrivkey(revocation.per_commitment_secret).get_public_key_bytes(compressed=True) derived_point = ecc.ECPrivkey(revocation.per_commitment_secret).get_public_key_bytes(compressed=True)
if cur_point != derived_point: if cur_point != derived_point:
raise Exception('revoked secret not for current point') raise Exception('revoked secret not for current point')
with self.db_lock: with self.db_lock:
self.revocation_store.add_next_entry(revocation.per_commitment_secret) self.revocation_store.add_next_entry(revocation.per_commitment_secret)
##### start applying fee/htlc changes ##### start applying fee/htlc changes
@ -763,10 +766,11 @@ class Channel(Logger):
with self.db_lock: with self.db_lock:
self.hm.send_fail(htlc_id) self.hm.send_fail(htlc_id)
def receive_fail_htlc(self, htlc_id): def receive_fail_htlc(self, htlc_id, reason):
self.logger.info("receive_fail_htlc") self.logger.info("receive_fail_htlc")
with self.db_lock: with self.db_lock:
self.hm.recv_fail(htlc_id) self.hm.recv_fail(htlc_id)
self._receive_fail_reasons[htlc_id] = reason
def pending_local_fee(self): def pending_local_fee(self):
return self.constraints.capacity - sum(x.value for x in self.get_next_commitment(LOCAL).outputs()) return self.constraints.capacity - sum(x.value for x in self.get_next_commitment(LOCAL).outputs())

5
electrum/lnhtlc.py

@ -298,6 +298,11 @@ class HTLCManager:
for htlc_id, ctns in self.log[LOCAL]['settles'].items() for htlc_id, ctns in self.log[LOCAL]['settles'].items()
if ctns[LOCAL] == ctn] if ctns[LOCAL] == ctn]
def failed_in_ctn(self, ctn: int) -> Sequence[UpdateAddHtlc]:
return [self.log[LOCAL]['adds'][htlc_id]
for htlc_id, ctns in self.log[LOCAL]['fails'].items()
if ctns[LOCAL] == ctn]
##### Queries re Fees: ##### Queries re Fees:
def get_feerate(self, subject: HTLCOwner, ctn: int) -> int: def get_feerate(self, subject: HTLCOwner, ctn: int) -> int:

63
electrum/lnpeer.py

@ -93,8 +93,6 @@ class Peer(Logger):
self.shutdown_received = {} self.shutdown_received = {}
self.announcement_signatures = defaultdict(asyncio.Queue) self.announcement_signatures = defaultdict(asyncio.Queue)
self.orphan_channel_updates = OrderedDict() self.orphan_channel_updates = OrderedDict()
self._local_changed_events = defaultdict(asyncio.Event)
self._remote_changed_events = defaultdict(asyncio.Event)
Logger.__init__(self) Logger.__init__(self)
self.taskgroup = SilentTaskGroup() self.taskgroup = SilentTaskGroup()
@ -1006,16 +1004,8 @@ class Peer(Logger):
reason = payload["reason"] reason = payload["reason"]
chan = self.channels[channel_id] chan = self.channels[channel_id]
self.logger.info(f"on_update_fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") self.logger.info(f"on_update_fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
chan.receive_fail_htlc(htlc_id) chan.receive_fail_htlc(htlc_id, reason)
local_ctn = chan.get_latest_ctn(LOCAL) self.maybe_send_commitment(chan)
asyncio.ensure_future(self._on_update_fail_htlc(channel_id, htlc_id, local_ctn, reason))
@log_exceptions
async def _on_update_fail_htlc(self, channel_id, htlc_id, local_ctn, reason):
chan = self.channels[channel_id]
await self.await_local(chan, local_ctn)
payment_hash = chan.get_payment_hash(htlc_id)
self.lnworker.payment_failed(payment_hash, reason)
def maybe_send_commitment(self, chan: Channel): def maybe_send_commitment(self, chan: Channel):
# REMOTE should revoke first before we can sign a new ctx # REMOTE should revoke first before we can sign a new ctx
@ -1028,27 +1018,9 @@ class Peer(Logger):
sig_64, htlc_sigs = chan.sign_next_commitment() sig_64, htlc_sigs = chan.sign_next_commitment()
self.send_message("commitment_signed", channel_id=chan.channel_id, signature=sig_64, num_htlcs=len(htlc_sigs), htlc_signature=b"".join(htlc_sigs)) self.send_message("commitment_signed", channel_id=chan.channel_id, signature=sig_64, num_htlcs=len(htlc_sigs), htlc_signature=b"".join(htlc_sigs))
async def await_remote(self, chan: Channel, ctn: int): def pay(self, route: 'LNPaymentRoute', chan: Channel, amount_msat: int,
"""Wait until remote 'ctn' gets revoked."""
# if 'ctn' is too high, we risk waiting "forever", hence assert:
assert chan.get_latest_ctn(REMOTE) >= ctn, (chan.get_latest_ctn(REMOTE), ctn)
self.maybe_send_commitment(chan)
while chan.get_oldest_unrevoked_ctn(REMOTE) <= ctn:
await self._remote_changed_events[chan.channel_id].wait()
async def await_local(self, chan: Channel, ctn: int):
"""Wait until local 'ctn' gets revoked."""
# if 'ctn' is too high, we risk waiting "forever", hence assert:
assert chan.get_latest_ctn(LOCAL) >= ctn, (chan.get_latest_ctn(LOCAL), ctn)
self.maybe_send_commitment(chan)
while chan.get_oldest_unrevoked_ctn(LOCAL) <= ctn:
await self._local_changed_events[chan.channel_id].wait()
async def pay(self, route: 'LNPaymentRoute', chan: Channel, amount_msat: int,
payment_hash: bytes, min_final_cltv_expiry: int) -> UpdateAddHtlc: payment_hash: bytes, min_final_cltv_expiry: int) -> UpdateAddHtlc:
assert amount_msat > 0, "amount_msat is not greater zero" assert amount_msat > 0, "amount_msat is not greater zero"
await asyncio.wait_for(self.initialized, LN_P2P_NETWORK_TIMEOUT)
# TODO also wait for channel reestablish to finish. (combine timeout with waiting for init?)
if not chan.can_send_update_add_htlc(): if not chan.can_send_update_add_htlc():
raise PaymentFailure("Channel cannot send update_add_htlc") raise PaymentFailure("Channel cannot send update_add_htlc")
# create onion packet # create onion packet
@ -1060,25 +1032,23 @@ class Peer(Logger):
# create htlc # create htlc
htlc = UpdateAddHtlc(amount_msat=amount_msat, payment_hash=payment_hash, cltv_expiry=cltv, timestamp=int(time.time())) htlc = UpdateAddHtlc(amount_msat=amount_msat, payment_hash=payment_hash, cltv_expiry=cltv, timestamp=int(time.time()))
htlc = chan.add_htlc(htlc) htlc = chan.add_htlc(htlc)
remote_ctn = chan.get_latest_ctn(REMOTE)
chan.set_onion_key(htlc.htlc_id, secret_key) chan.set_onion_key(htlc.htlc_id, secret_key)
self.logger.info(f"starting payment. len(route)={len(route)}. route: {route}. htlc: {htlc}") self.logger.info(f"starting payment. len(route)={len(route)}. route: {route}. htlc: {htlc}")
self.send_message("update_add_htlc", self.send_message(
"update_add_htlc",
channel_id=chan.channel_id, channel_id=chan.channel_id,
id=htlc.htlc_id, id=htlc.htlc_id,
cltv_expiry=htlc.cltv_expiry, cltv_expiry=htlc.cltv_expiry,
amount_msat=htlc.amount_msat, amount_msat=htlc.amount_msat,
payment_hash=htlc.payment_hash, payment_hash=htlc.payment_hash,
onion_routing_packet=onion.to_bytes()) onion_routing_packet=onion.to_bytes())
await self.await_remote(chan, remote_ctn) self.maybe_send_commitment(chan)
return htlc return htlc
def send_revoke_and_ack(self, chan: Channel): def send_revoke_and_ack(self, chan: Channel):
self.logger.info(f'send_revoke_and_ack. chan {chan.short_channel_id}. ctn: {chan.get_oldest_unrevoked_ctn(LOCAL)}') self.logger.info(f'send_revoke_and_ack. chan {chan.short_channel_id}. ctn: {chan.get_oldest_unrevoked_ctn(LOCAL)}')
rev, _ = chan.revoke_current_commitment() rev, _ = chan.revoke_current_commitment()
self.lnworker.save_channel(chan) self.lnworker.save_channel(chan)
self._local_changed_events[chan.channel_id].set()
self._local_changed_events[chan.channel_id].clear()
self.send_message("revoke_and_ack", self.send_message("revoke_and_ack",
channel_id=chan.channel_id, channel_id=chan.channel_id,
per_commitment_secret=rev.per_commitment_secret, per_commitment_secret=rev.per_commitment_secret,
@ -1113,13 +1083,7 @@ class Peer(Logger):
self.logger.info(f"on_update_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") self.logger.info(f"on_update_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
chan.receive_htlc_settle(preimage, htlc_id) chan.receive_htlc_settle(preimage, htlc_id)
self.lnworker.save_preimage(payment_hash, preimage) self.lnworker.save_preimage(payment_hash, preimage)
local_ctn = chan.get_latest_ctn(LOCAL) self.maybe_send_commitment(chan)
asyncio.ensure_future(self._on_update_fulfill_htlc(chan, local_ctn, payment_hash))
@log_exceptions
async def _on_update_fulfill_htlc(self, chan, local_ctn, payment_hash):
await self.await_local(chan, local_ctn)
self.lnworker.payment_sent(payment_hash)
def on_update_fail_malformed_htlc(self, payload): def on_update_fail_malformed_htlc(self, payload):
self.logger.info(f"on_update_fail_malformed_htlc. error {payload['data'].decode('ascii')}") self.logger.info(f"on_update_fail_malformed_htlc. error {payload['data'].decode('ascii')}")
@ -1272,8 +1236,6 @@ class Peer(Logger):
self.logger.info(f'on_revoke_and_ack. chan {chan.short_channel_id}. ctn: {chan.get_oldest_unrevoked_ctn(REMOTE)}') self.logger.info(f'on_revoke_and_ack. chan {chan.short_channel_id}. ctn: {chan.get_oldest_unrevoked_ctn(REMOTE)}')
rev = RevokeAndAck(payload["per_commitment_secret"], payload["next_per_commitment_point"]) rev = RevokeAndAck(payload["per_commitment_secret"], payload["next_per_commitment_point"])
chan.receive_revocation(rev) chan.receive_revocation(rev)
self._remote_changed_events[chan.channel_id].set()
self._remote_changed_events[chan.channel_id].clear()
self.lnworker.save_channel(chan) self.lnworker.save_channel(chan)
self.maybe_send_commitment(chan) self.maybe_send_commitment(chan)
@ -1303,11 +1265,11 @@ class Peer(Logger):
self.logger.info(f"(chan: {chan.get_id_for_log()}) current pending feerate {chan_fee}. " self.logger.info(f"(chan: {chan.get_id_for_log()}) current pending feerate {chan_fee}. "
f"new feerate {feerate_per_kw}") f"new feerate {feerate_per_kw}")
chan.update_fee(feerate_per_kw, True) chan.update_fee(feerate_per_kw, True)
remote_ctn = chan.get_latest_ctn(REMOTE) self.send_message(
self.send_message("update_fee", "update_fee",
channel_id=chan.channel_id, channel_id=chan.channel_id,
feerate_per_kw=feerate_per_kw) feerate_per_kw=feerate_per_kw)
await self.await_remote(chan, remote_ctn) self.maybe_send_commitment(chan)
@log_exceptions @log_exceptions
async def close_channel(self, chan_id: bytes): async def close_channel(self, chan_id: bytes):
@ -1351,9 +1313,8 @@ class Peer(Logger):
scriptpubkey = bfh(bitcoin.address_to_script(chan.sweep_address)) scriptpubkey = bfh(bitcoin.address_to_script(chan.sweep_address))
# wait until no more pending updates (bolt2) # wait until no more pending updates (bolt2)
chan.set_can_send_ctx_updates(False) chan.set_can_send_ctx_updates(False)
ctn = chan.get_latest_ctn(REMOTE) while chan.has_pending_changes(REMOTE):
if chan.has_pending_changes(REMOTE): await asyncio.sleep(0.1)
await self.await_remote(chan, ctn)
self.send_message('shutdown', channel_id=chan.channel_id, len=len(scriptpubkey), scriptpubkey=scriptpubkey) self.send_message('shutdown', channel_id=chan.channel_id, len=len(scriptpubkey), scriptpubkey=scriptpubkey)
chan.set_state(channel_states.CLOSING) chan.set_state(channel_states.CLOSING)
# can fullfill or fail htlcs. cannot add htlcs, because of CLOSING state # can fullfill or fail htlcs. cannot add htlcs, because of CLOSING state

13
electrum/lnworker.py

@ -523,6 +523,8 @@ class LNWallet(LNWorker):
preimage = self.get_preimage(htlc.payment_hash) preimage = self.get_preimage(htlc.payment_hash)
timestamp = int(time.time()) timestamp = int(time.time())
self.network.trigger_callback('ln_payment_completed', timestamp, direction, htlc, preimage, chan_id) self.network.trigger_callback('ln_payment_completed', timestamp, direction, htlc, preimage, chan_id)
if direction == SENT:
self.payment_sent(htlc.payment_hash)
def get_settled_payments(self): def get_settled_payments(self):
# return one item per payment_hash # return one item per payment_hash
@ -952,7 +954,8 @@ class LNWallet(LNWorker):
peer = self.peers.get(route[0].node_id) peer = self.peers.get(route[0].node_id)
if not peer: if not peer:
raise Exception('Dropped peer') raise Exception('Dropped peer')
htlc = await peer.pay(route, chan, int(lnaddr.amount * COIN * 1000), lnaddr.paymenthash, lnaddr.get_min_final_cltv_expiry()) await peer.initialized
htlc = peer.pay(route, chan, int(lnaddr.amount * COIN * 1000), lnaddr.paymenthash, lnaddr.get_min_final_cltv_expiry())
self.network.trigger_callback('htlc_added', htlc, lnaddr, SENT) self.network.trigger_callback('htlc_added', htlc, lnaddr, SENT)
success, preimage, reason = await self.await_payment(lnaddr.paymenthash) success, preimage, reason = await self.await_payment(lnaddr.paymenthash)
if success: if success:
@ -1207,12 +1210,16 @@ class LNWallet(LNWorker):
def payment_failed(self, payment_hash: bytes, reason): def payment_failed(self, payment_hash: bytes, reason):
self.set_payment_status(payment_hash, PR_UNPAID) self.set_payment_status(payment_hash, PR_UNPAID)
self.pending_payments[payment_hash].set_result((False, None, reason)) f = self.pending_payments[payment_hash]
if not f.cancelled():
f.set_result((False, None, reason))
def payment_sent(self, payment_hash: bytes): def payment_sent(self, payment_hash: bytes):
self.set_payment_status(payment_hash, PR_PAID) self.set_payment_status(payment_hash, PR_PAID)
preimage = self.get_preimage(payment_hash) preimage = self.get_preimage(payment_hash)
self.pending_payments[payment_hash].set_result((True, preimage, None)) f = self.pending_payments[payment_hash]
if not f.cancelled():
f.set_result((True, preimage, None))
def payment_received(self, payment_hash: bytes): def payment_received(self, payment_hash: bytes):
self.set_payment_status(payment_hash, PR_PAID) self.set_payment_status(payment_hash, PR_PAID)

2
electrum/tests/test_lnchannel.py

@ -618,7 +618,7 @@ class TestAvailableToSpend(ElectrumTestCase):
bob_idx = bob_channel.receive_htlc(htlc_dict).htlc_id bob_idx = bob_channel.receive_htlc(htlc_dict).htlc_id
force_state_transition(alice_channel, bob_channel) force_state_transition(alice_channel, bob_channel)
bob_channel.fail_htlc(bob_idx) bob_channel.fail_htlc(bob_idx)
alice_channel.receive_fail_htlc(alice_idx) alice_channel.receive_fail_htlc(alice_idx, None)
# Alice now has gotten all her original balance (5 BTC) back, however, # Alice now has gotten all her original balance (5 BTC) back, however,
# adding a new HTLC at this point SHOULD fail, since if she adds the # adding a new HTLC at this point SHOULD fail, since if she adds the
# HTLC and signs the next state, Bob cannot assume she received the # HTLC and signs the next state, Bob cannot assume she received the

7
electrum/tests/test_lnpeer.py

@ -131,6 +131,7 @@ class MockLNWallet:
_create_route_from_invoice = LNWallet._create_route_from_invoice _create_route_from_invoice = LNWallet._create_route_from_invoice
_check_invoice = staticmethod(LNWallet._check_invoice) _check_invoice = staticmethod(LNWallet._check_invoice)
_pay_to_route = LNWallet._pay_to_route _pay_to_route = LNWallet._pay_to_route
_pay = LNWallet._pay
force_close_channel = LNWallet.force_close_channel force_close_channel = LNWallet.force_close_channel
get_first_timestamp = lambda self: 0 get_first_timestamp = lambda self: 0
payment_completed = LNWallet.payment_completed payment_completed = LNWallet.payment_completed
@ -250,7 +251,7 @@ class TestPeer(ElectrumTestCase):
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
pay_req = self.prepare_invoice(w2) pay_req = self.prepare_invoice(w2)
async def pay(): async def pay():
result = await LNWallet._pay(w1, pay_req) result = await w1._pay(pay_req)
self.assertEqual(result, True) self.assertEqual(result, True)
gath.cancel() gath.cancel()
gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
@ -282,7 +283,7 @@ class TestPeer(ElectrumTestCase):
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
pay_req = self.prepare_invoice(w2) pay_req = self.prepare_invoice(w2)
async def pay(): async def pay():
result = await LNWallet._pay(w1, pay_req) result = await w1._pay(pay_req)
self.assertTrue(result) self.assertTrue(result)
gath.cancel() gath.cancel()
gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
@ -306,7 +307,7 @@ class TestPeer(ElectrumTestCase):
await asyncio.wait_for(p2.initialized, 1) await asyncio.wait_for(p2.initialized, 1)
# alice sends htlc # alice sends htlc
route = await w1._create_route_from_invoice(decoded_invoice=lnaddr) route = await w1._create_route_from_invoice(decoded_invoice=lnaddr)
htlc = await p1.pay(route, alice_channel, int(lnaddr.amount * COIN * 1000), lnaddr.paymenthash, lnaddr.get_min_final_cltv_expiry()) htlc = p1.pay(route, alice_channel, int(lnaddr.amount * COIN * 1000), lnaddr.paymenthash, lnaddr.get_min_final_cltv_expiry())
# alice closes # alice closes
await p1.close_channel(alice_channel.channel_id) await p1.close_channel(alice_channel.channel_id)
gath.cancel() gath.cancel()

Loading…
Cancel
Save