Christian Decker
5 years ago
committed by
Rusty Russell
6 changed files with 425 additions and 0 deletions
@ -0,0 +1,236 @@ |
|||
from .primitives import varint_decode, varint_encode |
|||
from io import BytesIO, SEEK_CUR |
|||
from binascii import hexlify, unhexlify |
|||
import struct |
|||
|
|||
|
|||
class OnionPayload(object): |
|||
|
|||
@classmethod |
|||
def from_bytes(cls, b): |
|||
if isinstance(b, bytes): |
|||
b = BytesIO(b) |
|||
|
|||
realm = b.read(1) |
|||
b.seek(-1, SEEK_CUR) |
|||
|
|||
if realm == b'\x00': |
|||
return LegacyOnionPayload.from_bytes(b) |
|||
elif realm != b'\x01': |
|||
return TlvPayload.from_bytes(b, skip_length=False) |
|||
else: |
|||
raise ValueError("Onion payloads with realm 0x01 are unsupported") |
|||
|
|||
@classmethod |
|||
def from_hex(cls, s): |
|||
if isinstance(s, str): |
|||
s = s.encode('ASCII') |
|||
return cls.from_bytes(bytes(unhexlify(s))) |
|||
|
|||
def to_bytes(self): |
|||
raise ValueError("OnionPayload is an abstract class, use " |
|||
"LegacyOnionPayload or TlvPayload instead") |
|||
|
|||
def to_hex(self): |
|||
return hexlify(self.to_bytes()).decode('ASCII') |
|||
|
|||
|
|||
class LegacyOnionPayload(OnionPayload): |
|||
|
|||
def __init__(self, amt_to_forward, outgoing_cltv_value, |
|||
short_channel_id=None, padding=None): |
|||
assert(padding is None or len(padding) == 12) |
|||
self.padding = b'\x00' * 12 if padding is None else padding |
|||
|
|||
if isinstance(amt_to_forward, str): |
|||
self.amt_to_forward = int(amt_to_forward) |
|||
else: |
|||
self.amt_to_forward = amt_to_forward |
|||
|
|||
self.outgoing_cltv_value = outgoing_cltv_value |
|||
|
|||
if isinstance(short_channel_id, str) and 'x' in short_channel_id: |
|||
# Convert the short_channel_id from its string representation to its numeric representation |
|||
block, tx, out = short_channel_id.split('x') |
|||
num_scid = int(block) << 40 | int(tx) << 16 | int(out) |
|||
self.short_channel_id = num_scid |
|||
elif isinstance(short_channel_id, int): |
|||
self.short_channel_id = short_channel_id |
|||
else: |
|||
raise ValueError("short_channel_id format cannot be recognized: {}".format(short_channel_id)) |
|||
|
|||
@classmethod |
|||
def from_bytes(cls, b): |
|||
if isinstance(b, bytes): |
|||
b = BytesIO(b) |
|||
|
|||
assert(b.read(1) == b'\x00') |
|||
|
|||
s, a, o = struct.unpack("!QQL", b.read(20)) |
|||
padding = b.read(12) |
|||
return LegacyOnionPayload(a, o, s, padding) |
|||
|
|||
def to_bytes(self, include_realm=True): |
|||
b = b'' |
|||
if include_realm: |
|||
b += b'\x00' |
|||
|
|||
b += struct.pack("!Q", self.short_channel_id) |
|||
b += struct.pack("!Q", self.amt_to_forward) |
|||
b += struct.pack("!L", self.outgoing_cltv_value) |
|||
b += self.padding |
|||
assert(len(b) == 32 + include_realm) |
|||
return b |
|||
|
|||
def to_hex(self, include_realm=True): |
|||
return hexlify(self.to_bytes(include_realm)).decode('ASCII') |
|||
|
|||
def __str__(self): |
|||
return ("LegacyOnionPayload[scid={self.short_channel_id}, " |
|||
"amt_to_forward={self.amt_to_forward}, " |
|||
"outgoing_cltv={self.outgoing_cltv_value}]").format(self=self) |
|||
|
|||
|
|||
class TlvPayload(OnionPayload): |
|||
|
|||
def __init__(self, fields=None): |
|||
self.fields = [] if fields is None else fields |
|||
|
|||
@classmethod |
|||
def from_bytes(cls, b, skip_length=False): |
|||
if isinstance(b, str): |
|||
b = b.encode('ASCII') |
|||
if isinstance(b, bytes): |
|||
b = BytesIO(b) |
|||
|
|||
if skip_length: |
|||
# Consume the entire remainder of the buffer. |
|||
payload_length = len(b.getvalue()) - b.tell() |
|||
else: |
|||
payload_length = varint_decode(b) |
|||
|
|||
instance = TlvPayload() |
|||
|
|||
start = b.tell() |
|||
while b.tell() < start + payload_length: |
|||
typenum = varint_decode(b) |
|||
if typenum is None: |
|||
break |
|||
length = varint_decode(b) |
|||
if length is None: |
|||
raise ValueError( |
|||
"Unable to read length at position {}".format(b.tell()) |
|||
) |
|||
val = b.read(length) |
|||
|
|||
# Get the subclass that is the correct interpretation of this |
|||
# field. Default to the binary field type. |
|||
c = tlv_types.get(typenum, (TlvField, "unknown")) |
|||
cls = c[0] |
|||
field = cls.from_bytes(typenum=typenum, b=val, description=c[1]) |
|||
instance.fields.append(field) |
|||
|
|||
return instance |
|||
|
|||
@classmethod |
|||
def from_hex(cls, h): |
|||
return cls.from_bytes(unhexlify(h)) |
|||
|
|||
def add_field(self, typenum, value): |
|||
self.fields.append(TlvField(typenum=typenum, value=value)) |
|||
|
|||
def get(self, key, default=None): |
|||
for f in self.fields: |
|||
if f.typenum == key: |
|||
return f |
|||
return default |
|||
|
|||
def to_bytes(self): |
|||
ser = [f.to_bytes() for f in self.fields] |
|||
b = BytesIO() |
|||
varint_encode(sum([len(b) for b in ser]), b) |
|||
for f in ser: |
|||
b.write(f) |
|||
return b.getvalue() |
|||
|
|||
def __str__(self): |
|||
return "TlvPayload[" + ', '.join([str(f) for f in self.fields]) + "]" |
|||
|
|||
|
|||
class TlvField(object): |
|||
|
|||
def __init__(self, typenum, value=None, description=None): |
|||
self.typenum = typenum |
|||
self.value = value |
|||
self.description = description |
|||
|
|||
@classmethod |
|||
def from_bytes(cls, typenum, b, description=None): |
|||
return TlvField(typenum=typenum, value=b, description=description) |
|||
|
|||
def __str__(self): |
|||
return "TlvField[{description},{num}={hex}]".format( |
|||
description=self.description, |
|||
num=self.typenum, |
|||
hex=hexlify(self.value).decode('ASCII') |
|||
) |
|||
|
|||
def to_bytes(self): |
|||
b = BytesIO() |
|||
varint_encode(self.typenum, b) |
|||
varint_encode(len(self.value), b) |
|||
b.write(self.value) |
|||
return b.getvalue() |
|||
|
|||
|
|||
class Tu32Field(TlvField): |
|||
pass |
|||
|
|||
|
|||
class Tu64Field(TlvField): |
|||
pass |
|||
|
|||
|
|||
class ShortChannelIdField(TlvField): |
|||
pass |
|||
|
|||
|
|||
class TextField(TlvField): |
|||
|
|||
@classmethod |
|||
def from_bytes(cls, typenum, b, description=None): |
|||
val = b.decode('UTF-8') |
|||
return TextField(typenum, value=val, description=description) |
|||
|
|||
def to_bytes(self): |
|||
b = BytesIO() |
|||
val = self.value.encode('UTF-8') |
|||
varint_encode(self.typenum, b) |
|||
varint_encode(len(val), b) |
|||
b.write(val) |
|||
return b.getvalue() |
|||
|
|||
def __str__(self): |
|||
return "TextField[{description},{num}=\"{val}\"]".format( |
|||
description=self.description, |
|||
num=self.typenum, |
|||
val=self.value, |
|||
) |
|||
|
|||
|
|||
class HashField(TlvField): |
|||
pass |
|||
|
|||
|
|||
class SignatureField(TlvField): |
|||
pass |
|||
|
|||
|
|||
# A mapping of known TLV types |
|||
tlv_types = { |
|||
2: (Tu64Field, 'amt_to_forward'), |
|||
4: (Tu32Field, 'outgoing_cltv_value'), |
|||
6: (ShortChannelIdField, 'short_channel_id'), |
|||
34349334: (TextField, 'noise_message_body'), |
|||
34349336: (SignatureField, 'noise_message_signature'), |
|||
} |
@ -0,0 +1,70 @@ |
|||
import struct |
|||
|
|||
|
|||
def varint_encode(i, w): |
|||
"""Encode an integer `i` into the writer `w` |
|||
""" |
|||
if i < 0xFD: |
|||
w.write(struct.pack("!B", i)) |
|||
elif i <= 0xFFFF: |
|||
w.write(struct.pack("!BH", 0xFD, i)) |
|||
elif i <= 0xFFFFFFFF: |
|||
w.write(struct.pack("!BL", 0xFE, i)) |
|||
else: |
|||
w.write(struct.pack("!BQ", 0xFF, i)) |
|||
|
|||
|
|||
def varint_decode(r): |
|||
"""Decode an integer from reader `r` |
|||
""" |
|||
raw = r.read(1) |
|||
if len(raw) != 1: |
|||
return None |
|||
|
|||
i, = struct.unpack("!B", raw) |
|||
if i < 0xFD: |
|||
return i |
|||
elif i == 0xFD: |
|||
return struct.unpack("!H", r.read(2))[0] |
|||
elif i == 0xFE: |
|||
return struct.unpack("!L", r.read(4))[0] |
|||
else: |
|||
return struct.unpack("!Q", r.read(8))[0] |
|||
|
|||
|
|||
class ShortChannelId(object): |
|||
def __init__(self, block, txnum, outnum): |
|||
self.block = block |
|||
self.txnum = txnum |
|||
self.outnum = outnum |
|||
|
|||
@classmethod |
|||
def from_bytes(cls, b): |
|||
assert(len(b) == 8) |
|||
i, = struct.unpack("!Q", b) |
|||
return cls.from_int(i) |
|||
|
|||
@classmethod |
|||
def from_int(cls, i): |
|||
block = (i >> 40) & 0xFFFFFF |
|||
txnum = (i >> 16) & 0xFFFFFF |
|||
outnum = (i >> 0) & 0xFFFF |
|||
return cls(block=block, txnum=txnum, outnum=outnum) |
|||
|
|||
@classmethod |
|||
def from_str(self, s): |
|||
block, txnum, outnum = s.split('x') |
|||
return ShortChannelId(block=int(block), txnum=int(txnum), |
|||
outnum=int(outnum)) |
|||
|
|||
def to_int(self): |
|||
return self.block << 40 | self.txnum << 16 | self.outnum |
|||
|
|||
def to_bytes(self): |
|||
return struct.pack("!Q", self.to_int()) |
|||
|
|||
def __str__(self): |
|||
return "{self.block}x{self.txnum}x{self.outnum}".format(self=self) |
|||
|
|||
def __eq__(self, other): |
|||
return self.block == other.block and self.txnum == other.txnum and self.outnum == other.outnum |
@ -0,0 +1,56 @@ |
|||
import bitstring |
|||
|
|||
|
|||
zbase32_chars = b'ybndrfg8ejkmcpqxot1uwisza345h769' |
|||
zbase32_revchars = [ |
|||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, |
|||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, |
|||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, |
|||
255, 255, 255, 255, 18, 255, 25, 26, 27, 30, 29, 7, 31, 255, 255, 255, 255, |
|||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, |
|||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, |
|||
255, 255, 255, 255, 255, 24, 1, 12, 3, 8, 5, 6, 28, 21, 9, 10, 255, 11, 2, |
|||
16, 13, 14, 4, 22, 17, 19, 255, 20, 15, 0, 23, 255, 255, 255, 255, 255, 255, |
|||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, |
|||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, |
|||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, |
|||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, |
|||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, |
|||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, |
|||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, |
|||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, |
|||
255, 255, 255, 255, 255, 255, 255 |
|||
] |
|||
|
|||
|
|||
def bitarray_to_u5(barr): |
|||
assert len(barr) % 5 == 0 |
|||
ret = [] |
|||
s = bitstring.ConstBitStream(barr) |
|||
while s.pos != s.len: |
|||
ret.append(s.read(5).uint) |
|||
return ret |
|||
|
|||
|
|||
def u5_to_bitarray(arr): |
|||
ret = bitstring.BitArray() |
|||
for a in arr: |
|||
ret += bitstring.pack("uint:5", a) |
|||
return ret |
|||
|
|||
|
|||
def encode(b): |
|||
uint5s = bitarray_to_u5(b) |
|||
res = [zbase32_chars[c] for c in uint5s] |
|||
return bytes(res) |
|||
|
|||
|
|||
def decode(b): |
|||
if isinstance(b, str): |
|||
b = b.encode('ASCII') |
|||
|
|||
uint5s = [] |
|||
for c in b: |
|||
uint5s.append(zbase32_revchars[c]) |
|||
dec = u5_to_bitarray(uint5s) |
|||
return dec.bytes |
@ -1,2 +1,3 @@ |
|||
bitstring==3.1.6 |
|||
cryptography==2.7 |
|||
coincurve==12.0.0 |
|||
|
@ -0,0 +1,32 @@ |
|||
from binascii import unhexlify |
|||
|
|||
from pyln.proto import onion |
|||
|
|||
|
|||
def test_legacy_payload(): |
|||
legacy = unhexlify( |
|||
b'00000067000001000100000000000003e800000075000000000000000000000000' |
|||
) |
|||
payload = onion.OnionPayload.from_bytes(legacy) |
|||
assert(payload.to_bytes(include_realm=True) == legacy) |
|||
|
|||
|
|||
def test_tlv_payload(): |
|||
tlv = unhexlify( |
|||
b'58fe020c21160c48656c6c6f20776f726c6421fe020c21184076e8acd54afbf2361' |
|||
b'0b7166ba689afcc9e8ec3c44e442e765012dfc1d299958827d0205f7e4e1a12620e' |
|||
b'7fc8ce1c7d3651acefde899c33f12b6958d3304106a0' |
|||
) |
|||
payload = onion.OnionPayload.from_bytes(tlv) |
|||
assert(payload.to_bytes() == tlv) |
|||
|
|||
fields = payload.fields |
|||
assert(len(fields) == 2) |
|||
assert(isinstance(fields[0], onion.TextField)) |
|||
assert(fields[0].typenum == 34349334 and fields[0].value == "Hello world!") |
|||
assert(fields[1].typenum == 34349336 and fields[1].value == unhexlify( |
|||
b'76e8acd54afbf23610b7166ba689afcc9e8ec3c44e442e765012dfc1d299958827d' |
|||
b'0205f7e4e1a12620e7fc8ce1c7d3651acefde899c33f12b6958d3304106a0' |
|||
)) |
|||
|
|||
assert(payload.to_bytes() == tlv) |
@ -0,0 +1,30 @@ |
|||
from binascii import hexlify, unhexlify |
|||
from pyln.proto import zbase32 |
|||
from pyln.proto.primitives import ShortChannelId |
|||
|
|||
|
|||
def test_short_channel_id(): |
|||
num = 618150934845652992 |
|||
b = unhexlify(b'08941d00090d0000') |
|||
s = '562205x2317x0' |
|||
s1 = ShortChannelId.from_int(num) |
|||
s2 = ShortChannelId.from_str(s) |
|||
s3 = ShortChannelId.from_bytes(b) |
|||
expected = ShortChannelId(block=562205, txnum=2317, outnum=0) |
|||
|
|||
assert(s1 == expected) |
|||
assert(s2 == expected) |
|||
assert(s3 == expected) |
|||
|
|||
assert(expected.to_bytes() == b) |
|||
assert(str(expected) == s) |
|||
assert(expected.to_int() == num) |
|||
|
|||
|
|||
def test_zbase32(): |
|||
zb32 = b'd75qtmgijm79rpooshmgzjwji9gj7dsdat8remuskyjp9oq1ugkaoj6orbxzhuo4njtyh96e3aq84p1tiuz77nchgxa1s4ka4carnbiy' |
|||
b = zbase32.decode(zb32) |
|||
assert(hexlify(b) == b'1f76e8acd54afbf23610b7166ba689afcc9e8ec3c44e442e765012dfc1d299958827d0205f7e4e1a12620e7fc8ce1c7d3651acefde899c33f12b6958d3304106a0') |
|||
|
|||
enc = zbase32.encode(b) |
|||
assert(enc == zb32) |
Loading…
Reference in new issue