Browse Source

move connection string decoding to lnworker, fix test_lnutil

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
Janus 6 years ago
committed by ThomasV
parent
commit
efc8d50570
  1. 6
      electrum/commands.py
  2. 54
      electrum/gui/qt/channels_list.py
  3. 47
      electrum/lnutil.py
  4. 49
      electrum/lnworker.py
  5. 46
      electrum/tests/test_lnutil.py

6
electrum/commands.py

@ -765,9 +765,9 @@ class Commands:
# lightning network commands
@command('wpn')
def open_channel(self, node_id, amount, channel_push=0, password=None):
f = self.wallet.lnworker.open_channel(bytes.fromhex(node_id), satoshis(amount), satoshis(channel_push), password)
return f.result()
def open_channel(self, connection_string, amount, channel_push=0, password=None):
f = self.wallet.lnworker.open_channel(connection_string, satoshis(amount), satoshis(channel_push), password)
return f.result(5)
@command('wn')
def reestablish_channel(self):

54
electrum/gui/qt/channels_list.py

@ -5,8 +5,7 @@ from PyQt5.QtWidgets import *
from electrum.util import inv_dict, bh2u, bfh
from electrum.i18n import _
from electrum.lnhtlc import HTLCStateMachine
from electrum.lnaddr import lndecode
from electrum.lnutil import LOCAL, REMOTE
from electrum.lnutil import LOCAL, REMOTE, ConnStringFormatError
from .util import MyTreeWidget, SortableTreeWidgetItem, WindowModalDialog, Buttons, OkButton, CancelButton
from .amountedit import BTCAmountEdit
@ -108,55 +107,12 @@ class ChannelsList(MyTreeWidget):
return
local_amt = local_amt_inp.get_amount()
push_amt = push_amt_inp.get_amount()
connect_contents = str(remote_nodeid.text())
nodeid_hex, rest = self.parse_connect_contents(connect_contents)
try:
node_id = bfh(nodeid_hex)
assert len(node_id) == 33
except:
self.parent.show_error(_('Invalid node ID, must be 33 bytes and hexadecimal'))
return
peer = lnworker.peers.get(node_id)
if not peer:
all_nodes = self.parent.network.channel_db.nodes
node_info = all_nodes.get(node_id, None)
if rest is not None:
try:
host, port = rest.split(":")
except ValueError:
self.parent.show_error(_('Connection strings must be in <node_pubkey>@<host>:<port> format'))
return
elif node_info:
host, port = node_info.addresses[0]
else:
self.parent.show_error(_('Unknown node:') + ' ' + nodeid_hex)
return
try:
int(port)
except:
self.parent.show_error(_('Port number must be decimal'))
return
lnworker.add_peer(host, port, node_id)
self.main_window.protect(self.open_channel, (node_id, local_amt, push_amt))
connect_contents = str(remote_nodeid.text()).strip()
@classmethod
def parse_connect_contents(cls, connect_contents: str):
rest = None
try:
# connection string?
nodeid_hex, rest = connect_contents.split("@")
except ValueError:
try:
# invoice?
invoice = lndecode(connect_contents)
nodeid_bytes = invoice.pubkey.serialize()
nodeid_hex = bh2u(nodeid_bytes)
except:
# node id as hex?
nodeid_hex = connect_contents
return nodeid_hex, rest
self.main_window.protect(self.open_channel, (connect_contents, local_amt, push_amt))
except ConnStringFormatError as e:
self.parent.show_error(str(e))
def open_channel(self, *args, **kwargs):
self.parent.wallet.lnworker.open_channel(*args, **kwargs)

47
electrum/lnutil.py

