Browse Source

sqlite in lnrouter: remove useless InDB suffix

dependabot/pip/contrib/deterministic-build/ecdsa-0.13.3
Janus 6 years ago
committed by ThomasV
parent
commit
3442e51fac
  1. 64
      electrum/lnrouter.py

64
electrum/lnrouter.py

@ -75,7 +75,7 @@ DBSession = scoped_session(session_factory)
FLAG_DISABLE = 1 << 1 FLAG_DISABLE = 1 << 1
FLAG_DIRECTION = 1 << 0 FLAG_DIRECTION = 1 << 0
class ChannelInfoInDB(Base): class ChannelInfo(Base):
__tablename__ = 'channel_info' __tablename__ = 'channel_info'
short_channel_id = Column(String(64), primary_key=True) short_channel_id = Column(String(64), primary_key=True)
node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False) node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
@ -98,7 +98,7 @@ class ChannelInfoInDB(Base):
capacity_sat = None capacity_sat = None
return ChannelInfoInDB(short_channel_id = channel_id, node1_id = node_id_1, return ChannelInfo(short_channel_id = channel_id, node1_id = node_id_1,
node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex, node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex,
trusted = False) trusted = False)
@ -186,7 +186,7 @@ class Policy(Base):
def is_disabled(self): def is_disabled(self):
return self.channel_flags & FLAG_DISABLE return self.channel_flags & FLAG_DISABLE
class NodeInfoInDB(Base): class NodeInfo(Base):
__tablename__ = 'node_info' __tablename__ = 'node_info'
node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE') node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
features = Column(Integer, nullable=False) features = Column(Integer, nullable=False)
@ -194,7 +194,7 @@ class NodeInfoInDB(Base):
alias = Column(String(64), nullable=False) alias = Column(String(64), nullable=False)
def get_addresses(self): def get_addresses(self):
return DBSession.query(AddressInDB).join(NodeInfoInDB).filter_by(node_id = self.node_id).all() return DBSession.query(Address).join(NodeInfo).filter_by(node_id = self.node_id).all()
@staticmethod @staticmethod
def from_msg(node_announcement_payload, addresses_already_parsed=False): def from_msg(node_announcement_payload, addresses_already_parsed=False):
@ -202,12 +202,12 @@ class NodeInfoInDB(Base):
features = int.from_bytes(node_announcement_payload['features'], "big") features = int.from_bytes(node_announcement_payload['features'], "big")
validate_features(features) validate_features(features)
if not addresses_already_parsed: if not addresses_already_parsed:
addresses = NodeInfoInDB.parse_addresses_field(node_announcement_payload['addresses']) addresses = NodeInfo.parse_addresses_field(node_announcement_payload['addresses'])
else: else:
addresses = node_announcement_payload['addresses'] addresses = node_announcement_payload['addresses']
alias = node_announcement_payload['alias'].rstrip(b'\x00').hex() alias = node_announcement_payload['alias'].rstrip(b'\x00').hex()
timestamp = datetime.datetime.fromtimestamp(int.from_bytes(node_announcement_payload['timestamp'], "big")) timestamp = datetime.datetime.fromtimestamp(int.from_bytes(node_announcement_payload['timestamp'], "big"))
return NodeInfoInDB(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [AddressInDB(host=host, port=port, node_id=node_id, last_connected_date=datetime.datetime.now()) for host, port in addresses] return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [Address(host=host, port=port, node_id=node_id, last_connected_date=datetime.datetime.now()) for host, port in addresses]
@staticmethod @staticmethod
def parse_addresses_field(addresses_field): def parse_addresses_field(addresses_field):
@ -249,7 +249,7 @@ class NodeInfoInDB(Base):
break break
return addresses return addresses
class AddressInDB(Base): class Address(Base):
__tablename__ = 'address' __tablename__ = 'address'
node_id = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True) node_id = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
host = Column(String(256), primary_key=True) host = Column(String(256), primary_key=True)
@ -288,13 +288,13 @@ class ChannelDB:
Base.metadata.create_all(engine) Base.metadata.create_all(engine)
def update_counts(self): def update_counts(self):
self.num_channels = DBSession.query(ChannelInfoInDB).count() self.num_channels = DBSession.query(ChannelInfo).count()
self.num_nodes = DBSession.query(NodeInfoInDB).count() self.num_nodes = DBSession.query(NodeInfo).count()
def add_recent_peer(self, peer : LNPeerAddr): def add_recent_peer(self, peer : LNPeerAddr):
addr = DBSession.query(AddressInDB).filter_by(node_id = peer.pubkey.hex()).one_or_none() addr = DBSession.query(Address).filter_by(node_id = peer.pubkey.hex()).one_or_none()
if addr is None: if addr is None:
addr = AddressInDB(node_id = peer.pubkey.hex(), host = peer.host, port = peer.port, last_connected_date = datetime.datetime.now()) addr = Address(node_id = peer.pubkey.hex(), host = peer.host, port = peer.port, last_connected_date = datetime.datetime.now())
else: else:
addr.last_connected_date = datetime.datetime.now() addr.last_connected_date = datetime.datetime.now()
DBSession.add(addr) DBSession.add(addr)
@ -302,8 +302,8 @@ class ChannelDB:
def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes): def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes):
unshuffled = DBSession \ unshuffled = DBSession \
.query(NodeInfoInDB) \ .query(NodeInfo) \
.filter(not_(NodeInfoInDB.node_id.in_(x.hex() for x in node_ids_bytes))) \ .filter(not_(NodeInfo.node_id.in_(x.hex() for x in node_ids_bytes))) \
.limit(200) \ .limit(200) \
.all() .all()
return random.sample(unshuffled, len(unshuffled)) return random.sample(unshuffled, len(unshuffled))
@ -313,15 +313,15 @@ class ChannelDB:
async def _nodes_get(self, node_id): async def _nodes_get(self, node_id):
return DBSession \ return DBSession \
.query(NodeInfoInDB) \ .query(NodeInfo) \
.filter_by(node_id = node_id.hex()) \ .filter_by(node_id = node_id.hex()) \
.one_or_none() .one_or_none()
def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]: def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]:
adr_db = DBSession \ adr_db = DBSession \
.query(AddressInDB) \ .query(Address) \
.filter_by(node_id = node_id.hex()) \ .filter_by(node_id = node_id.hex()) \
.order_by(AddressInDB.last_connected_date.desc()) \ .order_by(Address.last_connected_date.desc()) \
.one_or_none() .one_or_none()
if not adr_db: if not adr_db:
return None return None
@ -329,9 +329,9 @@ class ChannelDB:
def get_recent_peers(self): def get_recent_peers(self):
return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in DBSession \ return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in DBSession \
.query(AddressInDB) \ .query(Address) \
.select_from(NodeInfoInDB) \ .select_from(NodeInfo) \
.order_by(AddressInDB.last_connected_date.desc()) \ .order_by(Address.last_connected_date.desc()) \
.limit(self.NUM_MAX_RECENT_PEERS)] .limit(self.NUM_MAX_RECENT_PEERS)]
def get_channel_info(self, channel_id: bytes): def get_channel_info(self, channel_id: bytes):
@ -340,13 +340,13 @@ class ChannelDB:
def get_channels_for_node(self, node_id): def get_channels_for_node(self, node_id):
"""Returns the set of channels that have node_id as one of the endpoints.""" """Returns the set of channels that have node_id as one of the endpoints."""
condition = or_( condition = or_(
ChannelInfoInDB.node1_id == node_id.hex(), ChannelInfo.node1_id == node_id.hex(),
ChannelInfoInDB.node2_id == node_id.hex()) ChannelInfo.node2_id == node_id.hex())
rows = DBSession.query(ChannelInfoInDB).filter(condition).all() rows = DBSession.query(ChannelInfo).filter(condition).all()
return [bytes.fromhex(x.short_channel_id) for x in rows] return [bytes.fromhex(x.short_channel_id) for x in rows]
def missing_short_chan_ids(self) -> Set[int]: def missing_short_chan_ids(self) -> Set[int]:
expr = not_(Policy.short_channel_id.in_(DBSession.query(ChannelInfoInDB.short_channel_id))) expr = not_(Policy.short_channel_id.in_(DBSession.query(ChannelInfo.short_channel_id)))
return set(DBSession.query(Policy.short_channel_id).filter(expr).all()) return set(DBSession.query(Policy.short_channel_id).filter(expr).all())
def add_verified_channel_info(self, short_id, capacity): def add_verified_channel_info(self, short_id, capacity):
@ -362,13 +362,13 @@ class ChannelDB:
msg_payloads = [msg_payloads] msg_payloads = [msg_payloads]
for msg in msg_payloads: for msg in msg_payloads:
short_channel_id = msg['short_channel_id'] short_channel_id = msg['short_channel_id']
if DBSession.query(ChannelInfoInDB).filter_by(short_channel_id = bh2u(short_channel_id)).count(): if DBSession.query(ChannelInfo).filter_by(short_channel_id = bh2u(short_channel_id)).count():
continue continue
if constants.net.rev_genesis_bytes() != msg['chain_hash']: if constants.net.rev_genesis_bytes() != msg['chain_hash']:
#self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash']))) #self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash'])))
continue continue
try: try:
channel_info = ChannelInfoInDB.from_msg(msg) channel_info = ChannelInfo.from_msg(msg)
except UnknownEvenFeatureBits: except UnknownEvenFeatureBits:
continue continue
channel_info.trusted = trusted channel_info.trusted = trusted
@ -383,7 +383,7 @@ class ChannelDB:
if type(msg_payloads) is dict: if type(msg_payloads) is dict:
msg_payloads = [msg_payloads] msg_payloads = [msg_payloads]
short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads] short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads]
channel_infos_list = DBSession.query(ChannelInfoInDB).filter(ChannelInfoInDB.short_channel_id.in_(short_channel_ids)).all() channel_infos_list = DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all()
channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list} channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list}
for msg_payload in msg_payloads: for msg_payload in msg_payloads:
short_channel_id = msg_payload['short_channel_id'] short_channel_id = msg_payload['short_channel_id']
@ -397,12 +397,12 @@ class ChannelDB:
def on_node_announcement(self, msg_payloads): def on_node_announcement(self, msg_payloads):
if type(msg_payloads) is dict: if type(msg_payloads) is dict:
msg_payloads = [msg_payloads] msg_payloads = [msg_payloads]
addresses = DBSession.query(AddressInDB).all() addresses = DBSession.query(Address).all()
have_addr = {} have_addr = {}
for addr in addresses: for addr in addresses:
have_addr[(addr.node_id, addr.host, addr.port)] = addr have_addr[(addr.node_id, addr.host, addr.port)] = addr
nodes = DBSession.query(NodeInfoInDB).all() nodes = DBSession.query(NodeInfo).all()
timestamps = {} timestamps = {}
for node in nodes: for node in nodes:
no_millisecs = node.timestamp[:len("0000-00-00 00:00:00")] no_millisecs = node.timestamp[:len("0000-00-00 00:00:00")]
@ -415,7 +415,7 @@ class ChannelDB:
if not ecc.verify_signature(pubkey, signature, h): if not ecc.verify_signature(pubkey, signature, h):
continue continue
try: try:
new_node_info, addresses = NodeInfoInDB.from_msg(msg_payload) new_node_info, addresses = NodeInfo.from_msg(msg_payload)
except UnknownEvenFeatureBits: except UnknownEvenFeatureBits:
continue continue
if timestamps.get(pubkey) and timestamps[pubkey] >= new_node_info.timestamp: if timestamps.get(pubkey) and timestamps[pubkey] >= new_node_info.timestamp:
@ -464,7 +464,7 @@ class ChannelDB:
DBSession.commit() DBSession.commit()
def chan_query_for_id(self, short_channel_id) -> Query: def chan_query_for_id(self, short_channel_id) -> Query:
return DBSession.query(ChannelInfoInDB).filter_by(short_channel_id = short_channel_id.hex()) return DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex())
def print_graph(self, full_ids=False): def print_graph(self, full_ids=False):
# used for debugging. # used for debugging.
@ -478,11 +478,11 @@ class ChannelDB:
return other if full_ids else other[-4:] return other if full_ids else other[-4:]
self.print_msg('nodes') self.print_msg('nodes')
for node in DBSession.query(NodeInfoInDB).all(): for node in DBSession.query(NodeInfo).all():
self.print_msg(node) self.print_msg(node)
self.print_msg('channels') self.print_msg('channels')
for channel_info in DBSession.query(ChannelInfoInDB).all(): for channel_info in DBSession.query(ChannelInfo).all():
node1 = channel_info.node1_id node1 = channel_info.node1_id
node2 = channel_info.node2_id node2 = channel_info.node2_id
direction1 = channel_info.get_policy_for_node(node1) is not None direction1 = channel_info.get_policy_for_node(node1) is not None

Loading…
Cancel
Save