Browse Source

persist nodes in channel_db on disk

regtest_lnd
SomberNight 7 years ago
parent
commit
897447f40b
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 19
      electrum/gui/qt/channels_list.py
  2. 35
      electrum/lnbase.py
  3. 133
      electrum/lnrouter.py
  4. 18
      electrum/lnutil.py
  5. 9
      electrum/lnworker.py
  6. 1
      electrum/network.py
  7. 21
      electrum/tests/test_util.py
  8. 20
      electrum/util.py

19
electrum/gui/qt/channels_list.py

@ -64,10 +64,12 @@ class ChannelsList(MyTreeWidget):
return h return h
def update_status(self): def update_status(self):
n = len(self.parent.network.lightning_nodes) channel_db = self.parent.network.channel_db
nc = len(self.parent.network.channel_db) num_nodes = len(channel_db.nodes)
np = len(self.parent.wallet.lnworker.peers) num_channels = len(channel_db)
self.status.setText(_('{} peers, {} nodes, {} channels').format(np, n, nc)) num_peers = len(self.parent.wallet.lnworker.peers)
self.status.setText(_('{} peers, {} nodes, {} channels')
.format(num_peers, num_nodes, num_channels))
def new_channel_dialog(self): def new_channel_dialog(self):
lnworker = self.parent.wallet.lnworker lnworker = self.parent.wallet.lnworker
@ -116,15 +118,16 @@ class ChannelsList(MyTreeWidget):
peer = lnworker.peers.get(node_id) peer = lnworker.peers.get(node_id)
if not peer: if not peer:
known = node_id in self.parent.network.lightning_nodes all_nodes = self.parent.network.channel_db.nodes
node_info = all_nodes.get(node_id, None)
if rest is not None: if rest is not None:
try: try:
host, port = rest.split(":") host, port = rest.split(":")
except ValueError: except ValueError:
self.parent.show_error(_('Connection strings must be in <node_pubkey>@<host>:<port> format')) self.parent.show_error(_('Connection strings must be in <node_pubkey>@<host>:<port> format'))
elif known: return
node = self.network.lightning_nodes.get(node_id) elif node_info:
host, port = node['addresses'][0] host, port = node_info.addresses[0]
else: else:
self.parent.show_error(_('Unknown node:') + ' ' + nodeid_hex) self.parent.show_error(_('Unknown node:') + ' ' + nodeid_hex)
return return

35
electrum/lnbase.py

@ -29,7 +29,7 @@ from . import crypto
from .crypto import sha256 from .crypto import sha256
from . import constants from . import constants
from . import transaction from . import transaction
from .util import PrintError, bh2u, print_error, bfh, profiler, xor_bytes from .util import PrintError, bh2u, print_error, bfh
from .transaction import opcodes, Transaction from .transaction import opcodes, Transaction
from .lnonion import new_onion_packet, OnionHopsDataSingle, OnionPerHop, decode_onion_error from .lnonion import new_onion_packet, OnionHopsDataSingle, OnionPerHop, decode_onion_error
from .lnaddr import lndecode from .lnaddr import lndecode
@ -428,38 +428,7 @@ class Peer(PrintError):
self.funding_signed[channel_id].put_nowait(payload) self.funding_signed[channel_id].put_nowait(payload)
def on_node_announcement(self, payload): def on_node_announcement(self, payload):
pubkey = payload['node_id'] self.channel_db.on_node_announcement(payload)
signature = payload['signature']
h = bitcoin.Hash(payload['raw'][66:])
if not ecc.verify_signature(pubkey, signature, h):
return False
self.s = payload['addresses']
def read(n):
data, self.s = self.s[0:n], self.s[n:]
return data
addresses = []
while self.s:
atype = ord(read(1))
if atype == 0:
pass
elif atype == 1:
ipv4_addr = '.'.join(map(lambda x: '%d' % x, read(4)))
port = int.from_bytes(read(2), 'big')
x = ipv4_addr, port, binascii.hexlify(pubkey)
addresses.append((ipv4_addr, port))
elif atype == 2:
ipv6_addr = b':'.join([binascii.hexlify(read(2)) for i in range(4)])
port = int.from_bytes(read(2), 'big')
addresses.append((ipv6_addr, port))
else:
pass
continue
alias = payload['alias'].rstrip(b'\x00')
self.network.lightning_nodes[pubkey] = {
'alias': alias,
'addresses': addresses
}
#self.print_error('node announcement', binascii.hexlify(pubkey), alias, addresses)
self.network.trigger_callback('ln_status') self.network.trigger_callback('ln_status')
def on_init(self, payload): def on_init(self, payload):

