Browse Source

ChannelDB: avoid duplicate (host,port) entries in ChannelDB._addresses

before:
node_id -> set of (host, port, ts)
after:
node_id -> NetAddress -> timestamp

Look at e.g. add_recent_peer; we only want to store
the last connection time, not all of them.
patch-4
SomberNight 4 years ago
parent
commit
2ec548dda3
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 52
      electrum/channel_db.py
  2. 8
      electrum/lnutil.py

52
electrum/channel_db.py

@ -34,6 +34,7 @@ import asyncio
import threading
from enum import IntEnum
from aiorpcx import NetAddress
from .sql_db import SqlDB, sql
from . import constants, util
@ -53,14 +54,6 @@ FLAG_DISABLE = 1 << 1
FLAG_DIRECTION = 1 << 0
class NodeAddress(NamedTuple):
"""Holds address information of Lightning nodes
and how up to date this info is."""
host: str
port: int
timestamp: int
class ChannelInfo(NamedTuple):
short_channel_id: ShortChannelID
node1_id: bytes
@ -295,8 +288,8 @@ class ChannelDB(SqlDB):
self._channels = {} # type: Dict[ShortChannelID, ChannelInfo]
self._policies = {} # type: Dict[Tuple[bytes, ShortChannelID], Policy] # (node_id, scid) -> Policy
self._nodes = {} # type: Dict[bytes, NodeInfo] # node_id -> NodeInfo
# node_id -> (host, port, ts)
self._addresses = defaultdict(set) # type: Dict[bytes, Set[NodeAddress]]
# node_id -> NetAddress -> timestamp
self._addresses = defaultdict(dict) # type: Dict[bytes, Dict[NetAddress, int]]
self._channels_for_node = defaultdict(set) # type: Dict[bytes, Set[ShortChannelID]]
self._recent_peers = [] # type: List[bytes] # list of node_ids
self._chans_with_0_policies = set() # type: Set[ShortChannelID]
@ -321,7 +314,7 @@ class ChannelDB(SqlDB):
now = int(time.time())
node_id = peer.pubkey
with self.lock:
self._addresses[node_id].add(NodeAddress(peer.host, peer.port, now))
self._addresses[node_id][peer.net_addr()] = now
# list is ordered
if node_id in self._recent_peers:
self._recent_peers.remove(node_id)
@ -336,12 +329,12 @@ class ChannelDB(SqlDB):
def get_last_good_address(self, node_id: bytes) -> Optional[LNPeerAddr]:
"""Returns latest address we successfully connected to, for given node."""
r = self._addresses.get(node_id)
if not r:
addr_to_ts = self._addresses.get(node_id)
if not addr_to_ts:
return None
addr = sorted(list(r), key=lambda x: x.timestamp, reverse=True)[0]
addr = sorted(list(addr_to_ts), key=lambda a: addr_to_ts[a], reverse=True)[0]
try:
return LNPeerAddr(addr.host, addr.port, node_id)
return LNPeerAddr(str(addr.host), addr.port, node_id)
except ValueError:
return None
@ -583,7 +576,8 @@ class ChannelDB(SqlDB):
self._db_save_node_info(node_id, msg_payload['raw'])
with self.lock:
for addr in node_addresses:
self._addresses[node_id].add(NodeAddress(addr.host, addr.port, 0))
net_addr = NetAddress(addr.host, addr.port)
self._addresses[node_id][net_addr] = self._addresses[node_id].get(net_addr) or 0
self._db_save_node_addresses(node_addresses)
self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
@ -634,8 +628,13 @@ class ChannelDB(SqlDB):
# delete from database
self._db_delete_channel(short_channel_id)
def get_node_addresses(self, node_id):
return self._addresses.get(node_id)
def get_node_addresses(self, node_id: bytes) -> Sequence[Tuple[str, int, int]]:
"""Returns list of (host, port, timestamp)."""
addr_to_ts = self._addresses.get(node_id)
if not addr_to_ts:
return []
return [(str(net_addr.host), net_addr.port, ts)
for net_addr, ts in addr_to_ts.items()]
@sql
@profiler
@ -643,17 +642,19 @@ class ChannelDB(SqlDB):
if self.data_loaded.is_set():
return
# Note: this method takes several seconds... mostly due to lnmsg.decode_msg being slow.
# I believe lnmsg (and lightning.json) will need a rewrite anyway, so instead of tweaking
# load_data() here, that should be done. see #6006
c = self.conn.cursor()
c.execute("""SELECT * FROM address""")
for x in c:
node_id, host, port, timestamp = x
self._addresses[node_id].add(NodeAddress(str(host), int(port), int(timestamp or 0)))
try:
net_addr = NetAddress(host, port)
except Exception:
continue
self._addresses[node_id][net_addr] = int(timestamp or 0)
def newest_ts_for_node_id(node_id):
newest_ts = 0
for addr in self._addresses[node_id]:
newest_ts = max(newest_ts, addr.timestamp)
for addr, ts in self._addresses[node_id].items():
newest_ts = max(newest_ts, ts)
return newest_ts
sorted_node_ids = sorted(self._addresses.keys(), key=newest_ts_for_node_id, reverse=True)
self._recent_peers = sorted_node_ids[:self.NUM_MAX_RECENT_PEERS]
@ -791,7 +792,10 @@ class ChannelDB(SqlDB):
graph['nodes'].append(
nodeinfo._asdict(),
)
graph['nodes'][-1]['addresses'] = [addr._asdict() for addr in self._addresses[pk]]
graph['nodes'][-1]['addresses'] = [
{'host': str(addr.host), 'port': addr.port, 'timestamp': ts}
for addr, ts in self._addresses[pk].items()
]
# gather channels
for cid, channelinfo in self._channels.items():

8
electrum/lnutil.py

@ -1106,6 +1106,7 @@ def derive_payment_secret_from_payment_preimage(payment_preimage: bytes) -> byte
class LNPeerAddr:
# note: while not programmatically enforced, this class is meant to be *immutable*
def __init__(self, host: str, port: int, pubkey: bytes):
assert isinstance(host, str), repr(host)
@ -1120,7 +1121,7 @@ class LNPeerAddr:
self.host = host
self.port = port
self.pubkey = pubkey
self._net_addr_str = str(net_addr)
self._net_addr = net_addr
def __str__(self):
return '{}@{}'.format(self.pubkey.hex(), self.net_addr_str())
@ -1128,8 +1129,11 @@ class LNPeerAddr:
def __repr__(self):
return f'<LNPeerAddr host={self.host} port={self.port} pubkey={self.pubkey.hex()}>'
def net_addr(self) -> NetAddress:
return self._net_addr
def net_addr_str(self) -> str:
return self._net_addr_str
return str(self._net_addr)
def __eq__(self, other):
if not isinstance(other, LNPeerAddr):

Loading…
Cancel
Save