diff --git a/lib/lnbase.py b/lib/lnbase.py index 76b8b5a53..84fe7f063 100644 --- a/lib/lnbase.py +++ b/lib/lnbase.py @@ -276,7 +276,7 @@ ChannelConfig = namedtuple("ChannelConfig", [ "payment_basepoint", "multisig_key", "htlc_basepoint", "delayed_basepoint", "revocation_basepoint", "to_self_delay", "dust_limit_sat", "max_htlc_value_in_flight_msat", "max_accepted_htlcs"]) OnlyPubkeyKeypair = namedtuple("OnlyPubkeyKeypair", ["pubkey"]) -RemoteState = namedtuple("RemoteState", ["ctn", "next_per_commitment_point", "amount_sat"]) +RemoteState = namedtuple("RemoteState", ["ctn", "next_per_commitment_point", "amount_sat", "commitment_points"]) LocalState = namedtuple("LocalState", ["ctn", "per_commitment_secret_seed", "amount_sat"]) ChannelConstraints = namedtuple("ChannelConstraints", ["feerate", "capacity", "is_initiator", "funding_txn_minimum_depth"]) OpenChannel = namedtuple("OpenChannel", ["channel_id", "funding_outpoint", "local_config", "remote_config", "remote_state", "local_state", "constraints"]) @@ -691,17 +691,23 @@ class Peer(PrintError): self.send_message(gen_msg('pong', byteslen=l)) def on_channel_reestablish(self, payload): - self.channel_reestablish[payload["channel_id"]].set_result(payload) + chan_id = int.from_bytes(payload["channel_id"], 'big') + if chan_id not in self.channel_reestablish: raise Exception("Got unknown channel_reestablish") + self.channel_reestablish[chan_id].set_result(payload) def on_accept_channel(self, payload): - self.channel_accepted[payload["temporary_channel_id"]].set_result(payload) + temp_chan_id = payload["temporary_channel_id"] + if temp_chan_id not in self.channel_accepted: raise Exception("Got unknown accept_channel") + self.channel_accepted[temp_chan_id].set_result(payload) def on_funding_signed(self, payload): channel_id = int.from_bytes(payload['channel_id'], 'big') + if channel_id not in self.funding_signed: raise Exception("Got unknown funding_signed") self.funding_signed[channel_id].set_result(payload) def on_funding_locked(self, payload): channel_id = int.from_bytes(payload['channel_id'], 'big') + #if channel_id not in self.funding_signed: raise Exception("Got unknown funding_locked") self.remote_funding_locked[channel_id].set_result(payload) def on_node_announcement(self, payload): @@ -909,7 +915,8 @@ class Peer(PrintError): remote_state=RemoteState( ctn = 0, next_per_commitment_point=None, - amount_sat=remote_amount + amount_sat=remote_amount, + commitment_points=[bh2u(remote_per_commitment_point)] ), local_state=LocalState( ctn = 0, @@ -920,11 +927,24 @@ class Peer(PrintError): ) return chan - async def reestablish_channel(self, chan, m, n): + async def reestablish_channel(self, chan): + await self.initialized - self.send_message(gen_msg("channel_reestablish", channel_id=chan.channel_id, next_local_commitment_number=m, next_remote_revocation_number=n)) channel_reestablish_msg = await self.channel_reestablish[chan.channel_id] print(channel_reestablish_msg) + # { + # 'channel_id': b'\xfa\xce\x0b\x8cjZ6\x03\xd2\x99k\x12\x86\xc7\xed\xe5\xec\x80\x85F\xf2\x1bzn\xa1\xd30I\xf9_V\xfa', + # 'next_local_commitment_number': b'\x00\x00\x00\x00\x00\x00\x00\x01', + # 'next_remote_revocation_number': b'\x00\x00\x00\x00\x00\x00\x00\x00', + # 'your_last_per_commitment_secret': b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', + # 'my_current_per_commitment_point': b'\x03\x18\xb9\x1b\x99\xd4\xc3\xf1\x92\x0f\xfe\xe4c\x9e\xae\xa4\xf1\xdeX\xcf4\xa9[\xd1\tAh\x80\x88\x01b*[' + # } + if channel_reestablish_msg["my_current_per_commitment_point"] != bfh(chan.remote_state.commitment_points[-1]): + raise Exception("Remote PCP mismatch") + n = chan.local_state.ctn + 1 + self.send_message(gen_msg("channel_reestablish", channel_id=chan.channel_id, next_local_commitment_number=n, next_remote_revocation_number=chan.remote_state.ctn)) + return chan + async def wait_for_funding_locked(self, chan, wallet): channel_id = chan.channel_id diff --git a/lib/tests/test_lnbase_online.py b/lib/tests/test_lnbase_online.py index 2c927af2f..bcc2ab9de 100644 --- a/lib/tests/test_lnbase_online.py +++ b/lib/tests/test_lnbase_online.py @@ -64,7 +64,7 @@ def serialize_channels(channels): reconstructed = [reconstruct_namedtuples(x) for x in roundtripped] if reconstructed != channels: raise Exception("Channels did not roundtrip serialization without changes:\n" + repr(reconstructed) + "\n" + repr(channels)) - return dumped + return roundtripped if __name__ == "__main__": if len(sys.argv) > 3: @@ -107,15 +107,16 @@ if __name__ == "__main__": if sys.argv[1] == "new_channel": openingchannel = await peer.channel_establishment_flow(wallet, config, None, funding_satoshis, push_msat, temp_channel_id=os.urandom(32)) - dumped = serialize_channels([openingchannel]) + openchannel = await peer.wait_for_funding_locked(openingchannel, wallet) + dumped = serialize_channels([openchannel]) wallet.storage.put("channels", dumped) - return - else: - openingchannel = json.loads(channels)[0] - openingchannel = reconstruct_namedtuples(openingchannel) - next_local_commitment_number, next_remote_revocation_number = 1, 1 - await peer.reestablish_channel(openingchannel, next_local_commitment_number, next_remote_revocation_number) - openchannel = await peer.wait_for_funding_locked(openingchannel, wallet) + wallet.storage.write() + return openchannel.channel_id + if channels is None or len(channels) < 1: + raise Exception("Can't reestablish: No channel saved") + openchannel = channels[0] + openchannel = reconstruct_namedtuples(openchannel) + openchannel = await peer.reestablish_channel(openchannel) expected_received_sat = 400000 pay_req = lnencode(LnAddr(RHASH, amount=Decimal("0.00000001")*expected_received_sat, tags=[('d', 'one cup of coffee')]), peer.privkey[:32]) print("payment request", pay_req) @@ -124,9 +125,13 @@ if __name__ == "__main__": fut = asyncio.run_coroutine_threadsafe(async_test(), network.asyncio_loop) while not fut.done(): time.sleep(1) - if fut.exception(): - try: + try: + if fut.exception(): raise fut.exception() - except: - traceback.print_exc() - network.stop() + except: + traceback.print_exc() + else: + print("result", fut.result()) + finally: + network.stop() +