Browse Source

create class for ShortChannelID and use it

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
SomberNight 6 years ago
parent
commit
509df9ddaf
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 30
      electrum/channel_db.py
  2. 5
      electrum/lnchannel.py
  3. 5
      electrum/lnonion.py
  4. 10
      electrum/lnpeer.py
  5. 28
      electrum/lnrouter.py
  6. 49
      electrum/lnutil.py
  7. 35
      electrum/lnverifier.py
  8. 24
      electrum/lnworker.py

30
electrum/channel_db.py

@ -37,7 +37,7 @@ from .sql_db import SqlDB, sql
from . import constants from . import constants
from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits
from .logging import Logger from .logging import Logger
from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, format_short_channel_id from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, format_short_channel_id, ShortChannelID
from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
if TYPE_CHECKING: if TYPE_CHECKING:
@ -57,10 +57,10 @@ FLAG_DISABLE = 1 << 1
FLAG_DIRECTION = 1 << 0 FLAG_DIRECTION = 1 << 0
class ChannelInfo(NamedTuple): class ChannelInfo(NamedTuple):
short_channel_id: bytes short_channel_id: ShortChannelID
node1_id: bytes node1_id: bytes
node2_id: bytes node2_id: bytes
capacity_sat: int capacity_sat: Optional[int]
@staticmethod @staticmethod
def from_msg(payload): def from_msg(payload):
@ -72,10 +72,11 @@ class ChannelInfo(NamedTuple):
assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2] assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2]
capacity_sat = None capacity_sat = None
return ChannelInfo( return ChannelInfo(
short_channel_id = channel_id, short_channel_id = ShortChannelID.normalize(channel_id),
node1_id = node_id_1, node1_id = node_id_1,
node2_id = node_id_2, node2_id = node_id_2,
capacity_sat = capacity_sat) capacity_sat = capacity_sat
)
class Policy(NamedTuple): class Policy(NamedTuple):
@ -107,8 +108,8 @@ class Policy(NamedTuple):
return self.channel_flags & FLAG_DISABLE return self.channel_flags & FLAG_DISABLE
@property @property
def short_channel_id(self): def short_channel_id(self) -> ShortChannelID:
return self.key[0:8] return ShortChannelID.normalize(self.key[0:8])
@property @property
def start_node(self): def start_node(self):
@ -290,7 +291,7 @@ class ChannelDB(SqlDB):
msg_payloads = [msg_payloads] msg_payloads = [msg_payloads]
added = 0 added = 0
for msg in msg_payloads: for msg in msg_payloads:
short_channel_id = msg['short_channel_id'] short_channel_id = ShortChannelID(msg['short_channel_id'])
if short_channel_id in self._channels: if short_channel_id in self._channels:
continue continue
if constants.net.rev_genesis_bytes() != msg['chain_hash']: if constants.net.rev_genesis_bytes() != msg['chain_hash']:
@ -339,7 +340,7 @@ class ChannelDB(SqlDB):
known = [] known = []
now = int(time.time()) now = int(time.time())
for payload in payloads: for payload in payloads:
short_channel_id = payload['short_channel_id'] short_channel_id = ShortChannelID(payload['short_channel_id'])
timestamp = int.from_bytes(payload['timestamp'], "big") timestamp = int.from_bytes(payload['timestamp'], "big")
if max_age and now - timestamp > max_age: if max_age and now - timestamp > max_age:
expired.append(payload) expired.append(payload)
@ -357,7 +358,7 @@ class ChannelDB(SqlDB):
for payload in known: for payload in known:
timestamp = int.from_bytes(payload['timestamp'], "big") timestamp = int.from_bytes(payload['timestamp'], "big")
start_node = payload['start_node'] start_node = payload['start_node']
short_channel_id = payload['short_channel_id'] short_channel_id = ShortChannelID(payload['short_channel_id'])
key = (start_node, short_channel_id) key = (start_node, short_channel_id)
old_policy = self._policies.get(key) old_policy = self._policies.get(key)
if old_policy and timestamp <= old_policy.timestamp: if old_policy and timestamp <= old_policy.timestamp:
@ -434,11 +435,11 @@ class ChannelDB(SqlDB):
def verify_channel_update(self, payload): def verify_channel_update(self, payload):
short_channel_id = payload['short_channel_id'] short_channel_id = payload['short_channel_id']
scid = format_short_channel_id(short_channel_id) short_channel_id = ShortChannelID(short_channel_id)
if constants.net.rev_genesis_bytes() != payload['chain_hash']: if constants.net.rev_genesis_bytes() != payload['chain_hash']:
raise Exception('wrong chain hash') raise Exception('wrong chain hash')
if not verify_sig_for_channel_update(payload, payload['start_node']): if not verify_sig_for_channel_update(payload, payload['start_node']):
raise Exception(f'failed verifying channel update for {scid}') raise Exception(f'failed verifying channel update for {short_channel_id}')
def add_node_announcement(self, msg_payloads): def add_node_announcement(self, msg_payloads):
if type(msg_payloads) is dict: if type(msg_payloads) is dict:
@ -510,11 +511,11 @@ class ChannelDB(SqlDB):
def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes): def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes):
if not verify_sig_for_channel_update(msg_payload, start_node_id): if not verify_sig_for_channel_update(msg_payload, start_node_id):
return # ignore return # ignore
short_channel_id = msg_payload['short_channel_id'] short_channel_id = ShortChannelID(msg_payload['short_channel_id'])
msg_payload['start_node'] = start_node_id msg_payload['start_node'] = start_node_id
self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload
def remove_channel(self, short_channel_id): def remove_channel(self, short_channel_id: ShortChannelID):
channel_info = self._channels.pop(short_channel_id, None) channel_info = self._channels.pop(short_channel_id, None)
if channel_info: if channel_info:
self._channels_for_node[channel_info.node1_id].remove(channel_info.short_channel_id) self._channels_for_node[channel_info.node1_id].remove(channel_info.short_channel_id)
@ -533,6 +534,7 @@ class ChannelDB(SqlDB):
self._addresses[node_id].add((str(host), int(port), int(timestamp or 0))) self._addresses[node_id].add((str(host), int(port), int(timestamp or 0)))
c.execute("""SELECT * FROM channel_info""") c.execute("""SELECT * FROM channel_info""")
for x in c: for x in c:
x = (ShortChannelID.normalize(x[0]), *x[1:])
ci = ChannelInfo(*x) ci = ChannelInfo(*x)
self._channels[ci.short_channel_id] = ci self._channels[ci.short_channel_id] = ci
c.execute("""SELECT * FROM node_info""") c.execute("""SELECT * FROM node_info""")

