diff --git a/contrib/pyln-proto/pyln/proto/onion.py b/contrib/pyln-proto/pyln/proto/onion.py index 9303b0abd..3ca55f517 100644 --- a/contrib/pyln-proto/pyln/proto/onion.py +++ b/contrib/pyln-proto/pyln/proto/onion.py @@ -15,8 +15,9 @@ from cryptography.hazmat.primitives import hashes, hmac from cryptography.hazmat.primitives.ciphers import Cipher, algorithms from hashlib import sha256 from io import BytesIO, SEEK_CUR -from typing import List, Optional, Union +from typing import List, Optional, Union, Tuple import coincurve +import io import os import struct @@ -44,7 +45,7 @@ class OnionPayload(object): s = s.encode('ASCII') return cls.from_bytes(bytes(unhexlify(s))) - def to_bytes(self): + def to_bytes(self, include_prefix): raise ValueError("OnionPayload is an abstract class, use " "LegacyOnionPayload or TlvPayload instead") @@ -92,20 +93,20 @@ class LegacyOnionPayload(OnionPayload): padding = b.read(12) return LegacyOnionPayload(a, o, s, padding) - def to_bytes(self, include_realm=True): + def to_bytes(self, include_prefix=True): b = b'' - if include_realm: + if include_prefix: b += b'\x00' b += struct.pack("!Q", self.short_channel_id) b += struct.pack("!Q", self.amt_to_forward) b += struct.pack("!L", self.outgoing_cltv_value) b += self.padding - assert(len(b) == 32 + include_realm) + assert(len(b) == 32 + include_prefix) return b - def to_hex(self, include_realm=True): - return hexlify(self.to_bytes(include_realm)).decode('ASCII') + def to_hex(self, include_prefix=True): + return hexlify(self.to_bytes(include_prefix)).decode('ASCII') def __str__(self): return ("LegacyOnionPayload[scid={self.short_channel_id}, " @@ -143,6 +144,12 @@ class TlvPayload(OnionPayload): raise ValueError( "Unable to read length at position {}".format(b.tell()) ) + + elif length > start + payload_length - b.tell(): + b.seek(start + payload_length) + raise ValueError("Failed to parse TLV payload: value length " + "is longer than available bytes.") + val = b.read(length) # Get the subclass that is the correct interpretation of this @@ -167,10 +174,11 @@ class TlvPayload(OnionPayload): return f return default - def to_bytes(self): + def to_bytes(self, include_prefix=True) -> bytes: ser = [f.to_bytes() for f in self.fields] b = BytesIO() - varint_encode(sum([len(b) for b in ser]), b) + if include_prefix: + varint_encode(sum([len(b) for b in ser]), b) for f in ser: b.write(f) return b.getvalue() @@ -179,6 +187,40 @@ class TlvPayload(OnionPayload): return "TlvPayload[" + ', '.join([str(f) for f in self.fields]) + "]" +class RawPayload(OnionPayload): + """A payload that doesn't deserialize correctly as TLV stream. + + Mainly used if TLV parsing fails, but we still want access to the raw + payload. + + """ + + def __init__(self): + self.content: Optional[bytes] = None + + @classmethod + def from_bytes(cls, b): + if isinstance(b, str): + b = b.encode('ASCII') + if isinstance(b, bytes): + b = BytesIO(b) + + self = cls() + payload_length = varint_decode(b) + self.content = b.read(payload_length) + return self + + def to_bytes(self, include_prefix=True) -> bytes: + b = BytesIO() + if self.content is None: + raise ValueError("Cannot serialize empty TLV payload") + + if include_prefix: + varint_encode(len(self.content), b) + b.write(self.content) + return b.getvalue() + + class TlvField(object): def __init__(self, typenum, value=None, description=None): @@ -319,6 +361,57 @@ class RoutingOnion(object): def to_hex(self): return hexlify(self.to_bin()) + def unwrap(self, privkey: PrivateKey, assocdata: Optional[bytes]) \ + -> Tuple[OnionPayload, Optional['RoutingOnion']]: + shared_secret = ecdh(privkey, self.ephemeralkey) + keys = generate_keyset(shared_secret) + + h = hmac.HMAC(keys.mu, hashes.SHA256(), + backend=default_backend()) + h.update(self.payloads) + if assocdata is not None: + h.update(assocdata) + hh = h.finalize() + + if hh != self.hmac: + raise ValueError("HMAC does not match, onion might have been " + "tampered with: {hh} != {hmac}".format( + hh=hexlify(hh).decode('ascii'), + hmac=hexlify(self.hmac).decode('ascii'), + )) + + # Create the scratch twice as large as the original packet, since we + # need to left-shift a single payload off, which may itself be up to + # ROUTING_INFO_SIZE in length. + payloads = bytearray(2 * ROUTING_INFO_SIZE) + payloads[:ROUTING_INFO_SIZE] = self.payloads + chacha20_stream(keys.rho, payloads) + + r = io.BytesIO(payloads) + start = r.tell() + + try: + payload = OnionPayload.from_bytes(r) + except ValueError: + r.seek(start) + payload = RawPayload.from_bytes(r) + + next_hmac = r.read(32) + shift_size = r.tell() + + if next_hmac == bytes(32): + return payload, None + else: + b = blind(self.ephemeralkey, shared_secret) + ek = blind_group_element(self.ephemeralkey, b) + payloads = payloads[shift_size:shift_size + ROUTING_INFO_SIZE] + return payload, RoutingOnion( + version=self.version, + ephemeralkey=ek, + payloads=payloads, + hmac=next_hmac, + ) + KeySet = namedtuple('KeySet', ['rho', 'mu', 'um', 'pad', 'gamma', 'pi']) diff --git a/contrib/pyln-proto/tests/test_onion.py b/contrib/pyln-proto/tests/test_onion.py index 72a44440f..819ad1ea2 100644 --- a/contrib/pyln-proto/tests/test_onion.py +++ b/contrib/pyln-proto/tests/test_onion.py @@ -12,7 +12,7 @@ def test_legacy_payload(): b'00000067000001000100000000000003e800000075000000000000000000000000' ) payload = onion.OnionPayload.from_bytes(legacy) - assert(payload.to_bytes(include_realm=True) == legacy) + assert(payload.to_bytes(include_prefix=True) == legacy) def test_tlv_payload(): @@ -325,3 +325,21 @@ def test_sphinx_path_compile(): o = sp.compile() assert(o.to_bin() == unhexlify(v['onion'])) + + +def test_unwrap(): + f = 'tests/vectors/onion-test-multi-frame.json' + sp, v = sphinx_path_from_test_vector(f) + o = onion.RoutingOnion.from_hex(v['onion']) + assocdata = unhexlify(v['generate']['associated_data']) + privkeys = [onion.PrivateKey(unhexlify(h)) for h in v['decode']] + + for pk, h in zip(privkeys, v['generate']['hops']): + pl, o = o.unwrap(pk, assocdata=assocdata) + + b = hexlify(pl.to_bytes(include_prefix=False)) + if h['type'] == 'legacy': + assert(b == h['payload'].encode('ascii') + b'00' * 12) + else: + assert(b == h['payload'].encode('ascii')) + assert(o is None)