From 3442e51fac536067ed5bd62091d555ecdcae092e Mon Sep 17 00:00:00 2001 From: Janus Date: Wed, 20 Feb 2019 21:06:37 +0100 Subject: [PATCH] sqlite in lnrouter: remove useless InDB suffix --- electrum/lnrouter.py | 64 ++++++++++++++++++++++---------------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py index 0a2bd3775..61235abc8 100644 --- a/electrum/lnrouter.py +++ b/electrum/lnrouter.py @@ -75,7 +75,7 @@ DBSession = scoped_session(session_factory) FLAG_DISABLE = 1 << 1 FLAG_DIRECTION = 1 << 0 -class ChannelInfoInDB(Base): +class ChannelInfo(Base): __tablename__ = 'channel_info' short_channel_id = Column(String(64), primary_key=True) node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False) @@ -98,7 +98,7 @@ class ChannelInfoInDB(Base): capacity_sat = None - return ChannelInfoInDB(short_channel_id = channel_id, node1_id = node_id_1, + return ChannelInfo(short_channel_id = channel_id, node1_id = node_id_1, node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex, trusted = False) @@ -186,7 +186,7 @@ class Policy(Base): def is_disabled(self): return self.channel_flags & FLAG_DISABLE -class NodeInfoInDB(Base): +class NodeInfo(Base): __tablename__ = 'node_info' node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE') features = Column(Integer, nullable=False) @@ -194,7 +194,7 @@ class NodeInfoInDB(Base): alias = Column(String(64), nullable=False) def get_addresses(self): - return DBSession.query(AddressInDB).join(NodeInfoInDB).filter_by(node_id = self.node_id).all() + return DBSession.query(Address).join(NodeInfo).filter_by(node_id = self.node_id).all() @staticmethod def from_msg(node_announcement_payload, addresses_already_parsed=False): @@ -202,12 +202,12 @@ class NodeInfoInDB(Base): features = int.from_bytes(node_announcement_payload['features'], "big") validate_features(features) if not addresses_already_parsed: - addresses = NodeInfoInDB.parse_addresses_field(node_announcement_payload['addresses']) + addresses = NodeInfo.parse_addresses_field(node_announcement_payload['addresses']) else: addresses = node_announcement_payload['addresses'] alias = node_announcement_payload['alias'].rstrip(b'\x00').hex() timestamp = datetime.datetime.fromtimestamp(int.from_bytes(node_announcement_payload['timestamp'], "big")) - return NodeInfoInDB(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [AddressInDB(host=host, port=port, node_id=node_id, last_connected_date=datetime.datetime.now()) for host, port in addresses] + return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [Address(host=host, port=port, node_id=node_id, last_connected_date=datetime.datetime.now()) for host, port in addresses] @staticmethod def parse_addresses_field(addresses_field): @@ -249,7 +249,7 @@ class NodeInfoInDB(Base): break return addresses -class AddressInDB(Base): +class Address(Base): __tablename__ = 'address' node_id = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True) host = Column(String(256), primary_key=True) @@ -288,13 +288,13 @@ class ChannelDB: Base.metadata.create_all(engine) def update_counts(self): - self.num_channels = DBSession.query(ChannelInfoInDB).count() - self.num_nodes = DBSession.query(NodeInfoInDB).count() + self.num_channels = DBSession.query(ChannelInfo).count() + self.num_nodes = DBSession.query(NodeInfo).count() def add_recent_peer(self, peer : LNPeerAddr): - addr = DBSession.query(AddressInDB).filter_by(node_id = peer.pubkey.hex()).one_or_none() + addr = DBSession.query(Address).filter_by(node_id = peer.pubkey.hex()).one_or_none() if addr is None: - addr = AddressInDB(node_id = peer.pubkey.hex(), host = peer.host, port = peer.port, last_connected_date = datetime.datetime.now()) + addr = Address(node_id = peer.pubkey.hex(), host = peer.host, port = peer.port, last_connected_date = datetime.datetime.now()) else: addr.last_connected_date = datetime.datetime.now() DBSession.add(addr) @@ -302,8 +302,8 @@ class ChannelDB: def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes): unshuffled = DBSession \ - .query(NodeInfoInDB) \ - .filter(not_(NodeInfoInDB.node_id.in_(x.hex() for x in node_ids_bytes))) \ + .query(NodeInfo) \ + .filter(not_(NodeInfo.node_id.in_(x.hex() for x in node_ids_bytes))) \ .limit(200) \ .all() return random.sample(unshuffled, len(unshuffled)) @@ -313,15 +313,15 @@ class ChannelDB: async def _nodes_get(self, node_id): return DBSession \ - .query(NodeInfoInDB) \ + .query(NodeInfo) \ .filter_by(node_id = node_id.hex()) \ .one_or_none() def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]: adr_db = DBSession \ - .query(AddressInDB) \ + .query(Address) \ .filter_by(node_id = node_id.hex()) \ - .order_by(AddressInDB.last_connected_date.desc()) \ + .order_by(Address.last_connected_date.desc()) \ .one_or_none() if not adr_db: return None @@ -329,9 +329,9 @@ class ChannelDB: def get_recent_peers(self): return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in DBSession \ - .query(AddressInDB) \ - .select_from(NodeInfoInDB) \ - .order_by(AddressInDB.last_connected_date.desc()) \ + .query(Address) \ + .select_from(NodeInfo) \ + .order_by(Address.last_connected_date.desc()) \ .limit(self.NUM_MAX_RECENT_PEERS)] def get_channel_info(self, channel_id: bytes): @@ -340,13 +340,13 @@ class ChannelDB: def get_channels_for_node(self, node_id): """Returns the set of channels that have node_id as one of the endpoints.""" condition = or_( - ChannelInfoInDB.node1_id == node_id.hex(), - ChannelInfoInDB.node2_id == node_id.hex()) - rows = DBSession.query(ChannelInfoInDB).filter(condition).all() + ChannelInfo.node1_id == node_id.hex(), + ChannelInfo.node2_id == node_id.hex()) + rows = DBSession.query(ChannelInfo).filter(condition).all() return [bytes.fromhex(x.short_channel_id) for x in rows] def missing_short_chan_ids(self) -> Set[int]: - expr = not_(Policy.short_channel_id.in_(DBSession.query(ChannelInfoInDB.short_channel_id))) + expr = not_(Policy.short_channel_id.in_(DBSession.query(ChannelInfo.short_channel_id))) return set(DBSession.query(Policy.short_channel_id).filter(expr).all()) def add_verified_channel_info(self, short_id, capacity): @@ -362,13 +362,13 @@ class ChannelDB: msg_payloads = [msg_payloads] for msg in msg_payloads: short_channel_id = msg['short_channel_id'] - if DBSession.query(ChannelInfoInDB).filter_by(short_channel_id = bh2u(short_channel_id)).count(): + if DBSession.query(ChannelInfo).filter_by(short_channel_id = bh2u(short_channel_id)).count(): continue if constants.net.rev_genesis_bytes() != msg['chain_hash']: #self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash']))) continue try: - channel_info = ChannelInfoInDB.from_msg(msg) + channel_info = ChannelInfo.from_msg(msg) except UnknownEvenFeatureBits: continue channel_info.trusted = trusted @@ -383,7 +383,7 @@ class ChannelDB: if type(msg_payloads) is dict: msg_payloads = [msg_payloads] short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads] - channel_infos_list = DBSession.query(ChannelInfoInDB).filter(ChannelInfoInDB.short_channel_id.in_(short_channel_ids)).all() + channel_infos_list = DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all() channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list} for msg_payload in msg_payloads: short_channel_id = msg_payload['short_channel_id'] @@ -397,12 +397,12 @@ class ChannelDB: def on_node_announcement(self, msg_payloads): if type(msg_payloads) is dict: msg_payloads = [msg_payloads] - addresses = DBSession.query(AddressInDB).all() + addresses = DBSession.query(Address).all() have_addr = {} for addr in addresses: have_addr[(addr.node_id, addr.host, addr.port)] = addr - nodes = DBSession.query(NodeInfoInDB).all() + nodes = DBSession.query(NodeInfo).all() timestamps = {} for node in nodes: no_millisecs = node.timestamp[:len("0000-00-00 00:00:00")] @@ -415,7 +415,7 @@ class ChannelDB: if not ecc.verify_signature(pubkey, signature, h): continue try: - new_node_info, addresses = NodeInfoInDB.from_msg(msg_payload) + new_node_info, addresses = NodeInfo.from_msg(msg_payload) except UnknownEvenFeatureBits: continue if timestamps.get(pubkey) and timestamps[pubkey] >= new_node_info.timestamp: @@ -464,7 +464,7 @@ class ChannelDB: DBSession.commit() def chan_query_for_id(self, short_channel_id) -> Query: - return DBSession.query(ChannelInfoInDB).filter_by(short_channel_id = short_channel_id.hex()) + return DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex()) def print_graph(self, full_ids=False): # used for debugging. @@ -478,11 +478,11 @@ class ChannelDB: return other if full_ids else other[-4:] self.print_msg('nodes') - for node in DBSession.query(NodeInfoInDB).all(): + for node in DBSession.query(NodeInfo).all(): self.print_msg(node) self.print_msg('channels') - for channel_info in DBSession.query(ChannelInfoInDB).all(): + for channel_info in DBSession.query(ChannelInfo).all(): node1 = channel_info.node1_id node2 = channel_info.node2_id direction1 = channel_info.get_policy_for_node(node1) is not None