5
electrum/lnchannel.py

@ -45,7 +45,8 @@ from .lnutil import (Outpoint, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKey
make_htlc_tx_with_open_channel, make_commitment, make_received_htlc, make_offered_htlc, make_htlc_tx_with_open_channel, make_commitment, make_received_htlc, make_offered_htlc,
HTLC_TIMEOUT_WEIGHT, HTLC_SUCCESS_WEIGHT, extract_ctn_from_tx_and_chan, UpdateAddHtlc, HTLC_TIMEOUT_WEIGHT, HTLC_SUCCESS_WEIGHT, extract_ctn_from_tx_and_chan, UpdateAddHtlc,
funding_output_script, SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, make_commitment_outputs, funding_output_script, SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, make_commitment_outputs,
ScriptHtlc, PaymentFailure, calc_onchain_fees, RemoteMisbehaving, make_htlc_output_witness_script) ScriptHtlc, PaymentFailure, calc_onchain_fees, RemoteMisbehaving, make_htlc_output_witness_script,
ShortChannelID)
from .lnutil import FeeUpdate from .lnutil import FeeUpdate
from .lnsweep import create_sweeptxs_for_our_ctx, create_sweeptxs_for_their_ctx from .lnsweep import create_sweeptxs_for_our_ctx, create_sweeptxs_for_their_ctx
from .lnsweep import create_sweeptx_for_their_revoked_htlc from .lnsweep import create_sweeptx_for_their_revoked_htlc
@ -130,7 +131,7 @@ class Channel(Logger):
self.constraints = ChannelConstraints(**state["constraints"]) if type(state["constraints"]) is not ChannelConstraints else state["constraints"] self.constraints = ChannelConstraints(**state["constraints"]) if type(state["constraints"]) is not ChannelConstraints else state["constraints"]
self.funding_outpoint = Outpoint(**dict(decodeAll(state["funding_outpoint"], False))) if type(state["funding_outpoint"]) is not Outpoint else state["funding_outpoint"] self.funding_outpoint = Outpoint(**dict(decodeAll(state["funding_outpoint"], False))) if type(state["funding_outpoint"]) is not Outpoint else state["funding_outpoint"]
self.node_id = bfh(state["node_id"]) if type(state["node_id"]) not in (bytes, type(None)) else state["node_id"] # type: bytes self.node_id = bfh(state["node_id"]) if type(state["node_id"]) not in (bytes, type(None)) else state["node_id"] # type: bytes
self.short_channel_id = bfh(state["short_channel_id"]) if type(state["short_channel_id"]) not in (bytes, type(None)) else state["short_channel_id"] self.short_channel_id = ShortChannelID.normalize(state["short_channel_id"])
self.short_channel_id_predicted = self.short_channel_id self.short_channel_id_predicted = self.short_channel_id
self.onion_keys = str_bytes_dict_from_save(state.get('onion_keys', {})) self.onion_keys = str_bytes_dict_from_save(state.get('onion_keys', {}))
self.force_closed = state.get('force_closed') self.force_closed = state.get('force_closed')

