|
|
@ -309,6 +309,7 @@ class Peer(PrintError): |
|
|
|
self.temporary_channel_id_to_incoming_accept_channel = {} |
|
|
|
self.temporary_channel_id_to_incoming_funding_signed = {} |
|
|
|
self.init_message_received_future = asyncio.Future() |
|
|
|
self.localfeatures = 0x08 # request initial sync |
|
|
|
|
|
|
|
def diagnostic_name(self): |
|
|
|
return self.host |
|
|
@ -322,26 +323,29 @@ class Peer(PrintError): |
|
|
|
message_type, payload = decode_msg(msg) |
|
|
|
self.print_error("Sending '%s'"%message_type.upper(), payload) |
|
|
|
l = encode(len(msg), 2) |
|
|
|
lc = aead_encrypt(self.sk, self.sn, b'', l) |
|
|
|
c = aead_encrypt(self.sk, self.sn+1, b'', msg) |
|
|
|
lc = aead_encrypt(self.sk, self.sn(), b'', l) |
|
|
|
c = aead_encrypt(self.sk, self.sn(), b'', msg) |
|
|
|
assert len(lc) == 18 |
|
|
|
assert len(c) == len(msg) + 16 |
|
|
|
self.writer.write(lc+c) |
|
|
|
self.sn += 2 |
|
|
|
|
|
|
|
async def read_message(self): |
|
|
|
while True: |
|
|
|
self.read_buffer += await self.reader.read(2**10) |
|
|
|
s = await self.reader.read(2**10) |
|
|
|
if not s: |
|
|
|
raise Exception('connection closed') |
|
|
|
self.read_buffer += s |
|
|
|
if len(self.read_buffer) < 18: |
|
|
|
continue |
|
|
|
lc = self.read_buffer[:18] |
|
|
|
l = aead_decrypt(self.rk, self.rn, b'', lc) |
|
|
|
l = aead_decrypt(self.rk, self.rn(), b'', lc) |
|
|
|
length = int.from_bytes(l, byteorder="big") |
|
|
|
offset = 18 + length + 16 |
|
|
|
if len(self.read_buffer) < offset: |
|
|
|
continue |
|
|
|
c = self.read_buffer[18:offset] |
|
|
|
self.read_buffer = self.read_buffer[offset:] |
|
|
|
msg = aead_decrypt(self.rk, self.rn+1, b'', c) |
|
|
|
self.rn += 2 |
|
|
|
msg = aead_decrypt(self.rk, self.rn(), b'', c) |
|
|
|
return msg |
|
|
|
|
|
|
|
async def handshake(self): |
|
|
@ -373,8 +377,26 @@ class Peer(PrintError): |
|
|
|
msg = hs.handshake_version + c + t |
|
|
|
self.writer.write(msg) |
|
|
|
# init counters |
|
|
|
self.sn = 0 |
|
|
|
self.rn = 0 |
|
|
|
self._sn = 0 |
|
|
|
self._rn = 0 |
|
|
|
self.r_ck = ck |
|
|
|
self.s_ck = ck |
|
|
|
|
|
|
|
def rn(self): |
|
|
|
o = self._rn |
|
|
|
self._rn += 1 |
|
|
|
if self._rn == 1000: |
|
|
|
self.r_ck, self.rk = get_bolt8_hkdf(self.r_ck, self.rk) |
|
|
|
self._rn = 0 |
|
|
|
return o |
|
|
|
|
|
|
|
def sn(self): |
|
|
|
o = self._sn |
|
|
|
self._sn += 1 |
|
|
|
if self._sn == 1000: |
|
|
|
self.s_ck, self.sk = get_bolt8_hkdf(self.s_ck, self.sk) |
|
|
|
self._sn = 0 |
|
|
|
return o |
|
|
|
|
|
|
|
def process_message(self, message): |
|
|
|
message_type, payload = decode_msg(message) |
|
|
@ -458,7 +480,7 @@ class Peer(PrintError): |
|
|
|
self.reader, self.writer = await asyncio.open_connection(self.host, self.port) |
|
|
|
await self.handshake() |
|
|
|
# send init |
|
|
|
self.send_message(gen_msg("init", gflen=0, lflen=0)) |
|
|
|
self.send_message(gen_msg("init", gflen=0, lflen=1, localfeatures=self.localfeatures)) |
|
|
|
# read init |
|
|
|
msg = await self.read_message() |
|
|
|
self.process_message(msg) |
|
|
|