diff --git a/lib/lnbase.py b/lib/lnbase.py index eb2e4d0cd..7b550b460 100644 --- a/lib/lnbase.py +++ b/lib/lnbase.py @@ -572,29 +572,29 @@ def is_synced(network): class Peer(PrintError): - def __init__(self, host, port, pubkey, privkey, network, channel_db, path_finder, channel_state, channels, invoices, request_initial_sync=False): + def __init__(self, lnworker, host, port, pubkey, request_initial_sync=False): self.channel_update_event = asyncio.Event() self.host = host self.port = port - self.privkey = privkey self.pubkey = pubkey - self.network = network - self.channel_db = channel_db - self.path_finder = path_finder + self.lnworker = lnworker + self.privkey = lnworker.privkey + self.network = lnworker.network + self.channel_db = lnworker.channel_db + self.path_finder = lnworker.path_finder + self.channel_state = lnworker.channel_state self.read_buffer = b'' self.ping_time = 0 self.initialized = asyncio.Future() self.channel_accepted = defaultdict(asyncio.Queue) self.funding_signed = defaultdict(asyncio.Queue) - self.remote_funding_locked = defaultdict(asyncio.Queue) self.revoke_and_ack = defaultdict(asyncio.Queue) self.update_fulfill_htlc = defaultdict(asyncio.Queue) self.commitment_signed = defaultdict(asyncio.Queue) self.localfeatures = (0x08 if request_initial_sync else 0) - self.channel_state = channel_state self.nodes = {} - self.channels = channels - self.invoices = invoices + self.channels = lnworker.channels + self.invoices = lnworker.invoices def diagnostic_name(self): return self.host @@ -713,11 +713,6 @@ class Peer(PrintError): if channel_id not in self.funding_signed: raise Exception("Got unknown funding_signed") self.funding_signed[channel_id].put_nowait(payload) - def on_funding_locked(self, payload): - channel_id = payload['channel_id'] - if channel_id not in self.remote_funding_locked: print("Got unknown funding_locked", payload) - self.remote_funding_locked[channel_id].put_nowait(payload) - def on_node_announcement(self, payload): pubkey = payload['node_id'] signature = payload['signature'] @@ -957,22 +952,43 @@ class Peer(PrintError): raise Exception("expected local ctn {}, got {}".format(chan.local_state.ctn, local_ctn)) if channel_reestablish_msg["my_current_per_commitment_point"] != chan.remote_state.last_per_commitment_point: raise Exception("Remote PCP mismatch") - self.channel_state[chan_id] = 'OPEN' if chan.local_state.funding_locked_received else 'OPENING' + self.channel_state[chan_id] = 'OPENING' #if chan.local_state.funding_locked_received else 'OPENING' self.network.trigger_callback('channel', chan) - async def funding_locked(self, chan): + def funding_locked(self, chan): channel_id = chan.channel_id - short_channel_id = chan.short_channel_id per_commitment_secret_index = 2**48 - 2 per_commitment_point_second = secret_to_pubkey(int.from_bytes( get_per_commitment_secret_from_seed(chan.local_state.per_commitment_secret_seed, per_commitment_secret_index), 'big')) self.send_message(gen_msg("funding_locked", channel_id=channel_id, next_per_commitment_point=per_commitment_point_second)) - # wait until we receive funding_locked - remote_funding_locked_msg = await self.remote_funding_locked[channel_id].get() - self.print_error('Done waiting for remote_funding_locked', remote_funding_locked_msg) - new_remote_state = chan.remote_state._replace(next_per_commitment_point=remote_funding_locked_msg["next_per_commitment_point"]) + if chan.local_state.funding_locked_received: + self.mark_open(chan) + + def on_funding_locked(self, payload): + channel_id = payload['channel_id'] + chan = self.channels.get(channel_id) + if not chan: + raise Exception("Got unknown funding_locked", channel_id) + short_channel_id = chan.short_channel_id + new_remote_state = chan.remote_state._replace(next_per_commitment_point=payload["next_per_commitment_point"]) new_local_state = chan.local_state._replace(funding_locked_received = True) - return chan._replace(short_channel_id=short_channel_id, remote_state=new_remote_state, local_state=new_local_state) + chan = chan._replace(short_channel_id=short_channel_id, remote_state=new_remote_state, local_state=new_local_state) + self.lnworker.save_channel(chan) + if chan.short_channel_id: + self.mark_open(chan) + + def mark_open(self, chan): + if self.channel_state[chan.channel_id] == "OPEN": + return + assert chan.local_state.funding_locked_received + self.channel_state[chan.channel_id] = "OPEN" + self.network.trigger_callback('channel', chan) + # add channel to database + sorted_keys = list(sorted([self.pubkey, self.lnworker.pubkey])) + self.channel_db.on_channel_announcement({"short_channel_id": chan.short_channel_id, "node_id_1": sorted_keys[0], "node_id_2": sorted_keys[1]}) + self.channel_db.on_channel_update({"short_channel_id": chan.short_channel_id, 'flags': b'\x01', 'cltv_expiry_delta': b'\x90', 'htlc_minimum_msat': b'\x03\xe8', 'fee_base_msat': b'\x03\xe8', 'fee_proportional_millionths': b'\x01'}) + self.channel_db.on_channel_update({"short_channel_id": chan.short_channel_id, 'flags': b'\x00', 'cltv_expiry_delta': b'\x90', 'htlc_minimum_msat': b'\x03\xe8', 'fee_base_msat': b'\x03\xe8', 'fee_proportional_millionths': b'\x01'}) + self.print_error("CHANNEL OPENING COMPLETED") def on_update_fail_htlc(self, payload): print("UPDATE_FAIL_HTLC", decode_onion_error(payload["reason"], self.node_keys, self.secret_key)) @@ -993,37 +1009,16 @@ class Peer(PrintError): ) return chan, last_secret, this_point, next_point - async def pay(self, wallet, chan, amount_msat, payment_hash, pubkey_in_invoice, min_final_cltv_expiry): + @aiosafe + async def pay(self, path, chan, amount_msat, payment_hash, pubkey_in_invoice, min_final_cltv_expiry): assert self.channel_state[chan.channel_id] == "OPEN" + assert amount_msat > 0, "amount_msat is not greater zero" + height = self.network.get_local_height() their_revstore = chan.remote_state.revocation_store - while not is_synced(wallet.network): - await asyncio.sleep(1) - print("sleeping more") - if chan.channel_id in self.commitment_signed: print("too many commitments signed") del self.commitment_signed[chan.channel_id] - - height = wallet.get_local_height() - assert amount_msat > 0, "amount_msat is not greater zero" - - our_pubkey = ecc.ECPrivkey(self.privkey).get_public_key_bytes() - sorted_keys = list(sorted([self.pubkey, our_pubkey])) - self.channel_db.on_channel_announcement({"short_channel_id": chan.short_channel_id, "node_id_1": sorted_keys[0], "node_id_2": sorted_keys[1]}) - self.channel_db.on_channel_update({"short_channel_id": chan.short_channel_id, 'flags': b'\x01', 'cltv_expiry_delta': b'\x90', 'htlc_minimum_msat': b'\x03\xe8', 'fee_base_msat': b'\x03\xe8', 'fee_proportional_millionths': b'\x01'}) - self.channel_db.on_channel_update({"short_channel_id": chan.short_channel_id, 'flags': b'\x00', 'cltv_expiry_delta': b'\x90', 'htlc_minimum_msat': b'\x03\xe8', 'fee_base_msat': b'\x03\xe8', 'fee_proportional_millionths': b'\x01'}) - - print("our short chan id", chan.short_channel_id) - while True: - path = self.path_finder.find_path_for_payment(our_pubkey, pubkey_in_invoice, amount_msat) - if path is not None: - break - print("waiting for path") - await self.channel_update_event.wait() - self.channel_update_event.clear() - - route = self.path_finder.create_route_from_path(path, our_pubkey) - + route = self.path_finder.create_route_from_path(path, self.lnworker.pubkey) hops_data = [] sum_of_deltas = sum(route_edge.channel_policy.cltv_expiry_delta for route_edge in route[1:]) total_fee = 0 @@ -1035,17 +1030,11 @@ class Peer(PrintError): associated_data = payment_hash self.secret_key = os.urandom(32) self.node_keys = [x.node_id for x in route] - hops_data += [OnionHopsDataSingle(OnionPerHop(b"\x00"*8, amount_msat.to_bytes(8, "big"), (final_cltv_expiry_without_deltas).to_bytes(4, "big")))] - onion = new_onion_packet(self.node_keys, self.secret_key, hops_data, associated_data) - msat_local = chan.local_state.amount_msat - (amount_msat + total_fee) - msat_remote = chan.remote_state.amount_msat + (amount_msat + total_fee) - amount_msat += total_fee - self.send_message(gen_msg("update_add_htlc", channel_id=chan.channel_id, id=chan.local_state.next_htlc_id, cltv_expiry=final_cltv_expiry_with_deltas, amount_msat=amount_msat, payment_hash=payment_hash, onion_routing_packet=onion.to_bytes())) their_local_htlc_pubkey = derive_pubkey(chan.remote_config.htlc_basepoint.pubkey, chan.remote_state.next_per_commitment_point) @@ -1105,7 +1094,7 @@ class Peer(PrintError): revoke_and_ack_msg = await self.revoke_and_ack[chan.channel_id].get() # TODO check revoke_and_ack results - return chan._replace( + chan = chan._replace( local_state=chan.local_state._replace( amount_msat=msat_local, next_htlc_id=chan.local_state.next_htlc_id + 1 @@ -1118,6 +1107,7 @@ class Peer(PrintError): amount_msat=msat_remote ) ) + self.lnworker.save_channel(chan) @aiosafe async def receive_commitment_revoke_ack(self, htlc, decoded, payment_preimage): diff --git a/lib/lnworker.py b/lib/lnworker.py index 6aa311e77..1ff54b626 100644 --- a/lib/lnworker.py +++ b/lib/lnworker.py @@ -22,7 +22,7 @@ from .wallet import Wallet from .lnbase import Peer, Outpoint, ChannelConfig, LocalState, RemoteState, Keypair, OnlyPubkeyKeypair, OpenChannel, ChannelConstraints, RevocationStore, aiosafe, calc_short_channel_id, privkey_to_pubkey from .lightning_payencode.lnaddr import lnencode, LnAddr, lndecode from . import lnrouter - +from .ecc import ECPrivkey is_key = lambda k: k.endswith("_basepoint") or k.endswith("_key") @@ -96,6 +96,7 @@ class LNWorker(PrintError): wallet.storage.put('lightning_privkey', pk) wallet.storage.write() self.privkey = bfh(pk) + self.pubkey = ECPrivkey(self.privkey).get_public_key_bytes() self.config = network.config self.peers = {} # view of the network @@ -119,7 +120,7 @@ class LNWorker(PrintError): def add_peer(self, host, port, pubkey): node_id = bfh(pubkey) channels = self.channels_for_peer(node_id) - peer = Peer(host, int(port), node_id, self.privkey, self.network, self.channel_db, self.path_finder, self.channel_state, channels, self.invoices, request_initial_sync=True) + peer = Peer(self, host, int(port), node_id, request_initial_sync=False) self.network.futures.append(asyncio.run_coroutine_threadsafe(peer.main_loop(), asyncio.get_event_loop())) self.peers[node_id] = peer self.lock = threading.Lock() @@ -164,19 +165,7 @@ class LNWorker(PrintError): self.print_error("network update but funding tx is still not at sufficient depth") continue peer = self.peers[chan.node_id] - asyncio.run_coroutine_threadsafe(self.wait_funding_locked_and_mark_open(peer, chan), asyncio.get_event_loop()) - - # aiosafe because we don't wait for result - @aiosafe - async def wait_funding_locked_and_mark_open(self, peer, chan): - await peer.initialized - if self.channel_state[chan.channel_id] == "OPEN": - return - if not chan.local_state.funding_locked_received: - chan = await peer.funding_locked(chan) - self.save_channel(chan) - self.print_error("CHANNEL OPENING COMPLETED") - self.channel_state[chan.channel_id] = "OPEN" + peer.funding_locked(chan) # not aiosafe because we call .result() which will propagate an exception async def _open_channel_coroutine(self, node_id, amount_sat, push_sat, password): @@ -194,19 +183,18 @@ class LNWorker(PrintError): return asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop).result() def pay(self, invoice): - coro = self._pay_coroutine(invoice) - return asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop) - - @aiosafe - async def _pay_coroutine(self, invoice): - openchannel = next(iter(self.channels.values())) addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) payment_hash = addr.paymenthash - pubkey = addr.pubkey.serialize() - msat_amt = int(addr.amount * COIN * 1000) - peer = self.peers[openchannel.node_id] - openchannel = await peer.pay(self.wallet, openchannel, msat_amt, payment_hash, pubkey, addr.min_final_cltv_expiry) - self.save_channel(openchannel) + invoice_pubkey = addr.pubkey.serialize() + amount_msat = int(addr.amount * COIN * 1000) + path = self.path_finder.find_path_for_payment(self.pubkey, invoice_pubkey, amount_msat) + node_id, short_channel_id = path[0] + peer = self.peers[node_id] + for chan in self.channels.values(): + if chan.short_channel_id == short_channel_id: + break + coro = peer.pay(path, chan, amount_msat, payment_hash, invoice_pubkey, addr.min_final_cltv_expiry) + asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop) def add_invoice(self, amount_sat, message='one cup of coffee'): coro = self._add_invoice_coroutine(amount_sat, message)