5
electrum/lnonion.py

@ -32,7 +32,8 @@ from Cryptodome.Cipher import ChaCha20
from . import ecc from . import ecc
from .crypto import sha256, hmac_oneshot from .crypto import sha256, hmac_oneshot
from .util import bh2u, profiler, xor_bytes, bfh from .util import bh2u, profiler, xor_bytes, bfh
from .lnutil import get_ecdh, PaymentFailure, NUM_MAX_HOPS_IN_PAYMENT_PATH, NUM_MAX_EDGES_IN_PAYMENT_PATH from .lnutil import (get_ecdh, PaymentFailure, NUM_MAX_HOPS_IN_PAYMENT_PATH,
NUM_MAX_EDGES_IN_PAYMENT_PATH, ShortChannelID)
if TYPE_CHECKING: if TYPE_CHECKING:
from .lnrouter import RouteEdge from .lnrouter import RouteEdge
@ -51,7 +52,7 @@ class InvalidOnionMac(Exception): pass
class OnionPerHop: class OnionPerHop:
def __init__(self, short_channel_id: bytes, amt_to_forward: bytes, outgoing_cltv_value: bytes): def __init__(self, short_channel_id: bytes, amt_to_forward: bytes, outgoing_cltv_value: bytes):
self.short_channel_id = short_channel_id self.short_channel_id = ShortChannelID(short_channel_id)
self.amt_to_forward = amt_to_forward self.amt_to_forward = amt_to_forward
self.outgoing_cltv_value = outgoing_cltv_value self.outgoing_cltv_value = outgoing_cltv_value

10
electrum/lnpeer.py

@ -41,7 +41,7 @@ from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc,
LightningPeerConnectionClosed, HandshakeFailed, NotFoundChanAnnouncementForUpdate, LightningPeerConnectionClosed, HandshakeFailed, NotFoundChanAnnouncementForUpdate,
MINIMUM_MAX_HTLC_VALUE_IN_FLIGHT_ACCEPTED, MAXIMUM_HTLC_MINIMUM_MSAT_ACCEPTED, MINIMUM_MAX_HTLC_VALUE_IN_FLIGHT_ACCEPTED, MAXIMUM_HTLC_MINIMUM_MSAT_ACCEPTED,
MAXIMUM_REMOTE_TO_SELF_DELAY_ACCEPTED, RemoteMisbehaving, DEFAULT_TO_SELF_DELAY, MAXIMUM_REMOTE_TO_SELF_DELAY_ACCEPTED, RemoteMisbehaving, DEFAULT_TO_SELF_DELAY,
NBLOCK_OUR_CLTV_EXPIRY_DELTA, format_short_channel_id) NBLOCK_OUR_CLTV_EXPIRY_DELTA, format_short_channel_id, ShortChannelID)
from .lnutil import FeeUpdate from .lnutil import FeeUpdate
from .lntransport import LNTransport, LNTransportBase from .lntransport import LNTransport, LNTransportBase
from .lnmsg import encode_msg, decode_msg from .lnmsg import encode_msg, decode_msg
@ -283,7 +283,7 @@ class Peer(Logger):
# as it might be for our own direct channel with this peer # as it might be for our own direct channel with this peer
# (and we might not yet know the short channel id for that) # (and we might not yet know the short channel id for that)
for chan_upd_payload in orphaned: for chan_upd_payload in orphaned:
short_channel_id = chan_upd_payload['short_channel_id'] short_channel_id = ShortChannelID(chan_upd_payload['short_channel_id'])
self.orphan_channel_updates[short_channel_id] = chan_upd_payload self.orphan_channel_updates[short_channel_id] = chan_upd_payload
while len(self.orphan_channel_updates) > 25: while len(self.orphan_channel_updates) > 25:
self.orphan_channel_updates.popitem(last=False) self.orphan_channel_updates.popitem(last=False)
@ -959,7 +959,7 @@ class Peer(Logger):
def mark_open(self, chan: Channel): def mark_open(self, chan: Channel):
assert chan.short_channel_id is not None assert chan.short_channel_id is not None
scid = format_short_channel_id(chan.short_channel_id) scid = chan.short_channel_id
# only allow state transition to "OPEN" from "OPENING" # only allow state transition to "OPEN" from "OPENING"
if chan.get_state() != "OPENING": if chan.get_state() != "OPENING":
return return
@ -1096,7 +1096,7 @@ class Peer(Logger):
chan = self.channels[channel_id] chan = self.channels[channel_id]
key = (channel_id, htlc_id) key = (channel_id, htlc_id)
try: try:
route = self.attempted_route[key] route = self.attempted_route[key] # type: List[RouteEdge]
except KeyError: except KeyError:
# the remote might try to fail an htlc after we restarted... # the remote might try to fail an htlc after we restarted...
# attempted_route is not persisted, so we will get here then # attempted_route is not persisted, so we will get here then
@ -1310,7 +1310,7 @@ class Peer(Logger):
return return
dph = processed_onion.hop_data.per_hop dph = processed_onion.hop_data.per_hop
next_chan = self.lnworker.get_channel_by_short_id(dph.short_channel_id) next_chan = self.lnworker.get_channel_by_short_id(dph.short_channel_id)
next_chan_scid = format_short_channel_id(dph.short_channel_id) next_chan_scid = dph.short_channel_id
next_peer = self.lnworker.peers[next_chan.node_id] next_peer = self.lnworker.peers[next_chan.node_id]
local_height = self.network.get_local_height() local_height = self.network.get_local_height()
if next_chan is None: if next_chan is None:

