From 6ba08cc8d45d51478f606a71d445891041e2e3b1 Mon Sep 17 00:00:00 2001 From: SomberNight Date: Mon, 16 Mar 2020 22:07:00 +0100 Subject: [PATCH] ln feature bits: flatten namespaces, and impl feature deps and ctxs This implements: - flat feature bits https://github.com/lightningnetwork/lightning-rfc/pull/666 - feature bit dependencies https://github.com/lightningnetwork/lightning-rfc/pull/719 --- electrum/channel_db.py | 6 +- electrum/lnpeer.py | 24 +++--- electrum/lnutil.py | 148 ++++++++++++++++++++++++++++++---- electrum/lnworker.py | 22 ++--- electrum/tests/test_lnmsg.py | 18 ++--- electrum/tests/test_lnpeer.py | 6 +- electrum/tests/test_lnutil.py | 52 +++++++++++- 7 files changed, 225 insertions(+), 51 deletions(-) diff --git a/electrum/channel_db.py b/electrum/channel_db.py index aef25effe..59de812b9 100644 --- a/electrum/channel_db.py +++ b/electrum/channel_db.py @@ -38,7 +38,7 @@ from .sql_db import SqlDB, sql from . import constants from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits from .logging import Logger -from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, format_short_channel_id, ShortChannelID +from .lnutil import LN_FEATURES_IMPLEMENTED, LNPeerAddr, format_short_channel_id, ShortChannelID from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update from .lnmsg import decode_msg @@ -49,10 +49,10 @@ if TYPE_CHECKING: class UnknownEvenFeatureBits(Exception): pass -def validate_features(features : int): +def validate_features(features: int) -> None: enabled_features = list_enabled_bits(features) for fbit in enabled_features: - if (1 << fbit) not in LN_GLOBAL_FEATURES_KNOWN_SET and fbit % 2 == 0: + if (1 << fbit) & LN_FEATURES_IMPLEMENTED == 0 and fbit % 2 == 0: raise UnknownEvenFeatureBits() diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 0524dd18d..2a02aee26 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -37,7 +37,7 @@ from . import lnutil from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc, RemoteConfig, OnlyPubkeyKeypair, ChannelConstraints, RevocationStore, funding_output_script, get_per_commitment_secret_from_seed, - secret_to_pubkey, PaymentFailure, LnLocalFeatures, + secret_to_pubkey, PaymentFailure, LnFeatures, LOCAL, REMOTE, HTLCOwner, generate_keypair, LnKeyFamily, ln_compare_features, privkey_to_pubkey, UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_ACCEPTED, LightningPeerConnectionClosed, HandshakeFailed, NotFoundChanAnnouncementForUpdate, @@ -77,7 +77,7 @@ class Peer(Logger): self.pubkey = pubkey # remote pubkey self.lnworker = lnworker self.privkey = lnworker.node_keypair.privkey # local privkey - self.localfeatures = self.lnworker.localfeatures + self.features = self.lnworker.features self.node_ids = [self.pubkey, privkey_to_pubkey(self.privkey)] self.network = lnworker.network self.channel_db = lnworker.network.channel_db @@ -131,8 +131,8 @@ class Peer(Logger): async def initialize(self): if isinstance(self.transport, LNTransport): await self.transport.handshake() - # FIXME: "flen" hardcoded but actually it depends on "localfeatures"...: - self.send_message("init", gflen=0, flen=2, features=self.localfeatures, + # FIXME: "flen" hardcoded but actually it depends on "features"...: + self.send_message("init", gflen=0, flen=2, features=self.features.for_init_message(), init_tlvs={ 'networks': {'chains': constants.net.rev_genesis_bytes()} @@ -204,11 +204,15 @@ class Peer(Logger): if self._received_init: self.logger.info("ALREADY INITIALIZED BUT RECEIVED INIT") return - # if they required some even flag we don't have, they will close themselves - # but if we require an even flag they don't have, we close - their_localfeatures = int.from_bytes(payload['features'], byteorder="big") # TODO feature bit unification + their_features = LnFeatures(int.from_bytes(payload['features'], byteorder="big")) + their_globalfeatures = int.from_bytes(payload['globalfeatures'], byteorder="big") + their_features |= their_globalfeatures + # check transitive dependencies for received features + if not their_features.validate_transitive_dependecies(): + raise GracefulDisconnect("remote did not set all dependencies for the features they sent") + # check if features are compatible, and set self.features to what we negotiated try: - self.localfeatures = ln_compare_features(self.localfeatures, their_localfeatures) + self.features = ln_compare_features(self.features, their_features) except IncompatibleLightningFeatures as e: self.initialized.set_exception(e) raise GracefulDisconnect(f"{str(e)}") @@ -477,7 +481,7 @@ class Peer(Logger): self.lnworker.peer_closed(self) def is_static_remotekey(self): - return bool(self.localfeatures & LnLocalFeatures.OPTION_STATIC_REMOTEKEY_OPT) + return bool(self.features & LnFeatures.OPTION_STATIC_REMOTEKEY_OPT) def make_local_config(self, funding_sat: int, push_msat: int, initiator: HTLCOwner) -> LocalConfig: # key derivation @@ -756,7 +760,7 @@ class Peer(Logger): oldest_unrevoked_remote_ctn = chan.get_oldest_unrevoked_ctn(REMOTE) latest_remote_ctn = chan.get_latest_ctn(REMOTE) next_remote_ctn = chan.get_next_ctn(REMOTE) - assert self.localfeatures & LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_OPT + assert self.features & LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT # send message srk_enabled = chan.is_static_remotekey_enabled() if srk_enabled: diff --git a/electrum/lnutil.py b/electrum/lnutil.py index 1ae216919..3d035e205 100644 --- a/electrum/lnutil.py +++ b/electrum/lnutil.py @@ -3,8 +3,9 @@ # file LICENCE or http://www.opensource.org/licenses/mit-license.php from enum import IntFlag, IntEnum +import enum import json -from collections import namedtuple +from collections import namedtuple, defaultdict from typing import NamedTuple, List, Tuple, Mapping, Optional, TYPE_CHECKING, Union, Dict, Set, Sequence import re import attr @@ -708,19 +709,135 @@ def get_ecdh(priv: bytes, pub: bytes) -> bytes: return sha256(pt.get_public_key_bytes()) -class LnLocalFeatures(IntFlag): +class LnFeatureContexts(enum.Flag): + INIT = enum.auto() + NODE_ANN = enum.auto() + CHAN_ANN_AS_IS = enum.auto() + CHAN_ANN_ALWAYS_ODD = enum.auto() + CHAN_ANN_ALWAYS_EVEN = enum.auto() + INVOICE = enum.auto() + +LNFC = LnFeatureContexts + +_ln_feature_direct_dependencies = defaultdict(set) # type: Dict[LnFeatures, Set[LnFeatures]] +_ln_feature_contexts = {} # type: Dict[LnFeatures, LnFeatureContexts] + +class LnFeatures(IntFlag): OPTION_DATA_LOSS_PROTECT_REQ = 1 << 0 OPTION_DATA_LOSS_PROTECT_OPT = 1 << 1 + _ln_feature_contexts[OPTION_DATA_LOSS_PROTECT_OPT] = (LNFC.INIT | LnFeatureContexts.NODE_ANN) + _ln_feature_contexts[OPTION_DATA_LOSS_PROTECT_REQ] = (LNFC.INIT | LnFeatureContexts.NODE_ANN) + INITIAL_ROUTING_SYNC = 1 << 3 + _ln_feature_contexts[INITIAL_ROUTING_SYNC] = LNFC.INIT + OPTION_UPFRONT_SHUTDOWN_SCRIPT_REQ = 1 << 4 OPTION_UPFRONT_SHUTDOWN_SCRIPT_OPT = 1 << 5 + _ln_feature_contexts[OPTION_UPFRONT_SHUTDOWN_SCRIPT_OPT] = (LNFC.INIT | LNFC.NODE_ANN) + _ln_feature_contexts[OPTION_UPFRONT_SHUTDOWN_SCRIPT_REQ] = (LNFC.INIT | LNFC.NODE_ANN) + GOSSIP_QUERIES_REQ = 1 << 6 GOSSIP_QUERIES_OPT = 1 << 7 + _ln_feature_contexts[GOSSIP_QUERIES_OPT] = (LNFC.INIT | LNFC.NODE_ANN) + _ln_feature_contexts[GOSSIP_QUERIES_REQ] = (LNFC.INIT | LNFC.NODE_ANN) + + VAR_ONION_REQ = 1 << 8 + VAR_ONION_OPT = 1 << 9 + _ln_feature_contexts[VAR_ONION_OPT] = (LNFC.INIT | LNFC.NODE_ANN | LNFC.INVOICE) + _ln_feature_contexts[VAR_ONION_REQ] = (LNFC.INIT | LNFC.NODE_ANN | LNFC.INVOICE) + + GOSSIP_QUERIES_EX_REQ = 1 << 10 + GOSSIP_QUERIES_EX_OPT = 1 << 11 + _ln_feature_direct_dependencies[GOSSIP_QUERIES_EX_OPT] = {GOSSIP_QUERIES_OPT} + _ln_feature_contexts[GOSSIP_QUERIES_EX_OPT] = (LNFC.INIT | LNFC.NODE_ANN) + _ln_feature_contexts[GOSSIP_QUERIES_EX_REQ] = (LNFC.INIT | LNFC.NODE_ANN) + OPTION_STATIC_REMOTEKEY_REQ = 1 << 12 OPTION_STATIC_REMOTEKEY_OPT = 1 << 13 - -# note that these are powers of two, not the bits themselves -LN_LOCAL_FEATURES_KNOWN_SET = set(LnLocalFeatures) + _ln_feature_contexts[OPTION_STATIC_REMOTEKEY_OPT] = (LNFC.INIT | LNFC.NODE_ANN) + _ln_feature_contexts[OPTION_STATIC_REMOTEKEY_REQ] = (LNFC.INIT | LNFC.NODE_ANN) + + PAYMENT_SECRET_REQ = 1 << 14 + PAYMENT_SECRET_OPT = 1 << 15 + _ln_feature_direct_dependencies[PAYMENT_SECRET_OPT] = {VAR_ONION_OPT} + _ln_feature_contexts[PAYMENT_SECRET_OPT] = (LNFC.INIT | LNFC.NODE_ANN | LNFC.INVOICE) + _ln_feature_contexts[PAYMENT_SECRET_REQ] = (LNFC.INIT | LNFC.NODE_ANN | LNFC.INVOICE) + + BASIC_MPP_REQ = 1 << 16 + BASIC_MPP_OPT = 1 << 17 + _ln_feature_direct_dependencies[BASIC_MPP_OPT] = {PAYMENT_SECRET_OPT} + _ln_feature_contexts[BASIC_MPP_OPT] = (LNFC.INIT | LNFC.NODE_ANN | LNFC.INVOICE) + _ln_feature_contexts[BASIC_MPP_REQ] = (LNFC.INIT | LNFC.NODE_ANN | LNFC.INVOICE) + + OPTION_SUPPORT_LARGE_CHANNEL_REQ = 1 << 18 + OPTION_SUPPORT_LARGE_CHANNEL_OPT = 1 << 19 + _ln_feature_contexts[OPTION_SUPPORT_LARGE_CHANNEL_OPT] = (LNFC.INIT | LNFC.NODE_ANN | LNFC.CHAN_ANN_ALWAYS_EVEN) + _ln_feature_contexts[OPTION_SUPPORT_LARGE_CHANNEL_REQ] = (LNFC.INIT | LNFC.NODE_ANN | LNFC.CHAN_ANN_ALWAYS_EVEN) + + def validate_transitive_dependecies(self) -> bool: + # for all even bit set, set corresponding odd bit: + features = self # copy + flags = list_enabled_bits(features) + for flag in flags: + if flag % 2 == 0: + features |= 1 << get_ln_flag_pair_of_bit(flag) + # Check dependencies. We only check that the direct dependencies of each flag set + # are satisfied: this implies that transitive dependencies are also satisfied. + flags = list_enabled_bits(features) + for flag in flags: + for dependency in _ln_feature_direct_dependencies[1 << flag]: + if not (dependency & features): + return False + return True + + def for_init_message(self) -> 'LnFeatures': + features = LnFeatures(0) + for flag in list_enabled_bits(self): + if LnFeatureContexts.INIT & _ln_feature_contexts[1 << flag]: + features |= (1 << flag) + return features + + def for_node_announcement(self) -> 'LnFeatures': + features = LnFeatures(0) + for flag in list_enabled_bits(self): + if LnFeatureContexts.NODE_ANN & _ln_feature_contexts[1 << flag]: + features |= (1 << flag) + return features + + def for_invoice(self) -> 'LnFeatures': + features = LnFeatures(0) + for flag in list_enabled_bits(self): + if LnFeatureContexts.INVOICE & _ln_feature_contexts[1 << flag]: + features |= (1 << flag) + return features + + def for_channel_announcement(self) -> 'LnFeatures': + features = LnFeatures(0) + for flag in list_enabled_bits(self): + ctxs = _ln_feature_contexts[1 << flag] + if LnFeatureContexts.CHAN_ANN_AS_IS & ctxs: + features |= (1 << flag) + elif LnFeatureContexts.CHAN_ANN_ALWAYS_EVEN & ctxs: + if flag % 2 == 0: + features |= (1 << flag) + elif LnFeatureContexts.CHAN_ANN_ALWAYS_ODD & ctxs: + if flag % 2 == 0: + flag = get_ln_flag_pair_of_bit(flag) + features |= (1 << flag) + return features + + +del LNFC # name is ambiguous without context + +# features that are actually implemented and understood in our codebase: +# (note: this is not what we send in e.g. init!) +# (note: specify both OPT and REQ here) +LN_FEATURES_IMPLEMENTED = ( + LnFeatures(0) + | LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT | LnFeatures.OPTION_DATA_LOSS_PROTECT_REQ + | LnFeatures.GOSSIP_QUERIES_OPT | LnFeatures.GOSSIP_QUERIES_REQ + | LnFeatures.OPTION_STATIC_REMOTEKEY_OPT | LnFeatures.OPTION_STATIC_REMOTEKEY_REQ +) def get_ln_flag_pair_of_bit(flag_bit: int) -> int: @@ -735,23 +852,20 @@ def get_ln_flag_pair_of_bit(flag_bit: int) -> int: return flag_bit - 1 -class LnGlobalFeatures(IntFlag): - pass - -# note that these are powers of two, not the bits themselves -LN_GLOBAL_FEATURES_KNOWN_SET = set(LnGlobalFeatures) - class IncompatibleLightningFeatures(ValueError): pass -def ln_compare_features(our_features, their_features) -> int: - """raises IncompatibleLightningFeatures if incompatible""" +def ln_compare_features(our_features: 'LnFeatures', their_features: int) -> 'LnFeatures': + """Returns negotiated features. + Raises IncompatibleLightningFeatures if incompatible. + """ our_flags = set(list_enabled_bits(our_features)) their_flags = set(list_enabled_bits(their_features)) + # check that they have our required features, and disable the optional features they don't have for flag in our_flags: if flag not in their_flags and get_ln_flag_pair_of_bit(flag) not in their_flags: # they don't have this feature we wanted :( if flag % 2 == 0: # even flags are compulsory - raise IncompatibleLightningFeatures(f"remote does not support {LnLocalFeatures(1 << flag)!r}") + raise IncompatibleLightningFeatures(f"remote does not support {LnFeatures(1 << flag)!r}") our_features ^= 1 << flag # disable flag else: # They too have this flag. @@ -759,6 +873,12 @@ def ln_compare_features(our_features, their_features) -> int: # set the corresponding odd flag now. if flag % 2 == 0 and our_features & (1 << flag): our_features |= 1 << get_ln_flag_pair_of_bit(flag) + # check that we have their required features + for flag in their_flags: + if flag not in our_flags and get_ln_flag_pair_of_bit(flag) not in our_flags: + # we don't have this feature they wanted :( + if flag % 2 == 0: # even flags are compulsory + raise IncompatibleLightningFeatures(f"remote wanted feature we don't have: {LnFeatures(1 << flag)!r}") return our_features diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 96930517c..592bb5951 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -52,10 +52,10 @@ from .lnutil import (Outpoint, LNPeerAddr, generate_keypair, LnKeyFamily, LOCAL, REMOTE, UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE, NUM_MAX_EDGES_IN_PAYMENT_PATH, SENT, RECEIVED, HTLCOwner, - UpdateAddHtlc, Direction, LnLocalFeatures, + UpdateAddHtlc, Direction, LnFeatures, ShortChannelID, PaymentAttemptLog, PaymentAttemptFailureDetails, BarePaymentAttemptLog) -from .lnutil import ln_dummy_address, ln_compare_features +from .lnutil import ln_dummy_address, ln_compare_features, IncompatibleLightningFeatures from .transaction import PartialTxOutput, PartialTransaction, PartialTxInput from .lnonion import OnionFailureCode, process_onion_packet, OnionPacket from .lnmsg import decode_msg @@ -147,9 +147,9 @@ class LNWorker(Logger): self.taskgroup = SilentTaskGroup() # set some feature flags as baseline for both LNWallet and LNGossip # note that e.g. DATA_LOSS_PROTECT is needed for LNGossip as many peers require it - self.localfeatures = LnLocalFeatures(0) - self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_OPT - self.localfeatures |= LnLocalFeatures.OPTION_STATIC_REMOTEKEY_OPT + self.features = LnFeatures(0) + self.features |= LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT + self.features |= LnFeatures.OPTION_STATIC_REMOTEKEY_OPT def channels_for_peer(self, node_id): return {} @@ -248,8 +248,8 @@ class LNWorker(Logger): if not node: return False try: - ln_compare_features(self.localfeatures, node.features) - except ValueError: + ln_compare_features(self.features, node.features) + except IncompatibleLightningFeatures: return False #self.logger.info(f'is_good {peer.host}') return True @@ -366,8 +366,8 @@ class LNGossip(LNWorker): node = BIP32Node.from_rootseed(seed, xtype='standard') xprv = node.to_xprv() super().__init__(xprv) - self.localfeatures |= LnLocalFeatures.GOSSIP_QUERIES_OPT - self.localfeatures |= LnLocalFeatures.GOSSIP_QUERIES_REQ + self.features |= LnFeatures.GOSSIP_QUERIES_OPT + self.features |= LnFeatures.GOSSIP_QUERIES_REQ self.unknown_ids = set() def start_network(self, network: 'Network'): @@ -419,8 +419,8 @@ class LNWallet(LNWorker): self.db = wallet.db self.config = wallet.config LNWorker.__init__(self, xprv) - self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_REQ - self.localfeatures |= LnLocalFeatures.OPTION_STATIC_REMOTEKEY_REQ + self.features |= LnFeatures.OPTION_DATA_LOSS_PROTECT_REQ + self.features |= LnFeatures.OPTION_STATIC_REMOTEKEY_REQ self.payments = self.db.get_dict('lightning_payments') # RHASH -> amount, direction, is_paid self.preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage self.sweep_address = wallet.get_receiving_address() diff --git a/electrum/tests/test_lnmsg.py b/electrum/tests/test_lnmsg.py index 0bd144a1a..be83910a3 100644 --- a/electrum/tests/test_lnmsg.py +++ b/electrum/tests/test_lnmsg.py @@ -5,7 +5,7 @@ from electrum.lnmsg import (read_bigsize_int, write_bigsize_int, FieldEncodingNo MalformedMsg, MsgTrailingGarbage, MsgInvalidFieldOrder, encode_msg, decode_msg, UnexpectedFieldSizeForEncoder) from electrum.util import bfh -from electrum.lnutil import ShortChannelID, LnLocalFeatures +from electrum.lnutil import ShortChannelID, LnFeatures from electrum import constants from . import TestCaseForTestnet @@ -344,10 +344,10 @@ class TestLNMsg(TestCaseForTestnet): "init", gflen=0, flen=2, - features=(LnLocalFeatures.OPTION_STATIC_REMOTEKEY_OPT | - LnLocalFeatures.GOSSIP_QUERIES_OPT | - LnLocalFeatures.GOSSIP_QUERIES_REQ | - LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_OPT), + features=(LnFeatures.OPTION_STATIC_REMOTEKEY_OPT | + LnFeatures.GOSSIP_QUERIES_OPT | + LnFeatures.GOSSIP_QUERIES_REQ | + LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT), )) self.assertEqual(bfh("00100000000220c2"), encode_msg("init", gflen=0, flen=2, features=bfh("20c2"))) @@ -356,10 +356,10 @@ class TestLNMsg(TestCaseForTestnet): "init", gflen=0, flen=2, - features=(LnLocalFeatures.OPTION_STATIC_REMOTEKEY_OPT | - LnLocalFeatures.GOSSIP_QUERIES_OPT | - LnLocalFeatures.GOSSIP_QUERIES_REQ | - LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_OPT), + features=(LnFeatures.OPTION_STATIC_REMOTEKEY_OPT | + LnFeatures.GOSSIP_QUERIES_OPT | + LnFeatures.GOSSIP_QUERIES_REQ | + LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT), init_tlvs={ 'networks': {'chains': b'CI\x7f\xd7\xf8&\x95q\x08\xf4\xa3\x0f\xd9\xce\xc3\xae\xbay\x97 \x84\xe9\x0e\xad\x01\xea3\t\x00\x00\x00\x00'} diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 1f08568e5..8c2452b9a 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -21,7 +21,7 @@ from electrum.util import bh2u, create_and_start_event_loop from electrum.lnpeer import Peer from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving -from electrum.lnutil import PaymentFailure, LnLocalFeatures, HTLCOwner +from electrum.lnutil import PaymentFailure, LnFeatures, HTLCOwner from electrum.lnchannel import channel_states, peer_states, Channel from electrum.lnrouter import LNPathFinder from electrum.channel_db import ChannelDB @@ -95,8 +95,8 @@ class MockLNWallet(Logger): self.payments = {} self.logs = defaultdict(list) self.wallet = MockWallet() - self.localfeatures = LnLocalFeatures(0) - self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_OPT + self.features = LnFeatures(0) + self.features |= LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT self.pending_payments = defaultdict(asyncio.Future) chan.lnworker = self chan.node_id = remote_keypair.pubkey diff --git a/electrum/tests/test_lnutil.py b/electrum/tests/test_lnutil.py index 885386684..aee907b3c 100644 --- a/electrum/tests/test_lnutil.py +++ b/electrum/tests/test_lnutil.py @@ -8,7 +8,7 @@ from electrum.lnutil import (RevocationStore, get_per_commitment_secret_from_see make_htlc_tx_inputs, secret_to_pubkey, derive_blinded_pubkey, derive_privkey, derive_pubkey, make_htlc_tx, extract_ctn_from_tx, UnableToDeriveSecret, get_compressed_pubkey_from_bech32, split_host_port, ConnStringFormatError, - ScriptHtlc, extract_nodeid, calc_fees_for_commitment_tx, UpdateAddHtlc) + ScriptHtlc, extract_nodeid, calc_fees_for_commitment_tx, UpdateAddHtlc, LnFeatures) from electrum.util import bh2u, bfh, MyEncoder from electrum.transaction import Transaction, PartialTransaction @@ -755,3 +755,53 @@ class TestLNUtil(ElectrumTestCase): with self.assertRaises(ConnStringFormatError): extract_nodeid("00" * 33 + "@") self.assertEqual(extract_nodeid("00" * 33 + "@localhost"), (b"\x00" * 33, "localhost")) + + def test_ln_features_validate_transitive_dependecies(self): + features = LnFeatures.OPTION_DATA_LOSS_PROTECT_REQ + self.assertTrue(features.validate_transitive_dependecies()) + features = LnFeatures.PAYMENT_SECRET_OPT + self.assertFalse(features.validate_transitive_dependecies()) + features = LnFeatures.PAYMENT_SECRET_REQ + self.assertFalse(features.validate_transitive_dependecies()) + features = LnFeatures.PAYMENT_SECRET_REQ | LnFeatures.VAR_ONION_REQ + self.assertTrue(features.validate_transitive_dependecies()) + features = LnFeatures.BASIC_MPP_OPT | LnFeatures.PAYMENT_SECRET_REQ + self.assertFalse(features.validate_transitive_dependecies()) + features = LnFeatures.BASIC_MPP_OPT | LnFeatures.PAYMENT_SECRET_REQ | LnFeatures.VAR_ONION_OPT + self.assertTrue(features.validate_transitive_dependecies()) + features = LnFeatures.BASIC_MPP_OPT | LnFeatures.PAYMENT_SECRET_REQ | LnFeatures.VAR_ONION_REQ + self.assertTrue(features.validate_transitive_dependecies()) + + def test_ln_features_for_init_message(self): + features = LnFeatures.OPTION_DATA_LOSS_PROTECT_REQ + self.assertEqual(features, features.for_init_message()) + features = LnFeatures.PAYMENT_SECRET_OPT + self.assertEqual(features, features.for_init_message()) + features = LnFeatures.PAYMENT_SECRET_REQ + self.assertEqual(features, features.for_init_message()) + features = LnFeatures.PAYMENT_SECRET_REQ | LnFeatures.VAR_ONION_REQ + self.assertEqual(features, features.for_init_message()) + features = LnFeatures.BASIC_MPP_OPT | LnFeatures.PAYMENT_SECRET_REQ + self.assertEqual(features, features.for_init_message()) + features = LnFeatures.BASIC_MPP_OPT | LnFeatures.PAYMENT_SECRET_REQ | LnFeatures.VAR_ONION_OPT + self.assertEqual(features, features.for_init_message()) + features = LnFeatures.BASIC_MPP_OPT | LnFeatures.PAYMENT_SECRET_REQ | LnFeatures.VAR_ONION_REQ + self.assertEqual(features, features.for_init_message()) + + def test_ln_features_for_invoice(self): + features = LnFeatures.OPTION_DATA_LOSS_PROTECT_REQ + self.assertEqual(LnFeatures(0), features.for_invoice()) + features = LnFeatures.PAYMENT_SECRET_OPT + self.assertEqual(features, features.for_invoice()) + features = LnFeatures.PAYMENT_SECRET_REQ + self.assertEqual(features, features.for_invoice()) + features = LnFeatures.PAYMENT_SECRET_REQ | LnFeatures.VAR_ONION_REQ + self.assertEqual(features, features.for_invoice()) + features = LnFeatures.BASIC_MPP_OPT | LnFeatures.PAYMENT_SECRET_REQ | LnFeatures.OPTION_DATA_LOSS_PROTECT_REQ + self.assertEqual(LnFeatures.BASIC_MPP_OPT | LnFeatures.PAYMENT_SECRET_REQ, + features.for_invoice()) + features = LnFeatures.BASIC_MPP_OPT | LnFeatures.PAYMENT_SECRET_REQ | LnFeatures.VAR_ONION_OPT | LnFeatures.OPTION_DATA_LOSS_PROTECT_REQ + self.assertEqual(LnFeatures.BASIC_MPP_OPT | LnFeatures.PAYMENT_SECRET_REQ | LnFeatures.VAR_ONION_OPT, + features.for_invoice()) + features = LnFeatures.BASIC_MPP_OPT | LnFeatures.PAYMENT_SECRET_REQ | LnFeatures.VAR_ONION_REQ + self.assertEqual(features, features.for_invoice())