Browse Source

move connection string decoding to lnworker, fix test_lnutil

regtest_lnd
Janus 6 years ago
committed by SomberNight
parent
commit
92d8373a42
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 6
      electrum/commands.py
  2. 54
      electrum/gui/qt/channels_list.py
  3. 47
      electrum/lnutil.py
  4. 49
      electrum/lnworker.py
  5. 46
      electrum/tests/test_lnutil.py

6
electrum/commands.py

@ -769,9 +769,9 @@ class Commands:
# lightning network commands # lightning network commands
@command('wpn') @command('wpn')
def open_channel(self, node_id, amount, channel_push=0, password=None): def open_channel(self, connection_string, amount, channel_push=0, password=None):
f = self.wallet.lnworker.open_channel(bytes.fromhex(node_id), satoshis(amount), satoshis(channel_push), password) f = self.wallet.lnworker.open_channel(connection_string, satoshis(amount), satoshis(channel_push), password)
return f.result() return f.result(5)
@command('wn') @command('wn')
def reestablish_channel(self): def reestablish_channel(self):

54
electrum/gui/qt/channels_list.py

@ -5,8 +5,7 @@ from PyQt5.QtWidgets import *
from electrum.util import inv_dict, bh2u, bfh from electrum.util import inv_dict, bh2u, bfh
from electrum.i18n import _ from electrum.i18n import _
from electrum.lnhtlc import HTLCStateMachine from electrum.lnhtlc import HTLCStateMachine
from electrum.lnaddr import lndecode from electrum.lnutil import LOCAL, REMOTE, ConnStringFormatError
from electrum.lnutil import LOCAL, REMOTE
from .util import MyTreeWidget, SortableTreeWidgetItem, WindowModalDialog, Buttons, OkButton, CancelButton from .util import MyTreeWidget, SortableTreeWidgetItem, WindowModalDialog, Buttons, OkButton, CancelButton
from .amountedit import BTCAmountEdit from .amountedit import BTCAmountEdit
@ -108,55 +107,12 @@ class ChannelsList(MyTreeWidget):
return return
local_amt = local_amt_inp.get_amount() local_amt = local_amt_inp.get_amount()
push_amt = push_amt_inp.get_amount() push_amt = push_amt_inp.get_amount()
connect_contents = str(remote_nodeid.text()) connect_contents = str(remote_nodeid.text()).strip()
nodeid_hex, rest = self.parse_connect_contents(connect_contents)
try:
node_id = bfh(nodeid_hex)
assert len(node_id) == 33
except:
self.parent.show_error(_('Invalid node ID, must be 33 bytes and hexadecimal'))
return
peer = lnworker.peers.get(node_id)
if not peer:
all_nodes = self.parent.network.channel_db.nodes
node_info = all_nodes.get(node_id, None)
if rest is not None:
try:
host, port = rest.split(":")
except ValueError:
self.parent.show_error(_('Connection strings must be in <node_pubkey>@<host>:<port> format'))
return
elif node_info:
host, port = node_info.addresses[0]
else:
self.parent.show_error(_('Unknown node:') + ' ' + nodeid_hex)
return
try:
int(port)
except:
self.parent.show_error(_('Port number must be decimal'))
return
lnworker.add_peer(host, port, node_id)
self.main_window.protect(self.open_channel, (node_id, local_amt, push_amt))
@classmethod
def parse_connect_contents(cls, connect_contents: str):
rest = None
try:
# connection string?
nodeid_hex, rest = connect_contents.split("@")
except ValueError:
try: try:
# invoice? self.main_window.protect(self.open_channel, (connect_contents, local_amt, push_amt))
invoice = lndecode(connect_contents) except ConnStringFormatError as e:
nodeid_bytes = invoice.pubkey.serialize() self.parent.show_error(str(e))
nodeid_hex = bh2u(nodeid_bytes)
except:
# node id as hex?
nodeid_hex = connect_contents
return nodeid_hex, rest
def open_channel(self, *args, **kwargs): def open_channel(self, *args, **kwargs):
self.parent.wallet.lnworker.open_channel(*args, **kwargs) self.parent.wallet.lnworker.open_channel(*args, **kwargs)

47
electrum/lnutil.py