28
electrum/lnrouter.py

@ -29,7 +29,7 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK
from .util import bh2u, profiler from .util import bh2u, profiler
from .logging import Logger from .logging import Logger
from .lnutil import NUM_MAX_EDGES_IN_PAYMENT_PATH from .lnutil import NUM_MAX_EDGES_IN_PAYMENT_PATH, ShortChannelID
from .channel_db import ChannelDB, Policy from .channel_db import ChannelDB, Policy
if TYPE_CHECKING: if TYPE_CHECKING:
@ -38,7 +38,8 @@ if TYPE_CHECKING:
class NoChannelPolicy(Exception): class NoChannelPolicy(Exception):
def __init__(self, short_channel_id: bytes): def __init__(self, short_channel_id: bytes):
super().__init__(f'cannot find channel policy for short_channel_id: {bh2u(short_channel_id)}') short_channel_id = ShortChannelID.normalize(short_channel_id)
super().__init__(f'cannot find channel policy for short_channel_id: {short_channel_id}')
def fee_for_edge_msat(forwarded_amount_msat: int, fee_base_msat: int, fee_proportional_millionths: int) -> int: def fee_for_edge_msat(forwarded_amount_msat: int, fee_base_msat: int, fee_proportional_millionths: int) -> int:
@ -46,12 +47,13 @@ def fee_for_edge_msat(forwarded_amount_msat: int, fee_base_msat: int, fee_propor
+ (forwarded_amount_msat * fee_proportional_millionths // 1_000_000) + (forwarded_amount_msat * fee_proportional_millionths // 1_000_000)
class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes), class RouteEdge(NamedTuple):
('short_channel_id', bytes),
('fee_base_msat', int),
('fee_proportional_millionths', int),
('cltv_expiry_delta', int)])):
"""if you travel through short_channel_id, you will reach node_id""" """if you travel through short_channel_id, you will reach node_id"""
node_id: bytes
short_channel_id: ShortChannelID
fee_base_msat: int
fee_proportional_millionths: int
cltv_expiry_delta: int
def fee_for_edge(self, amount_msat: int) -> int: def fee_for_edge(self, amount_msat: int) -> int:
return fee_for_edge_msat(forwarded_amount_msat=amount_msat, return fee_for_edge_msat(forwarded_amount_msat=amount_msat,
@ -61,10 +63,10 @@ class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes),
@classmethod @classmethod
def from_channel_policy(cls, channel_policy: 'Policy', def from_channel_policy(cls, channel_policy: 'Policy',
short_channel_id: bytes, end_node: bytes) -> 'RouteEdge': short_channel_id: bytes, end_node: bytes) -> 'RouteEdge':
assert type(short_channel_id) is bytes assert isinstance(short_channel_id, bytes)
assert type(end_node) is bytes assert type(end_node) is bytes
return RouteEdge(end_node, return RouteEdge(end_node,
short_channel_id, ShortChannelID.normalize(short_channel_id),
channel_policy.fee_base_msat, channel_policy.fee_base_msat,
channel_policy.fee_proportional_millionths, channel_policy.fee_proportional_millionths,
channel_policy.cltv_expiry_delta) channel_policy.cltv_expiry_delta)
@ -119,8 +121,8 @@ class LNPathFinder(Logger):
self.channel_db = channel_db self.channel_db = channel_db
self.blacklist = set() self.blacklist = set()
def add_to_blacklist(self, short_channel_id): def add_to_blacklist(self, short_channel_id: ShortChannelID):
self.logger.info(f'blacklisting channel {bh2u(short_channel_id)}') self.logger.info(f'blacklisting channel {short_channel_id}')
self.blacklist.add(short_channel_id) self.blacklist.add(short_channel_id)
def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes, def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes,
@ -218,7 +220,7 @@ class LNPathFinder(Logger):
# so there are duplicates in the queue, that we discard now: # so there are duplicates in the queue, that we discard now:
continue continue
for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode): for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode):
assert type(edge_channel_id) is bytes assert isinstance(edge_channel_id, bytes)
if edge_channel_id in self.blacklist: if edge_channel_id in self.blacklist:
continue continue
channel_info = self.channel_db.get_channel_info(edge_channel_id) channel_info = self.channel_db.get_channel_info(edge_channel_id)
@ -237,7 +239,7 @@ class LNPathFinder(Logger):
return path return path
def create_route_from_path(self, path, from_node_id: bytes) -> List[RouteEdge]: def create_route_from_path(self, path, from_node_id: bytes) -> List[RouteEdge]:
assert type(from_node_id) is bytes assert isinstance(from_node_id, bytes)
if path is None: if path is None:
raise Exception('cannot create route from None path') raise Exception('cannot create route from None path')
route = [] route = []

