|
@ -75,7 +75,7 @@ DBSession = scoped_session(session_factory) |
|
|
FLAG_DISABLE = 1 << 1 |
|
|
FLAG_DISABLE = 1 << 1 |
|
|
FLAG_DIRECTION = 1 << 0 |
|
|
FLAG_DIRECTION = 1 << 0 |
|
|
|
|
|
|
|
|
class ChannelInfoInDB(Base): |
|
|
class ChannelInfo(Base): |
|
|
__tablename__ = 'channel_info' |
|
|
__tablename__ = 'channel_info' |
|
|
short_channel_id = Column(String(64), primary_key=True) |
|
|
short_channel_id = Column(String(64), primary_key=True) |
|
|
node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False) |
|
|
node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False) |
|
@ -98,7 +98,7 @@ class ChannelInfoInDB(Base): |
|
|
|
|
|
|
|
|
capacity_sat = None |
|
|
capacity_sat = None |
|
|
|
|
|
|
|
|
return ChannelInfoInDB(short_channel_id = channel_id, node1_id = node_id_1, |
|
|
return ChannelInfo(short_channel_id = channel_id, node1_id = node_id_1, |
|
|
node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex, |
|
|
node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex, |
|
|
trusted = False) |
|
|
trusted = False) |
|
|
|
|
|
|
|
@ -186,7 +186,7 @@ class Policy(Base): |
|
|
def is_disabled(self): |
|
|
def is_disabled(self): |
|
|
return self.channel_flags & FLAG_DISABLE |
|
|
return self.channel_flags & FLAG_DISABLE |
|
|
|
|
|
|
|
|
class NodeInfoInDB(Base): |
|
|
class NodeInfo(Base): |
|
|
__tablename__ = 'node_info' |
|
|
__tablename__ = 'node_info' |
|
|
node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE') |
|
|
node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE') |
|
|
features = Column(Integer, nullable=False) |
|
|
features = Column(Integer, nullable=False) |
|
@ -194,7 +194,7 @@ class NodeInfoInDB(Base): |
|
|
alias = Column(String(64), nullable=False) |
|
|
alias = Column(String(64), nullable=False) |
|
|
|
|
|
|
|
|
def get_addresses(self): |
|
|
def get_addresses(self): |
|
|
return DBSession.query(AddressInDB).join(NodeInfoInDB).filter_by(node_id = self.node_id).all() |
|
|
return DBSession.query(Address).join(NodeInfo).filter_by(node_id = self.node_id).all() |
|
|
|
|
|
|
|
|
@staticmethod |
|
|
@staticmethod |
|
|
def from_msg(node_announcement_payload, addresses_already_parsed=False): |
|
|
def from_msg(node_announcement_payload, addresses_already_parsed=False): |
|
@ -202,12 +202,12 @@ class NodeInfoInDB(Base): |
|
|
features = int.from_bytes(node_announcement_payload['features'], "big") |
|
|
features = int.from_bytes(node_announcement_payload['features'], "big") |
|
|
validate_features(features) |
|
|
validate_features(features) |
|
|
if not addresses_already_parsed: |
|
|
if not addresses_already_parsed: |
|
|
addresses = NodeInfoInDB.parse_addresses_field(node_announcement_payload['addresses']) |
|
|
addresses = NodeInfo.parse_addresses_field(node_announcement_payload['addresses']) |
|
|
else: |
|
|
else: |
|
|
addresses = node_announcement_payload['addresses'] |
|
|
addresses = node_announcement_payload['addresses'] |
|
|
alias = node_announcement_payload['alias'].rstrip(b'\x00').hex() |
|
|
alias = node_announcement_payload['alias'].rstrip(b'\x00').hex() |
|
|
timestamp = datetime.datetime.fromtimestamp(int.from_bytes(node_announcement_payload['timestamp'], "big")) |
|
|
timestamp = datetime.datetime.fromtimestamp(int.from_bytes(node_announcement_payload['timestamp'], "big")) |
|
|
return NodeInfoInDB(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [AddressInDB(host=host, port=port, node_id=node_id, last_connected_date=datetime.datetime.now()) for host, port in addresses] |
|
|
return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [Address(host=host, port=port, node_id=node_id, last_connected_date=datetime.datetime.now()) for host, port in addresses] |
|
|
|
|
|
|
|
|
@staticmethod |
|
|
@staticmethod |
|
|
def parse_addresses_field(addresses_field): |
|
|
def parse_addresses_field(addresses_field): |
|
@ -249,7 +249,7 @@ class NodeInfoInDB(Base): |
|
|
break |
|
|
break |
|
|
return addresses |
|
|
return addresses |
|
|
|
|
|
|
|
|
class AddressInDB(Base): |
|
|
class Address(Base): |
|
|
__tablename__ = 'address' |
|
|
__tablename__ = 'address' |
|
|
node_id = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True) |
|
|
node_id = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True) |
|
|
host = Column(String(256), primary_key=True) |
|
|
host = Column(String(256), primary_key=True) |
|
@ -288,13 +288,13 @@ class ChannelDB: |
|
|
Base.metadata.create_all(engine) |
|
|
Base.metadata.create_all(engine) |
|
|
|
|
|
|
|
|
def update_counts(self): |
|
|
def update_counts(self): |
|
|
self.num_channels = DBSession.query(ChannelInfoInDB).count() |
|
|
self.num_channels = DBSession.query(ChannelInfo).count() |
|
|
self.num_nodes = DBSession.query(NodeInfoInDB).count() |
|
|
self.num_nodes = DBSession.query(NodeInfo).count() |
|
|
|
|
|
|
|
|
def add_recent_peer(self, peer : LNPeerAddr): |
|
|
def add_recent_peer(self, peer : LNPeerAddr): |
|
|
addr = DBSession.query(AddressInDB).filter_by(node_id = peer.pubkey.hex()).one_or_none() |
|
|
addr = DBSession.query(Address).filter_by(node_id = peer.pubkey.hex()).one_or_none() |
|
|
if addr is None: |
|
|
if addr is None: |
|
|
addr = AddressInDB(node_id = peer.pubkey.hex(), host = peer.host, port = peer.port, last_connected_date = datetime.datetime.now()) |
|
|
addr = Address(node_id = peer.pubkey.hex(), host = peer.host, port = peer.port, last_connected_date = datetime.datetime.now()) |
|
|
else: |
|
|
else: |
|
|
addr.last_connected_date = datetime.datetime.now() |
|
|
addr.last_connected_date = datetime.datetime.now() |
|
|
DBSession.add(addr) |
|
|
DBSession.add(addr) |
|
@ -302,8 +302,8 @@ class ChannelDB: |
|
|
|
|
|
|
|
|
def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes): |
|
|
def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes): |
|
|
unshuffled = DBSession \ |
|
|
unshuffled = DBSession \ |
|
|
.query(NodeInfoInDB) \ |
|
|
.query(NodeInfo) \ |
|
|
.filter(not_(NodeInfoInDB.node_id.in_(x.hex() for x in node_ids_bytes))) \ |
|
|
.filter(not_(NodeInfo.node_id.in_(x.hex() for x in node_ids_bytes))) \ |
|
|
.limit(200) \ |
|
|
.limit(200) \ |
|
|
.all() |
|
|
.all() |
|
|
return random.sample(unshuffled, len(unshuffled)) |
|
|
return random.sample(unshuffled, len(unshuffled)) |
|
@ -313,15 +313,15 @@ class ChannelDB: |
|
|
|
|
|
|
|
|
async def _nodes_get(self, node_id): |
|
|
async def _nodes_get(self, node_id): |
|
|
return DBSession \ |
|
|
return DBSession \ |
|
|
.query(NodeInfoInDB) \ |
|
|
.query(NodeInfo) \ |
|
|
.filter_by(node_id = node_id.hex()) \ |
|
|
.filter_by(node_id = node_id.hex()) \ |
|
|
.one_or_none() |
|
|
.one_or_none() |
|
|
|
|
|
|
|
|
def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]: |
|
|
def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]: |
|
|
adr_db = DBSession \ |
|
|
adr_db = DBSession \ |
|
|
.query(AddressInDB) \ |
|
|
.query(Address) \ |
|
|
.filter_by(node_id = node_id.hex()) \ |
|
|
.filter_by(node_id = node_id.hex()) \ |
|
|
.order_by(AddressInDB.last_connected_date.desc()) \ |
|
|
.order_by(Address.last_connected_date.desc()) \ |
|
|
.one_or_none() |
|
|
.one_or_none() |
|
|
if not adr_db: |
|
|
if not adr_db: |
|
|
return None |
|
|
return None |
|
@ -329,9 +329,9 @@ class ChannelDB: |
|
|
|
|
|
|
|
|
def get_recent_peers(self): |
|
|
def get_recent_peers(self): |
|
|
return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in DBSession \ |
|
|
return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in DBSession \ |
|
|
.query(AddressInDB) \ |
|
|
.query(Address) \ |
|
|
.select_from(NodeInfoInDB) \ |
|
|
.select_from(NodeInfo) \ |
|
|
.order_by(AddressInDB.last_connected_date.desc()) \ |
|
|
.order_by(Address.last_connected_date.desc()) \ |
|
|
.limit(self.NUM_MAX_RECENT_PEERS)] |
|
|
.limit(self.NUM_MAX_RECENT_PEERS)] |
|
|
|
|
|
|
|
|
def get_channel_info(self, channel_id: bytes): |
|
|
def get_channel_info(self, channel_id: bytes): |
|
@ -340,13 +340,13 @@ class ChannelDB: |
|
|
def get_channels_for_node(self, node_id): |
|
|
def get_channels_for_node(self, node_id): |
|
|
"""Returns the set of channels that have node_id as one of the endpoints.""" |
|
|
"""Returns the set of channels that have node_id as one of the endpoints.""" |
|
|
condition = or_( |
|
|
condition = or_( |
|
|
ChannelInfoInDB.node1_id == node_id.hex(), |
|
|
ChannelInfo.node1_id == node_id.hex(), |
|
|
ChannelInfoInDB.node2_id == node_id.hex()) |
|
|
ChannelInfo.node2_id == node_id.hex()) |
|
|
rows = DBSession.query(ChannelInfoInDB).filter(condition).all() |
|
|
rows = DBSession.query(ChannelInfo).filter(condition).all() |
|
|
return [bytes.fromhex(x.short_channel_id) for x in rows] |
|
|
return [bytes.fromhex(x.short_channel_id) for x in rows] |
|
|
|
|
|
|
|
|
def missing_short_chan_ids(self) -> Set[int]: |
|
|
def missing_short_chan_ids(self) -> Set[int]: |
|
|
expr = not_(Policy.short_channel_id.in_(DBSession.query(ChannelInfoInDB.short_channel_id))) |
|
|
expr = not_(Policy.short_channel_id.in_(DBSession.query(ChannelInfo.short_channel_id))) |
|
|
return set(DBSession.query(Policy.short_channel_id).filter(expr).all()) |
|
|
return set(DBSession.query(Policy.short_channel_id).filter(expr).all()) |
|
|
|
|
|
|
|
|
def add_verified_channel_info(self, short_id, capacity): |
|
|
def add_verified_channel_info(self, short_id, capacity): |
|
@ -362,13 +362,13 @@ class ChannelDB: |
|
|
msg_payloads = [msg_payloads] |
|
|
msg_payloads = [msg_payloads] |
|
|
for msg in msg_payloads: |
|
|
for msg in msg_payloads: |
|
|
short_channel_id = msg['short_channel_id'] |
|
|
short_channel_id = msg['short_channel_id'] |
|
|
if DBSession.query(ChannelInfoInDB).filter_by(short_channel_id = bh2u(short_channel_id)).count(): |
|
|
if DBSession.query(ChannelInfo).filter_by(short_channel_id = bh2u(short_channel_id)).count(): |
|
|
continue |
|
|
continue |
|
|
if constants.net.rev_genesis_bytes() != msg['chain_hash']: |
|
|
if constants.net.rev_genesis_bytes() != msg['chain_hash']: |
|
|
#self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash']))) |
|
|
#self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash']))) |
|
|
continue |
|
|
continue |
|
|
try: |
|
|
try: |
|
|
channel_info = ChannelInfoInDB.from_msg(msg) |
|
|
channel_info = ChannelInfo.from_msg(msg) |
|
|
except UnknownEvenFeatureBits: |
|
|
except UnknownEvenFeatureBits: |
|
|
continue |
|
|
continue |
|
|
channel_info.trusted = trusted |
|
|
channel_info.trusted = trusted |
|
@ -383,7 +383,7 @@ class ChannelDB: |
|
|
if type(msg_payloads) is dict: |
|
|
if type(msg_payloads) is dict: |
|
|
msg_payloads = [msg_payloads] |
|
|
msg_payloads = [msg_payloads] |
|
|
short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads] |
|
|
short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads] |
|
|
channel_infos_list = DBSession.query(ChannelInfoInDB).filter(ChannelInfoInDB.short_channel_id.in_(short_channel_ids)).all() |
|
|
channel_infos_list = DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all() |
|
|
channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list} |
|
|
channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list} |
|
|
for msg_payload in msg_payloads: |
|
|
for msg_payload in msg_payloads: |
|
|
short_channel_id = msg_payload['short_channel_id'] |
|
|
short_channel_id = msg_payload['short_channel_id'] |
|
@ -397,12 +397,12 @@ class ChannelDB: |
|
|
def on_node_announcement(self, msg_payloads): |
|
|
def on_node_announcement(self, msg_payloads): |
|
|
if type(msg_payloads) is dict: |
|
|
if type(msg_payloads) is dict: |
|
|
msg_payloads = [msg_payloads] |
|
|
msg_payloads = [msg_payloads] |
|
|
addresses = DBSession.query(AddressInDB).all() |
|
|
addresses = DBSession.query(Address).all() |
|
|
have_addr = {} |
|
|
have_addr = {} |
|
|
for addr in addresses: |
|
|
for addr in addresses: |
|
|
have_addr[(addr.node_id, addr.host, addr.port)] = addr |
|
|
have_addr[(addr.node_id, addr.host, addr.port)] = addr |
|
|
|
|
|
|
|
|
nodes = DBSession.query(NodeInfoInDB).all() |
|
|
nodes = DBSession.query(NodeInfo).all() |
|
|
timestamps = {} |
|
|
timestamps = {} |
|
|
for node in nodes: |
|
|
for node in nodes: |
|
|
no_millisecs = node.timestamp[:len("0000-00-00 00:00:00")] |
|
|
no_millisecs = node.timestamp[:len("0000-00-00 00:00:00")] |
|
@ -415,7 +415,7 @@ class ChannelDB: |
|
|
if not ecc.verify_signature(pubkey, signature, h): |
|
|
if not ecc.verify_signature(pubkey, signature, h): |
|
|
continue |
|
|
continue |
|
|
try: |
|
|
try: |
|
|
new_node_info, addresses = NodeInfoInDB.from_msg(msg_payload) |
|
|
new_node_info, addresses = NodeInfo.from_msg(msg_payload) |
|
|
except UnknownEvenFeatureBits: |
|
|
except UnknownEvenFeatureBits: |
|
|
continue |
|
|
continue |
|
|
if timestamps.get(pubkey) and timestamps[pubkey] >= new_node_info.timestamp: |
|
|
if timestamps.get(pubkey) and timestamps[pubkey] >= new_node_info.timestamp: |
|
@ -464,7 +464,7 @@ class ChannelDB: |
|
|
DBSession.commit() |
|
|
DBSession.commit() |
|
|
|
|
|
|
|
|
def chan_query_for_id(self, short_channel_id) -> Query: |
|
|
def chan_query_for_id(self, short_channel_id) -> Query: |
|
|
return DBSession.query(ChannelInfoInDB).filter_by(short_channel_id = short_channel_id.hex()) |
|
|
return DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex()) |
|
|
|
|
|
|
|
|
def print_graph(self, full_ids=False): |
|
|
def print_graph(self, full_ids=False): |
|
|
# used for debugging. |
|
|
# used for debugging. |
|
@ -478,11 +478,11 @@ class ChannelDB: |
|
|
return other if full_ids else other[-4:] |
|
|
return other if full_ids else other[-4:] |
|
|
|
|
|
|
|
|
self.print_msg('nodes') |
|
|
self.print_msg('nodes') |
|
|
for node in DBSession.query(NodeInfoInDB).all(): |
|
|
for node in DBSession.query(NodeInfo).all(): |
|
|
self.print_msg(node) |
|
|
self.print_msg(node) |
|
|
|
|
|
|
|
|
self.print_msg('channels') |
|
|
self.print_msg('channels') |
|
|
for channel_info in DBSession.query(ChannelInfoInDB).all(): |
|
|
for channel_info in DBSession.query(ChannelInfo).all(): |
|
|
node1 = channel_info.node1_id |
|
|
node1 = channel_info.node1_id |
|
|
node2 = channel_info.node2_id |
|
|
node2 = channel_info.node2_id |
|
|
direction1 = channel_info.get_policy_for_node(node1) is not None |
|
|
direction1 = channel_info.get_policy_for_node(node1) is not None |
|
|