133
electrum/lnrouter.py

@ -29,17 +29,31 @@ import json
import threading import threading
from collections import namedtuple, defaultdict from collections import namedtuple, defaultdict
from typing import Sequence, Union, Tuple, Optional from typing import Sequence, Union, Tuple, Optional
import binascii
import base64
from . import constants from . import constants
from .util import PrintError, bh2u, profiler, get_headers_dir, bfh from .util import PrintError, bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits
from .storage import JsonDB from .storage import JsonDB
from .lnchanannverifier import LNChanAnnVerifier, verify_sig_for_channel_update from .lnchanannverifier import LNChanAnnVerifier, verify_sig_for_channel_update
from .crypto import Hash
from . import ecc
from .lnutil import LN_GLOBAL_FEATURE_BITS
class UnknownEvenFeatureBits(Exception): pass
class ChannelInfo(PrintError): class ChannelInfo(PrintError):
def __init__(self, channel_announcement_payload): def __init__(self, channel_announcement_payload):
self.features_len = channel_announcement_payload['len']
self.features = channel_announcement_payload['features']
enabled_features = list_enabled_bits(int.from_bytes(self.features, "big"))
for fbit in enabled_features:
if fbit not in LN_GLOBAL_FEATURE_BITS and fbit % 2 == 0:
raise UnknownEvenFeatureBits()
self.channel_id = channel_announcement_payload['short_channel_id'] self.channel_id = channel_announcement_payload['short_channel_id']
self.node_id_1 = channel_announcement_payload['node_id_1'] self.node_id_1 = channel_announcement_payload['node_id_1']
self.node_id_2 = channel_announcement_payload['node_id_2'] self.node_id_2 = channel_announcement_payload['node_id_2']
@ -47,8 +61,6 @@ class ChannelInfo(PrintError):
assert type(self.node_id_2) is bytes assert type(self.node_id_2) is bytes
assert list(sorted([self.node_id_1, self.node_id_2])) == [self.node_id_1, self.node_id_2] assert list(sorted([self.node_id_1, self.node_id_2])) == [self.node_id_1, self.node_id_2]
self.features_len = channel_announcement_payload['len']
self.features = channel_announcement_payload['features']
self.bitcoin_key_1 = channel_announcement_payload['bitcoin_key_1'] self.bitcoin_key_1 = channel_announcement_payload['bitcoin_key_1']
self.bitcoin_key_2 = channel_announcement_payload['bitcoin_key_2'] self.bitcoin_key_2 = channel_announcement_payload['bitcoin_key_2']
@ -162,6 +174,86 @@ class ChannelInfoDirectedPolicy:
return ChannelInfoDirectedPolicy(d2) return ChannelInfoDirectedPolicy(d2)
class NodeInfo(PrintError):
def __init__(self, node_announcement_payload, addresses_already_parsed=False):
self.pubkey = node_announcement_payload['node_id']
self.features_len = node_announcement_payload['flen']
self.features = node_announcement_payload['features']
enabled_features = list_enabled_bits(int.from_bytes(self.features, "big"))
for fbit in enabled_features:
if fbit not in LN_GLOBAL_FEATURE_BITS and fbit % 2 == 0:
raise UnknownEvenFeatureBits()
if not addresses_already_parsed:
self.addresses = self.parse_addresses_field(node_announcement_payload['addresses'])
else:
self.addresses = node_announcement_payload['addresses']
self.alias = node_announcement_payload['alias'].rstrip(b'\x00')
self.timestamp = int.from_bytes(node_announcement_payload['timestamp'], "big")
@classmethod
def parse_addresses_field(cls, addresses_field):
buf = addresses_field
def read(n):
nonlocal buf
data, buf = buf[0:n], buf[n:]
return data
addresses = []
while buf:
atype = ord(read(1))
if atype == 0:
pass
elif atype == 1: # IPv4
ipv4_addr = '.'.join(map(lambda x: '%d' % x, read(4)))
port = int.from_bytes(read(2), 'big')
if is_ip_address(ipv4_addr) and port != 0:
addresses.append((ipv4_addr, port))
elif atype == 2: # IPv6
ipv6_addr = b':'.join([binascii.hexlify(read(2)) for i in range(8)])
ipv6_addr = ipv6_addr.decode('ascii')
port = int.from_bytes(read(2), 'big')
if is_ip_address(ipv6_addr) and port != 0:
addresses.append((ipv6_addr, port))
elif atype == 3: # onion v2
host = base64.b32encode(read(10)) + b'.onion'
host = host.decode('ascii').lower()
port = int.from_bytes(read(2), 'big')
addresses.append((host, port))
elif atype == 4: # onion v3
host = base64.b32encode(read(35)) + b'.onion'
host = host.decode('ascii').lower()
port = int.from_bytes(read(2), 'big')
addresses.append((host, port))
else:
# unknown address type
# we don't know how long it is -> have to escape
# if there are other addresses we could have parsed later, they are lost.
break
return addresses
def to_json(self) -> dict:
d = {}
d['node_id'] = bh2u(self.pubkey)
d['flen'] = bh2u(self.features_len)
d['features'] = bh2u(self.features)
d['addresses'] = self.addresses
d['alias'] = bh2u(self.alias)
d['timestamp'] = self.timestamp
return d
@classmethod
def from_json(cls, d: dict):
if d is None: return None
d2 = {}
d2['node_id'] = bfh(d['node_id'])
d2['flen'] = bfh(d['flen'])
d2['features'] = bfh(d['features'])
d2['addresses'] = d['addresses']
d2['alias'] = bfh(d['alias'])
d2['timestamp'] = d['timestamp'].to_bytes(4, "big")
return NodeInfo(d2, addresses_already_parsed=True)
class ChannelDB(JsonDB): class ChannelDB(JsonDB):
def __init__(self, network): def __init__(self, network):
@ -173,6 +265,7 @@ class ChannelDB(JsonDB):
self.lock = threading.Lock() self.lock = threading.Lock()
self._id_to_channel_info = {} self._id_to_channel_info = {}
self._channels_for_node = defaultdict(set) # node -> set(short_channel_id) self._channels_for_node = defaultdict(set) # node -> set(short_channel_id)
self.nodes = {} # node_id -> NodeInfo
self.ca_verifier = LNChanAnnVerifier(network, self) self.ca_verifier = LNChanAnnVerifier(network, self)
self.network.add_jobs([self.ca_verifier]) self.network.add_jobs([self.ca_verifier])
@ -184,21 +277,35 @@ class ChannelDB(JsonDB):
with open(self.path, "r", encoding='utf-8') as f: with open(self.path, "r", encoding='utf-8') as f:
raw = f.read() raw = f.read()
self.data = json.loads(raw) self.data = json.loads(raw)
# channels
channel_infos = self.get('channel_infos', {}) channel_infos = self.get('channel_infos', {})
for short_channel_id, channel_info_d in channel_infos.items(): for short_channel_id, channel_info_d in channel_infos.items():
channel_info = ChannelInfo.from_json(channel_info_d) channel_info = ChannelInfo.from_json(channel_info_d)
short_channel_id = bfh(short_channel_id) short_channel_id = bfh(short_channel_id)
self.add_verified_channel_info(short_channel_id, channel_info) self.add_verified_channel_info(short_channel_id, channel_info)
# nodes
node_infos = self.get('node_infos', {})
for node_id, node_info_d in node_infos.items():
node_info = NodeInfo.from_json(node_info_d)
node_id = bfh(node_id)
self.nodes[node_id] = node_info
def save_data(self): def save_data(self):
with self.lock: with self.lock:
# channels
channel_infos = {} channel_infos = {}
for short_channel_id, channel_info in self._id_to_channel_info.items(): for short_channel_id, channel_info in self._id_to_channel_info.items():
channel_infos[bh2u(short_channel_id)] = channel_info channel_infos[bh2u(short_channel_id)] = channel_info
self.put('channel_infos', channel_infos) self.put('channel_infos', channel_infos)
# nodes
node_infos = {}
for node_id, node_info in self.nodes.items():
node_infos[bh2u(node_id)] = node_info
self.put('node_infos', node_infos)
self.write() self.write()
def __len__(self): def __len__(self):
# number of channels
return len(self._id_to_channel_info) return len(self._id_to_channel_info)
def get_channel_info(self, channel_id) -> Optional[ChannelInfo]: def get_channel_info(self, channel_id) -> Optional[ChannelInfo]:
@ -220,7 +327,10 @@ class ChannelDB(JsonDB):
return return
if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']: if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']:
return return
try:
channel_info = ChannelInfo(msg_payload) channel_info = ChannelInfo(msg_payload)
except UnknownEvenFeatureBits:
return
if trusted: if trusted:
self.add_verified_channel_info(short_channel_id, channel_info) self.add_verified_channel_info(short_channel_id, channel_info)
else: else:
@ -244,6 +354,21 @@ class ChannelDB(JsonDB):
return return
channel_info.on_channel_update(msg_payload, trusted=trusted) channel_info.on_channel_update(msg_payload, trusted=trusted)
def on_node_announcement(self, msg_payload):
pubkey = msg_payload['node_id']
signature = msg_payload['signature']
h = Hash(msg_payload['raw'][66:])
if not ecc.verify_signature(pubkey, signature, h):
return
old_node_info = self.nodes.get(pubkey, None)
try:
new_node_info = NodeInfo(msg_payload)
except UnknownEvenFeatureBits:
return
if old_node_info and old_node_info.timestamp >= new_node_info.timestamp:
return # ignore
self.nodes[pubkey] = new_node_info
def remove_channel(self, short_channel_id): def remove_channel(self, short_channel_id):
try: try:
channel_info = self._id_to_channel_info[short_channel_id] channel_info = self._id_to_channel_info[short_channel_id]

