Browse Source

fix channel_reestablish

regtest_lnd
ThomasV 7 years ago
committed by SomberNight
parent
commit
eb4d151324
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 2
      gui/qt/lightning_channels_list.py
  2. 40
      lib/lnbase.py
  3. 2
      lib/lnrouter.py
  4. 25
      lib/lnworker.py
  5. 2
      lib/tests/test_lnbase.py

2
gui/qt/lightning_channels_list.py

@ -45,7 +45,7 @@ class LightningChannelsList(QtWidgets.QWidget):
push_amt = int(push_amt_inp.text()) push_amt = int(push_amt_inp.text())
assert local_amt >= 200000 assert local_amt >= 200000
assert local_amt >= push_amt assert local_amt >= push_amt
obj = self.lnworker.open_channel_from_other_thread(node_id, local_amt, push_amt, self.update_rows.emit, password) obj = self.lnworker.open_channel(node_id, local_amt, push_amt, password)
@QtCore.pyqtSlot(dict) @QtCore.pyqtSlot(dict)
def do_update_single_row(self, new): def do_update_single_row(self, new):

40
lib/lnbase.py

@ -568,8 +568,7 @@ def is_synced(network):
class Peer(PrintError): class Peer(PrintError):
def __init__(self, host, port, pubkey, privkey, network, channel_db, path_finder, channel_state, handle_channel_reestablish, request_initial_sync=False): def __init__(self, host, port, pubkey, privkey, network, channel_db, path_finder, channel_state, channels, request_initial_sync=False):
self.handle_channel_reestablish = handle_channel_reestablish
self.update_add_htlc_event = asyncio.Event() self.update_add_htlc_event = asyncio.Event()
self.channel_update_event = asyncio.Event() self.channel_update_event = asyncio.Event()
self.host = host self.host = host
@ -594,7 +593,6 @@ class Peer(PrintError):
self.local_funding_locked = defaultdict(asyncio.Future) self.local_funding_locked = defaultdict(asyncio.Future)
self.remote_funding_locked = defaultdict(asyncio.Future) self.remote_funding_locked = defaultdict(asyncio.Future)
self.revoke_and_ack = defaultdict(asyncio.Future) self.revoke_and_ack = defaultdict(asyncio.Future)
self.channel_reestablish = defaultdict(asyncio.Future)
self.update_fulfill_htlc = defaultdict(asyncio.Future) self.update_fulfill_htlc = defaultdict(asyncio.Future)
self.commitment_signed = defaultdict(asyncio.Future) self.commitment_signed = defaultdict(asyncio.Future)
self.initialized = asyncio.Future() self.initialized = asyncio.Future()
@ -602,6 +600,7 @@ class Peer(PrintError):
self.unfulfilled_htlcs = [] self.unfulfilled_htlcs = []
self.channel_state = channel_state self.channel_state = channel_state
self.nodes = {} self.nodes = {}
self.channels = channels
def diagnostic_name(self): def diagnostic_name(self):
return self.host return self.host
@ -714,13 +713,6 @@ class Peer(PrintError):
l = int.from_bytes(payload['num_pong_bytes'], 'big') l = int.from_bytes(payload['num_pong_bytes'], 'big')
self.send_message(gen_msg('pong', byteslen=l)) self.send_message(gen_msg('pong', byteslen=l))
def on_channel_reestablish(self, payload):
chan_id = int.from_bytes(payload["channel_id"], 'big')
if chan_id in self.channel_reestablish:
self.channel_reestablish[chan_id].set_result(payload)
else:
asyncio.run_coroutine_threadsafe(self.handle_channel_reestablish(chan_id, payload), self.network.asyncio_loop).result()
def on_accept_channel(self, payload): def on_accept_channel(self, payload):
temp_chan_id = payload["temporary_channel_id"] temp_chan_id = payload["temporary_channel_id"]
if temp_chan_id not in self.channel_accepted: raise Exception("Got unknown accept_channel") if temp_chan_id not in self.channel_accepted: raise Exception("Got unknown accept_channel")
@ -795,6 +787,8 @@ class Peer(PrintError):
self.process_message(msg) self.process_message(msg)
# initialized # initialized
self.initialized.set_result(msg) self.initialized.set_result(msg)
# reestablish channels
[await self.reestablish_channel(c) for c in self.channels]
# loop # loop
while True: while True:
self.ping_if_required() self.ping_if_required()
@ -963,33 +957,29 @@ class Peer(PrintError):
async def reestablish_channel(self, chan): async def reestablish_channel(self, chan):
assert chan.channel_id not in self.channel_state assert chan.channel_id not in self.channel_state
await self.initialized
self.send_message(gen_msg("channel_reestablish", self.send_message(gen_msg("channel_reestablish",
channel_id=chan.channel_id, channel_id=chan.channel_id,
next_local_commitment_number=chan.local_state.ctn+1, next_local_commitment_number=chan.local_state.ctn+1,
next_remote_revocation_number=chan.remote_state.ctn next_remote_revocation_number=chan.remote_state.ctn
)) ))
channel_reestablish_msg = await self.channel_reestablish[chan.channel_id]
print(channel_reestablish_msg) def on_channel_reestablish(self, payload):
# { chan_id = int.from_bytes(payload["channel_id"], 'big')
# 'channel_id': b'\xfa\xce\x0b\x8cjZ6\x03\xd2\x99k\x12\x86\xc7\xed\xe5\xec\x80\x85F\xf2\x1bzn\xa1\xd30I\xf9_V\xfa', for chan in self.channels:
# 'next_local_commitment_number': b'\x00\x00\x00\x00\x00\x00\x00\x01', if chan.channel_id == chan_id:
# 'next_remote_revocation_number': b'\x00\x00\x00\x00\x00\x00\x00\x00', break
# 'your_last_per_commitment_secret': b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', else:
# 'my_current_per_commitment_point': b'\x03\x18\xb9\x1b\x99\xd4\xc3\xf1\x92\x0f\xfe\xe4c\x9e\xae\xa4\xf1\xdeX\xcf4\xa9[\xd1\tAh\x80\x88\x01b*[' print("Warning: received unknown channel_reestablish", chan_id, list(self.channels))
# } return
channel_reestablish_msg = payload
remote_ctn = int.from_bytes(channel_reestablish_msg["next_local_commitment_number"], 'big') remote_ctn = int.from_bytes(channel_reestablish_msg["next_local_commitment_number"], 'big')
if remote_ctn != chan.remote_state.ctn + 1: if remote_ctn != chan.remote_state.ctn + 1:
raise Exception("expected remote ctn {}, got {}".format(chan.remote_state.ctn + 1, remote_ctn)) raise Exception("expected remote ctn {}, got {}".format(chan.remote_state.ctn + 1, remote_ctn))
local_ctn = int.from_bytes(channel_reestablish_msg["next_remote_revocation_number"], 'big') local_ctn = int.from_bytes(channel_reestablish_msg["next_remote_revocation_number"], 'big')
if local_ctn != chan.local_state.ctn: if local_ctn != chan.local_state.ctn:
raise Exception("expected local ctn {}, got {}".format(chan.local_state.ctn, local_ctn)) raise Exception("expected local ctn {}, got {}".format(chan.local_state.ctn, local_ctn))
if channel_reestablish_msg["my_current_per_commitment_point"] != chan.remote_state.last_per_commitment_point: if channel_reestablish_msg["my_current_per_commitment_point"] != chan.remote_state.last_per_commitment_point:
raise Exception("Remote PCP mismatch") raise Exception("Remote PCP mismatch")
self.channel_state[chan.channel_id] = "OPEN" self.channel_state[chan.channel_id] = "OPEN"
async def funding_locked(self, chan): async def funding_locked(self, chan):
@ -1009,9 +999,7 @@ class Peer(PrintError):
finally: finally:
del self.remote_funding_locked[channel_id] del self.remote_funding_locked[channel_id]
self.print_error('Done waiting for remote_funding_locked', remote_funding_locked_msg) self.print_error('Done waiting for remote_funding_locked', remote_funding_locked_msg)
self.channel_state[chan.channel_id] = "OPEN" self.channel_state[chan.channel_id] = "OPEN"
return chan._replace(short_channel_id=short_channel_id, remote_state=chan.remote_state._replace(next_per_commitment_point=remote_funding_locked_msg["next_per_commitment_point"])) return chan._replace(short_channel_id=short_channel_id, remote_state=chan.remote_state._replace(next_per_commitment_point=remote_funding_locked_msg["next_per_commitment_point"]))
def on_update_fail_htlc(self, payload): def on_update_fail_htlc(self, payload):