49
electrum/lnutil.py

@ -546,17 +546,6 @@ def funding_output_script_from_keys(pubkey1: bytes, pubkey2: bytes) -> str:
pubkeys = sorted([bh2u(pubkey1), bh2u(pubkey2)]) pubkeys = sorted([bh2u(pubkey1), bh2u(pubkey2)])
return transaction.multisig_script(pubkeys, 2) return transaction.multisig_script(pubkeys, 2)
def calc_short_channel_id(block_height: int, tx_pos_in_block: int, output_index: int) -> bytes:
bh = block_height.to_bytes(3, byteorder='big')
tpos = tx_pos_in_block.to_bytes(3, byteorder='big')
oi = output_index.to_bytes(2, byteorder='big')
return bh + tpos + oi
def invert_short_channel_id(short_channel_id: bytes) -> (int, int, int):
bh = int.from_bytes(short_channel_id[:3], byteorder='big')
tpos = int.from_bytes(short_channel_id[3:6], byteorder='big')
oi = int.from_bytes(short_channel_id[6:8], byteorder='big')
return bh, tpos, oi
def get_obscured_ctn(ctn: int, funder: bytes, fundee: bytes) -> int: def get_obscured_ctn(ctn: int, funder: bytes, fundee: bytes) -> int:
mask = int.from_bytes(sha256(funder + fundee)[-6:], 'big') mask = int.from_bytes(sha256(funder + fundee)[-6:], 'big')
@ -705,6 +694,44 @@ def generate_keypair(ln_keystore: BIP32_KeyStore, key_family: LnKeyFamily, index
NUM_MAX_HOPS_IN_PAYMENT_PATH = 20 NUM_MAX_HOPS_IN_PAYMENT_PATH = 20
NUM_MAX_EDGES_IN_PAYMENT_PATH = NUM_MAX_HOPS_IN_PAYMENT_PATH + 1 NUM_MAX_EDGES_IN_PAYMENT_PATH = NUM_MAX_HOPS_IN_PAYMENT_PATH + 1
class ShortChannelID(bytes):
def __repr__(self):
return f"<ShortChannelID: {format_short_channel_id(self)}>"
def __str__(self):
return format_short_channel_id(self)
@classmethod
def from_components(cls, block_height: int, tx_pos_in_block: int, output_index: int) -> 'ShortChannelID':
bh = block_height.to_bytes(3, byteorder='big')
tpos = tx_pos_in_block.to_bytes(3, byteorder='big')
oi = output_index.to_bytes(2, byteorder='big')
return ShortChannelID(bh + tpos + oi)
@classmethod
def normalize(cls, data: Union[None, str, bytes, 'ShortChannelID']) -> Optional['ShortChannelID']:
if isinstance(data, ShortChannelID) or data is None:
return data
if isinstance(data, str):
return ShortChannelID.fromhex(data)
if isinstance(data, bytes):
return ShortChannelID(data)
@property
def block_height(self) -> int:
return int.from_bytes(self[:3], byteorder='big')
@property
def txpos(self) -> int:
return int.from_bytes(self[3:6], byteorder='big')
@property
def output_index(self) -> int:
return int.from_bytes(self[6:8], byteorder='big')
def format_short_channel_id(short_channel_id: Optional[bytes]): def format_short_channel_id(short_channel_id: Optional[bytes]):
if not short_channel_id: if not short_channel_id:
return _('Not yet available') return _('Not yet available')

35
electrum/lnverifier.py

@ -25,7 +25,7 @@
import asyncio import asyncio
import threading import threading
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Dict, Set
import aiorpcx import aiorpcx
@ -33,7 +33,7 @@ from . import bitcoin
from . import ecc from . import ecc
from . import constants from . import constants
from .util import bh2u, bfh, NetworkJobOnDefaultServer from .util import bh2u, bfh, NetworkJobOnDefaultServer
from .lnutil import invert_short_channel_id, funding_output_script_from_keys from .lnutil import funding_output_script_from_keys, ShortChannelID
from .verifier import verify_tx_is_in_block, MerkleVerificationFailure from .verifier import verify_tx_is_in_block, MerkleVerificationFailure
from .transaction import Transaction from .transaction import Transaction
from .interface import GracefulDisconnect from .interface import GracefulDisconnect
@ -56,17 +56,16 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
NetworkJobOnDefaultServer.__init__(self, network) NetworkJobOnDefaultServer.__init__(self, network)
self.channel_db = channel_db self.channel_db = channel_db
self.lock = threading.Lock() self.lock = threading.Lock()
self.unverified_channel_info = {} # short_channel_id -> msg_payload self.unverified_channel_info = {} # type: Dict[ShortChannelID, dict] # scid -> msg_payload
# channel announcements that seem to be invalid: # channel announcements that seem to be invalid:
self.blacklist = set() # short_channel_id self.blacklist = set() # type: Set[ShortChannelID]
def _reset(self): def _reset(self):
super()._reset() super()._reset()
self.started_verifying_channel = set() # short_channel_id self.started_verifying_channel = set() # type: Set[ShortChannelID]
# TODO make async; and rm self.lock completely # TODO make async; and rm self.lock completely
def add_new_channel_info(self, short_channel_id_hex, msg_payload): def add_new_channel_info(self, short_channel_id: ShortChannelID, msg_payload):
short_channel_id = bfh(short_channel_id_hex)
if short_channel_id in self.unverified_channel_info: if short_channel_id in self.unverified_channel_info:
return return
if short_channel_id in self.blacklist: if short_channel_id in self.blacklist:
@ -93,7 +92,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
for short_channel_id in unverified_channel_info: for short_channel_id in unverified_channel_info:
if short_channel_id in self.started_verifying_channel: if short_channel_id in self.started_verifying_channel:
continue continue
block_height, tx_pos, output_idx = invert_short_channel_id(short_channel_id) block_height = short_channel_id.block_height
# only resolve short_channel_id if headers are available. # only resolve short_channel_id if headers are available.
if block_height <= 0 or block_height > local_height: if block_height <= 0 or block_height > local_height:
continue continue
@ -103,16 +102,17 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
await self.group.spawn(self.network.request_chunk(block_height, None, can_return_early=True)) await self.group.spawn(self.network.request_chunk(block_height, None, can_return_early=True))
continue continue
self.started_verifying_channel.add(short_channel_id) self.started_verifying_channel.add(short_channel_id)
await self.group.spawn(self.verify_channel(block_height, tx_pos, short_channel_id)) await self.group.spawn(self.verify_channel(block_height, short_channel_id))
#self.logger.info(f'requested short_channel_id {bh2u(short_channel_id)}') #self.logger.info(f'requested short_channel_id {bh2u(short_channel_id)}')
async def verify_channel(self, block_height: int, tx_pos: int, short_channel_id: bytes): async def verify_channel(self, block_height: int, short_channel_id: ShortChannelID):
# we are verifying channel announcements as they are from untrusted ln peers. # we are verifying channel announcements as they are from untrusted ln peers.
# we use electrum servers to do this. however we don't trust electrum servers either... # we use electrum servers to do this. however we don't trust electrum servers either...
try: try:
result = await self.network.get_txid_from_txpos(block_height, tx_pos, True) result = await self.network.get_txid_from_txpos(
block_height, short_channel_id.txpos, True)
except aiorpcx.jsonrpc.RPCError: except aiorpcx.jsonrpc.RPCError:
# the electrum server is complaining about the tx_pos for given block. # the electrum server is complaining about the txpos for given block.
# it is not clear what to do now, but let's believe the server. # it is not clear what to do now, but let's believe the server.
self._blacklist_short_channel_id(short_channel_id) self._blacklist_short_channel_id(short_channel_id)
return return
@ -122,7 +122,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
async with self.network.bhi_lock: async with self.network.bhi_lock:
header = self.network.blockchain().read_header(block_height) header = self.network.blockchain().read_header(block_height)
try: try:
verify_tx_is_in_block(tx_hash, merkle_branch, tx_pos, header, block_height) verify_tx_is_in_block(tx_hash, merkle_branch, short_channel_id.txpos, header, block_height)
except MerkleVerificationFailure as e: except MerkleVerificationFailure as e:
# the electrum server sent an incorrect proof. blame is on server, not the ln peer # the electrum server sent an incorrect proof. blame is on server, not the ln peer
raise GracefulDisconnect(e) from e raise GracefulDisconnect(e) from e
@ -151,28 +151,27 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
assert msg_type == 'channel_announcement' assert msg_type == 'channel_announcement'
redeem_script = funding_output_script_from_keys(chan_ann['bitcoin_key_1'], chan_ann['bitcoin_key_2']) redeem_script = funding_output_script_from_keys(chan_ann['bitcoin_key_1'], chan_ann['bitcoin_key_2'])
expected_address = bitcoin.redeem_script_to_address('p2wsh', redeem_script) expected_address = bitcoin.redeem_script_to_address('p2wsh', redeem_script)
output_idx = invert_short_channel_id(short_channel_id)[2]
try: try:
actual_output = tx.outputs()[output_idx] actual_output = tx.outputs()[short_channel_id.output_index]
except IndexError: except IndexError:
self._blacklist_short_channel_id(short_channel_id) self._blacklist_short_channel_id(short_channel_id)
return return
if expected_address != actual_output.address: if expected_address != actual_output.address:
# FIXME what now? best would be to ban the originating ln peer. # FIXME what now? best would be to ban the originating ln peer.
self.logger.info(f"funding output script mismatch for {bh2u(short_channel_id)}") self.logger.info(f"funding output script mismatch for {short_channel_id}")
self._remove_channel_from_unverified_db(short_channel_id) self._remove_channel_from_unverified_db(short_channel_id)
return return
# put channel into channel DB # put channel into channel DB
self.channel_db.add_verified_channel_info(short_channel_id, actual_output.value) self.channel_db.add_verified_channel_info(short_channel_id, actual_output.value)
self._remove_channel_from_unverified_db(short_channel_id) self._remove_channel_from_unverified_db(short_channel_id)
def _remove_channel_from_unverified_db(self, short_channel_id: bytes): def _remove_channel_from_unverified_db(self, short_channel_id: ShortChannelID):
with self.lock: with self.lock:
self.unverified_channel_info.pop(short_channel_id, None) self.unverified_channel_info.pop(short_channel_id, None)
try: self.started_verifying_channel.remove(short_channel_id) try: self.started_verifying_channel.remove(short_channel_id)
except KeyError: pass except KeyError: pass
def _blacklist_short_channel_id(self, short_channel_id: bytes) -> None: def _blacklist_short_channel_id(self, short_channel_id: ShortChannelID) -> None:
self.blacklist.add(short_channel_id) self.blacklist.add(short_channel_id)
with self.lock: with self.lock:
self.unverified_channel_info.pop(short_channel_id, None) self.unverified_channel_info.pop(short_channel_id, None)

24
electrum/lnworker.py

@ -39,13 +39,14 @@ from .ecc import der_sig_from_sig_string
from .ecc_fast import is_using_fast_ecc from .ecc_fast import is_using_fast_ecc
from .lnchannel import Channel, ChannelJsonEncoder from .lnchannel import Channel, ChannelJsonEncoder
from . import lnutil from . import lnutil
from .lnutil import (Outpoint, calc_short_channel_id, LNPeerAddr, from .lnutil import (Outpoint, LNPeerAddr,
get_compressed_pubkey_from_bech32, extract_nodeid, get_compressed_pubkey_from_bech32, extract_nodeid,
PaymentFailure, split_host_port, ConnStringFormatError, PaymentFailure, split_host_port, ConnStringFormatError,
generate_keypair, LnKeyFamily, LOCAL, REMOTE, generate_keypair, LnKeyFamily, LOCAL, REMOTE,
UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE, UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE,
NUM_MAX_EDGES_IN_PAYMENT_PATH, SENT, RECEIVED, HTLCOwner, NUM_MAX_EDGES_IN_PAYMENT_PATH, SENT, RECEIVED, HTLCOwner,
UpdateAddHtlc, Direction, LnLocalFeatures, format_short_channel_id) UpdateAddHtlc, Direction, LnLocalFeatures, format_short_channel_id,
ShortChannelID)
from .i18n import _ from .i18n import _
from .lnrouter import RouteEdge, is_route_sane_to_use from .lnrouter import RouteEdge, is_route_sane_to_use
from .address_synchronizer import TX_HEIGHT_LOCAL from .address_synchronizer import TX_HEIGHT_LOCAL
@ -553,10 +554,11 @@ class LNWallet(LNWorker):
if conf > 0: if conf > 0:
block_height, tx_pos = self.lnwatcher.get_txpos(chan.funding_outpoint.txid) block_height, tx_pos = self.lnwatcher.get_txpos(chan.funding_outpoint.txid)
assert tx_pos >= 0 assert tx_pos >= 0
chan.short_channel_id_predicted = calc_short_channel_id(block_height, tx_pos, chan.funding_outpoint.output_index) chan.short_channel_id_predicted = ShortChannelID.from_components(
block_height, tx_pos, chan.funding_outpoint.output_index)
if conf >= chan.constraints.funding_txn_minimum_depth > 0: if conf >= chan.constraints.funding_txn_minimum_depth > 0:
self.logger.info(f"save_short_channel_id")
chan.short_channel_id = chan.short_channel_id_predicted chan.short_channel_id = chan.short_channel_id_predicted
self.logger.info(f"save_short_channel_id: {chan.short_channel_id}")
self.save_channel(chan) self.save_channel(chan)
self.on_channels_updated() self.on_channels_updated()
else: else:
@ -795,7 +797,7 @@ class LNWallet(LNWorker):
else: else:
self.network.trigger_callback('payment_status', key, 'failure') self.network.trigger_callback('payment_status', key, 'failure')
def get_channel_by_short_id(self, short_channel_id): def get_channel_by_short_id(self, short_channel_id: ShortChannelID) -> Channel:
with self.lock: with self.lock:
for chan in self.channels.values(): for chan in self.channels.values():
if chan.short_channel_id == short_channel_id: if chan.short_channel_id == short_channel_id:
@ -815,7 +817,7 @@ class LNWallet(LNWorker):
for i in range(attempts): for i in range(attempts):
route = await self._create_route_from_invoice(decoded_invoice=addr) route = await self._create_route_from_invoice(decoded_invoice=addr)
if not self.get_channel_by_short_id(route[0].short_channel_id): if not self.get_channel_by_short_id(route[0].short_channel_id):
scid = format_short_channel_id(route[0].short_channel_id) scid = route[0].short_channel_id
raise Exception(f"Got route with unknown first channel: {scid}") raise Exception(f"Got route with unknown first channel: {scid}")
self.network.trigger_callback('payment_status', key, 'progress', i) self.network.trigger_callback('payment_status', key, 'progress', i)
if await self._pay_to_route(route, addr, invoice): if await self._pay_to_route(route, addr, invoice):
@ -826,8 +828,8 @@ class LNWallet(LNWorker):
short_channel_id = route[0].short_channel_id short_channel_id = route[0].short_channel_id
chan = self.get_channel_by_short_id(short_channel_id) chan = self.get_channel_by_short_id(short_channel_id)
if not chan: if not chan:
scid = format_short_channel_id(short_channel_id) raise Exception(f"PathFinder returned path with short_channel_id "
raise Exception(f"PathFinder returned path with short_channel_id {scid} that is not in channel list") f"{short_channel_id} that is not in channel list")
peer = self.peers[route[0].node_id] peer = self.peers[route[0].node_id]
htlc = await peer.pay(route, chan, int(addr.amount * COIN * 1000), addr.paymenthash, addr.get_min_final_cltv_expiry()) htlc = await peer.pay(route, chan, int(addr.amount * COIN * 1000), addr.paymenthash, addr.get_min_final_cltv_expiry())
self.network.trigger_callback('htlc_added', htlc, addr, SENT) self.network.trigger_callback('htlc_added', htlc, addr, SENT)
@ -879,6 +881,7 @@ class LNWallet(LNWorker):
prev_node_id = border_node_pubkey prev_node_id = border_node_pubkey
for node_pubkey, edge_rest in zip(private_route_nodes, private_route_rest): for node_pubkey, edge_rest in zip(private_route_nodes, private_route_rest):
short_channel_id, fee_base_msat, fee_proportional_millionths, cltv_expiry_delta = edge_rest short_channel_id, fee_base_msat, fee_proportional_millionths, cltv_expiry_delta = edge_rest
short_channel_id = ShortChannelID(short_channel_id)
# if we have a routing policy for this edge in the db, that takes precedence, # if we have a routing policy for this edge in the db, that takes precedence,
# as it is likely from a previous failure # as it is likely from a previous failure
channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id) channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id)
@ -1030,7 +1033,7 @@ class LNWallet(LNWorker):
if amount_sat and chan.balance(REMOTE) // 1000 < amount_sat: if amount_sat and chan.balance(REMOTE) // 1000 < amount_sat:
continue continue
chan_id = chan.short_channel_id chan_id = chan.short_channel_id
assert type(chan_id) is bytes, chan_id assert isinstance(chan_id, bytes), chan_id
channel_info = self.channel_db.get_channel_info(chan_id) channel_info = self.channel_db.get_channel_info(chan_id)
# note: as a fallback, if we don't have a channel update for the # note: as a fallback, if we don't have a channel update for the
# incoming direction of our private channel, we fill the invoice with garbage. # incoming direction of our private channel, we fill the invoice with garbage.
@ -1048,8 +1051,7 @@ class LNWallet(LNWorker):
cltv_expiry_delta = policy.cltv_expiry_delta cltv_expiry_delta = policy.cltv_expiry_delta
missing_info = False missing_info = False
if missing_info: if missing_info:
scid = format_short_channel_id(chan_id) self.logger.info(f"Warning. Missing channel update for our channel {chan_id}; "
self.logger.info(f"Warning. Missing channel update for our channel {scid}; "
f"filling invoice with incorrect data.") f"filling invoice with incorrect data.")
routing_hints.append(('r', [(chan.node_id, routing_hints.append(('r', [(chan.node_id,
chan_id, chan_id,

Loading…
Cancel
Save