|
|
@ -15,8 +15,9 @@ from cryptography.hazmat.primitives import hashes, hmac |
|
|
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms |
|
|
|
from hashlib import sha256 |
|
|
|
from io import BytesIO, SEEK_CUR |
|
|
|
from typing import List, Optional, Union |
|
|
|
from typing import List, Optional, Union, Tuple |
|
|
|
import coincurve |
|
|
|
import io |
|
|
|
import os |
|
|
|
import struct |
|
|
|
|
|
|
@ -44,7 +45,7 @@ class OnionPayload(object): |
|
|
|
s = s.encode('ASCII') |
|
|
|
return cls.from_bytes(bytes(unhexlify(s))) |
|
|
|
|
|
|
|
def to_bytes(self): |
|
|
|
def to_bytes(self, include_prefix): |
|
|
|
raise ValueError("OnionPayload is an abstract class, use " |
|
|
|
"LegacyOnionPayload or TlvPayload instead") |
|
|
|
|
|
|
@ -92,20 +93,20 @@ class LegacyOnionPayload(OnionPayload): |
|
|
|
padding = b.read(12) |
|
|
|
return LegacyOnionPayload(a, o, s, padding) |
|
|
|
|
|
|
|
def to_bytes(self, include_realm=True): |
|
|
|
def to_bytes(self, include_prefix=True): |
|
|
|
b = b'' |
|
|
|
if include_realm: |
|
|
|
if include_prefix: |
|
|
|
b += b'\x00' |
|
|
|
|
|
|
|
b += struct.pack("!Q", self.short_channel_id) |
|
|
|
b += struct.pack("!Q", self.amt_to_forward) |
|
|
|
b += struct.pack("!L", self.outgoing_cltv_value) |
|
|
|
b += self.padding |
|
|
|
assert(len(b) == 32 + include_realm) |
|
|
|
assert(len(b) == 32 + include_prefix) |
|
|
|
return b |
|
|
|
|
|
|
|
def to_hex(self, include_realm=True): |
|
|
|
return hexlify(self.to_bytes(include_realm)).decode('ASCII') |
|
|
|
def to_hex(self, include_prefix=True): |
|
|
|
return hexlify(self.to_bytes(include_prefix)).decode('ASCII') |
|
|
|
|
|
|
|
def __str__(self): |
|
|
|
return ("LegacyOnionPayload[scid={self.short_channel_id}, " |
|
|
@ -143,6 +144,12 @@ class TlvPayload(OnionPayload): |
|
|
|
raise ValueError( |
|
|
|
"Unable to read length at position {}".format(b.tell()) |
|
|
|
) |
|
|
|
|
|
|
|
elif length > start + payload_length - b.tell(): |
|
|
|
b.seek(start + payload_length) |
|
|
|
raise ValueError("Failed to parse TLV payload: value length " |
|
|
|
"is longer than available bytes.") |
|
|
|
|
|
|
|
val = b.read(length) |
|
|
|
|
|
|
|
# Get the subclass that is the correct interpretation of this |
|
|
@ -167,10 +174,11 @@ class TlvPayload(OnionPayload): |
|
|
|
return f |
|
|
|
return default |
|
|
|
|
|
|
|
def to_bytes(self): |
|
|
|
def to_bytes(self, include_prefix=True) -> bytes: |
|
|
|
ser = [f.to_bytes() for f in self.fields] |
|
|
|
b = BytesIO() |
|
|
|
varint_encode(sum([len(b) for b in ser]), b) |
|
|
|
if include_prefix: |
|
|
|
varint_encode(sum([len(b) for b in ser]), b) |
|
|
|
for f in ser: |
|
|
|
b.write(f) |
|
|
|
return b.getvalue() |
|
|
@ -179,6 +187,40 @@ class TlvPayload(OnionPayload): |
|
|
|
return "TlvPayload[" + ', '.join([str(f) for f in self.fields]) + "]" |
|
|
|
|
|
|
|
|
|
|
|
class RawPayload(OnionPayload): |
|
|
|
"""A payload that doesn't deserialize correctly as TLV stream. |
|
|
|
|
|
|
|
Mainly used if TLV parsing fails, but we still want access to the raw |
|
|
|
payload. |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
self.content: Optional[bytes] = None |
|
|
|
|
|
|
|
@classmethod |
|
|
|
def from_bytes(cls, b): |
|
|
|
if isinstance(b, str): |
|
|
|
b = b.encode('ASCII') |
|
|
|
if isinstance(b, bytes): |
|
|
|
b = BytesIO(b) |
|
|
|
|
|
|
|
self = cls() |
|
|
|
payload_length = varint_decode(b) |
|
|
|
self.content = b.read(payload_length) |
|
|
|
return self |
|
|
|
|
|
|
|
def to_bytes(self, include_prefix=True) -> bytes: |
|
|
|
b = BytesIO() |
|
|
|
if self.content is None: |
|
|
|
raise ValueError("Cannot serialize empty TLV payload") |
|
|
|
|
|
|
|
if include_prefix: |
|
|
|
varint_encode(len(self.content), b) |
|
|
|
b.write(self.content) |
|
|
|
return b.getvalue() |
|
|
|
|
|
|
|
|
|
|
|
class TlvField(object): |
|
|
|
|
|
|
|
def __init__(self, typenum, value=None, description=None): |
|
|
@ -319,6 +361,57 @@ class RoutingOnion(object): |
|
|
|
def to_hex(self): |
|
|
|
return hexlify(self.to_bin()) |
|
|
|
|
|
|
|
def unwrap(self, privkey: PrivateKey, assocdata: Optional[bytes]) \ |
|
|
|
-> Tuple[OnionPayload, Optional['RoutingOnion']]: |
|
|
|
shared_secret = ecdh(privkey, self.ephemeralkey) |
|
|
|
keys = generate_keyset(shared_secret) |
|
|
|
|
|
|
|
h = hmac.HMAC(keys.mu, hashes.SHA256(), |
|
|
|
backend=default_backend()) |
|
|
|
h.update(self.payloads) |
|
|
|
if assocdata is not None: |
|
|
|
h.update(assocdata) |
|
|
|
hh = h.finalize() |
|
|
|
|
|
|
|
if hh != self.hmac: |
|
|
|
raise ValueError("HMAC does not match, onion might have been " |
|
|
|
"tampered with: {hh} != {hmac}".format( |
|
|
|
hh=hexlify(hh).decode('ascii'), |
|
|
|
hmac=hexlify(self.hmac).decode('ascii'), |
|
|
|
)) |
|
|
|
|
|
|
|
# Create the scratch twice as large as the original packet, since we |
|
|
|
# need to left-shift a single payload off, which may itself be up to |
|
|
|
# ROUTING_INFO_SIZE in length. |
|
|
|
payloads = bytearray(2 * ROUTING_INFO_SIZE) |
|
|
|
payloads[:ROUTING_INFO_SIZE] = self.payloads |
|
|
|
chacha20_stream(keys.rho, payloads) |
|
|
|
|
|
|
|
r = io.BytesIO(payloads) |
|
|
|
start = r.tell() |
|
|
|
|
|
|
|
try: |
|
|
|
payload = OnionPayload.from_bytes(r) |
|
|
|
except ValueError: |
|
|
|
r.seek(start) |
|
|
|
payload = RawPayload.from_bytes(r) |
|
|
|
|
|
|
|
next_hmac = r.read(32) |
|
|
|
shift_size = r.tell() |
|
|
|
|
|
|
|
if next_hmac == bytes(32): |
|
|
|
return payload, None |
|
|
|
else: |
|
|
|
b = blind(self.ephemeralkey, shared_secret) |
|
|
|
ek = blind_group_element(self.ephemeralkey, b) |
|
|
|
payloads = payloads[shift_size:shift_size + ROUTING_INFO_SIZE] |
|
|
|
return payload, RoutingOnion( |
|
|
|
version=self.version, |
|
|
|
ephemeralkey=ek, |
|
|
|
payloads=payloads, |
|
|
|
hmac=next_hmac, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
KeySet = namedtuple('KeySet', ['rho', 'mu', 'um', 'pad', 'gamma', 'pi']) |
|
|
|
|
|
|
|