From 399fe08047afaa15f33e95d424f1647ffc9fbd45 Mon Sep 17 00:00:00 2001 From: Janus Date: Fri, 15 Jun 2018 16:35:29 +0200 Subject: [PATCH] ln: avoid code duplication --- lib/lnbase.py | 149 +++++++++++---------------------------- lib/lnhtlc.py | 24 +++++-- lib/tests/test_lnhtlc.py | 2 +- 3 files changed, 63 insertions(+), 112 deletions(-) diff --git a/lib/lnbase.py b/lib/lnbase.py index 63074bfba..e72347803 100644 --- a/lib/lnbase.py +++ b/lib/lnbase.py @@ -984,28 +984,11 @@ class Peer(PrintError): def on_update_fail_htlc(self, payload): print("UPDATE_FAIL_HTLC", decode_onion_error(payload["reason"], self.node_keys, self.secret_key)) - def derive_and_incr(self, chan): - last_small_num = chan.local_state.ctn - next_small_num = last_small_num + 2 - this_small_num = last_small_num + 1 - last_secret = get_per_commitment_secret_from_seed(chan.local_state.per_commitment_secret_seed, 2**48-last_small_num-1) - this_secret = get_per_commitment_secret_from_seed(chan.local_state.per_commitment_secret_seed, 2**48-this_small_num-1) - this_point = secret_to_pubkey(int.from_bytes(this_secret, 'big')) - next_secret = get_per_commitment_secret_from_seed(chan.local_state.per_commitment_secret_seed, 2**48-next_small_num-1) - next_point = secret_to_pubkey(int.from_bytes(next_secret, 'big')) - chan = chan._replace( - local_state=chan.local_state._replace( - ctn=chan.local_state.ctn + 1 - ) - ) - return chan, last_secret, this_point, next_point - @aiosafe async def pay(self, path, chan, amount_msat, payment_hash, pubkey_in_invoice, min_final_cltv_expiry): assert self.channel_state[chan.channel_id] == "OPEN" assert amount_msat > 0, "amount_msat is not greater zero" height = self.network.get_local_height() - their_revstore = chan.remote_state.revocation_store route = self.lnworker.path_finder.create_route_from_path(path, self.lnworker.pubkey) hops_data = [] sum_of_deltas = sum(route_edge.channel_policy.cltv_expiry_delta for route_edge in route[1:]) @@ -1035,69 +1018,53 @@ class Peer(PrintError): self.send_message(gen_msg("commitment_signed", channel_id=chan.channel_id, signature=sig_64, num_htlcs=1, htlc_signature=htlc_sig)) - revoke_and_ack_msg = await self.revoke_and_ack[chan.channel_id].get() - m.receive_revocation(RevokeAndAck(revoke_and_ack_msg["per_commitment_secret"], revoke_and_ack_msg["next_per_commitment_point"])) + await self.receive_revoke(m) - rev, _ = m.revoke_current_commitment() - self.send_message(gen_msg("revoke_and_ack", - channel_id=chan.channel_id, - per_commitment_secret=rev.per_commitment_secret, - next_per_commitment_point=rev.next_per_commitment_point)) - - chan = m.state - - print("waiting for update_fulfill") + self.revoke(m) update_fulfill_htlc_msg = await self.update_fulfill_htlc[chan.channel_id].get() + m.receive_htlc_settle(update_fulfill_htlc_msg["payment_preimage"], int.from_bytes(update_fulfill_htlc_msg["id"], "big")) - print("waiting for commitment_signed") - commitment_signed_msg = await self.commitment_signed[chan.channel_id].get() - - chan, last_secret, _, next_point = self.derive_and_incr(chan) - self.send_message(gen_msg("revoke_and_ack", - channel_id=chan.channel_id, - per_commitment_secret=last_secret, - next_per_commitment_point=next_point)) + self.revoke(m) - next_per_commitment_point = revoke_and_ack_msg["next_per_commitment_point"] + while (await self.commitment_signed[chan.channel_id].get())["htlc_signature"] == b"": + pass + # TODO process above commitment transactions - bare_ctx = make_commitment_using_open_channel(chan, chan.remote_state.ctn + 1, False, next_per_commitment_point, + bare_ctx = make_commitment_using_open_channel(m.state, m.state.remote_state.ctn + 1, False, m.state.remote_state.next_per_commitment_point, msat_remote, msat_local) sig_64 = sign_and_get_sig_string(bare_ctx, chan.local_config, chan.remote_config) - self.send_message(gen_msg("commitment_signed", channel_id=chan.channel_id, signature=sig_64, num_htlcs=0)) + m.state = m.state._replace(remote_state=m.state.remote_state._replace(ctn=m.state.remote_state.ctn + 1)) - revoke_and_ack_msg = await self.revoke_and_ack[chan.channel_id].get() - # TODO check revoke_and_ack results + await self.receive_revoke(m) - chan = chan._replace( - local_state=chan.local_state._replace( - amount_msat=msat_local, - next_htlc_id=chan.local_state.next_htlc_id + 1 - ), - remote_state=chan.remote_state._replace( - ctn=chan.remote_state.ctn + 1, - revocation_store=their_revstore, - last_per_commitment_point=next_per_commitment_point, - next_per_commitment_point=revoke_and_ack_msg["next_per_commitment_point"], - amount_msat=msat_remote - ) - ) - self.lnworker.save_channel(chan) + self.lnworker.save_channel(m.state) + + async def receive_revoke(self, m): + revoke_and_ack_msg = await self.revoke_and_ack[m.state.channel_id].get() + m.receive_revocation(RevokeAndAck(revoke_and_ack_msg["per_commitment_secret"], revoke_and_ack_msg["next_per_commitment_point"])) + + def revoke(self, m): + rev, _ = m.revoke_current_commitment() + self.send_message(gen_msg("revoke_and_ack", + channel_id=m.state.channel_id, + per_commitment_secret=rev.per_commitment_secret, + next_per_commitment_point=rev.next_per_commitment_point)) + + async def receive_commitment(self, m): + commitment_signed_msg = await self.commitment_signed[m.state.channel_id].get() + data = commitment_signed_msg["htlc_signature"] + htlc_sigs = [data[i:i+64] for i in range(0, len(data), 64)] + m.receive_new_commitment(commitment_signed_msg["signature"], htlc_sigs) + return len(htlc_sigs) @aiosafe async def receive_commitment_revoke_ack(self, htlc, decoded, payment_preimage): chan = self.channels[htlc['channel_id']] channel_id = chan.channel_id expected_received_msat = int(decoded.amount * COIN * 1000) - while True: - self.print_error("receiving commitment") - commitment_signed_msg = await self.commitment_signed[channel_id].get() - num_htlcs = int.from_bytes(commitment_signed_msg["num_htlcs"], "big") - print("num_htlcs", num_htlcs) - if num_htlcs == 1: - break htlc_id = int.from_bytes(htlc["id"], 'big') assert htlc_id == chan.remote_state.next_htlc_id, (htlc_id, chan.remote_state.next_htlc_id) @@ -1112,67 +1079,36 @@ class Peer(PrintError): htlc = UpdateAddHtlc(amount_msat, payment_hash, cltv_expiry, 0) m = HTLCStateMachine(chan) + m.receive_htlc(htlc) - data = commitment_signed_msg["htlc_signature"] - htlc_sigs = [data[i:i+64] for i in range(0, len(data), 64)] - m.receive_new_commitment(commitment_signed_msg["signature"], htlc_sigs) + assert (await self.receive_commitment(m)) == 1 - rev, _ = m.revoke_current_commitment() - self.send_message(gen_msg("revoke_and_ack", - channel_id=channel_id, - per_commitment_secret=rev.per_commitment_secret, - next_per_commitment_point=rev.next_per_commitment_point)) + self.revoke(m) sig_64, htlc_sigs = m.sign_next_commitment() - chan = m.state htlc_sig = htlc_sigs[0] self.send_message(gen_msg("commitment_signed", channel_id=channel_id, signature=sig_64, num_htlcs=1, htlc_signature=htlc_sig)) - revoke_and_ack_msg = await self.revoke_and_ack[channel_id].get() - - # TODO check revoke_and_ack_msg contents + await self.receive_revoke(m) + m.settle_htlc(payment_preimage, htlc_id) self.send_message(gen_msg("update_fulfill_htlc", channel_id=channel_id, id=htlc_id, payment_preimage=payment_preimage)) - remote_next_commitment_point = revoke_and_ack_msg["next_per_commitment_point"] - # remote commitment transaction without htlcs - bare_ctx = make_commitment_using_open_channel(chan, chan.remote_state.ctn + 2, False, remote_next_commitment_point, - chan.remote_state.amount_msat - expected_received_msat, chan.local_state.amount_msat + expected_received_msat) - - sig_64 = sign_and_get_sig_string(bare_ctx, chan.local_config, chan.remote_config) - + bare_ctx = make_commitment_using_open_channel(m.state, m.state.remote_state.ctn + 1, False, m.state.remote_state.next_per_commitment_point, + m.state.remote_state.amount_msat - expected_received_msat, m.state.local_state.amount_msat + expected_received_msat) + sig_64 = sign_and_get_sig_string(bare_ctx, m.state.local_config, m.state.remote_config) self.send_message(gen_msg("commitment_signed", channel_id=channel_id, signature=sig_64, num_htlcs=0)) + m.state = m.state._replace(remote_state=m.state.remote_state._replace(ctn=m.state.remote_state.ctn + 1)) - revoke_and_ack_msg = await self.revoke_and_ack[channel_id].get() - - # TODO check revoke_and_ack results - - commitment_signed_msg = await self.commitment_signed[channel_id].get() + await self.receive_revoke(m) - # TODO check commitment_signed results + assert (await self.receive_commitment(m)) == 0 - chan, last_secret, _, next_point = self.derive_and_incr(chan) + self.revoke(m) - self.send_message(gen_msg("revoke_and_ack", - channel_id=channel_id, - per_commitment_secret=last_secret, - next_per_commitment_point=next_point)) - - new_chan = chan._replace( - local_state=chan.local_state._replace( - amount_msat=chan.local_state.amount_msat + expected_received_msat - ), - remote_state=chan.remote_state._replace( - ctn=chan.remote_state.ctn + 2, - last_per_commitment_point=remote_next_commitment_point, - next_per_commitment_point=revoke_and_ack_msg["next_per_commitment_point"], - amount_msat=chan.remote_state.amount_msat - expected_received_msat, - next_htlc_id=htlc_id + 1 - ) - ) - self.lnworker.save_channel(new_chan) + self.lnworker.save_channel(m.state) def on_commitment_signed(self, payload): self.print_error("commitment_signed", payload) @@ -1203,6 +1139,7 @@ class Peer(PrintError): assert False def on_revoke_and_ack(self, payload): + print("got revoke_and_ack") channel_id = payload["channel_id"] self.revoke_and_ack[channel_id].put_nowait(payload) diff --git a/lib/lnhtlc.py b/lib/lnhtlc.py index 180885e87..f4f6f6a89 100644 --- a/lib/lnhtlc.py +++ b/lib/lnhtlc.py @@ -66,7 +66,8 @@ class HTLCStateMachine(PrintError): assert type(htlc) is UpdateAddHtlc self.local_update_log.append(htlc) self.print_error("add_htlc") - htlc_id = len(self.local_update_log)-1 + htlc_id = self.state.local_state.next_htlc_id + self.state = self.state._replace(local_state=self.state.local_state._replace(next_htlc_id=htlc_id + 1)) htlc.htlc_id = htlc_id return htlc_id @@ -79,7 +80,8 @@ class HTLCStateMachine(PrintError): assert type(htlc) is UpdateAddHtlc self.remote_update_log.append(htlc) self.print_error("receive_htlc") - htlc_id = len(self.remote_update_log)-1 + htlc_id = self.state.remote_state.next_htlc_id + self.state = self.state._replace(remote_state=self.state.remote_state._replace(next_htlc_id=htlc_id + 1)) htlc.htlc_id = htlc_id return htlc_id @@ -226,15 +228,23 @@ class HTLCStateMachine(PrintError): continue settle_fails2.append(x) + sent_this_batch = 0 + received_this_batch = 0 + for x in settle_fails2: - self.total_msat_sent += self.lookup_htlc(self.local_update_log, x.htlc_id).amount_msat + htlc = self.lookup_htlc(self.local_update_log, x.htlc_id) + sent_this_batch += htlc.amount_msat + htlc.total_fee + + self.total_msat_sent += sent_this_batch # increase received_msat counter for htlc's that have been settled adds2 = self.gen_htlc_indices("remote") for htlc in adds2: htlc_id = htlc.htlc_id if SettleHtlc(htlc_id) in self.local_update_log: - self.total_msat_received += self.lookup_htlc(self.remote_update_log, htlc_id).amount_msat + htlc = self.lookup_htlc(self.remote_update_log, htlc_id) + received_this_batch += htlc.amount_msat + htlc.total_fee + self.total_msat_received += received_this_batch # log compaction (remove entries relating to htlc's that have been settled) @@ -269,6 +279,10 @@ class HTLCStateMachine(PrintError): ctn=self.state.remote_state.ctn + 1, last_per_commitment_point=next_point, next_per_commitment_point=revocation.next_per_commitment_point, + amount_msat=self.state.remote_state.amount_msat + (sent_this_batch - received_this_batch) + ), + local_state=self.state.local_state._replace( + amount_msat = self.state.local_state.amount_msat + (received_this_batch - sent_this_batch) ) ) @@ -368,7 +382,7 @@ class HTLCStateMachine(PrintError): def htlcs_in_remote(self): return self.gen_htlc_indices("remote") - def settle_htlc(self, preimage, htlc_id, source_ref, dest_ref, close_key): + def settle_htlc(self, preimage, htlc_id): """ SettleHTLC attempts to settle an existing outstanding received HTLC. """ diff --git a/lib/tests/test_lnhtlc.py b/lib/tests/test_lnhtlc.py index 7232876c8..23073f89b 100644 --- a/lib/tests/test_lnhtlc.py +++ b/lib/tests/test_lnhtlc.py @@ -193,7 +193,7 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase): # Now we'll repeat a similar exchange, this time with Bob settling the # HTLC once he learns of the preimage. preimage = paymentPreimage - bob_channel.settle_htlc(preimage, bobHtlcIndex, None, None, None) + bob_channel.settle_htlc(preimage, bobHtlcIndex) alice_channel.receive_htlc_settle(preimage, aliceHtlcIndex)