Browse Source

fix sql conflicts in lnrouter

regtest_lnd
ThomasV 6 years ago
committed by SomberNight
parent
commit
d5b5c7ddef
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 138
      electrum/lnrouter.py

138
electrum/lnrouter.py

@ -23,7 +23,7 @@
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE. # SOFTWARE.
import datetime import time
import random import random
import queue import queue
import os import os
@ -35,7 +35,7 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK
import binascii import binascii
import base64 import base64
from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean from sqlalchemy import Column, ForeignKey, Integer, String, Boolean
from sqlalchemy.orm.query import Query from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import not_, or_ from sqlalchemy.sql import not_, or_
@ -81,14 +81,14 @@ class ChannelInfo(Base):
trusted = Column(Boolean, nullable=False) trusted = Column(Boolean, nullable=False)
@staticmethod @staticmethod
def from_msg(channel_announcement_payload): def from_msg(payload):
features = int.from_bytes(channel_announcement_payload['features'], 'big') features = int.from_bytes(payload['features'], 'big')
validate_features(features) validate_features(features)
channel_id = channel_announcement_payload['short_channel_id'].hex() channel_id = payload['short_channel_id'].hex()
node_id_1 = channel_announcement_payload['node_id_1'].hex() node_id_1 = payload['node_id_1'].hex()
node_id_2 = channel_announcement_payload['node_id_2'].hex() node_id_2 = payload['node_id_2'].hex()
assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2] assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2]
msg_payload_hex = encode_msg('channel_announcement', **channel_announcement_payload).hex() msg_payload_hex = encode_msg('channel_announcement', **payload).hex()
capacity_sat = None capacity_sat = None
return ChannelInfo(short_channel_id = channel_id, node1_id = node_id_1, return ChannelInfo(short_channel_id = channel_id, node1_id = node_id_1,
node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex, node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex,
@ -109,17 +109,17 @@ class Policy(Base):
fee_base_msat = Column(Integer, nullable=False) fee_base_msat = Column(Integer, nullable=False)
fee_proportional_millionths = Column(Integer, nullable=False) fee_proportional_millionths = Column(Integer, nullable=False)
channel_flags = Column(Integer, nullable=False) channel_flags = Column(Integer, nullable=False)
timestamp = Column(DateTime, nullable=False) timestamp = Column(Integer, nullable=False)
@staticmethod @staticmethod
def from_msg(channel_update_payload, start_node, short_channel_id): def from_msg(payload, start_node, short_channel_id):
cltv_expiry_delta = channel_update_payload['cltv_expiry_delta'] cltv_expiry_delta = payload['cltv_expiry_delta']
htlc_minimum_msat = channel_update_payload['htlc_minimum_msat'] htlc_minimum_msat = payload['htlc_minimum_msat']
fee_base_msat = channel_update_payload['fee_base_msat'] fee_base_msat = payload['fee_base_msat']
fee_proportional_millionths = channel_update_payload['fee_proportional_millionths'] fee_proportional_millionths = payload['fee_proportional_millionths']
channel_flags = channel_update_payload['channel_flags'] channel_flags = payload['channel_flags']
timestamp = channel_update_payload['timestamp'] timestamp = payload['timestamp']
htlc_maximum_msat = channel_update_payload.get('htlc_maximum_msat') # optional htlc_maximum_msat = payload.get('htlc_maximum_msat') # optional
cltv_expiry_delta = int.from_bytes(cltv_expiry_delta, "big") cltv_expiry_delta = int.from_bytes(cltv_expiry_delta, "big")
htlc_minimum_msat = int.from_bytes(htlc_minimum_msat, "big") htlc_minimum_msat = int.from_bytes(htlc_minimum_msat, "big")
@ -127,7 +127,7 @@ class Policy(Base):
fee_base_msat = int.from_bytes(fee_base_msat, "big") fee_base_msat = int.from_bytes(fee_base_msat, "big")
fee_proportional_millionths = int.from_bytes(fee_proportional_millionths, "big") fee_proportional_millionths = int.from_bytes(fee_proportional_millionths, "big")
channel_flags = int.from_bytes(channel_flags, "big") channel_flags = int.from_bytes(channel_flags, "big")
timestamp = datetime.datetime.fromtimestamp(int.from_bytes(timestamp, "big")) timestamp = int.from_bytes(timestamp, "big")
return Policy(start_node=start_node, return Policy(start_node=start_node,
short_channel_id=short_channel_id, short_channel_id=short_channel_id,
@ -150,17 +150,16 @@ class NodeInfo(Base):
alias = Column(String(64), nullable=False) alias = Column(String(64), nullable=False)
@staticmethod @staticmethod
def from_msg(node_announcement_payload, addresses_already_parsed=False): def from_msg(payload):
node_id = node_announcement_payload['node_id'].hex() node_id = payload['node_id'].hex()
features = int.from_bytes(node_announcement_payload['features'], "big") features = int.from_bytes(payload['features'], "big")
validate_features(features) validate_features(features)
if not addresses_already_parsed: addresses = NodeInfo.parse_addresses_field(payload['addresses'])
addresses = NodeInfo.parse_addresses_field(node_announcement_payload['addresses']) alias = payload['alias'].rstrip(b'\x00').hex()
else: timestamp = int.from_bytes(payload['timestamp'], "big")
addresses = node_announcement_payload['addresses'] now = int(time.time())
alias = node_announcement_payload['alias'].rstrip(b'\x00').hex() return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [
timestamp = datetime.datetime.fromtimestamp(int.from_bytes(node_announcement_payload['timestamp'], "big")) Address(host=host, port=port, node_id=node_id, last_connected_date=now) for host, port in addresses]
return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [Address(host=host, port=port, node_id=node_id, last_connected_date=datetime.datetime.now()) for host, port in addresses]
@staticmethod @staticmethod
def parse_addresses_field(addresses_field): def parse_addresses_field(addresses_field):
@ -207,7 +206,7 @@ class Address(Base):
node_id = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True) node_id = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
host = Column(String(256), primary_key=True) host = Column(String(256), primary_key=True)
port = Column(Integer, primary_key=True) port = Column(Integer, primary_key=True)
last_connected_date = Column(DateTime(), nullable=False) last_connected_date = Column(Integer(), nullable=False)
@ -235,12 +234,14 @@ class ChannelDB(SqlDB):
@sql @sql
def add_recent_peer(self, peer: LNPeerAddr): def add_recent_peer(self, peer: LNPeerAddr):
addr = self.DBSession.query(Address).filter_by(node_id = peer.pubkey.hex()).one_or_none() now = int(time.time())
if addr is None: node_id = peer.pubkey.hex()
addr = Address(node_id = peer.pubkey.hex(), host = peer.host, port = peer.port, last_connected_date = datetime.datetime.now()) addr = self.DBSession.query(Address).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none()
if addr:
addr.last_connected_date = now
else: else:
addr.last_connected_date = datetime.datetime.now() addr = Address(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now)
self.DBSession.add(addr) self.DBSession.add(addr)
self.DBSession.commit() self.DBSession.commit()
@sql @sql
@ -317,25 +318,31 @@ class ChannelDB(SqlDB):
self.DBSession.commit() self.DBSession.commit()
@sql @sql
@profiler #@profiler
def on_channel_announcement(self, msg_payloads, trusted=False): def on_channel_announcement(self, msg_payloads, trusted=False):
if type(msg_payloads) is dict: if type(msg_payloads) is dict:
msg_payloads = [msg_payloads] msg_payloads = [msg_payloads]
new_channels = {}
for msg in msg_payloads: for msg in msg_payloads:
short_channel_id = msg['short_channel_id'] short_channel_id = bh2u(msg['short_channel_id'])
if self.DBSession.query(ChannelInfo).filter_by(short_channel_id = bh2u(short_channel_id)).count(): if self.DBSession.query(ChannelInfo).filter_by(short_channel_id=short_channel_id).count():
continue continue
if constants.net.rev_genesis_bytes() != msg['chain_hash']: if constants.net.rev_genesis_bytes() != msg['chain_hash']:
#self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash']))) self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash'])))
continue continue
try: try:
channel_info = ChannelInfo.from_msg(msg) channel_info = ChannelInfo.from_msg(msg)
except UnknownEvenFeatureBits: except UnknownEvenFeatureBits:
self.print_error("unknown feature bits")
continue continue
channel_info.trusted = trusted channel_info.trusted = trusted
new_channels[short_channel_id] = channel_info
if not trusted:
self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload)
for channel_info in new_channels.values():
self.DBSession.add(channel_info) self.DBSession.add(channel_info)
if not trusted: self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload)
self.DBSession.commit() self.DBSession.commit()
self.print_error('on_channel_announcement: %d/%d'%(len(new_channels), len(msg_payloads)))
self._update_counts() self._update_counts()
self.network.trigger_callback('ln_status') self.network.trigger_callback('ln_status')
@ -379,21 +386,13 @@ class ChannelDB(SqlDB):
self.DBSession.commit() self.DBSession.commit()
@sql @sql
@profiler #@profiler
def on_node_announcement(self, msg_payloads): def on_node_announcement(self, msg_payloads):
if type(msg_payloads) is dict: if type(msg_payloads) is dict:
msg_payloads = [msg_payloads] msg_payloads = [msg_payloads]
addresses = self.DBSession.query(Address).all()
have_addr = {}
for addr in addresses:
have_addr[(addr.node_id, addr.host, addr.port)] = addr
nodes = self.DBSession.query(NodeInfo).all()
timestamps = {}
for node in nodes:
no_millisecs = node.timestamp[:len("0000-00-00 00:00:00")]
timestamps[bfh(node.node_id)] = datetime.datetime.strptime(no_millisecs, "%Y-%m-%d %H:%M:%S")
old_addr = None old_addr = None
new_nodes = {}
new_addresses = {}
for msg_payload in msg_payloads: for msg_payload in msg_payloads:
pubkey = msg_payload['node_id'] pubkey = msg_payload['node_id']
signature = msg_payload['signature'] signature = msg_payload['signature']
@ -401,30 +400,33 @@ class ChannelDB(SqlDB):
if not ecc.verify_signature(pubkey, signature, h): if not ecc.verify_signature(pubkey, signature, h):
continue continue
try: try:
new_node_info, addresses = NodeInfo.from_msg(msg_payload) node_info, node_addresses = NodeInfo.from_msg(msg_payload)
except UnknownEvenFeatureBits: except UnknownEvenFeatureBits:
continue continue
if timestamps.get(pubkey) and timestamps[pubkey] >= new_node_info.timestamp: node_id = node_info.node_id
continue # ignore node = self.DBSession.query(NodeInfo).filter_by(node_id=node_id).one_or_none()
self.DBSession.add(new_node_info) if node and node.timestamp >= node_info.timestamp:
for new_addr in addresses: continue
key = (new_addr.node_id, new_addr.host, new_addr.port) node = new_nodes.get(node_id)
old_addr = have_addr.get(key) if node and node.timestamp >= node_info.timestamp:
if old_addr: continue
# since old_addr is embedded in have_addr, new_nodes[node_id] = node_info
# it will still live when commmit is called for addr in node_addresses:
old_addr.last_connected_date = new_addr.last_connected_date new_addresses[(addr.node_id,addr.host,addr.port)] = addr
del new_addr
else: self.print_error("on_node_announcements: %d/%d"%(len(new_nodes), len(msg_payloads)))
self.DBSession.add(new_addr) for node_info in new_nodes.values():
have_addr[key] = new_addr self.DBSession.add(node_info)
for new_addr in new_addresses.values():
old_addr = self.DBSession.query(Address).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none()
if old_addr:
old_addr.last_connected_date = new_addr.last_connected_date
else:
self.DBSession.add(new_addr)
# TODO if this message is for a new node, and if we have no associated # TODO if this message is for a new node, and if we have no associated
# channels for this node, we should ignore the message and return here, # channels for this node, we should ignore the message and return here,
# to mitigate DOS. but race condition: the channels we have for this # to mitigate DOS. but race condition: the channels we have for this
# node, might be under verification in self.ca_verifier, what then? # node, might be under verification in self.ca_verifier, what then?
del nodes, addresses
if old_addr:
del old_addr
self.DBSession.commit() self.DBSession.commit()
self._update_counts() self._update_counts()
self.network.trigger_callback('ln_status') self.network.trigger_callback('ln_status')

Loading…
Cancel
Save