18
electrum/lnutil.py

@ -1,4 +1,4 @@
from .util import bfh, bh2u from .util import bfh, bh2u, inv_dict
from .crypto import sha256 from .crypto import sha256
import json import json
from collections import namedtuple from collections import namedtuple
@ -380,3 +380,19 @@ def overall_weight(num_htlc):
def get_ecdh(priv: bytes, pub: bytes) -> bytes: def get_ecdh(priv: bytes, pub: bytes) -> bytes:
pt = ECPubkey(pub) * string_to_number(priv) pt = ECPubkey(pub) * string_to_number(priv)
return sha256(pt.get_public_key_bytes()) return sha256(pt.get_public_key_bytes())
LN_LOCAL_FEATURE_BITS = {
0: 'option_data_loss_protect_req',
1: 'option_data_loss_protect_opt',
3: 'initial_routing_sync',
4: 'option_upfront_shutdown_script_req',
5: 'option_upfront_shutdown_script_opt',
6: 'gossip_queries_req',
7: 'gossip_queries_opt',
}
LN_LOCAL_FEATURE_BITS_INV = inv_dict(LN_LOCAL_FEATURE_BITS)
LN_GLOBAL_FEATURE_BITS = {}
LN_GLOBAL_FEATURE_BITS_INV = inv_dict(LN_GLOBAL_FEATURE_BITS)

