Browse Source

pylightning: Implement the lightning handshake and wire protocol

Simple transcript of the specification in python :-)

Signed-off-by: Christian Decker <decker.christian@gmail.com>
pull/2803/head
Christian Decker 6 years ago
parent
commit
8fc813e0da
  1. 379
      contrib/pylightning/lightning/wire.py
  2. 2
      contrib/pylightning/requirements.txt
  3. 6
      contrib/pylightning/setup.py
  4. 189
      contrib/pylightning/tests/test_wire.py

379
contrib/pylightning/lightning/wire.py

@ -0,0 +1,379 @@
from binascii import hexlify
from cryptography.exceptions import InvalidTag
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives import serialization
from hashlib import sha256
import coincurve
import os
import socket
import struct
def hkdf(ikm, salt=b"", info=b""):
hkdf = HKDF(
algorithm=hashes.SHA256(),
length=64,
salt=salt,
info=info,
backend=default_backend())
return hkdf.derive(ikm)
def hkdf_two_keys(ikm, salt):
t = hkdf(ikm, salt)
return t[:32], t[32:]
def ecdh(k, rk):
k = coincurve.PrivateKey(secret=k.rawkey)
rk = coincurve.PublicKey(data=rk.serializeCompressed())
a = k.ecdh(rk.public_key)
return Secret(a)
def encryptWithAD(k, n, ad, plaintext):
chacha = ChaCha20Poly1305(k)
return chacha.encrypt(n, plaintext, ad)
def decryptWithAD(k, n, ad, ciphertext):
chacha = ChaCha20Poly1305(k)
return chacha.decrypt(n, ciphertext, ad)
class PrivateKey(object):
def __init__(self, rawkey):
assert len(rawkey) == 32 and isinstance(rawkey, bytes)
self.rawkey = rawkey
rawkey = int(hexlify(rawkey), base=16)
self.key = ec.derive_private_key(rawkey, ec.SECP256K1(),
default_backend())
def serializeCompressed(self):
return self.key.private_bytes(serialization.Encoding.Raw,
serialization.PrivateFormat.Raw, None)
def public_key(self):
return PublicKey(self.key.public_key())
class Secret(object):
def __init__(self, raw):
assert(len(raw) == 32)
self.raw = raw
def __str__(self):
return "Secret[0x{}]".format(hexlify(self.raw).decode('ASCII'))
class PublicKey(object):
def __init__(self, innerkey):
# We accept either 33-bytes raw keys, or an EC PublicKey as returned
# by cryptography.io
if isinstance(innerkey, bytes):
innerkey = ec.EllipticCurvePublicKey.from_encoded_point(
ec.SECP256K1(), innerkey
)
elif not isinstance(innerkey, ec.EllipticCurvePublicKey):
raise ValueError(
"Key must either be bytes or ec.EllipticCurvePublicKey"
)
self.key = innerkey
def serializeCompressed(self):
raw = self.key.public_bytes(
serialization.Encoding.X962,
serialization.PublicFormat.CompressedPoint
)
return raw
def __str__(self):
return "PublicKey[0x{}]".format(
hexlify(self.serializeCompressed()).decode('ASCII')
)
def Keypair(object):
def __init__(self, priv, pub):
self.priv, self.pub = priv, pub
class Sha256Mixer(object):
def __init__(self, base):
self.hash = sha256(base).digest()
def update(self, data):
h = sha256(self.hash)
h.update(data)
self.hash = h.digest()
return self.hash
def digest(self):
return self.hash
def __str__(self):
return "Sha256Mixer[0x{}]".format(hexlify(self.hash).decode('ASCII'))
class LightningConnection(object):
def __init__(self, connection, remote_pubkey, local_privkey, is_initiator):
self.connection = connection
self.chaining_key = None
self.handshake_hash = None
self.local_privkey = local_privkey
self.local_pubkey = self.local_privkey.public_key()
self.remote_pubkey = remote_pubkey
self.is_initiator = is_initiator
self.init_handshake()
self.rn, self.sn = 0, 0
@classmethod
def nonce(cls, n):
"""Transforms a numeric nonce into a byte formatted one
Nonce n encoded as 32 zero bits, followed by a little-endian 64-bit
value. Note: this follows the Noise Protocol convention, rather than
our normal endian.
"""
return b'\x00' * 4 + struct.pack("<Q", n)
def init_handshake(self):
h = sha256(b'Noise_XK_secp256k1_ChaChaPoly_SHA256').digest()
self.chaining_key = h
h = sha256(h + b'lightning').digest()
if self.is_initiator:
responder_pubkey = self.remote_pubkey
else:
responder_pubkey = self.local_pubkey
h = sha256(h + responder_pubkey.serializeCompressed()).digest()
self.handshake = {
'h': h,
'e': PrivateKey(os.urandom(32)),
}
def handshake_act_one_initiator(self):
h = Sha256Mixer(b'')
h.hash = self.handshake['h']
h.update(self.handshake['e'].public_key().serializeCompressed())
es = ecdh(self.handshake['e'], self.remote_pubkey)
t = hkdf(salt=self.chaining_key, ikm=es.raw, info=b'')
assert(len(t) == 64)
self.chaining_key, temp_k1 = t[:32], t[32:]
c = encryptWithAD(temp_k1, self.nonce(0), h.digest(), b'')
self.handshake['h'] = h.update(c)
pk = self.handshake['e'].public_key().serializeCompressed()
m = b'\x00' + pk + c
return m
def handshake_act_one_responder(self, m):
v, re, c = m[0], PublicKey(m[1:34]), m[34:]
if v != 0:
raise ValueError("Unsupported handshake version {}, only version "
"0 is supported.".format(v))
h = Sha256Mixer(b'')
h.hash = self.handshake['h']
h.update(re.serializeCompressed())
es = ecdh(self.local_privkey, re)
self.handshake['re'] = re
t = hkdf(salt=self.chaining_key, ikm=es.raw, info=b'')
self.chaining_key, temp_k1 = t[:32], t[32:]
try:
decryptWithAD(temp_k1, self.nonce(0), h.digest(), c)
except InvalidTag:
ValueError("Verification of tag failed, remote peer doesn't know "
"our node ID.")
h.update(c)
self.handshake['h'] = h.digest()
def handshake_act_two_responder(self):
h = Sha256Mixer(b'')
h.hash = self.handshake['h']
h.update(self.handshake['e'].public_key().serializeCompressed())
ee = ecdh(self.handshake['e'], self.handshake['re'])
t = hkdf(salt=self.chaining_key, ikm=ee.raw, info=b'')
assert(len(t) == 64)
self.chaining_key, self.temp_k2 = t[:32], t[32:]
c = encryptWithAD(self.temp_k2, self.nonce(0), h.digest(), b'')
h.update(c)
self.handshake['h'] = h.digest()
pk = self.handshake['e'].public_key().serializeCompressed()
m = b'\x00' + pk + c
return m
def handshake_act_two_initiator(self, m):
v, re, c = m[0], PublicKey(m[1:34]), m[34:]
if v != 0:
raise ValueError("Unsupported handshake version {}, only version "
"0 is supported.".format(v))
self.re = re
h = Sha256Mixer(b'')
h.hash = self.handshake['h']
h.update(re.serializeCompressed())
ee = ecdh(self.handshake['e'], re)
self.chaining_key, self.temp_k2 = hkdf_two_keys(
salt=self.chaining_key, ikm=ee.raw
)
try:
decryptWithAD(self.temp_k2, self.nonce(0), h.digest(), c)
except InvalidTag:
ValueError("Verification of tag failed.")
h.update(c)
self.handshake['h'] = h.digest()
def handshake_act_three_initiator(self):
h = Sha256Mixer(b'')
h.hash = self.handshake['h']
pk = self.local_pubkey.serializeCompressed()
c = encryptWithAD(self.temp_k2, self.nonce(1), h.digest(), pk)
h.update(c)
se = ecdh(self.local_privkey, self.re)
self.chaining_key, self.temp_k3 = hkdf_two_keys(
salt=self.chaining_key, ikm=se.raw
)
t = encryptWithAD(self.temp_k3, self.nonce(0), h.digest(), b'')
m = b'\x00' + c + t
t = hkdf(salt=self.chaining_key, ikm=b'', info=b'')
self.sk, self.rk = hkdf_two_keys(salt=self.chaining_key, ikm=b'')
self.rn, self.sn = 0, 0
return m
def handshake_act_three_responder(self, m):
h = Sha256Mixer(b'')
h.hash = self.handshake['h']
v, c, t = m[0], m[1:50], m[50:]
if v != 0:
raise ValueError("Unsupported handshake version {}, only version "
"0 is supported.".format(v))
rs = decryptWithAD(self.temp_k2, self.nonce(1), h.digest(), c)
h.update(c)
se = ecdh(self.handshake['e'], PublicKey(rs))
self.chaining_key, self.temp_k3 = hkdf_two_keys(
se.raw, self.chaining_key
)
decryptWithAD(self.temp_k3, self.nonce(0), h.digest(), t)
self.rn, self.sn = 0, 0
self.rk, self.sk = hkdf_two_keys(salt=self.chaining_key, ikm=b'')
def read_message(self):
lc = self.connection.recv(18)
if len(lc) != 18:
raise ValueError(
"Short read reading the message length: 18 != {}".format(
len(lc))
)
length = decryptWithAD(self.rk, self.nonce(self.rn), b'', lc)
length, = struct.unpack("!H", length)
self.rn += 1
mc = self.connection.recv(length + 16)
if len(mc) < length + 16:
raise ValueError("Short read reading the message: {} != {}".format(
length + 16, len(lc))
)
m = decryptWithAD(self.rk, self.nonce(self.rn), b'', mc)
self.rn += 1
assert(self.rn % 2 == 0)
self._maybe_rotate_keys()
return m
def send_message(self, m):
length = struct.pack("!H", len(m))
lc = encryptWithAD(self.sk, self.nonce(self.sn), b'', length)
self.sn += 1
mc = encryptWithAD(self.sk, self.nonce(self.sn), b'', m)
self.sn += 1
self.connection.send(lc)
self.connection.send(mc)
assert(self.sn % 2 == 0)
self._maybe_rotate_keys()
def _maybe_rotate_keys(self):
if self.sn == 1000:
self.sck, self.sk = hkdf_two_keys(salt=self.sck, ikm=self.sk)
self.sn = 0
if self.rn == 1000:
self.rck, self.rk = hkdf_two_keys(salt=self.rck, ikm=self.rk)
self.rn = 0
def shake(self):
if self.is_initiator:
m = self.handshake_act_one_initiator()
self.connection.send(m)
m = self.connection.recv(50)
if len(m) != 50:
raise ValueError(
"Short read from peer reading act2: 50 != {}".format(
len(m))
)
self.handshake_act_two_initiator(m)
m = self.handshake_act_three_initiator()
self.connection.send(m)
else:
m = self.connection.recv(50)
if len(m) != 50:
raise ValueError(
"Short read from peer reading act1: 50 != {}".format(
len(m))
)
self.handshake_act_one_responder(m)
m = self.handshake_act_two_responder()
self.connection.send(m)
m = self.connection.recv(66)
if len(m) != 66:
raise ValueError(
"Short read from peer reading act3: 66 != {}".format(
len(m))
)
self.handshake_act_three_responder(m)
self.sck = self.chaining_key
self.rck = self.chaining_key
class LightningServerSocket(socket.socket):
def __init__(self, local_privkey):
socket.socket.__init__(self)
self.local_privkey = local_privkey
self.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
def accept(self):
print("Accepting a new socket")
conn, address = socket.socket.accept(self)
lconn = LightningConnection(
conn, remote_pubkey=None,
local_privkey=self.local_privkey,
is_initiator=False)
lconn.shake()
return (lconn, address)
def connect(local_privkey, node_id, host, port=9735):
if isinstance(node_id, bytes) and len(node_id) == 33:
remote_pubkey = PublicKey(node_id)
elif isinstance(node_id, ec.EllipticCurvePublicKey):
remote_pubkey = PublicKey(node_id)
elif isinstance(node_id, PublicKey):
remote_pubkey = node_id
else:
raise ValueError(
"node_id must be either a 33 byte array, or a PublicKey"
)
conn = socket.create_connection((host, port))
lconn = LightningConnection(conn, remote_pubkey, local_privkey,
is_initiator=True)
lconn.shake()
return lconn

