Christian Decker
5 years ago
committed by
Rusty Russell
6 changed files with 425 additions and 0 deletions
@ -0,0 +1,236 @@ |
|||||
|
from .primitives import varint_decode, varint_encode |
||||
|
from io import BytesIO, SEEK_CUR |
||||
|
from binascii import hexlify, unhexlify |
||||
|
import struct |
||||
|
|
||||
|
|
||||
|
class OnionPayload(object): |
||||
|
|
||||
|
@classmethod |
||||
|
def from_bytes(cls, b): |
||||
|
if isinstance(b, bytes): |
||||
|
b = BytesIO(b) |
||||
|
|
||||
|
realm = b.read(1) |
||||
|
b.seek(-1, SEEK_CUR) |
||||
|
|
||||
|
if realm == b'\x00': |
||||
|
return LegacyOnionPayload.from_bytes(b) |
||||
|
elif realm != b'\x01': |
||||
|
return TlvPayload.from_bytes(b, skip_length=False) |
||||
|
else: |
||||
|
raise ValueError("Onion payloads with realm 0x01 are unsupported") |
||||
|
|
||||
|
@classmethod |
||||
|
def from_hex(cls, s): |
||||
|
if isinstance(s, str): |
||||
|
s = s.encode('ASCII') |
||||
|
return cls.from_bytes(bytes(unhexlify(s))) |
||||
|
|
||||
|
def to_bytes(self): |
||||
|
raise ValueError("OnionPayload is an abstract class, use " |
||||
|
"LegacyOnionPayload or TlvPayload instead") |
||||
|
|
||||
|
def to_hex(self): |
||||
|
return hexlify(self.to_bytes()).decode('ASCII') |
||||
|
|
||||
|
|
||||
|
class LegacyOnionPayload(OnionPayload): |
||||
|
|
||||
|
def __init__(self, amt_to_forward, outgoing_cltv_value, |
||||
|
short_channel_id=None, padding=None): |
||||
|
assert(padding is None or len(padding) == 12) |
||||
|
self.padding = b'\x00' * 12 if padding is None else padding |
||||
|
|
||||
|
if isinstance(amt_to_forward, str): |
||||
|
self.amt_to_forward = int(amt_to_forward) |
||||
|
else: |
||||
|
self.amt_to_forward = amt_to_forward |
||||
|
|
||||
|
self.outgoing_cltv_value = outgoing_cltv_value |
||||
|
|
||||
|
if isinstance(short_channel_id, str) and 'x' in short_channel_id: |
||||
|
# Convert the short_channel_id from its string representation to its numeric representation |
||||
|
block, tx, out = short_channel_id.split('x') |
||||
|
num_scid = int(block) << 40 | int(tx) << 16 | int(out) |
||||
|
self.short_channel_id = num_scid |
||||
|
elif isinstance(short_channel_id, int): |
||||
|
self.short_channel_id = short_channel_id |
||||
|
else: |
||||
|
raise ValueError("short_channel_id format cannot be recognized: {}".format(short_channel_id)) |
||||
|
|
||||
|
@classmethod |
||||
|
def from_bytes(cls, b): |
||||
|
if isinstance(b, bytes): |
||||
|
b = BytesIO(b) |
||||
|
|
||||
|
assert(b.read(1) == b'\x00') |
||||
|
|
||||
|
s, a, o = struct.unpack("!QQL", b.read(20)) |
||||
|
padding = b.read(12) |
||||
|
return LegacyOnionPayload(a, o, s, padding) |
||||
|
|
||||
|
def to_bytes(self, include_realm=True): |
||||
|
b = b'' |
||||
|
if include_realm: |
||||
|
b += b'\x00' |
||||
|
|
||||
|
b += struct.pack("!Q", self.short_channel_id) |
||||
|
b += struct.pack("!Q", self.amt_to_forward) |
||||
|
b += struct.pack("!L", self.outgoing_cltv_value) |
||||
|
b += self.padding |
||||
|
assert(len(b) == 32 + include_realm) |
||||
|
return b |
||||
|
|
||||
|
def to_hex(self, include_realm=True): |
||||
|
return hexlify(self.to_bytes(include_realm)).decode('ASCII') |
||||
|
|
||||
|
def __str__(self): |
||||
|
return ("LegacyOnionPayload[scid={self.short_channel_id}, " |
||||
|
"amt_to_forward={self.amt_to_forward}, " |
||||
|
"outgoing_cltv={self.outgoing_cltv_value}]").format(self=self) |
||||
|
|
||||
|
|
||||
|
class TlvPayload(OnionPayload): |
||||
|
|
||||
|
def __init__(self, fields=None): |
||||
|
self.fields = [] if fields is None else fields |
||||
|
|
||||
|
@classmethod |
||||
|
def from_bytes(cls, b, skip_length=False): |
||||
|
if isinstance(b, str): |
||||
|
b = b.encode('ASCII') |
||||
|
if isinstance(b, bytes): |
||||
|
b = BytesIO(b) |
||||
|
|
||||
|
if skip_length: |
||||
|
# Consume the entire remainder of the buffer. |
||||
|
payload_length = len(b.getvalue()) - b.tell() |
||||
|
else: |
||||
|
payload_length = varint_decode(b) |
||||
|
|
||||
|
instance = TlvPayload() |
||||
|
|
||||
|
start = b.tell() |
||||
|
while b.tell() < start + payload_length: |
||||
|
typenum = varint_decode(b) |
||||
|
if typenum is None: |
||||
|
break |
||||
|
length = varint_decode(b) |
||||
|
if length is None: |
||||
|
raise ValueError( |
||||
|
"Unable to read length at position {}".format(b.tell()) |
||||
|
) |
||||
|
val = b.read(length) |
||||
|
|
||||
|
# Get the subclass that is the correct interpretation of this |
||||
|
# field. Default to the binary field type. |
||||
|
c = tlv_types.get(typenum, (TlvField, "unknown")) |
||||
|
cls = c[0] |
||||
|
field = cls.from_bytes(typenum=typenum, b=val, description=c[1]) |
||||
|
instance.fields.append(field) |
||||
|
|
||||
|
return instance |
||||
|
|
||||
|
@classmethod |
||||
|
def from_hex(cls, h): |
||||
|
return cls.from_bytes(unhexlify(h)) |
||||
|
|
||||
|
def add_field(self, typenum, value): |
||||
|
self.fields.append(TlvField(typenum=typenum, value=value)) |
||||
|
|
||||
|
def get(self, key, default=None): |
||||
|
for f in self.fields: |
||||
|
if f.typenum == key: |
||||
|
return f |
||||
|
return default |
||||
|
|
||||
|
def to_bytes(self): |
||||
|
ser = [f.to_bytes() for f in self.fields] |
||||
|
b = BytesIO() |
||||
|
varint_encode(sum([len(b) for b in ser]), b) |
||||
|
for f in ser: |
||||
|
b.write(f) |
||||
|
return b.getvalue() |
||||
|
|
||||
|
def __str__(self): |
||||
|
return "TlvPayload[" + ', '.join([str(f) for f in self.fields]) + "]" |
||||
|
|
||||
|
|
||||
|
class TlvField(object): |
||||
|
|
||||
|
def __init__(self, typenum, value=None, description=None): |
||||
|
self.typenum = typenum |
||||
|
self.value = value |
||||
|
self.description = description |
||||
|
|
||||
|
@classmethod |
||||
|
def from_bytes(cls, typenum, b, description=None): |
||||
|
return TlvField(typenum=typenum, value=b, description=description) |
||||
|
|
||||
|
def __str__(self): |
||||
|
return "TlvField[{description},{num}={hex}]".format( |
||||
|
description=self.description, |
||||
|
num=self.typenum, |
||||
|
hex=hexlify(self.value).decode('ASCII') |
||||
|
) |
||||
|
|
||||
|
def to_bytes(self): |
||||
|
b = BytesIO() |
||||
|
varint_encode(self.typenum, b) |
||||
|
varint_encode(len(self.value), b) |
||||
|
b.write(self.value) |
||||
|
return b.getvalue() |
||||
|
|
||||
|
|
||||
|
class Tu32Field(TlvField): |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
class Tu64Field(TlvField): |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
class ShortChannelIdField(TlvField): |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
class TextField(TlvField): |
||||
|
|
||||
|
@classmethod |
||||
|
def from_bytes(cls, typenum, b, description=None): |
||||
|
val = b.decode('UTF-8') |
||||
|
return TextField(typenum, value=val, description=description) |
||||
|
|
||||
|
def to_bytes(self): |
||||
|
b = BytesIO() |
||||
|
val = self.value.encode('UTF-8') |
||||
|
varint_encode(self.typenum, b) |
||||
|
varint_encode(len(val), b) |
||||
|
b.write(val) |
||||
|
return b.getvalue() |
||||
|
|
||||
|
def __str__(self): |
||||
|
return "TextField[{description},{num}=\"{val}\"]".format( |
||||
|
description=self.description, |
||||
|
num=self.typenum, |
||||
|
val=self.value, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
class HashField(TlvField): |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
class SignatureField(TlvField): |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
# A mapping of known TLV types |
||||
|
tlv_types = { |
||||
|
2: (Tu64Field, 'amt_to_forward'), |
||||
|
4: (Tu32Field, 'outgoing_cltv_value'), |
||||
|
6: (ShortChannelIdField, 'short_channel_id'), |
||||
|
34349334: (TextField, 'noise_message_body'), |
||||
|
34349336: (SignatureField, 'noise_message_signature'), |
||||
|
} |
@ -0,0 +1,70 @@ |
|||||
|
import struct |
||||
|
|
||||
|
|
||||
|
def varint_encode(i, w): |
||||
|
"""Encode an integer `i` into the writer `w` |
||||
|
""" |
||||
|
if i < 0xFD: |
||||
|
w.write(struct.pack("!B", i)) |
||||
|
elif i <= 0xFFFF: |
||||
|
w.write(struct.pack("!BH", 0xFD, i)) |
||||
|
elif i <= 0xFFFFFFFF: |
||||
|
w.write(struct.pack("!BL", 0xFE, i)) |
||||
|
else: |
||||
|
w.write(struct.pack("!BQ", 0xFF, i)) |
||||
|
|
||||
|
|
||||
|
def varint_decode(r): |
||||
|
"""Decode an integer from reader `r` |
||||
|
""" |
||||
|
raw = r.read(1) |
||||
|
if len(raw) != 1: |
||||
|
return None |
||||
|
|
||||
|
i, = struct.unpack("!B", raw) |
||||
|
if i < 0xFD: |
||||
|
return i |
||||
|
elif i == 0xFD: |
||||
|
return struct.unpack("!H", r.read(2))[0] |
||||
|
elif i == 0xFE: |
||||
|
return struct.unpack("!L", r.read(4))[0] |
||||
|
else: |
||||
|
return struct.unpack("!Q", r.read(8))[0] |
||||
|
|
||||
|
|
||||
|
class ShortChannelId(object): |
||||
|
def __init__(self, block, txnum, outnum): |
||||
|
self.block = block |
||||
|
self.txnum = txnum |
||||
|
self.outnum = outnum |
||||
|
|
||||
|
@classmethod |
||||
|
def from_bytes(cls, b): |
||||
|
assert(len(b) == 8) |
||||
|
i, = struct.unpack("!Q", b) |
||||
|
return cls.from_int(i) |
||||
|
|
||||
|
@classmethod |
||||
|
def from_int(cls, i): |
||||
|
block = (i >> 40) & 0xFFFFFF |
||||
|
txnum = (i >> 16) & 0xFFFFFF |
||||
|
outnum = (i >> 0) & 0xFFFF |
||||
|
return cls(block=block, txnum=txnum, outnum=outnum) |
||||
|
|
||||
|
@classmethod |
||||
|
def from_str(self, s): |
||||
|
block, txnum, outnum = s.split('x') |
||||
|
return ShortChannelId(block=int(block), txnum=int(txnum), |
||||
|
outnum=int(outnum)) |
||||
|
|
||||
|
def to_int(self): |
||||
|
return self.block << 40 | self.txnum << 16 | self.outnum |
||||
|
|
||||
|
def to_bytes(self): |
||||
|
return struct.pack("!Q", self.to_int()) |
||||
|
|
||||
|
def __str__(self): |
||||
|
return "{self.block}x{self.txnum}x{self.outnum}".format(self=self) |
||||
|
|
||||
|
def __eq__(self, other): |
||||
|
return self.block == other.block and self.txnum == other.txnum and self.outnum == other.outnum |
@ -0,0 +1,56 @@ |
|||||
|
import bitstring |
||||
|
|
||||
|
|
||||
|
zbase32_chars = b'ybndrfg8ejkmcpqxot1uwisza345h769' |
||||
|
zbase32_revchars = [ |
||||
|
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, |
||||
|
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, |
||||
|
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, |
||||
|
255, 255, 255, 255, 18, 255, 25, 26, 27, 30, 29, 7, 31, 255, 255, 255, 255, |
||||
|
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, |
||||
|
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, |
||||
|
255, 255, 255, 255, 255, 24, 1, 12, 3, 8, 5, 6, 28, 21, 9, 10, 255, 11, 2, |
||||
|
16, 13, 14, 4, 22, 17, 19, 255, 20, 15, 0, 23, 255, 255, 255, 255, 255, 255, |
||||
|
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, |
||||
|
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, |
||||
|
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, |
||||
|
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, |
||||
|
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, |
||||
|
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, |
||||
|
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, |
||||
|
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, |
||||
|
255, 255, 255, 255, 255, 255, 255 |
||||
|
] |
||||
|
|
||||
|
|
||||
|
def bitarray_to_u5(barr): |
||||
|
assert len(barr) % 5 == 0 |
||||
|
ret = [] |
||||
|
s = bitstring.ConstBitStream(barr) |
||||
|
while s.pos != s.len: |
||||
|
ret.append(s.read(5).uint) |
||||
|
return ret |
||||
|
|
||||
|
|
||||
|
def u5_to_bitarray(arr): |
||||
|
ret = bitstring.BitArray() |
||||
|
for a in arr: |
||||
|
ret += bitstring.pack("uint:5", a) |
||||
|
return ret |
||||
|
|
||||
|
|
||||
|
def encode(b): |
||||
|
uint5s = bitarray_to_u5(b) |
||||
|
res = [zbase32_chars[c] for c in uint5s] |
||||
|
return bytes(res) |
||||
|
|
||||
|
|
||||
|
def decode(b): |
||||
|
if isinstance(b, str): |
||||
|
b = b.encode('ASCII') |
||||
|
|
||||
|
uint5s = [] |
||||
|
for c in b: |
||||
|
uint5s.append(zbase32_revchars[c]) |
||||
|
dec = u5_to_bitarray(uint5s) |
||||
|
return dec.bytes |
@ -1,2 +1,3 @@ |
|||||
|
bitstring==3.1.6 |
||||
cryptography==2.7 |
cryptography==2.7 |
||||
coincurve==12.0.0 |
coincurve==12.0.0 |
||||
|
@ -0,0 +1,32 @@ |
|||||
|
from binascii import unhexlify |
||||
|
|
||||
|
from pyln.proto import onion |
||||
|
|
||||
|
|
||||
|
def test_legacy_payload(): |
||||
|
legacy = unhexlify( |
||||
|
b'00000067000001000100000000000003e800000075000000000000000000000000' |
||||
|
) |
||||
|
payload = onion.OnionPayload.from_bytes(legacy) |
||||
|
assert(payload.to_bytes(include_realm=True) == legacy) |
||||
|
|
||||
|
|
||||
|
def test_tlv_payload(): |
||||
|
tlv = unhexlify( |
||||
|
b'58fe020c21160c48656c6c6f20776f726c6421fe020c21184076e8acd54afbf2361' |
||||
|
b'0b7166ba689afcc9e8ec3c44e442e765012dfc1d299958827d0205f7e4e1a12620e' |
||||
|
b'7fc8ce1c7d3651acefde899c33f12b6958d3304106a0' |
||||
|
) |
||||
|
payload = onion.OnionPayload.from_bytes(tlv) |
||||
|
assert(payload.to_bytes() == tlv) |
||||
|
|
||||
|
fields = payload.fields |
||||
|
assert(len(fields) == 2) |
||||
|
assert(isinstance(fields[0], onion.TextField)) |
||||
|
assert(fields[0].typenum == 34349334 and fields[0].value == "Hello world!") |
||||
|
assert(fields[1].typenum == 34349336 and fields[1].value == unhexlify( |
||||
|
b'76e8acd54afbf23610b7166ba689afcc9e8ec3c44e442e765012dfc1d299958827d' |
||||
|
b'0205f7e4e1a12620e7fc8ce1c7d3651acefde899c33f12b6958d3304106a0' |
||||
|
)) |
||||
|
|
||||
|
assert(payload.to_bytes() == tlv) |
@ -0,0 +1,30 @@ |
|||||
|
from binascii import hexlify, unhexlify |
||||
|
from pyln.proto import zbase32 |
||||
|
from pyln.proto.primitives import ShortChannelId |
||||
|
|
||||
|
|
||||
|
def test_short_channel_id(): |
||||
|
num = 618150934845652992 |
||||
|
b = unhexlify(b'08941d00090d0000') |
||||
|
s = '562205x2317x0' |
||||
|
s1 = ShortChannelId.from_int(num) |
||||
|
s2 = ShortChannelId.from_str(s) |
||||
|
s3 = ShortChannelId.from_bytes(b) |
||||
|
expected = ShortChannelId(block=562205, txnum=2317, outnum=0) |
||||
|
|
||||
|
assert(s1 == expected) |
||||
|
assert(s2 == expected) |
||||
|
assert(s3 == expected) |
||||
|
|
||||
|
assert(expected.to_bytes() == b) |
||||
|
assert(str(expected) == s) |
||||
|
assert(expected.to_int() == num) |
||||
|
|
||||
|
|
||||
|
def test_zbase32(): |
||||
|
zb32 = b'd75qtmgijm79rpooshmgzjwji9gj7dsdat8remuskyjp9oq1ugkaoj6orbxzhuo4njtyh96e3aq84p1tiuz77nchgxa1s4ka4carnbiy' |
||||
|
b = zbase32.decode(zb32) |
||||
|
assert(hexlify(b) == b'1f76e8acd54afbf23610b7166ba689afcc9e8ec3c44e442e765012dfc1d299958827d0205f7e4e1a12620e7fc8ce1c7d3651acefde899c33f12b6958d3304106a0') |
||||
|
|
||||
|
enc = zbase32.encode(b) |
||||
|
assert(enc == zb32) |
Loading…
Reference in new issue