9
electrum/lnworker.py

@ -228,11 +228,12 @@ class LNWorker(PrintError):
self.peers.pop(k) self.peers.pop(k)
if len(self.peers) > 3: if len(self.peers) > 3:
continue continue
if not self.network.lightning_nodes: if not self.network.channel_db.nodes:
continue continue
node_id = random.choice(list(self.network.lightning_nodes.keys())) all_nodes = self.network.channel_db.nodes
node = self.network.lightning_nodes.get(node_id) node_id = random.choice(list(all_nodes))
addresses = node.get('addresses') node = all_nodes.get(node_id)
addresses = node.addresses
if addresses: if addresses:
host, port = addresses[0] host, port = addresses[0]
self.print_error("trying node", bh2u(node_id)) self.print_error("trying node", bh2u(node_id))

1
electrum/network.py

@ -299,7 +299,6 @@ class Network(Logger):
self._set_status('disconnected') self._set_status('disconnected')
# lightning network # lightning network
self.lightning_nodes = {}
self.channel_db = lnrouter.ChannelDB(self) self.channel_db = lnrouter.ChannelDB(self)
self.path_finder = lnrouter.LNPathFinder(self.channel_db) self.path_finder = lnrouter.LNPathFinder(self.channel_db)
self.lnwatcher = lnwatcher.LNWatcher(self) self.lnwatcher = lnwatcher.LNWatcher(self)

