Browse Source

lnbase: implement key rotation, request initial sync in localfeatures

regtest_lnd
ThomasV 7 years ago
committed by SomberNight
parent
commit
19ded0ff10
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 42
      lib/lnbase.py

42
lib/lnbase.py

@ -309,6 +309,7 @@ class Peer(PrintError):
self.temporary_channel_id_to_incoming_accept_channel = {} self.temporary_channel_id_to_incoming_accept_channel = {}
self.temporary_channel_id_to_incoming_funding_signed = {} self.temporary_channel_id_to_incoming_funding_signed = {}
self.init_message_received_future = asyncio.Future() self.init_message_received_future = asyncio.Future()
self.localfeatures = 0x08 # request initial sync
def diagnostic_name(self): def diagnostic_name(self):
return self.host return self.host
@ -322,26 +323,29 @@ class Peer(PrintError):
message_type, payload = decode_msg(msg) message_type, payload = decode_msg(msg)
self.print_error("Sending '%s'"%message_type.upper(), payload) self.print_error("Sending '%s'"%message_type.upper(), payload)
l = encode(len(msg), 2) l = encode(len(msg), 2)
lc = aead_encrypt(self.sk, self.sn, b'', l) lc = aead_encrypt(self.sk, self.sn(), b'', l)
c = aead_encrypt(self.sk, self.sn+1, b'', msg) c = aead_encrypt(self.sk, self.sn(), b'', msg)
assert len(lc) == 18 assert len(lc) == 18
assert len(c) == len(msg) + 16 assert len(c) == len(msg) + 16
self.writer.write(lc+c) self.writer.write(lc+c)
self.sn += 2
async def read_message(self): async def read_message(self):
while True: while True:
self.read_buffer += await self.reader.read(2**10) s = await self.reader.read(2**10)
if not s:
raise Exception('connection closed')
self.read_buffer += s
if len(self.read_buffer) < 18:
continue
lc = self.read_buffer[:18] lc = self.read_buffer[:18]
l = aead_decrypt(self.rk, self.rn, b'', lc) l = aead_decrypt(self.rk, self.rn(), b'', lc)
length = int.from_bytes(l, byteorder="big") length = int.from_bytes(l, byteorder="big")
offset = 18 + length + 16 offset = 18 + length + 16
if len(self.read_buffer) < offset: if len(self.read_buffer) < offset:
continue continue
c = self.read_buffer[18:offset] c = self.read_buffer[18:offset]
self.read_buffer = self.read_buffer[offset:] self.read_buffer = self.read_buffer[offset:]
msg = aead_decrypt(self.rk, self.rn+1, b'', c) msg = aead_decrypt(self.rk, self.rn(), b'', c)
self.rn += 2
return msg return msg
async def handshake(self): async def handshake(self):
@ -373,8 +377,26 @@ class Peer(PrintError):
msg = hs.handshake_version + c + t msg = hs.handshake_version + c + t
self.writer.write(msg) self.writer.write(msg)
# init counters # init counters
self.sn = 0 self._sn = 0
self.rn = 0 self._rn = 0
self.r_ck = ck
self.s_ck = ck
def rn(self):
o = self._rn
self._rn += 1
if self._rn == 1000:
self.r_ck, self.rk = get_bolt8_hkdf(self.r_ck, self.rk)
self._rn = 0
return o
def sn(self):
o = self._sn
self._sn += 1
if self._sn == 1000:
self.s_ck, self.sk = get_bolt8_hkdf(self.s_ck, self.sk)
self._sn = 0
return o
def process_message(self, message): def process_message(self, message):
message_type, payload = decode_msg(message) message_type, payload = decode_msg(message)
@ -458,7 +480,7 @@ class Peer(PrintError):
self.reader, self.writer = await asyncio.open_connection(self.host, self.port) self.reader, self.writer = await asyncio.open_connection(self.host, self.port)
await self.handshake() await self.handshake()
# send init # send init
self.send_message(gen_msg("init", gflen=0, lflen=0)) self.send_message(gen_msg("init", gflen=0, lflen=1, localfeatures=self.localfeatures))
# read init # read init
msg = await self.read_message() msg = await self.read_message()
self.process_message(msg) self.process_message(msg)

Loading…
Cancel
Save