@ -2,6 +2,7 @@ from enum import IntFlag
import json import json
from collections import namedtuple from collections import namedtuple
from typing import NamedTuple, List, Tuple from typing import NamedTuple, List, Tuple
import re
from .util import bfh, bh2u, inv_dict from .util import bfh, bh2u, inv_dict
from .crypto import sha256 from .crypto import sha256
@ -11,6 +12,7 @@ from . import ecc, bitcoin, crypto, transaction
from .transaction import opcodes, TxOutput from .transaction import opcodes, TxOutput
from .bitcoin import push_script from .bitcoin import push_script
from . import segwit_addr from . import segwit_addr
from .i18n import _
HTLC_TIMEOUT_WEIGHT = 663 HTLC_TIMEOUT_WEIGHT = 663
HTLC_SUCCESS_WEIGHT = 703 HTLC_SUCCESS_WEIGHT = 703
@ -478,3 +480,48 @@ def make_closing_tx(local_funding_pubkey: bytes, remote_funding_pubkey: bytes,
c_input['sequence'] = 0xFFFF_FFFF c_input['sequence'] = 0xFFFF_FFFF
tx = Transaction.from_io([c_input], outputs, locktime=0, version=2) tx = Transaction.from_io([c_input], outputs, locktime=0, version=2)
return tx return tx
class ConnStringFormatError(Exception):
pass
def split_host_port(host_port: str) -> Tuple[str, str]: # port returned as string
ipv6 = re.compile(r'\[(?P<host>[:0-9]+)\](?P<port>:\d+)?$')
other = re.compile(r'(?P<host>[^:]+)(?P<port>:\d+)?$')
m = ipv6.match(host_port)
if not m:
m = other.match(host_port)
if not m:
raise ConnStringFormatError(_('Connection strings must be in <node_pubkey>@<host>:<port> format'))
host = m.group('host')
if m.group('port'):
port = m.group('port')[1:]
else:
port = '9735'
try:
int(port)
except ValueError:
raise ConnStringFormatError(_('Port number must be decimal'))
return host, port
def extract_nodeid(connect_contents: str) -> Tuple[bytes, str]:
rest = None
try:
# connection string?
nodeid_hex, rest = connect_contents.split("@", 1)
except ValueError:
try:
# invoice?
invoice = lndecode(connect_contents)
nodeid_bytes = invoice.pubkey.serialize()
nodeid_hex = bh2u(nodeid_bytes)
except:
# node id as hex?
nodeid_hex = connect_contents
if rest == '':
raise ConnStringFormatError(_('At least a hostname must be supplied after the at symbol.'))
try:
node_id = bfh(nodeid_hex)
assert len(node_id) == 33
except:
raise ConnStringFormatError(_('Invalid node ID, must be 33 bytes and hexadecimal'))
return node_id, rest

49
electrum/lnworker.py

@ -3,9 +3,10 @@ import os
from decimal import Decimal from decimal import Decimal
import random import random
import time import time
from typing import Optional, Sequence from typing import Optional, Sequence, Tuple, List
import threading import threading
from functools import partial from functools import partial
import socket
import dns.resolver import dns.resolver
import dns.exception import dns.exception
@ -17,8 +18,10 @@ from .lnbase import Peer, privkey_to_pubkey, aiosafe
from .lnaddr import lnencode, LnAddr, lndecode from .lnaddr import lnencode, LnAddr, lndecode
from .ecc import der_sig_from_sig_string from .ecc import der_sig_from_sig_string
from .lnhtlc import HTLCStateMachine from .lnhtlc import HTLCStateMachine
from .lnutil import (Outpoint, calc_short_channel_id, LNPeerAddr, get_compressed_pubkey_from_bech32, from .lnutil import (Outpoint, calc_short_channel_id, LNPeerAddr,
PaymentFailure) get_compressed_pubkey_from_bech32, extract_nodeid,
PaymentFailure, split_host_port, ConnStringFormatError)
from electrum.lnaddr import lndecode
from .i18n import _ from .i18n import _
@ -30,7 +33,6 @@ FALLBACK_NODE_LIST = (
LNPeerAddr('ecdsa.net', 9735, bfh('038370f0e7a03eded3e1d41dc081084a87f0afa1c5b22090b4f3abb391eb15d8ff')), LNPeerAddr('ecdsa.net', 9735, bfh('038370f0e7a03eded3e1d41dc081084a87f0afa1c5b22090b4f3abb391eb15d8ff')),
) )
class LNWorker(PrintError): class LNWorker(PrintError):
def __init__(self, wallet, network): def __init__(self, wallet, network):
@ -89,6 +91,7 @@ class LNWorker(PrintError):
asyncio.run_coroutine_threadsafe(self.network.main_taskgroup.spawn(peer.main_loop()), self.network.asyncio_loop) asyncio.run_coroutine_threadsafe(self.network.main_taskgroup.spawn(peer.main_loop()), self.network.asyncio_loop)
self.peers[node_id] = peer self.peers[node_id] = peer
self.network.trigger_callback('ln_status') self.network.trigger_callback('ln_status')
return peer
def save_channel(self, openchannel): def save_channel(self, openchannel):
assert type(openchannel) is HTLCStateMachine assert type(openchannel) is HTLCStateMachine
@ -154,8 +157,10 @@ class LNWorker(PrintError):
conf = self.wallet.get_tx_height(chan.funding_outpoint.txid).conf conf = self.wallet.get_tx_height(chan.funding_outpoint.txid).conf
peer.on_network_update(chan, conf) peer.on_network_update(chan, conf)
async def _open_channel_coroutine(self, node_id, local_amount_sat, push_sat, password): async def _open_channel_coroutine(self, peer, local_amount_sat, push_sat, password):
peer = self.peers[node_id] # peer might just have been connected to
await asyncio.wait_for(peer.initialized, 5)
openingchannel = await peer.channel_establishment_flow(self.wallet, self.config, password, openingchannel = await peer.channel_establishment_flow(self.wallet, self.config, password,
funding_sat=local_amount_sat + push_sat, funding_sat=local_amount_sat + push_sat,
push_msat=push_sat * 1000, push_msat=push_sat * 1000,
@ -171,8 +176,34 @@ class LNWorker(PrintError):
def on_channels_updated(self): def on_channels_updated(self):
self.network.trigger_callback('channels') self.network.trigger_callback('channels')
def open_channel(self, node_id, local_amt_sat, push_amt_sat, pw): @staticmethod
coro = self._open_channel_coroutine(node_id, local_amt_sat, push_amt_sat, None if pw == "" else pw) def choose_preferred_address(addr_list: List[Tuple[str, int]]) -> Tuple[str, int]:
for host, port in addr_list:
if is_ip_address(host):
return host, port
# TODO maybe filter out onion if not on tor?
self.print_error('Chose random address from ' + str(node_info.addresses))
return random.choice(node_info.addresses)
def open_channel(self, connect_contents, local_amt_sat, push_amt_sat, pw):
node_id, rest = extract_nodeid(connect_contents)
peer = self.peers.get(node_id)
if not peer:
all_nodes = self.network.channel_db.nodes
node_info = all_nodes.get(node_id, None)
if rest is not None:
host, port = split_host_port(rest)
elif node_info and len(node_info.addresses) > 0:
host, port = self.choose_preferred_address(node_info.addresses)
else:
raise ConnStringFormatError(_('Unknown node:') + ' ' + bh2u(node_id))
try:
socket.getaddrinfo(host, int(port))
except socket.gaierror:
raise ConnStringFormatError(_('Hostname does not resolve (getaddrinfo failed)'))
peer = self.add_peer(host, port, node_id)
coro = self._open_channel_coroutine(peer, local_amt_sat, push_amt_sat, None if pw == "" else pw)
return asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop) return asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
def pay(self, invoice, amount_sat=None): def pay(self, invoice, amount_sat=None):
@ -262,7 +293,7 @@ class LNWorker(PrintError):
if node is None: continue if node is None: continue
addresses = node.addresses addresses = node.addresses
if not addresses: continue if not addresses: continue
host, port = random.choice(addresses) host, port = self.choose_preferred_address(addresses)
peer = LNPeerAddr(host, port, node_id) peer = LNPeerAddr(host, port, node_id)
if peer.pubkey in self.peers: continue if peer.pubkey in self.peers: continue
if peer in self._last_tried_peer: continue if peer in self._last_tried_peer: continue

46
electrum/tests/test_lnutil.py

@ -5,7 +5,9 @@ from electrum.lnutil import (RevocationStore, get_per_commitment_secret_from_see
make_received_htlc, make_commitment, make_htlc_tx_witness, make_htlc_tx_output, make_received_htlc, make_commitment, make_htlc_tx_witness, make_htlc_tx_output,
make_htlc_tx_inputs, secret_to_pubkey, derive_blinded_pubkey, derive_privkey, make_htlc_tx_inputs, secret_to_pubkey, derive_blinded_pubkey, derive_privkey,
derive_pubkey, make_htlc_tx, extract_ctn_from_tx, UnableToDeriveSecret, derive_pubkey, make_htlc_tx, extract_ctn_from_tx, UnableToDeriveSecret,
get_compressed_pubkey_from_bech32) get_compressed_pubkey_from_bech32, split_host_port, ConnStringFormatError,
ScriptHtlc, extract_nodeid)
from electrum import lnhtlc
from electrum.util import bh2u, bfh from electrum.util import bh2u, bfh
from electrum.transaction import Transaction from electrum.transaction import Transaction
@ -488,13 +490,14 @@ class TestLNUtil(unittest.TestCase):
remote_signature = "304402204fd4928835db1ccdfc40f5c78ce9bd65249b16348df81f0c44328dcdefc97d630220194d3869c38bc732dd87d13d2958015e2fc16829e74cd4377f84d215c0b70606" remote_signature = "304402204fd4928835db1ccdfc40f5c78ce9bd65249b16348df81f0c44328dcdefc97d630220194d3869c38bc732dd87d13d2958015e2fc16829e74cd4377f84d215c0b70606"
output_commit_tx = "02000000000101bef67e4e2fb9ddeeb3461973cd4c62abb35050b1add772995b820b584a488489000000000038b02b8007e80300000000000022002052bfef0479d7b293c27e0f1eb294bea154c63a3294ef092c19af51409bce0e2ad007000000000000220020403d394747cae42e98ff01734ad5c08f82ba123d3d9a620abda88989651e2ab5d007000000000000220020748eba944fedc8827f6b06bc44678f93c0f9e6078b35c6331ed31e75f8ce0c2db80b000000000000220020c20b5d1f8584fd90443e7b7b720136174fa4b9333c261d04dbbd012635c0f419a00f0000000000002200208c48d15160397c9731df9bc3b236656efb6665fbfe92b4a6878e88a499f741c4c0c62d0000000000160014ccf1af2f2aabee14bb40fa3851ab2301de843110e0a06a00000000002200204adb4e2f00643db396dd120d4e7dc17625f5f2c11a40d857accc862d6b7dd80e04004730440220275b0c325a5e9355650dc30c0eccfbc7efb23987c24b556b9dfdd40effca18d202206caceb2c067836c51f296740c7ae807ffcbfbf1dd3a0d56b6de9a5b247985f060147304402204fd4928835db1ccdfc40f5c78ce9bd65249b16348df81f0c44328dcdefc97d630220194d3869c38bc732dd87d13d2958015e2fc16829e74cd4377f84d215c0b7060601475221023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb21030e9f7b623d2ccc7c9bd44d66d5ce21ce504c0acf6385a132cec6d3c39fa711c152ae3e195220" output_commit_tx = "02000000000101bef67e4e2fb9ddeeb3461973cd4c62abb35050b1add772995b820b584a488489000000000038b02b8007e80300000000000022002052bfef0479d7b293c27e0f1eb294bea154c63a3294ef092c19af51409bce0e2ad007000000000000220020403d394747cae42e98ff01734ad5c08f82ba123d3d9a620abda88989651e2ab5d007000000000000220020748eba944fedc8827f6b06bc44678f93c0f9e6078b35c6331ed31e75f8ce0c2db80b000000000000220020c20b5d1f8584fd90443e7b7b720136174fa4b9333c261d04dbbd012635c0f419a00f0000000000002200208c48d15160397c9731df9bc3b236656efb6665fbfe92b4a6878e88a499f741c4c0c62d0000000000160014ccf1af2f2aabee14bb40fa3851ab2301de843110e0a06a00000000002200204adb4e2f00643db396dd120d4e7dc17625f5f2c11a40d857accc862d6b7dd80e04004730440220275b0c325a5e9355650dc30c0eccfbc7efb23987c24b556b9dfdd40effca18d202206caceb2c067836c51f296740c7ae807ffcbfbf1dd3a0d56b6de9a5b247985f060147304402204fd4928835db1ccdfc40f5c78ce9bd65249b16348df81f0c44328dcdefc97d630220194d3869c38bc732dd87d13d2958015e2fc16829e74cd4377f84d215c0b7060601475221023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb21030e9f7b623d2ccc7c9bd44d66d5ce21ce504c0acf6385a132cec6d3c39fa711c152ae3e195220"
htlc_msat = {} htlc_obj = {}
htlc_msat[0] = 1000 * 1000 for num, msat in [(0, 1000 * 1000),
htlc_msat[2] = 2000 * 1000 (2, 2000 * 1000),
htlc_msat[1] = 2000 * 1000 (1, 2000 * 1000),
htlc_msat[3] = 3000 * 1000 (3, 3000 * 1000),
htlc_msat[4] = 4000 * 1000 (4, 4000 * 1000)]:
htlcs = [(htlc[x], htlc_msat[x]) for x in range(5)] htlc_obj[num] = lnhtlc.UpdateAddHtlc(amount_msat=msat, payment_hash=bitcoin.sha256(htlc_payment_preimage[num]), cltv_expiry=None, htlc_id=None)
htlcs = [ScriptHtlc(htlc[x], htlc_obj[x]) for x in range(5)]
our_commit_tx = make_commitment( our_commit_tx = make_commitment(
commitment_number, commitment_number,
@ -531,7 +534,7 @@ class TestLNUtil(unittest.TestCase):
for i in range(5): for i in range(5):
self.assertEqual(output_htlc_tx[i][1], self.htlc_tx(htlc[i], htlc_output_index[i], self.assertEqual(output_htlc_tx[i][1], self.htlc_tx(htlc[i], htlc_output_index[i],
htlc_msat[i], htlcs[i].htlc.amount_msat,
htlc_payment_preimage[i], htlc_payment_preimage[i],
signature_for_output_remote_htlc[i], signature_for_output_remote_htlc[i],
output_htlc_tx[i][0], htlc_cltv_timeout[i] if not output_htlc_tx[i][0] else 0, output_htlc_tx[i][0], htlc_cltv_timeout[i] if not output_htlc_tx[i][0] else 0,
@ -680,3 +683,28 @@ class TestLNUtil(unittest.TestCase):
def test_get_compressed_pubkey_from_bech32(self): def test_get_compressed_pubkey_from_bech32(self):
self.assertEqual(b'\x03\x84\xef\x87\xd9d\xa2\xaaa7=\xff\xb8\xfe=t8[}>;\n\x13\xa8e\x8eo:\xf5Mi\xb5H', self.assertEqual(b'\x03\x84\xef\x87\xd9d\xa2\xaaa7=\xff\xb8\xfe=t8[}>;\n\x13\xa8e\x8eo:\xf5Mi\xb5H',
get_compressed_pubkey_from_bech32('ln1qwzwlp7evj325cfh8hlm3l3awsu9klf78v9p82r93ehn4a2ddx65s66awg5')) get_compressed_pubkey_from_bech32('ln1qwzwlp7evj325cfh8hlm3l3awsu9klf78v9p82r93ehn4a2ddx65s66awg5'))
def test_split_host_port(self):
self.assertEqual(split_host_port("[::1]:8000"), ("::1", "8000"))
self.assertEqual(split_host_port("[::1]"), ("::1", "9735"))
self.assertEqual(split_host_port("kæn.guru:8000"), ("kæn.guru", "8000"))
self.assertEqual(split_host_port("kæn.guru"), ("kæn.guru", "9735"))
self.assertEqual(split_host_port("127.0.0.1:8000"), ("127.0.0.1", "8000"))
self.assertEqual(split_host_port("127.0.0.1"), ("127.0.0.1", "9735"))
# accepted by getaddrinfo but not ipaddress.ip_address
self.assertEqual(split_host_port("127.0.0:8000"), ("127.0.0", "8000"))
self.assertEqual(split_host_port("127.0.0"), ("127.0.0", "9735"))
self.assertEqual(split_host_port("electrum.org:8000"), ("electrum.org", "8000"))
self.assertEqual(split_host_port("electrum.org"), ("electrum.org", "9735"))
with self.assertRaises(ConnStringFormatError):
split_host_port("electrum.org:8000:")
with self.assertRaises(ConnStringFormatError):
split_host_port("electrum.org:")
def test_extract_nodeid(self):
with self.assertRaises(ConnStringFormatError):
extract_nodeid("00" * 32 + "@localhost")
with self.assertRaises(ConnStringFormatError):
extract_nodeid("00" * 33 + "@")
self.assertEqual(extract_nodeid("00" * 33 + "@localhost"), (b"\x00" * 33, "localhost"))

Loading…
Cancel
Save