2
contrib/pylightning/requirements.txt

@ -0,0 +1,2 @@
cryptography==2.7
coincurve==12.0.0

6
contrib/pylightning/setup.py

@ -6,6 +6,9 @@ import io
with io.open('README.md', encoding='utf-8') as f:
long_description = f.read()
with io.open('requirements.txt', encoding='utf-8') as f:
requirements = [r for r in f.read().split('\n') if len(r)]
setup(name='pylightning',
version=lightning.__version__,
description='Client library for lightningd',
@ -17,4 +20,5 @@ setup(name='pylightning',
license='MIT',
packages=['lightning'],
scripts=['lightning-pay'],
zip_safe=True)
zip_safe=True,
install_requires=requirements)

189
contrib/pylightning/tests/test_wire.py

@ -0,0 +1,189 @@
from binascii import hexlify, unhexlify
from lightning.wire import PrivateKey, PublicKey, LightningConnection
import socket
from lightning import wire
import threading
def test_primitives():
raw_privkey = unhexlify('1111111111111111111111111111111111111111111111111111111111111111')
privkey = PrivateKey(raw_privkey)
pubkey = privkey.public_key()
assert(hexlify(pubkey.serializeCompressed()) == b'034f355bdcb7cc0af728ef3cceb9615d90684bb5b2ca5f859ab0f0b704075871aa')
# Now try with the raw constructor once more
pubkey = PublicKey(unhexlify(b'034f355bdcb7cc0af728ef3cceb9615d90684bb5b2ca5f859ab0f0b704075871aa'))
assert(hexlify(pubkey.serializeCompressed()) == b'034f355bdcb7cc0af728ef3cceb9615d90684bb5b2ca5f859ab0f0b704075871aa')
def test_encrypt_decrypt():
""" Test encryptWithAD and decryptWithAD primitives
Taken from https://github.com/lightningnetwork/lightning-rfc/blob/master/08-transport.md#initiator-tests
"""
inp = [b'e68f69b7f096d7917245f5e5cf8ae1595febe4d4644333c99f9c4a1282031c9f', b'000000000000000000000000', b'9e0e7de8bb75554f21db034633de04be41a2b8a18da7a319a03c803bf02b396c', b'']
inp = [unhexlify(i) for i in inp]
c = wire.encryptWithAD(*inp)
assert(hexlify(c) == b'0df6086551151f58b8afe6c195782c6a')
def test_handshake():
rs_privkey = PrivateKey(unhexlify('2121212121212121212121212121212121212121212121212121212121212121'))
rs_pubkey = rs_privkey.public_key()
assert(hexlify(rs_pubkey.serializeCompressed()) == b'028d7500dd4c12685d1f568b4c2b5048e8534b873319f3a8daa612b469132ec7f7')
ls_privkey = PrivateKey(unhexlify('1111111111111111111111111111111111111111111111111111111111111111'))
ls_pubkey = ls_privkey.public_key()
assert(hexlify(ls_pubkey.serializeCompressed()) == b'034f355bdcb7cc0af728ef3cceb9615d90684bb5b2ca5f859ab0f0b704075871aa')
c1, c2 = socket.socketpair()
lc1 = LightningConnection(c1, rs_pubkey, ls_privkey, is_initiator=True)
lc2 = LightningConnection(c2, ls_pubkey, rs_privkey, is_initiator=False)
# Override the generated ephemeral key for the test:
lc1.handshake['e'] = PrivateKey(unhexlify('1212121212121212121212121212121212121212121212121212121212121212'))
lc2.handshake['e'] = PrivateKey(unhexlify(b'2222222222222222222222222222222222222222222222222222222222222222'))
assert(hexlify(lc1.handshake['e'].public_key().serializeCompressed()) == b'036360e856310ce5d294e8be33fc807077dc56ac80d95d9cd4ddbd21325eff73f7')
assert(hexlify(lc1.handshake['h']) == b'8401b3fdcaaa710b5405400536a3d5fd7792fe8e7fe29cd8b687216fe323ecbd')
assert(lc1.handshake['h'] == lc2.handshake['h'])
m = lc1.handshake_act_one_initiator()
assert(hexlify(m) == b'00036360e856310ce5d294e8be33fc807077dc56ac80d95d9cd4ddbd21325eff73f70df6086551151f58b8afe6c195782c6a')
lc2.handshake_act_one_responder(m)
assert(hexlify(lc1.handshake['h']) == b'9d1ffbb639e7e20021d9259491dc7b160aab270fb1339ef135053f6f2cebe9ce')
assert(hexlify(lc1.handshake['h']) == hexlify(lc2.handshake['h']))
assert(hexlify(lc1.chaining_key) == b'b61ec1191326fa240decc9564369dbb3ae2b34341d1e11ad64ed89f89180582f')
assert(hexlify(lc2.chaining_key) == b'b61ec1191326fa240decc9564369dbb3ae2b34341d1e11ad64ed89f89180582f')
m = lc2.handshake_act_two_responder()
assert(hexlify(m) == b'0002466d7fcae563e5cb09a0d1870bb580344804617879a14949cf22285f1bae3f276e2470b93aac583c9ef6eafca3f730ae')
assert(hexlify(lc2.handshake['h']) == b'90578e247e98674e661013da3c5c1ca6a8c8f48c90b485c0dfa1494e23d56d72')
lc1.handshake_act_two_initiator(m)
assert(hexlify(lc1.handshake['h']) == b'90578e247e98674e661013da3c5c1ca6a8c8f48c90b485c0dfa1494e23d56d72')
assert(hexlify(lc1.chaining_key) == b'e89d31033a1b6bf68c07d22e08ea4d7884646c4b60a9528598ccb4ee2c8f56ba')
assert(hexlify(lc2.chaining_key) == b'e89d31033a1b6bf68c07d22e08ea4d7884646c4b60a9528598ccb4ee2c8f56ba')
m = lc1.handshake_act_three_initiator()
assert(hexlify(m) == b'00b9e3a702e93e3a9948c2ed6e5fd7590a6e1c3a0344cfc9d5b57357049aa22355361aa02e55a8fc28fef5bd6d71ad0c38228dc68b1c466263b47fdf31e560e139ba')
assert(hexlify(lc1.sk) == b'969ab31b4d288cedf6218839b27a3e2140827047f2c0f01bf5c04435d43511a9')
assert(hexlify(lc1.rk) == b'bb9020b8965f4df047e07f955f3c4b88418984aadc5cdb35096b9ea8fa5c3442')
lc2.handshake_act_three_responder(m)
assert(lc1.rk == lc2.sk)
assert(lc1.sk == lc2.rk)
assert(lc1.sn == lc2.rn)
assert(lc1.rn == lc2.sn)
assert(hexlify(lc2.rk) == b'969ab31b4d288cedf6218839b27a3e2140827047f2c0f01bf5c04435d43511a9')
assert(hexlify(lc2.sk) == b'bb9020b8965f4df047e07f955f3c4b88418984aadc5cdb35096b9ea8fa5c3442')
def test_shake():
rs_privkey = PrivateKey(unhexlify('2121212121212121212121212121212121212121212121212121212121212121'))
rs_pubkey = rs_privkey.public_key()
assert(hexlify(rs_pubkey.serializeCompressed()) == b'028d7500dd4c12685d1f568b4c2b5048e8534b873319f3a8daa612b469132ec7f7')
ls_privkey = PrivateKey(unhexlify('1111111111111111111111111111111111111111111111111111111111111111'))
ls_pubkey = ls_privkey.public_key()
assert(hexlify(ls_pubkey.serializeCompressed()) == b'034f355bdcb7cc0af728ef3cceb9615d90684bb5b2ca5f859ab0f0b704075871aa')
c1, c2 = socket.socketpair()
lc1 = LightningConnection(c1, rs_pubkey, ls_privkey, is_initiator=True)
lc2 = LightningConnection(c2, ls_pubkey, rs_privkey, is_initiator=False)
t = threading.Thread(target=lc2.shake)
t.start()
lc1.shake()
t.join()
assert(lc1.rk == lc2.sk)
assert(lc1.sk == lc2.rk)
assert(lc1.sn == lc2.rn)
assert(lc1.rn == lc2.sn)
def test_read_key_rotation():
ls_privkey = PrivateKey(unhexlify('1111111111111111111111111111111111111111111111111111111111111111'))
rs_privkey = PrivateKey(unhexlify('2121212121212121212121212121212121212121212121212121212121212121'))
rs_pubkey = rs_privkey.public_key()
c1, c2 = socket.socketpair()
lc = LightningConnection(c1, rs_pubkey, ls_privkey, is_initiator=True)
# ck=0x919219dbb2920afa8db80f9a51787a840bcf111ed8d588caf9ab4be716e42b01
# sk=0x969ab31b4d288cedf6218839b27a3e2140827047f2c0f01bf5c04435d43511a9
# rk=0xbb9020b8965f4df047e07f955f3c4b88418984aadc5cdb35096b9ea8fa5c3442
lc.chaining_key = unhexlify(b'919219dbb2920afa8db80f9a51787a840bcf111ed8d588caf9ab4be716e42b01')
lc.sk = unhexlify(b'969ab31b4d288cedf6218839b27a3e2140827047f2c0f01bf5c04435d43511a9')
lc.rk = unhexlify(b'bb9020b8965f4df047e07f955f3c4b88418984aadc5cdb35096b9ea8fa5c3442')
lc.sn, lc.rn = 0, 0
lc.sck, lc.rck = lc.chaining_key, lc.chaining_key
msg = unhexlify(b'68656c6c6f')
lc.send_message(msg)
m = c2.recv(18 + 21)
assert(hexlify(m) == b'cf2b30ddf0cf3f80e7c35a6e6730b59fe802473180f396d88a8fb0db8cbcf25d2f214cf9ea1d95')
# Send 498 more messages, to get just below the switch threshold
for i in range(0, 498):
lc.send_message(msg)
m = c2.recv(18 + 21)
# Check the last message against the test vector
assert(hexlify(lc.sk) == b'969ab31b4d288cedf6218839b27a3e2140827047f2c0f01bf5c04435d43511a9')
# This next message triggers the rotation:
lc.send_message(msg)
m = c2.recv(18 + 21)
# Now try to send with the new keys:
lc.send_message(msg)
m = c2.recv(18 + 21)
assert(hexlify(m) == b'178cb9d7387190fa34db9c2d50027d21793c9bc2d40b1e14dcf30ebeeeb220f48364f7a4c68bf8')
lc.send_message(msg)
m = c2.recv(18 + 21)
assert(hexlify(m) == b'1b186c57d44eb6de4c057c49940d79bb838a145cb528d6e8fd26dbe50a60ca2c104b56b60e45bd')
for i in range(0, 498):
lc.send_message(msg)
m = c2.recv(18 + 21)
lc.send_message(msg)
m = c2.recv(18 + 21)
assert(hexlify(m) == b'4a2f3cc3b5e78ddb83dcb426d9863d9d9a723b0337c89dd0b005d89f8d3c05c52b76b29b740f09')
lc.send_message(msg)
m = c2.recv(18 + 21)
assert(hexlify(m) == b'2ecd8c8a5629d0d02ab457a0fdd0f7b90a192cd46be5ecb6ca570bfc5e268338b1a16cf4ef2d36')
def test_listen_connect():
"""Roundtrip test using the public constructors.
"""
n1_privkey = PrivateKey(unhexlify(b'1111111111111111111111111111111111111111111111111111111111111111'))
n2_privkey = PrivateKey(unhexlify('2121212121212121212121212121212121212121212121212121212121212121'))
lss = wire.LightningServerSocket(n2_privkey)
lss.bind(('0.0.0.0', 1234))
lss.listen()
port = lss.getsockname()[1]
print(port)
def connect():
lc = wire.connect(n1_privkey, n2_privkey.public_key(), '127.0.0.1', port)
lc.send_message(b'hello')
m = lc.read_message()
assert(m == b'world')
t = threading.Thread(target=connect)
t.daemon = True
t.start()
c, _ = lss.accept()
m = c.read_message()
assert(m == b'hello')
c.send_message(b'world')
t.join()
Loading…
Cancel
Save