Browse Source

fix sql conflicts in lnrouter

regtest_lnd
ThomasV 6 years ago
committed by SomberNight
parent
commit
d5b5c7ddef
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 138
      electrum/lnrouter.py

138
electrum/lnrouter.py

@ -23,7 +23,7 @@
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import datetime
import time
import random
import queue
import os
@ -35,7 +35,7 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK
import binascii
import base64
from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean
from sqlalchemy import Column, ForeignKey, Integer, String, Boolean
from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import not_, or_
@ -81,14 +81,14 @@ class ChannelInfo(Base):
trusted = Column(Boolean, nullable=False)
@staticmethod
def from_msg(channel_announcement_payload):
features = int.from_bytes(channel_announcement_payload['features'], 'big')
def from_msg(payload):
features = int.from_bytes(payload['features'], 'big')
validate_features(features)
channel_id = channel_announcement_payload['short_channel_id'].hex()
node_id_1 = channel_announcement_payload['node_id_1'].hex()
node_id_2 = channel_announcement_payload['node_id_2'].hex()
channel_id = payload['short_channel_id'].hex()
node_id_1 = payload['node_id_1'].hex()
node_id_2 = payload['node_id_2'].hex()
assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2]
msg_payload_hex = encode_msg('channel_announcement', **channel_announcement_payload).hex()
msg_payload_hex = encode_msg('channel_announcement', **payload).hex()
capacity_sat = None
return ChannelInfo(short_channel_id = channel_id, node1_id = node_id_1,
node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex,
@ -109,17 +109,17 @@ class Policy(Base):
fee_base_msat = Column(Integer, nullable=False)
fee_proportional_millionths = Column(Integer, nullable=False)
channel_flags = Column(Integer, nullable=False)
timestamp = Column(DateTime, nullable=False)
timestamp = Column(Integer, nullable=False)
@staticmethod
def from_msg(channel_update_payload, start_node, short_channel_id):
cltv_expiry_delta = channel_update_payload['cltv_expiry_delta']
htlc_minimum_msat = channel_update_payload['htlc_minimum_msat']
fee_base_msat = channel_update_payload['fee_base_msat']
fee_proportional_millionths = channel_update_payload['fee_proportional_millionths']
channel_flags = channel_update_payload['channel_flags']
timestamp = channel_update_payload['timestamp']
htlc_maximum_msat = channel_update_payload.get('htlc_maximum_msat') # optional
def from_msg(payload, start_node, short_channel_id):
cltv_expiry_delta = payload['cltv_expiry_delta']
htlc_minimum_msat = payload['htlc_minimum_msat']
fee_base_msat = payload['fee_base_msat']
fee_proportional_millionths = payload['fee_proportional_millionths']
channel_flags = payload['channel_flags']
timestamp = payload['timestamp']
htlc_maximum_msat = payload.get('htlc_maximum_msat') # optional
cltv_expiry_delta = int.from_bytes(cltv_expiry_delta, "big")
htlc_minimum_msat = int.from_bytes(htlc_minimum_msat, "big")
@ -127,7 +127,7 @@ class Policy(Base):
fee_base_msat = int.from_bytes(fee_base_msat, "big")
fee_proportional_millionths = int.from_bytes(fee_proportional_millionths, "big")
channel_flags = int.from_bytes(channel_flags, "big")
timestamp = datetime.datetime.fromtimestamp(int.from_bytes(timestamp, "big"))
timestamp = int.from_bytes(timestamp, "big")
return Policy(start_node=start_node,
short_channel_id=short_channel_id,
@ -150,17 +150,16 @@ class NodeInfo(Base):
alias = Column(String(64), nullable=False)
@staticmethod
def from_msg(node_announcement_payload, addresses_already_parsed=False):
node_id = node_announcement_payload['node_id'].hex()
features = int.from_bytes(node_announcement_payload['features'], "big")
def from_msg(payload):
node_id = payload['node_id'].hex()
features = int.from_bytes(payload['features'], "big")
validate_features(features)
if not addresses_already_parsed:
addresses = NodeInfo.parse_addresses_field(node_announcement_payload['addresses'])
else:
addresses = node_announcement_payload['addresses']
alias = node_announcement_payload['alias'].rstrip(b'\x00').hex()
timestamp = datetime.datetime.fromtimestamp(int.from_bytes(node_announcement_payload['timestamp'], "big"))
return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [Address(host=host, port=port, node_id=node_id, last_connected_date=datetime.datetime.now()) for host, port in addresses]
addresses = NodeInfo.parse_addresses_field(payload['addresses'])
alias = payload['alias'].rstrip(b'\x00').hex()
timestamp = int.from_bytes(payload['timestamp'], "big")
now = int(time.time())
return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [
Address(host=host, port=port, node_id=node_id, last_connected_date=now) for host, port in addresses]
@staticmethod
def parse_addresses_field(addresses_field):
@ -207,7 +206,7 @@ class Address(Base):
node_id = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
host = Column(String(256), primary_key=True)
port = Column(Integer, primary_key=True)
last_connected_date = Column(DateTime(), nullable=False)
last_connected_date = Column(Integer(), nullable=False)
@ -235,12 +234,14 @@ class ChannelDB(SqlDB):
@sql
def add_recent_peer(self, peer: LNPeerAddr):
addr = self.DBSession.query(Address).filter_by(node_id = peer.pubkey.hex()).one_or_none()
if addr is None:
addr = Address(node_id = peer.pubkey.hex(), host = peer.host, port = peer.port, last_connected_date = datetime.datetime.now())
now = int(time.time())
node_id = peer.pubkey.hex()
addr = self.DBSession.query(Address).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none()
if addr:
addr.last_connected_date = now
else:
addr.last_connected_date = datetime.datetime.now()
self.DBSession.add(addr)
addr = Address(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now)
self.DBSession.add(addr)
self.DBSession.commit()
@sql
@ -317,25 +318,31 @@ class ChannelDB(SqlDB):
self.DBSession.commit()
@sql
@profiler
#@profiler
def on_channel_announcement(self, msg_payloads, trusted=False):
if type(msg_payloads) is dict:
msg_payloads = [msg_payloads]
new_channels = {}
for msg in msg_payloads:
short_channel_id = msg['short_channel_id']
if self.DBSession.query(ChannelInfo).filter_by(short_channel_id = bh2u(short_channel_id)).count():
short_channel_id = bh2u(msg['short_channel_id'])
if self.DBSession.query(ChannelInfo).filter_by(short_channel_id=short_channel_id).count():
continue
if constants.net.rev_genesis_bytes() != msg['chain_hash']:
#self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash'])))
self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash'])))
continue
try:
channel_info = ChannelInfo.from_msg(msg)
except UnknownEvenFeatureBits:
self.print_error("unknown feature bits")
continue
channel_info.trusted = trusted
new_channels[short_channel_id] = channel_info
if not trusted:
self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload)
for channel_info in new_channels.values():
self.DBSession.add(channel_info)
if not trusted: self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload)
self.DBSession.commit()
self.print_error('on_channel_announcement: %d/%d'%(len(new_channels), len(msg_payloads)))
self._update_counts()
self.network.trigger_callback('ln_status')
@ -379,21 +386,13 @@ class ChannelDB(SqlDB):
self.DBSession.commit()
@sql
@profiler
#@profiler
def on_node_announcement(self, msg_payloads):
if type(msg_payloads) is dict:
msg_payloads = [msg_payloads]
addresses = self.DBSession.query(Address).all()
have_addr = {}
for addr in addresses:
have_addr[(addr.node_id, addr.host, addr.port)] = addr
nodes = self.DBSession.query(NodeInfo).all()
timestamps = {}
for node in nodes:
no_millisecs = node.timestamp[:len("0000-00-00 00:00:00")]
timestamps[bfh(node.node_id)] = datetime.datetime.strptime(no_millisecs, "%Y-%m-%d %H:%M:%S")
old_addr = None
new_nodes = {}
new_addresses = {}
for msg_payload in msg_payloads:
pubkey = msg_payload['node_id']
signature = msg_payload['signature']
@ -401,30 +400,33 @@ class ChannelDB(SqlDB):
if not ecc.verify_signature(pubkey, signature, h):
continue
try:
new_node_info, addresses = NodeInfo.from_msg(msg_payload)
node_info, node_addresses = NodeInfo.from_msg(msg_payload)
except UnknownEvenFeatureBits:
continue
if timestamps.get(pubkey) and timestamps[pubkey] >= new_node_info.timestamp:
continue # ignore
self.DBSession.add(new_node_info)
for new_addr in addresses:
key = (new_addr.node_id, new_addr.host, new_addr.port)
old_addr = have_addr.get(key)
if old_addr:
# since old_addr is embedded in have_addr,
# it will still live when commmit is called
old_addr.last_connected_date = new_addr.last_connected_date
del new_addr
else:
self.DBSession.add(new_addr)
have_addr[key] = new_addr
node_id = node_info.node_id
node = self.DBSession.query(NodeInfo).filter_by(node_id=node_id).one_or_none()
if node and node.timestamp >= node_info.timestamp:
continue
node = new_nodes.get(node_id)
if node and node.timestamp >= node_info.timestamp:
continue
new_nodes[node_id] = node_info
for addr in node_addresses:
new_addresses[(addr.node_id,addr.host,addr.port)] = addr
self.print_error("on_node_announcements: %d/%d"%(len(new_nodes), len(msg_payloads)))
for node_info in new_nodes.values():
self.DBSession.add(node_info)
for new_addr in new_addresses.values():
old_addr = self.DBSession.query(Address).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none()
if old_addr:
old_addr.last_connected_date = new_addr.last_connected_date
else:
self.DBSession.add(new_addr)
# TODO if this message is for a new node, and if we have no associated
# channels for this node, we should ignore the message and return here,
# to mitigate DOS. but race condition: the channels we have for this
# node, might be under verification in self.ca_verifier, what then?
del nodes, addresses
if old_addr:
del old_addr
self.DBSession.commit()
self._update_counts()
self.network.trigger_callback('ln_status')

Loading…
Cancel
Save