@ -2,6 +2,7 @@ from enum import IntFlag
import json
from collections import namedtuple
from typing import NamedTuple, List, Tuple
import re
from .util import bfh, bh2u, inv_dict
from .crypto import sha256
@ -11,6 +12,7 @@ from . import ecc, bitcoin, crypto, transaction
from .transaction import opcodes, TxOutput
from .bitcoin import push_script
from . import segwit_addr
from .i18n import _
HTLC_TIMEOUT_WEIGHT = 663
HTLC_SUCCESS_WEIGHT = 703
@ -478,3 +480,48 @@ def make_closing_tx(local_funding_pubkey: bytes, remote_funding_pubkey: bytes,
c_input['sequence'] = 0xFFFF_FFFF
tx = Transaction.from_io([c_input], outputs, locktime=0, version=2)
return tx
class ConnStringFormatError(Exception):
pass
def split_host_port(host_port: str) -> Tuple[str, str]: # port returned as string
ipv6 = re.compile(r'\[(?P<host>[:0-9]+)\](?P<port>:\d+)?$')
other = re.compile(r'(?P<host>[^:]+)(?P<port>:\d+)?$')
m = ipv6.match(host_port)
if not m:
m = other.match(host_port)
if not m:
raise ConnStringFormatError(_('Connection strings must be in <node_pubkey>@<host>:<port> format'))
host = m.group('host')
if m.group('port'):
port = m.group('port')[1:]
else:
port = '9735'
try:
int(port)
except ValueError:
raise ConnStringFormatError(_('Port number must be decimal'))
return host, port
def extract_nodeid(connect_contents: str) -> Tuple[bytes, str]:
rest = None
try:
# connection string?
nodeid_hex, rest = connect_contents.split("@", 1)
except ValueError:
try:
# invoice?
invoice = lndecode(connect_contents)
nodeid_bytes = invoice.pubkey.serialize()
nodeid_hex = bh2u(nodeid_bytes)
except:
# node id as hex?
nodeid_hex = connect_contents
if rest == '':
raise ConnStringFormatError(_('At least a hostname must be supplied after the at symbol.'))
try:
node_id = bfh(nodeid_hex)
assert len(node_id) == 33
except:
raise ConnStringFormatError(_('Invalid node ID, must be 33 bytes and hexadecimal'))
return node_id, rest

49
electrum/lnworker.py

@ -3,9 +3,10 @@ import os
from decimal import Decimal
import random
import time
from typing import Optional, Sequence
from typing import Optional, Sequence, Tuple, List
import threading
from functools import partial
import socket
import dns.resolver
import dns.exception
@ -17,8 +18,10 @@ from .lnbase import Peer, privkey_to_pubkey, aiosafe
from .lnaddr import lnencode, LnAddr, lndecode
from .ecc import der_sig_from_sig_string
from .lnhtlc import HTLCStateMachine
from .lnutil import (Outpoint, calc_short_channel_id, LNPeerAddr, get_compressed_pubkey_from_bech32,
PaymentFailure)
from .lnutil import (Outpoint, calc_short_channel_id, LNPeerAddr,
get_compressed_pubkey_from_bech32, extract_nodeid,
PaymentFailure, split_host_port, ConnStringFormatError)
from electrum.lnaddr import lndecode
from .i18n import _
@ -30,7 +33,6 @@ FALLBACK_NODE_LIST = (
LNPeerAddr('ecdsa.net', 9735, bfh('038370f0e7a03eded3e1d41dc081084a87f0afa1c5b22090b4f3abb391eb15d8ff')),
)
class LNWorker(PrintError):
def __init__(self, wallet, network):
@ -89,6 +91,7 @@ class LNWorker(PrintError):
asyncio.run_coroutine_threadsafe(self.network.main_taskgroup.spawn(peer.main_loop()), self.network.asyncio_loop)
self.peers[node_id] = peer
self.network.trigger_callback('ln_status')
return peer
def save_channel(self, openchannel):
assert type(openchannel) is HTLCStateMachine
@ -154,8 +157,10 @@ class LNWorker(PrintError):
conf = self.wallet.get_tx_height(chan.funding_outpoint.txid).conf
peer.on_network_update(chan, conf)
async def _open_channel_coroutine(self, node_id, local_amount_sat, push_sat, password):
peer = self.peers[node_id]
async def _open_channel_coroutine(self, peer, local_amount_sat, push_sat, password):
# peer might just have been connected to
await asyncio.wait_for(peer.initialized, 5)
openingchannel = await peer.channel_establishment_flow(self.wallet, self.config, password,
funding_sat=local_amount_sat + push_sat,
push_msat=push_sat * 1000,
@ -171,8 +176,34 @@ class LNWorker(PrintError):
def on_channels_updated(self):
self.network.trigger_callback('channels')
def open_channel(self, node_id, local_amt_sat, push_amt_sat, pw):
coro = self._open_channel_coroutine(node_id, local_amt_sat, push_amt_sat, None if pw == "" else pw)
@staticmethod
def choose_preferred_address(addr_list: List[Tuple[str, int]]) -> Tuple[str, int]:
for host, port in addr_list:
if is_ip_address(host):
return host, port
# TODO maybe filter out onion if not on tor?
self.print_error('Chose random address from ' + str(node_info.addresses))
return random.choice(node_info.addresses)
def open_channel(self, connect_contents, local_amt_sat, push_amt_sat, pw):
node_id, rest = extract_nodeid(connect_contents)
peer = self.peers.get(node_id)
if not peer:
all_nodes = self.network.channel_db.nodes
node_info = all_nodes.get(node_id, None)
if rest is not None:
host, port = split_host_port(rest)
elif node_info and len(node_info.addresses) > 0:
host, port = self.choose_preferred_address(node_info.addresses)
else:
raise ConnStringFormatError(_('Unknown node:') + ' ' + bh2u(node_id))
try:
socket.getaddrinfo(host, int(port))
except socket.gaierror:
raise ConnStringFormatError(_('Hostname does not resolve (getaddrinfo failed)'))
peer = self.add_peer(host, port, node_id)
coro = self._open_channel_coroutine(peer, local_amt_sat, push_amt_sat, None if pw == "" else pw)
return asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
def pay(self, invoice, amount_sat=None):
@ -262,7 +293,7 @@ class LNWorker(PrintError):
if node is None: continue
addresses = node.addresses
if not addresses: continue
host, port = random.choice(addresses)
host, port = self.choose_preferred_address(addresses)
peer = LNPeerAddr(host, port, node_id)
if peer.pubkey in self.peers: continue
if peer in self._last_tried_peer: continue

46
electrum/tests/test_lnutil.py

@ -5,7 +5,9 @@ from electrum.lnutil import (RevocationStore, get_per_commitment_secret_from_see
make_received_htlc, make_commitment, make_htlc_tx_witness, make_htlc_tx_output,
make_htlc_tx_inputs, secret_to_pubkey, derive_blinded_pubkey, derive_privkey,
derive_pubkey, make_htlc_tx, extract_ctn_from_tx, UnableToDeriveSecret,
get_compressed_pubkey_from_bech32)
get_compressed_pubkey_from_bech32, split_host_port, ConnStringFormatError,
ScriptHtlc, extract_nodeid)
from electrum import lnhtlc
from electrum.util import bh2u, bfh
from electrum.transaction import Transaction
@ -488,13 +490,14 @@ class TestLNUtil(unittest.TestCase):
remote_signature = "304402204fd4928835db1ccdfc40f5c78ce9bd65249b16348df81f0c44328dcdefc97d630220194d3869c38bc732dd87d13d2958015e2fc16829e74cd4377f84d215c0b70606"
output_commit_tx = "02000000000101bef67e4e2fb9ddeeb3461973cd4c62abb35050b1add772995b820b584a488489000000000038b02b8007e80300000000000022002052bfef0479d7b293c27e0f1eb294bea154c63a3294ef092c19af51409bce0e2ad007000000000000220020403d394747cae42e98ff01734ad5c08f82ba123d3d9a620abda88989651e2ab5d007000000000000220020748eba944fedc8827f6b06bc44678f93c0f9e6078b35c6331ed31e75f8ce0c2db80b000000000000220020c20b5d1f8584fd90443e7b7b720136174fa4b9333c261d04dbbd012635c0f419a00f0000000000002200208c48d15160397c9731df9bc3b236656efb6665fbfe92b4a6878e88a499f741c4c0c62d0000000000160014ccf1af2f2aabee14bb40fa3851ab2301de843110e0a06a00000000002200204adb4e2f00643db396dd120d4e7dc17625f5f2c11a40d857accc862d6b7dd80e04004730440220275b0c325a5e9355650dc30c0eccfbc7efb23987c24b556b9dfdd40effca18d202206caceb2c067836c51f296740c7ae807ffcbfbf1dd3a0d56b6de9a5b247985f060147304402204fd4928835db1ccdfc40f5c78ce9bd65249b16348df81f0c44328dcdefc97d630220194d3869c38bc732dd87d13d2958015e2fc16829e74cd4377f84d215c0b7060601475221023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb21030e9f7b623d2ccc7c9bd44d66d5ce21ce504c0acf6385a132cec6d3c39fa711c152ae3e195220"
htlc_msat = {}
htlc_msat[0] = 1000 * 1000
htlc_msat[2] = 2000 * 1000
htlc_msat[1] = 2000 * 1000
htlc_msat[3] = 3000 * 1000
htlc_msat[4] = 4000 * 1000
htlcs = [(htlc[x], htlc_msat[x]) for x in range(5)]
htlc_obj = {}
for num, msat in [(0, 1000 * 1000),
(2, 2000 * 1000),
(1, 2000 * 1000),
(3, 3000 * 1000),
(4, 4000 * 1000)]:
htlc_obj[num] = lnhtlc.UpdateAddHtlc(amount_msat=msat, payment_hash=bitcoin.sha256(htlc_payment_preimage[num]), cltv_expiry=None, htlc_id=None)
htlcs = [ScriptHtlc(htlc[x], htlc_obj[x]) for x in range(5)]
our_commit_tx = make_commitment(
commitment_number,
@ -531,7 +534,7 @@ class TestLNUtil(unittest.TestCase):
for i in range(5):
self.assertEqual(output_htlc_tx[i][1], self.htlc_tx(htlc[i], htlc_output_index[i],
htlc_msat[i],
htlcs[i].htlc.amount_msat,
htlc_payment_preimage[i],
signature_for_output_remote_htlc[i],
output_htlc_tx[i][0], htlc_cltv_timeout[i] if not output_htlc_tx[i][0] else 0,
@ -680,3 +683,28 @@ class TestLNUtil(unittest.TestCase):
def test_get_compressed_pubkey_from_bech32(self):
self.assertEqual(b'\x03\x84\xef\x87\xd9d\xa2\xaaa7=\xff\xb8\xfe=t8[}>;\n\x13\xa8e\x8eo:\xf5Mi\xb5H',
get_compressed_pubkey_from_bech32('ln1qwzwlp7evj325cfh8hlm3l3awsu9klf78v9p82r93ehn4a2ddx65s66awg5'))
def test_split_host_port(self):
self.assertEqual(split_host_port("[::1]:8000"), ("::1", "8000"))
self.assertEqual(split_host_port("[::1]"), ("::1", "9735"))
self.assertEqual(split_host_port("kæn.guru:8000"), ("kæn.guru", "8000"))
self.assertEqual(split_host_port("kæn.guru"), ("kæn.guru", "9735"))
self.assertEqual(split_host_port("127.0.0.1:8000"), ("127.0.0.1", "8000"))
self.assertEqual(split_host_port("127.0.0.1"), ("127.0.0.1", "9735"))
# accepted by getaddrinfo but not ipaddress.ip_address
self.assertEqual(split_host_port("127.0.0:8000"), ("127.0.0", "8000"))
self.assertEqual(split_host_port("127.0.0"), ("127.0.0", "9735"))
self.assertEqual(split_host_port("electrum.org:8000"), ("electrum.org", "8000"))
self.assertEqual(split_host_port("electrum.org"), ("electrum.org", "9735"))
with self.assertRaises(ConnStringFormatError):
split_host_port("electrum.org:8000:")
with self.assertRaises(ConnStringFormatError):
split_host_port("electrum.org:")
def test_extract_nodeid(self):
with self.assertRaises(ConnStringFormatError):
extract_nodeid("00" * 32 + "@localhost")
with self.assertRaises(ConnStringFormatError):
extract_nodeid("00" * 33 + "@")
self.assertEqual(extract_nodeid("00" * 33 + "@localhost"), (b"\x00" * 33, "localhost"))

Loading…
Cancel
Save