21
electrum/tests/test_util.py

@ -1,7 +1,7 @@
from decimal import Decimal from decimal import Decimal
from electrum.util import (format_satoshis, format_fee_satoshis, parse_URI, from electrum.util import (format_satoshis, format_fee_satoshis, parse_URI,
is_hash256_str, chunks) is_hash256_str, chunks, is_ip_address, list_enabled_bits)
from . import SequentialTestCase from . import SequentialTestCase
@ -110,3 +110,22 @@ class TestUtil(SequentialTestCase):
list(chunks([1, 2, 3, 4, 5], 2))) list(chunks([1, 2, 3, 4, 5], 2)))
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
list(chunks([1, 2, 3], 0)) list(chunks([1, 2, 3], 0))
def test_list_enabled_bits(self):
self.assertEqual((0, 2, 3, 6), list_enabled_bits(77))
self.assertEqual((), list_enabled_bits(0))
def test_is_ip_address(self):
self.assertTrue(is_ip_address("127.0.0.1"))
self.assertTrue(is_ip_address("127.000.000.1"))
self.assertTrue(is_ip_address("255.255.255.255"))
self.assertFalse(is_ip_address("255.255.256.255"))
self.assertFalse(is_ip_address("123.456.789.000"))
self.assertTrue(is_ip_address("2001:0db8:0000:0000:0000:ff00:0042:8329"))
self.assertTrue(is_ip_address("2001:db8:0:0:0:ff00:42:8329"))
self.assertTrue(is_ip_address("2001:db8::ff00:42:8329"))
self.assertFalse(is_ip_address("2001:::db8::ff00:42:8329"))
self.assertTrue(is_ip_address("::1"))
self.assertFalse(is_ip_address("2001:db8:0:0:g:ff00:42:8329"))
self.assertFalse(is_ip_address("lol"))
self.assertFalse(is_ip_address(":@ASD:@AS\x77\x22\xff¬!"))

20
electrum/util.py

@ -23,7 +23,7 @@
import binascii import binascii
import os, sys, re, json import os, sys, re, json
from collections import defaultdict, OrderedDict from collections import defaultdict, OrderedDict
from typing import NamedTuple, Union, TYPE_CHECKING, Tuple, Optional, Callable, Any from typing import NamedTuple, Union, TYPE_CHECKING, Tuple, Optional, Callable, Any, Sequence
from datetime import datetime from datetime import datetime
import decimal import decimal
from decimal import Decimal from decimal import Decimal
@ -40,6 +40,7 @@ import json
import time import time
from typing import NamedTuple, Optional from typing import NamedTuple, Optional
import ssl import ssl
import ipaddress
import aiohttp import aiohttp
from aiohttp_socks import SocksConnector, SocksVer from aiohttp_socks import SocksConnector, SocksVer
@ -1157,3 +1158,20 @@ def multisig_type(wallet_type):
if match: if match:
match = [int(x) for x in match.group(1, 2)] match = [int(x) for x in match.group(1, 2)]
return match return match
def is_ip_address(x: Union[str, bytes]) -> bool:
if isinstance(x, bytes):
x = x.decode("utf-8")
try:
ipaddress.ip_address(x)
return True
except ValueError:
return False
def list_enabled_bits(x: int) -> Sequence[int]:
"""e.g. 77 (0b1001101) --> (0, 2, 3, 6)"""
binary = bin(x)[2:]
rev_bin = reversed(binary)
return tuple(i for i, b in enumerate(rev_bin) if b == '1')

Loading…
Cancel
Save