Browse Source

lnbase: store remote revocation store, don't store all remote revocation points, verify ctn numbers in reestablish

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
Janus 7 years ago
committed by ThomasV
parent
commit
ee87920573
  1. 43
      lib/lnbase.py
  2. 5
      lib/tests/test_lnbase.py
  3. 7
      lib/tests/test_lnbase_online.py

43
lib/lnbase.py

@ -276,7 +276,7 @@ ChannelConfig = namedtuple("ChannelConfig", [
"payment_basepoint", "multisig_key", "htlc_basepoint", "delayed_basepoint", "revocation_basepoint", "payment_basepoint", "multisig_key", "htlc_basepoint", "delayed_basepoint", "revocation_basepoint",
"to_self_delay", "dust_limit_sat", "max_htlc_value_in_flight_msat", "max_accepted_htlcs"]) "to_self_delay", "dust_limit_sat", "max_htlc_value_in_flight_msat", "max_accepted_htlcs"])
OnlyPubkeyKeypair = namedtuple("OnlyPubkeyKeypair", ["pubkey"]) OnlyPubkeyKeypair = namedtuple("OnlyPubkeyKeypair", ["pubkey"])
RemoteState = namedtuple("RemoteState", ["ctn", "next_per_commitment_point", "amount_sat", "commitment_points"]) RemoteState = namedtuple("RemoteState", ["ctn", "next_per_commitment_point", "amount_sat", "revocation_store", "last_per_commitment_point"])
LocalState = namedtuple("LocalState", ["ctn", "per_commitment_secret_seed", "amount_sat"]) LocalState = namedtuple("LocalState", ["ctn", "per_commitment_secret_seed", "amount_sat"])
ChannelConstraints = namedtuple("ChannelConstraints", ["feerate", "capacity", "is_initiator", "funding_txn_minimum_depth"]) ChannelConstraints = namedtuple("ChannelConstraints", ["feerate", "capacity", "is_initiator", "funding_txn_minimum_depth"])
OpenChannel = namedtuple("OpenChannel", ["channel_id", "funding_outpoint", "local_config", "remote_config", "remote_state", "local_state", "constraints"]) OpenChannel = namedtuple("OpenChannel", ["channel_id", "funding_outpoint", "local_config", "remote_config", "remote_state", "local_state", "constraints"])
@ -911,6 +911,7 @@ class Peer(PrintError):
# broadcast funding tx # broadcast funding tx
success, _txid = self.network.broadcast(funding_tx) success, _txid = self.network.broadcast(funding_tx)
assert success, success assert success, success
their_revocation_store = RevocationStore()
chan = OpenChannel( chan = OpenChannel(
channel_id=channel_id, channel_id=channel_id,
funding_outpoint=Outpoint(funding_txid, funding_index), funding_outpoint=Outpoint(funding_txid, funding_index),
@ -919,8 +920,9 @@ class Peer(PrintError):
remote_state=RemoteState( remote_state=RemoteState(
ctn = 0, ctn = 0,
next_per_commitment_point=None, next_per_commitment_point=None,
last_per_commitment_point=remote_per_commitment_point,
amount_sat=remote_amount, amount_sat=remote_amount,
commitment_points=[bh2u(remote_per_commitment_point)] revocation_store=their_revocation_store
), ),
local_state=LocalState( local_state=LocalState(
ctn = 0, ctn = 0,
@ -943,7 +945,15 @@ class Peer(PrintError):
# 'your_last_per_commitment_secret': b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', # 'your_last_per_commitment_secret': b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
# 'my_current_per_commitment_point': b'\x03\x18\xb9\x1b\x99\xd4\xc3\xf1\x92\x0f\xfe\xe4c\x9e\xae\xa4\xf1\xdeX\xcf4\xa9[\xd1\tAh\x80\x88\x01b*[' # 'my_current_per_commitment_point': b'\x03\x18\xb9\x1b\x99\xd4\xc3\xf1\x92\x0f\xfe\xe4c\x9e\xae\xa4\xf1\xdeX\xcf4\xa9[\xd1\tAh\x80\x88\x01b*['
# } # }
if channel_reestablish_msg["my_current_per_commitment_point"] != bfh(chan.remote_state.commitment_points[-1]): remote_ctn = int.from_bytes(channel_reestablish_msg["next_local_commitment_number"], 'big')
if remote_ctn != chan.remote_state.ctn + 1:
raise Exception("expected remote ctn {}, got {}".format(chan.remote_state.ctn + 1, remote_ctn))
local_ctn = int.from_bytes(channel_reestablish_msg["next_remote_revocation_number"], 'big')
if local_ctn != chan.local_state.ctn:
raise Exception("expected local ctn {}, got {}".format(chan.local_state.ctn, local_ctn))
if channel_reestablish_msg["my_current_per_commitment_point"] != chan.remote_state.last_per_commitment_point:
raise Exception("Remote PCP mismatch") raise Exception("Remote PCP mismatch")
self.send_message(gen_msg("channel_reestablish", self.send_message(gen_msg("channel_reestablish",
channel_id=chan.channel_id, channel_id=chan.channel_id,
@ -991,10 +1001,10 @@ class Peer(PrintError):
return chan._replace(remote_state=chan.remote_state._replace(next_per_commitment_point=remote_funding_locked_msg["next_per_commitment_point"])) return chan._replace(remote_state=chan.remote_state._replace(next_per_commitment_point=remote_funding_locked_msg["next_per_commitment_point"]))
async def receive_commitment_revoke_ack(self, chan, expected_received_sat, payment_preimage): async def receive_commitment_revoke_ack(self, chan, expected_received_sat, payment_preimage):
def derive_and_incr(): def derive_and_incr(last = False):
nonlocal chan nonlocal chan
last_secret = get_per_commitment_secret_from_seed(chan.local_state.per_commitment_secret_seed, 2**48-chan.local_state.ctn-1) last_secret = get_per_commitment_secret_from_seed(chan.local_state.per_commitment_secret_seed, 2**48-chan.local_state.ctn-1)
next_secret = get_per_commitment_secret_from_seed(chan.local_state.per_commitment_secret_seed, 2**48-chan.local_state.ctn-2) next_secret = get_per_commitment_secret_from_seed(chan.local_state.per_commitment_secret_seed, 2**48-chan.local_state.ctn-(2 if not last else 3))
next_point = secret_to_pubkey(int.from_bytes(next_secret, 'big')) next_point = secret_to_pubkey(int.from_bytes(next_secret, 'big'))
chan = chan._replace( chan = chan._replace(
local_state=chan.local_state._replace( local_state=chan.local_state._replace(
@ -1002,6 +1012,9 @@ class Peer(PrintError):
) )
) )
return last_secret, next_point return last_secret, next_point
their_revstore = chan.remote_state.revocation_store
channel_id = chan.channel_id channel_id = chan.channel_id
try: try:
commitment_signed_msg = await self.commitment_signed[channel_id] commitment_signed_msg = await self.commitment_signed[channel_id]
@ -1052,6 +1065,8 @@ class Peer(PrintError):
print("SENDING FIRST REVOKE AND ACK") print("SENDING FIRST REVOKE AND ACK")
their_revstore.add_next_entry(last_secret)
self.send_message(gen_msg("revoke_and_ack", self.send_message(gen_msg("revoke_and_ack",
channel_id=channel_id, channel_id=channel_id,
per_commitment_secret=last_secret, per_commitment_secret=last_secret,
@ -1111,7 +1126,9 @@ class Peer(PrintError):
# TODO check commitment_signed results # TODO check commitment_signed results
last_secret, next_point = derive_and_incr() last_secret, next_point = derive_and_incr(True)
their_revstore.add_next_entry(last_secret)
print("SENDING SECOND REVOKE AND ACK") print("SENDING SECOND REVOKE AND ACK")
self.send_message(gen_msg("revoke_and_ack", self.send_message(gen_msg("revoke_and_ack",
@ -1125,7 +1142,8 @@ class Peer(PrintError):
), ),
remote_state=chan.remote_state._replace( remote_state=chan.remote_state._replace(
ctn=chan.remote_state.ctn + 2, ctn=chan.remote_state.ctn + 2,
commitment_points=chan.remote_state.commitment_points + [bh2u(remote_next_commitment_point)], revocation_store=their_revstore,
last_per_commitment_point=remote_next_commitment_point,
next_per_commitment_point=revoke_and_ack_msg["next_per_commitment_point"], next_per_commitment_point=revoke_and_ack_msg["next_per_commitment_point"],
amount_sat=chan.remote_state.amount_sat - expected_received_sat amount_sat=chan.remote_state.amount_sat - expected_received_sat
) )
@ -1575,3 +1593,14 @@ class RevocationStore:
raise Exception("hash is not derivable: {} {} {}".format(bh2u(e.secret), bh2u(this_bucket.secret), this_bucket.index)) raise Exception("hash is not derivable: {} {} {}".format(bh2u(e.secret), bh2u(this_bucket.secret), this_bucket.index))
self.buckets[bucket] = new_element self.buckets[bucket] = new_element
self.index -= 1 self.index -= 1
def serialize(self):
return {"index": self.index, "buckets": [[bh2u(k.secret), k.index] if k is not None else None for k in self.buckets]}
@staticmethod
def from_json_obj(decoded_json_obj):
store = RevocationStore()
decode = lambda to_decode: ShachainElement(bfh(to_decode[0]), int(to_decode[1]))
store.buckets = [k if k is None else decode(k) for k in decoded_json_obj["buckets"]]
store.index = decoded_json_obj["index"]
return store
def __eq__(self, o):
return self.buckets == o.buckets and self.index == o.index

5
lib/tests/test_lnbase.py

@ -7,7 +7,7 @@ from lib.lnbase import make_commitment, get_obscured_ctn, Peer, make_offered_htl
from lib.lnbase import secret_to_pubkey, derive_pubkey, derive_privkey, derive_blinded_pubkey, overall_weight from lib.lnbase import secret_to_pubkey, derive_pubkey, derive_privkey, derive_blinded_pubkey, overall_weight
from lib.lnbase import make_htlc_tx_output, make_htlc_tx_inputs, get_per_commitment_secret_from_seed from lib.lnbase import make_htlc_tx_output, make_htlc_tx_inputs, get_per_commitment_secret_from_seed
from lib.lnbase import make_htlc_tx_witness, OnionHopsDataSingle, new_onion_packet, OnionPerHop from lib.lnbase import make_htlc_tx_witness, OnionHopsDataSingle, new_onion_packet, OnionPerHop
from lib.lnbase import RevocationStore, ShachainElement, shachain_derive from lib.lnbase import RevocationStore
from lib.transaction import Transaction from lib.transaction import Transaction
from lib import bitcoin from lib import bitcoin
import ecdsa.ellipticcurve import ecdsa.ellipticcurve
@ -790,8 +790,9 @@ class Test_LNBase(unittest.TestCase):
seed = bitcoin.sha256(b"shachaintest") seed = bitcoin.sha256(b"shachaintest")
consumer = RevocationStore() consumer = RevocationStore()
for i in range(10000): for i in range(10000):
secret = shachain_derive(ShachainElement(seed, 0), 2**48 - i - 1).secret secret = get_per_commitment_secret_from_seed(seed, 2**48 - i - 1)
try: try:
consumer.add_next_entry(secret) consumer.add_next_entry(secret)
except Exception as e: except Exception as e:
raise Exception("iteration " + str(i) + ": " + str(e)) raise Exception("iteration " + str(i) + ": " + str(e))
if i % 1000 == 0: self.assertEqual(consumer.serialize(), RevocationStore.from_json_obj(json.loads(json.dumps(consumer.serialize()))).serialize())

7
lib/tests/test_lnbase_online.py

@ -13,14 +13,14 @@ from lib.simple_config import SimpleConfig
from lib.network import Network from lib.network import Network
from lib.storage import WalletStorage from lib.storage import WalletStorage
from lib.wallet import Wallet from lib.wallet import Wallet
from lib.lnbase import Peer, node_list, Outpoint, ChannelConfig, LocalState, RemoteState, Keypair, OnlyPubkeyKeypair, OpenChannel, ChannelConstraints from lib.lnbase import Peer, node_list, Outpoint, ChannelConfig, LocalState, RemoteState, Keypair, OnlyPubkeyKeypair, OpenChannel, ChannelConstraints, RevocationStore
from lib.lightning_payencode.lnaddr import lnencode, LnAddr from lib.lightning_payencode.lnaddr import lnencode, LnAddr
import lib.constants as constants import lib.constants as constants
is_key = lambda k: k.endswith("_basepoint") or k.endswith("_key") is_key = lambda k: k.endswith("_basepoint") or k.endswith("_key")
def maybeDecode(k, v): def maybeDecode(k, v):
if k in ["pubkey", "privkey", "next_per_commitment_point", "per_commitment_secret_seed"] and v is not None: if k in ["pubkey", "privkey", "last_per_commitment_point", "next_per_commitment_point", "per_commitment_secret_seed"] and v is not None:
return binascii.unhexlify(v) return binascii.unhexlify(v)
return v return v
@ -45,6 +45,7 @@ def reconstruct_namedtuples(openingchannel):
new_local_state = decodeAll(openingchannel.local_state) new_local_state = decodeAll(openingchannel.local_state)
openingchannel = openingchannel._replace(local_state=LocalState(**new_local_state)) openingchannel = openingchannel._replace(local_state=LocalState(**new_local_state))
new_remote_state = decodeAll(openingchannel.remote_state) new_remote_state = decodeAll(openingchannel.remote_state)
new_remote_state["revocation_store"] = RevocationStore.from_json_obj(new_remote_state["revocation_store"])
openingchannel = openingchannel._replace(remote_state=RemoteState(**new_remote_state)) openingchannel = openingchannel._replace(remote_state=RemoteState(**new_remote_state))
openingchannel = openingchannel._replace(constraints=ChannelConstraints(**openingchannel.constraints)) openingchannel = openingchannel._replace(constraints=ChannelConstraints(**openingchannel.constraints))
return openingchannel return openingchannel
@ -58,6 +59,8 @@ def serialize_channels(channels):
def default(self, o): def default(self, o):
if isinstance(o, bytes): if isinstance(o, bytes):
return binascii.hexlify(o).decode("ascii") return binascii.hexlify(o).decode("ascii")
if isinstance(o, RevocationStore):
return o.serialize()
return super(MyJsonEncoder, self) return super(MyJsonEncoder, self)
dumped = MyJsonEncoder().encode(serialized_channels) dumped = MyJsonEncoder().encode(serialized_channels)
roundtripped = json.loads(dumped) roundtripped = json.loads(dumped)

Loading…
Cancel
Save