2
lib/lnrouter.py

@ -120,7 +120,7 @@ class ChannelDB(PrintError):
try: try:
channel_info = self._id_to_channel_info[short_channel_id] channel_info = self._id_to_channel_info[short_channel_id]
except KeyError: except KeyError:
print("could not find", short_channel_id) self.print_error("could not find", short_channel_id)
else: else:
channel_info.on_channel_update(msg_payload) channel_info.on_channel_update(msg_payload)

25
lib/lnworker.py

@ -98,7 +98,6 @@ class LNWorker:
self.nodes = {} # received node announcements self.nodes = {} # received node announcements
self.channel_db = lnrouter.ChannelDB() self.channel_db = lnrouter.ChannelDB()
self.path_finder = lnrouter.LNPathFinder(self.channel_db) self.path_finder = lnrouter.LNPathFinder(self.channel_db)
self.channels = [reconstruct_namedtuples(x) for x in wallet.storage.get("channels", {})] self.channels = [reconstruct_namedtuples(x) for x in wallet.storage.get("channels", {})]
peer_list = network.config.get('lightning_peers', node_list) peer_list = network.config.get('lightning_peers', node_list)
self.channel_state = {} self.channel_state = {}
@ -109,15 +108,11 @@ class LNWorker:
self.on_network_update('updated') # shortcut (don't block) if funding tx locked and verified self.on_network_update('updated') # shortcut (don't block) if funding tx locked and verified
def add_peer(self, host, port, pubkey): def add_peer(self, host, port, pubkey):
peer = Peer(host, int(port), binascii.unhexlify(pubkey), self.privkey, node_id = bfh(pubkey)
self.network, self.channel_db, self.path_finder, self.channel_state, self.handle_channel_reestablish) channels = list(filter(lambda x: x.node_id == node_id, self.channels))
peer = Peer(host, int(port), node_id, self.privkey, self.network, self.channel_db, self.path_finder, self.channel_state, channels)
self.network.futures.append(asyncio.run_coroutine_threadsafe(peer.main_loop(), asyncio.get_event_loop())) self.network.futures.append(asyncio.run_coroutine_threadsafe(peer.main_loop(), asyncio.get_event_loop()))
self.peers[bfh(pubkey)] = peer self.peers[node_id] = peer
async def handle_channel_reestablish(self, chan_id, payload):
chans = [x for x in self.channels if x.channel_id == chan_id ]
chan = chans[0]
await self.peers[chan.node_id].reestablish_channel(chan)
def save_channel(self, openchannel): def save_channel(self, openchannel):
self.channels = [openchannel] # TODO multiple channels self.channels = [openchannel] # TODO multiple channels
@ -179,17 +174,6 @@ class LNWorker:
def list_channels(self): def list_channels(self):
return serialize_channels(self.channels) return serialize_channels(self.channels)
def reestablish_channels(self):
coro = self._reestablish_channels_coroutine()
return asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop).result()
# not aiosafe because we call .result() which will propagate an exception
async def _reestablish_channels_coroutine(self):
if self.channels is None or len(self.channels) < 1:
raise Exception("Can't reestablish: No channel saved")
peer = self.peers[self.channels[0].node_id]
await peer.reestablish_channel(self.channels[0])
# not aiosafe because we call .result() which will propagate an exception # not aiosafe because we call .result() which will propagate an exception
async def _pay_coroutine(self, invoice): async def _pay_coroutine(self, invoice):
openchannel = self.channels[0] openchannel = self.channels[0]
@ -216,7 +200,6 @@ class LNWorker:
openchannel = await peer.receive_commitment_revoke_ack(openchannel, expected_received_msat, payment_preimage) openchannel = await peer.receive_commitment_revoke_ack(openchannel, expected_received_msat, payment_preimage)
self.save_channel(openchannel) self.save_channel(openchannel)
def subscribe_payment_received_from_other_thread(self, emit_function): def subscribe_payment_received_from_other_thread(self, emit_function):
pass pass

