|
|
@@ -70,7 +70,6 @@ def validate_features(features : int):
|
|
|
|
|
|
|
Base = declarative_base() |
|
|
|
session_factory = sessionmaker() |
|
|
|
DBSession = scoped_session(session_factory) |
|
|
|
|
|
|
|
FLAG_DISABLE = 1 << 1 |
|
|
|
FLAG_DIRECTION = 1 << 0 |
|
|
@@ -88,16 +87,12 @@ class ChannelInfo(Base):
|
|
|
def from_msg(channel_announcement_payload): |
|
|
|
features = int.from_bytes(channel_announcement_payload['features'], 'big') |
|
|
|
validate_features(features) |
|
|
|
|
|
|
|
channel_id = channel_announcement_payload['short_channel_id'].hex() |
|
|
|
node_id_1 = channel_announcement_payload['node_id_1'].hex() |
|
|
|
node_id_2 = channel_announcement_payload['node_id_2'].hex() |
|
|
|
assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2] |
|
|
|
|
|
|
|
msg_payload_hex = encode_msg('channel_announcement', **channel_announcement_payload).hex() |
|
|
|
|
|
|
|
capacity_sat = None |
|
|
|
|
|
|
|
return ChannelInfo(short_channel_id = channel_id, node1_id = node_id_1, |
|
|
|
node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex, |
|
|
|
trusted = False) |
|
|
@@ -106,42 +101,6 @@ class ChannelInfo(Base):
|
|
|
def msg_payload(self): |
|
|
|
return bytes.fromhex(self.msg_payload_hex) |
|
|
|
|
|
|
|
def on_channel_update(self, msg: dict, trusted=False): |
|
|
|
assert self.short_channel_id == msg['short_channel_id'].hex() |
|
|
|
flags = int.from_bytes(msg['channel_flags'], 'big') |
|
|
|
direction = flags & FLAG_DIRECTION |
|
|
|
if direction == 0: |
|
|
|
node_id = self.node1_id |
|
|
|
else: |
|
|
|
node_id = self.node2_id |
|
|
|
new_policy = Policy.from_msg(msg, node_id, self.short_channel_id) |
|
|
|
old_policy = DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node=node_id).one_or_none() |
|
|
|
if not old_policy: |
|
|
|
DBSession.add(new_policy) |
|
|
|
return |
|
|
|
if old_policy.timestamp >= new_policy.timestamp: |
|
|
|
return # ignore |
|
|
|
if not trusted and not verify_sig_for_channel_update(msg, bytes.fromhex(node_id)): |
|
|
|
return # ignore |
|
|
|
old_policy.cltv_expiry_delta = new_policy.cltv_expiry_delta |
|
|
|
old_policy.htlc_minimum_msat = new_policy.htlc_minimum_msat |
|
|
|
old_policy.htlc_maximum_msat = new_policy.htlc_maximum_msat |
|
|
|
old_policy.fee_base_msat = new_policy.fee_base_msat |
|
|
|
old_policy.fee_proportional_millionths = new_policy.fee_proportional_millionths |
|
|
|
old_policy.channel_flags = new_policy.channel_flags |
|
|
|
old_policy.timestamp = new_policy.timestamp |
|
|
|
|
|
|
|
def get_policy_for_node(self, node) -> Optional['Policy']: |
|
|
|
""" |
|
|
|
raises when initiator/non-initiator both unequal node |
|
|
|
""" |
|
|
|
if node.hex() not in (self.node1_id, self.node2_id): |
|
|
|
raise Exception("the given node is not a party in this channel") |
|
|
|
n1 = DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node = self.node1_id).one_or_none() |
|
|
|
if n1: |
|
|
|
return n1 |
|
|
|
n2 = DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node = self.node2_id).one_or_none() |
|
|
|
return n2 |
|
|
|
|
|
|
|
class Policy(Base): |
|
|
|
__tablename__ = 'policy' |
|
|
@@ -193,9 +152,6 @@ class NodeInfo(Base):
|
|
|
timestamp = Column(Integer, nullable=False) |
|
|
|
alias = Column(String(64), nullable=False) |
|
|
|
|
|
|
|
def get_addresses(self): |
|
|
|
return DBSession.query(Address).join(NodeInfo).filter_by(node_id = self.node_id).all() |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def from_msg(node_announcement_payload, addresses_already_parsed=False): |
|
|
|
node_id = node_announcement_payload['node_id'].hex() |
|
|
@@ -281,27 +237,28 @@ class ChannelDB:
|
|
|
the lnpeer loop is running from, which will make calls into here
|
|
|
""" |
|
|
|
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True) |
|
|
|
DBSession.remove() |
|
|
|
DBSession.configure(bind=engine, autoflush=False) |
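# bind a scoped_session to this ChannelDB instance, instead of relying on the module-global DBSession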
|
|
|
self.DBSession = scoped_session(session_factory) |
|
|
|
self.DBSession.remove() |
|
|
|
self.DBSession.configure(bind=engine, autoflush=False) |
|
|
|
|
|
|
|
Base.metadata.drop_all(engine) |
|
|
|
Base.metadata.create_all(engine) |
|
|
|
|
|
|
|
def update_counts(self): |
|
|
|
self.num_channels = DBSession.query(ChannelInfo).count() |
|
|
|
self.num_nodes = DBSession.query(NodeInfo).count() |
|
|
|
self.num_channels = self.DBSession.query(ChannelInfo).count() |
|
|
|
self.num_nodes = self.DBSession.query(NodeInfo).count() |
|
|
|
|
|
|
|
def add_recent_peer(self, peer : LNPeerAddr): |
|
|
|
addr = DBSession.query(Address).filter_by(node_id = peer.pubkey.hex()).one_or_none() |
|
|
|
addr = self.DBSession.query(Address).filter_by(node_id = peer.pubkey.hex()).one_or_none() |
|
|
|
if addr is None: |
|
|
|
addr = Address(node_id = peer.pubkey.hex(), host = peer.host, port = peer.port, last_connected_date = datetime.datetime.now()) |
|
|
|
else: |
|
|
|
addr.last_connected_date = datetime.datetime.now() |
|
|
|
DBSession.add(addr) |
|
|
|
DBSession.commit() |
|
|
|
self.DBSession.add(addr) |
|
|
|
self.DBSession.commit() |
|
|
|
|
|
|
|
def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes): |
|
|
|
unshuffled = DBSession \ |
|
|
|
unshuffled = self.DBSession \ |
|
|
|
.query(NodeInfo) \ |
|
|
|
.filter(not_(NodeInfo.node_id.in_(x.hex() for x in node_ids_bytes))) \ |
|
|
|
.limit(200) \ |
|
|
@@ -312,13 +269,13 @@ class ChannelDB:
|
|
|
return self.network.run_from_another_thread(self._nodes_get(node_id)) |
|
|
|
|
|
|
|
async def _nodes_get(self, node_id): |
|
|
|
return DBSession \ |
|
|
|
return self.DBSession \ |
|
|
|
.query(NodeInfo) \ |
|
|
|
.filter_by(node_id = node_id.hex()) \ |
|
|
|
.one_or_none() |
|
|
|
|
|
|
|
def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]: |
|
|
|
adr_db = DBSession \ |
|
|
|
adr_db = self.DBSession \ |
|
|
|
.query(Address) \ |
|
|
|
.filter_by(node_id = node_id.hex()) \ |
|
|
|
.order_by(Address.last_connected_date.desc()) \ |
|
|
@@ -328,7 +285,7 @@ class ChannelDB:
|
|
|
return LNPeerAddr(adr_db.host, adr_db.port, bytes.fromhex(adr_db.node_id)) |
|
|
|
|
|
|
|
def get_recent_peers(self): |
|
|
|
return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in DBSession \ |
|
|
|
return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in self.DBSession \ |
|
|
|
.query(Address) \ |
|
|
|
.select_from(NodeInfo) \ |
|
|
|
.order_by(Address.last_connected_date.desc()) \ |
|
|
@@ -342,21 +299,21 @@ class ChannelDB:
|
|
|
condition = or_( |
|
|
|
ChannelInfo.node1_id == node_id.hex(), |
|
|
|
ChannelInfo.node2_id == node_id.hex()) |
|
|
|
rows = DBSession.query(ChannelInfo).filter(condition).all() |
|
|
|
rows = self.DBSession.query(ChannelInfo).filter(condition).all() |
|
|
|
return [bytes.fromhex(x.short_channel_id) for x in rows] |
|
|
|
|
|
|
|
def missing_short_chan_ids(self) -> Set[int]: |
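# first, find short_channel_ids for which we have a policy (channel_update) but no channel_announcement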
|
|
|
expr = not_(Policy.short_channel_id.in_(DBSession.query(ChannelInfo.short_channel_id))) |
|
|
|
chan_ids_from_policy = set(x[0] for x in DBSession.query(Policy.short_channel_id).filter(expr).all()) |
|
|
|
expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id))) |
|
|
|
chan_ids_from_policy = set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all()) |
|
|
|
if chan_ids_from_policy: |
|
|
|
return chan_ids_from_policy |
|
|
|
# fetch channels for node_ids missing in node_info. that will also give us node_announcement |
|
|
|
expr = not_(ChannelInfo.node1_id.in_(DBSession.query(NodeInfo.node_id))) |
|
|
|
chan_ids_from_id1 = set(x[0] for x in DBSession.query(ChannelInfo.short_channel_id).filter(expr).all()) |
|
|
|
expr = not_(ChannelInfo.node1_id.in_(self.DBSession.query(NodeInfo.node_id))) |
|
|
|
chan_ids_from_id1 = set(x[0] for x in self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).all()) |
|
|
|
if chan_ids_from_id1: |
|
|
|
return chan_ids_from_id1 |
|
|
|
expr = not_(ChannelInfo.node2_id.in_(DBSession.query(NodeInfo.node_id))) |
|
|
|
chan_ids_from_id2 = set(x[0] for x in DBSession.query(ChannelInfo.short_channel_id).filter(expr).all()) |
|
|
|
expr = not_(ChannelInfo.node2_id.in_(self.DBSession.query(NodeInfo.node_id))) |
|
|
|
chan_ids_from_id2 = set(x[0] for x in self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).all()) |
|
|
|
if chan_ids_from_id2: |
|
|
|
return chan_ids_from_id2 |
|
|
|
return set() |
|
|
@@ -366,7 +323,7 @@ class ChannelDB:
|
|
|
channel_info = self.get_channel_info(short_id) |
|
|
|
channel_info.trusted = True |
|
|
|
channel_info.capacity = capacity |
|
|
|
DBSession.commit() |
|
|
|
self.DBSession.commit() |
|
|
|
|
|
|
|
@profiler |
|
|
|
def on_channel_announcement(self, msg_payloads, trusted=False): |
|
|
@@ -374,7 +331,7 @@ class ChannelDB:
|
|
|
msg_payloads = [msg_payloads] |
|
|
|
for msg in msg_payloads: |
|
|
|
short_channel_id = msg['short_channel_id'] |
|
|
|
if DBSession.query(ChannelInfo).filter_by(short_channel_id = bh2u(short_channel_id)).count(): |
|
|
|
if self.DBSession.query(ChannelInfo).filter_by(short_channel_id = bh2u(short_channel_id)).count(): |
|
|
|
continue |
|
|
|
if constants.net.rev_genesis_bytes() != msg['chain_hash']: |
|
|
|
#self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash']))) |
|
|
@@ -384,9 +341,9 @@ class ChannelDB:
|
|
|
except UnknownEvenFeatureBits: |
|
|
|
continue |
|
|
|
channel_info.trusted = trusted |
|
|
|
DBSession.add(channel_info) |
|
|
|
self.DBSession.add(channel_info) |
|
|
|
if not trusted: self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload) |
|
|
|
DBSession.commit() |
|
|
|
self.DBSession.commit() |
|
|
|
self.network.trigger_callback('ln_status') |
|
|
|
self.update_counts() |
|
|
|
|
|
|
@@ -395,7 +352,7 @@ class ChannelDB:
|
|
|
if type(msg_payloads) is dict: |
|
|
|
msg_payloads = [msg_payloads] |
|
|
|
short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads] |
|
|
|
channel_infos_list = DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all() |
|
|
|
channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all() |
|
|
|
channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list} |
|
|
|
for msg_payload in msg_payloads: |
|
|
|
short_channel_id = msg_payload['short_channel_id'] |
|
|
@@ -404,19 +361,19 @@ class ChannelDB:
|
|
|
channel_info = channel_infos.get(short_channel_id) |
|
|
|
if not channel_info: |
|
|
|
continue |
|
|
|
channel_info.on_channel_update(msg_payload, trusted=trusted) |
|
|
|
DBSession.commit() |
|
|
|
self._update_channel_info(channel_info, msg_payload, trusted=trusted) |
|
|
|
self.DBSession.commit() |
|
|
|
|
|
|
|
@profiler |
|
|
|
def on_node_announcement(self, msg_payloads): |
|
|
|
if type(msg_payloads) is dict: |
|
|
|
msg_payloads = [msg_payloads] |
|
|
|
addresses = DBSession.query(Address).all() |
|
|
|
addresses = self.DBSession.query(Address).all() |
|
|
|
have_addr = {} |
|
|
|
for addr in addresses: |
|
|
|
have_addr[(addr.node_id, addr.host, addr.port)] = addr |
|
|
|
|
|
|
|
nodes = DBSession.query(NodeInfo).all() |
|
|
|
nodes = self.DBSession.query(NodeInfo).all() |
|
|
|
timestamps = {} |
|
|
|
for node in nodes: |
|
|
|
no_millisecs = node.timestamp[:len("0000-00-00 00:00:00")] |
|
|
@@ -434,7 +391,7 @@ class ChannelDB:
|
|
|
continue |
|
|
|
if timestamps.get(pubkey) and timestamps[pubkey] >= new_node_info.timestamp: |
|
|
|
continue # ignore |
|
|
|
DBSession.add(new_node_info) |
|
|
|
self.DBSession.add(new_node_info) |
|
|
|
for new_addr in addresses: |
|
|
|
key = (new_addr.node_id, new_addr.host, new_addr.port) |
|
|
|
old_addr = have_addr.get(key) |
|
|
@@ -444,7 +401,7 @@ class ChannelDB:
|
|
|
old_addr.last_connected_date = new_addr.last_connected_date |
|
|
|
del new_addr |
|
|
|
else: |
|
|
|
DBSession.add(new_addr) |
|
|
|
self.DBSession.add(new_addr) |
|
|
|
have_addr[key] = new_addr |
|
|
|
# TODO if this message is for a new node, and if we have no associated |
|
|
|
# channels for this node, we should ignore the message and return here, |
|
|
@@ -453,7 +410,7 @@ class ChannelDB:
|
|
|
del nodes, addresses |
|
|
|
if old_addr: |
|
|
|
del old_addr |
|
|
|
DBSession.commit() |
|
|
|
self.DBSession.commit() |
|
|
|
self.network.trigger_callback('ln_status') |
|
|
|
self.update_counts() |
|
|
|
|
|
|
@@ -462,9 +419,10 @@ class ChannelDB:
|
|
|
if not start_node_id or not short_channel_id: return None |
|
|
|
channel_info = self.get_channel_info(short_channel_id) |
|
|
|
if channel_info is not None: |
|
|
|
return channel_info.get_policy_for_node(start_node_id) |
|
|
|
return self.get_policy_for_node(channel_info, start_node_id) |
|
|
|
msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id)) |
|
|
|
if not msg: return None |
|
|
|
if not msg: |
|
|
|
return None |
|
|
|
return Policy.from_msg(msg, None, short_channel_id) # won't actually be written to DB |
|
|
|
|
|
|
|
def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes): |
|
|
@@ -475,10 +433,10 @@ class ChannelDB:
|
|
|
|
|
|
|
def remove_channel(self, short_channel_id): |
|
|
|
self.chan_query_for_id(short_channel_id).delete('evaluate') |
|
|
|
DBSession.commit() |
|
|
|
self.DBSession.commit() |
|
|
|
|
|
|
|
def chan_query_for_id(self, short_channel_id) -> Query: |
|
|
|
return DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex()) |
|
|
|
return self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex()) |
|
|
|
|
|
|
|
def print_graph(self, full_ids=False): |
|
|
|
# used for debugging. |
|
|
@@ -492,15 +450,15 @@ class ChannelDB:
|
|
|
return other if full_ids else other[-4:] |
|
|
|
|
|
|
|
self.print_msg('nodes') |
|
|
|
for node in DBSession.query(NodeInfo).all(): |
|
|
|
for node in self.DBSession.query(NodeInfo).all(): |
|
|
|
self.print_msg(node) |
|
|
|
|
|
|
|
self.print_msg('channels') |
|
|
|
for channel_info in DBSession.query(ChannelInfo).all(): |
|
|
|
for channel_info in self.DBSession.query(ChannelInfo).all(): |
|
|
|
node1 = channel_info.node1_id |
|
|
|
node2 = channel_info.node2_id |
|
|
|
direction1 = channel_info.get_policy_for_node(node1) is not None |
|
|
|
direction2 = channel_info.get_policy_for_node(node2) is not None |
|
|
|
direction1 = self.get_policy_for_node(channel_info, node1) is not None |
|
|
|
direction2 = self.get_policy_for_node(channel_info, node2) is not None |
|
|
|
if direction1 and direction2: |
|
|
|
direction = 'both' |
|
|
|
elif direction1: |
|
|
@@ -515,6 +473,44 @@ class ChannelDB:
|
|
|
bh2u(node2) if full_ids else bh2u(node2[-4:]), |
|
|
|
direction)) |
|
|
|
|
|
|
|
def _update_channel_info(self, channel_info, msg: dict, trusted=False): |
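# applies a channel_update message to the stored Policy for one direction of this channel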
|
|
|
assert channel_info.short_channel_id == msg['short_channel_id'].hex() |
|
|
|
flags = int.from_bytes(msg['channel_flags'], 'big') |
|
|
|
direction = flags & FLAG_DIRECTION |
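# bit 0 of channel_flags is the direction: 0 means the update comes from node1, 1 means it comes from node2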
|
|
|
node_id = channel_info.node1_id if direction == 0 else channel_info.node2_id |
|
|
|
new_policy = Policy.from_msg(msg, node_id, channel_info.short_channel_id) |
|
|
|
old_policy = self.DBSession.query(Policy).filter_by(short_channel_id = channel_info.short_channel_id, start_node=node_id).one_or_none() |
|
|
|
if not old_policy: |
|
|
|
self.DBSession.add(new_policy) |
|
|
|
return |
|
|
|
if old_policy.timestamp >= new_policy.timestamp: |
|
|
|
return # ignore |
|
|
|
if not trusted and not verify_sig_for_channel_update(msg, bytes.fromhex(node_id)): |
|
|
|
return # ignore |
|
|
|
old_policy.cltv_expiry_delta = new_policy.cltv_expiry_delta |
|
|
|
old_policy.htlc_minimum_msat = new_policy.htlc_minimum_msat |
|
|
|
old_policy.htlc_maximum_msat = new_policy.htlc_maximum_msat |
|
|
|
old_policy.fee_base_msat = new_policy.fee_base_msat |
|
|
|
old_policy.fee_proportional_millionths = new_policy.fee_proportional_millionths |
|
|
|
old_policy.channel_flags = new_policy.channel_flags |
|
|
|
old_policy.timestamp = new_policy.timestamp |
|
|
|
|
|
|
|
def get_policy_for_node(self, channel_info, node) -> Optional['Policy']:
|
|
|
""" |
|
|
|
raises when the given node is not a party in this channel
|
|
|
""" |
|
|
|
if node.hex() not in (channel_info.node1_id, channel_info.node2_id):
|
|
|
raise Exception("the given node is not a party in this channel") |
|
|
|
n1 = self.DBSession.query(Policy).filter_by(short_channel_id = channel_info.short_channel_id, start_node = channel_info.node1_id).one_or_none()
|
|
|
if n1: |
|
|
|
return n1 |
|
|
|
n2 = self.DBSession.query(Policy).filter_by(short_channel_id = channel_info.short_channel_id, start_node = channel_info.node2_id).one_or_none()
|
|
|
return n2 |
|
|
|
|
|
|
|
def get_node_addresses(self, node_info): |
|
|
|
return self.DBSession.query(Address).join(NodeInfo).filter_by(node_id = node_info.node_id).all() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes), |
|
|
|
('short_channel_id', bytes), |
|
|
@@ -596,7 +592,7 @@ class LNPathFinder(PrintError):
|
|
|
if channel_info is None: |
|
|
|
return float('inf'), 0 |
|
|
|
|
|
|
|
channel_policy = channel_info.get_policy_for_node(start_node) |
|
|
|
channel_policy = self.channel_db.get_policy_for_node(channel_info, start_node) |
|
|
|
if channel_policy is None: return float('inf'), 0 |
|
|
|
if channel_policy.is_disabled(): return float('inf'), 0 |
|
|
|
route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node) |
|
|
|