2
lib/tests/test_lnbase.py

@ -256,7 +256,7 @@ class Test_LNBase(unittest.TestCase):
def test_find_path_for_payment(self): def test_find_path_for_payment(self):
channel_db = lnrouter.ChannelDB() channel_db = lnrouter.ChannelDB()
path_finder = lnrouter.LNPathFinder(channel_db) path_finder = lnrouter.LNPathFinder(channel_db)
p = Peer('', 0, 'a', bitcoin.sha256('privkeyseed'), None, channel_db, path_finder, {}, lambda x, y: None) p = Peer('', 0, 'a', bitcoin.sha256('privkeyseed'), None, channel_db, path_finder, {}, [])
p.on_channel_announcement({'node_id_1': b'b', 'node_id_2': b'c', 'short_channel_id': bfh('0000000000000001')}) p.on_channel_announcement({'node_id_1': b'b', 'node_id_2': b'c', 'short_channel_id': bfh('0000000000000001')})
p.on_channel_announcement({'node_id_1': b'b', 'node_id_2': b'e', 'short_channel_id': bfh('0000000000000002')}) p.on_channel_announcement({'node_id_1': b'b', 'node_id_2': b'e', 'short_channel_id': bfh('0000000000000002')})
p.on_channel_announcement({'node_id_1': b'a', 'node_id_2': b'b', 'short_channel_id': bfh('0000000000000003')}) p.on_channel_announcement({'node_id_1': b'a', 'node_id_2': b'b', 'short_channel_id': bfh('0000000000000003')})

Loading…
Cancel
Save