Browse Source

Use separate lightning nodes for gossip and channel operations.

regtest_lnd
ThomasV 6 years ago
committed by SomberNight
parent
commit
244380d00d
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 5
      electrum/lnpeer.py
  2. 361
      electrum/lnworker.py
  3. 4
      electrum/network.py
  4. 27
      electrum/tests/test_lnpeer.py
  5. 4
      electrum/wallet.py

5
electrum/lnpeer.py

@ -63,6 +63,7 @@ class Peer(PrintError):
self.pubkey = pubkey self.pubkey = pubkey
self.lnworker = lnworker self.lnworker = lnworker
self.privkey = lnworker.node_keypair.privkey self.privkey = lnworker.node_keypair.privkey
self.localfeatures = self.lnworker.localfeatures
self.node_ids = [self.pubkey, privkey_to_pubkey(self.privkey)] self.node_ids = [self.pubkey, privkey_to_pubkey(self.privkey)]
self.network = lnworker.network self.network = lnworker.network
self.lnwatcher = lnworker.network.lnwatcher self.lnwatcher = lnworker.network.lnwatcher
@ -76,10 +77,6 @@ class Peer(PrintError):
self.announcement_signatures = defaultdict(asyncio.Queue) self.announcement_signatures = defaultdict(asyncio.Queue)
self.closing_signed = defaultdict(asyncio.Queue) self.closing_signed = defaultdict(asyncio.Queue)
self.payment_preimages = defaultdict(asyncio.Queue) self.payment_preimages = defaultdict(asyncio.Queue)
self.localfeatures = LnLocalFeatures(0)
self.localfeatures |= LnLocalFeatures.GOSSIP_QUERIES_REQ
#self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_REQ
#self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_OPT
self.attempted_route = {} self.attempted_route = {}
self.orphan_channel_updates = OrderedDict() self.orphan_channel_updates = OrderedDict()
self.sent_commitment_for_ctn_last = defaultdict(lambda: None) # type: Dict[Channel, Optional[int]] self.sent_commitment_for_ctn_last = defaultdict(lambda: None) # type: Dict[Channel, Optional[int]]

361
electrum/lnworker.py

@ -39,7 +39,7 @@ from .lnutil import (Outpoint, calc_short_channel_id, LNPeerAddr,
generate_keypair, LnKeyFamily, LOCAL, REMOTE, generate_keypair, LnKeyFamily, LOCAL, REMOTE,
UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE, UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE,
NUM_MAX_EDGES_IN_PAYMENT_PATH, SENT, RECEIVED, HTLCOwner, NUM_MAX_EDGES_IN_PAYMENT_PATH, SENT, RECEIVED, HTLCOwner,
UpdateAddHtlc, Direction) UpdateAddHtlc, Direction, LnLocalFeatures)
from .i18n import _ from .i18n import _
from .lnrouter import RouteEdge, is_route_sane_to_use from .lnrouter import RouteEdge, is_route_sane_to_use
from .address_synchronizer import TX_HEIGHT_LOCAL from .address_synchronizer import TX_HEIGHT_LOCAL
@ -74,16 +74,178 @@ encoder = ChannelJsonEncoder()
class LNWorker(PrintError): class LNWorker(PrintError):
def __init__(self, xprv):
self.node_keypair = generate_keypair(keystore.from_xprv(xprv), LnKeyFamily.NODE_KEY, 0)
self.peers = {} # type: Dict[bytes, Peer] # pubkey -> Peer
self.localfeatures = LnLocalFeatures(0)
async def maybe_listen(self):
listen_addr = self.config.get('lightning_listen')
if listen_addr:
addr, port = listen_addr.rsplit(':', 2)
if addr[0] == '[':
# ipv6
addr = addr[1:-1]
async def cb(reader, writer):
transport = LNResponderTransport(self.node_keypair.privkey, reader, writer)
try:
node_id = await transport.handshake()
except:
self.print_error('handshake failure from incoming connection')
return
peer = Peer(self, node_id, transport)
self.peers[node_id] = peer
await self.network.main_taskgroup.spawn(peer.main_loop())
self.network.trigger_callback('ln_status')
await asyncio.start_server(cb, addr, int(port))
async def main_loop(self):
while True:
await asyncio.sleep(1)
now = time.time()
if len(self.peers) >= NUM_PEERS_TARGET:
continue
peers = self._get_next_peers_to_try()
for peer in peers:
last_tried = self._last_tried_peer.get(peer, 0)
if last_tried + PEER_RETRY_INTERVAL < now:
await self.add_peer(peer.host, peer.port, peer.pubkey)
async def add_peer(self, host, port, node_id):
if node_id in self.peers:
return self.peers[node_id]
port = int(port)
peer_addr = LNPeerAddr(host, port, node_id)
transport = LNTransport(self.node_keypair.privkey, peer_addr)
self._last_tried_peer[peer_addr] = time.time()
self.print_error("adding peer", peer_addr)
peer = Peer(self, node_id, transport)
await self.network.main_taskgroup.spawn(peer.main_loop())
self.peers[node_id] = peer
self.network.trigger_callback('ln_status')
return peer
def start_network(self, network: 'Network'):
self.network = network
self.config = network.config
self.channel_db = self.network.channel_db
self._last_tried_peer = {} # LNPeerAddr -> unix timestamp
self._add_peers_from_config()
# wait until we see confirmations
asyncio.run_coroutine_threadsafe(self.network.main_taskgroup.spawn(self.main_loop()), self.network.asyncio_loop)
self.first_timestamp_requested = None
def _add_peers_from_config(self):
peer_list = self.config.get('lightning_peers', [])
for host, port, pubkey in peer_list:
asyncio.run_coroutine_threadsafe(
self.add_peer(host, int(port), bfh(pubkey)),
self.network.asyncio_loop)
def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]:
now = time.time()
recent_peers = self.channel_db.get_recent_peers()
# maintenance for last tried times
# due to this, below we can just test membership in _last_tried_peer
for peer in list(self._last_tried_peer):
if now >= self._last_tried_peer[peer] + PEER_RETRY_INTERVAL:
del self._last_tried_peer[peer]
# first try from recent peers
for peer in recent_peers:
if peer.pubkey in self.peers: continue
if peer in self._last_tried_peer: continue
return [peer]
# try random peer from graph
unconnected_nodes = self.channel_db.get_200_randomly_sorted_nodes_not_in(self.peers.keys())
if unconnected_nodes:
for node in unconnected_nodes:
addrs = self.channel_db.get_node_addresses(node)
if not addrs:
continue
host, port = self.choose_preferred_address(addrs)
peer = LNPeerAddr(host, port, bytes.fromhex(node.node_id))
if peer in self._last_tried_peer: continue
self.print_error('taking random ln peer from our channel db')
return [peer]
# TODO remove this. For some reason the dns seeds seem to ignore the realm byte
# and only return mainnet nodes. so for the time being dns seeding is disabled:
if constants.net in (constants.BitcoinTestnet, ):
return [random.choice(FALLBACK_NODE_LIST_TESTNET)]
elif constants.net in (constants.BitcoinMainnet, ):
return [random.choice(FALLBACK_NODE_LIST_MAINNET)]
else:
return []
# try peers from dns seed.
# return several peers to reduce the number of dns queries.
if not constants.net.LN_DNS_SEEDS:
return []
dns_seed = random.choice(constants.net.LN_DNS_SEEDS)
self.print_error('asking dns seed "{}" for ln peers'.format(dns_seed))
try:
# note: this might block for several seconds
# this will include bech32-encoded-pubkeys and ports
srv_answers = resolve_dns_srv('r{}.{}'.format(
constants.net.LN_REALM_BYTE, dns_seed))
except dns.exception.DNSException as e:
return []
random.shuffle(srv_answers)
num_peers = 2 * NUM_PEERS_TARGET
srv_answers = srv_answers[:num_peers]
# we now have pubkeys and ports but host is still needed
peers = []
for srv_ans in srv_answers:
try:
# note: this might block for several seconds
answers = dns.resolver.query(srv_ans['host'])
except dns.exception.DNSException:
continue
try:
ln_host = str(answers[0])
port = int(srv_ans['port'])
bech32_pubkey = srv_ans['host'].split('.')[0]
pubkey = get_compressed_pubkey_from_bech32(bech32_pubkey)
peers.append(LNPeerAddr(ln_host, port, pubkey))
except Exception as e:
self.print_error('error with parsing peer from dns seed: {}'.format(e))
continue
self.print_error('got {} ln peers from dns seed'.format(len(peers)))
return peers
class LNGossip(LNWorker):
def __init__(self, network):
seed = os.urandom(32)
node = BIP32Node.from_rootseed(seed, xtype='standard')
xprv = node.to_xprv()
super().__init__(xprv)
self.localfeatures |= LnLocalFeatures.GOSSIP_QUERIES_REQ
class LNWallet(LNWorker):
def __init__(self, wallet: 'Abstract_Wallet'): def __init__(self, wallet: 'Abstract_Wallet'):
self.wallet = wallet self.wallet = wallet
self.storage = wallet.storage self.storage = wallet.storage
xprv = self.storage.get('lightning_privkey2')
if xprv is None:
# TODO derive this deterministically from wallet.keystore at keystore generation time
# probably along a hardened path ( lnd-equivalent would be m/1017'/coinType'/ )
seed = os.urandom(32)
node = BIP32Node.from_rootseed(seed, xtype='standard')
xprv = node.to_xprv()
self.storage.put('lightning_privkey2', xprv)
super().__init__(xprv)
self.ln_keystore = keystore.from_xprv(xprv)
#self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_REQ
#self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_OPT
self.invoices = self.storage.get('lightning_invoices', {}) # RHASH -> (invoice, direction, is_paid) self.invoices = self.storage.get('lightning_invoices', {}) # RHASH -> (invoice, direction, is_paid)
self.preimages = self.storage.get('lightning_preimages', {}) # RHASH -> preimage self.preimages = self.storage.get('lightning_preimages', {}) # RHASH -> preimage
self.sweep_address = wallet.get_receiving_address() self.sweep_address = wallet.get_receiving_address()
self.lock = threading.RLock() self.lock = threading.RLock()
self.ln_keystore = self._read_ln_keystore()
self.node_keypair = generate_keypair(self.ln_keystore, LnKeyFamily.NODE_KEY, 0)
self.peers = {} # type: Dict[bytes, Peer] # pubkey -> Peer
self.channels = {} # type: Dict[bytes, Channel] self.channels = {} # type: Dict[bytes, Channel]
for x in wallet.storage.get("channels", []): for x in wallet.storage.get("channels", []):
c = Channel(x, sweep_address=self.sweep_address, lnworker=self) c = Channel(x, sweep_address=self.sweep_address, lnworker=self)
@ -95,19 +257,20 @@ class LNWorker(PrintError):
def start_network(self, network: 'Network'): def start_network(self, network: 'Network'):
self.network = network self.network = network
self.config = network.config
self.channel_db = self.network.channel_db
for chan_id, chan in self.channels.items():
self.network.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
chan.lnwatcher = network.lnwatcher
self._last_tried_peer = {} # LNPeerAddr -> unix timestamp
self._add_peers_from_config()
# wait until we see confirmations
self.network.register_callback(self.on_network_update, ['wallet_updated', 'network_updated', 'verified', 'fee']) # thread safe self.network.register_callback(self.on_network_update, ['wallet_updated', 'network_updated', 'verified', 'fee']) # thread safe
self.network.register_callback(self.on_channel_open, ['channel_open']) self.network.register_callback(self.on_channel_open, ['channel_open'])
self.network.register_callback(self.on_channel_closed, ['channel_closed']) self.network.register_callback(self.on_channel_closed, ['channel_closed'])
asyncio.run_coroutine_threadsafe(self.network.main_taskgroup.spawn(self.main_loop()), self.network.asyncio_loop) for chan_id, chan in self.channels.items():
self.first_timestamp_requested = None self.network.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
chan.lnwatcher = network.lnwatcher
super().start_network(network)
for coro in [
self.maybe_listen(),
self.on_network_update('network_updated'), # shortcut (don't block) if funding tx locked and verified
self.network.lnwatcher.on_network_update('network_updated'), # ping watcher to check our channels
self.reestablish_peers_and_channels()
]:
asyncio.run_coroutine_threadsafe(self.network.main_taskgroup.spawn(coro), self.network.asyncio_loop)
def payment_completed(self, chan: Channel, direction: Direction, def payment_completed(self, chan: Channel, direction: Direction,
htlc: UpdateAddHtlc): htlc: UpdateAddHtlc):
@ -193,17 +356,6 @@ class LNWorker(PrintError):
item['balance_msat'] = balance_msat item['balance_msat'] = balance_msat
return out return out
def _read_ln_keystore(self) -> BIP32_KeyStore:
xprv = self.storage.get('lightning_privkey2')
if xprv is None:
# TODO derive this deterministically from wallet.keystore at keystore generation time
# probably along a hardened path ( lnd-equivalent would be m/1017'/coinType'/ )
seed = os.urandom(32)
node = BIP32Node.from_rootseed(seed, xtype='standard')
xprv = node.to_xprv()
self.storage.put('lightning_privkey2', xprv)
return keystore.from_xprv(xprv)
def get_and_inc_counter_for_channel_keys(self): def get_and_inc_counter_for_channel_keys(self):
with self.lock: with self.lock:
ctr = self.storage.get('lightning_channel_key_der_ctr', -1) ctr = self.storage.get('lightning_channel_key_der_ctr', -1)
@ -212,14 +364,6 @@ class LNWorker(PrintError):
self.storage.write() self.storage.write()
return ctr return ctr
def _add_peers_from_config(self):
peer_list = self.config.get('lightning_peers', [])
for host, port, pubkey in peer_list:
asyncio.run_coroutine_threadsafe(
self.add_peer(host, int(port), bfh(pubkey)),
self.network.asyncio_loop)
def suggest_peer(self): def suggest_peer(self):
for node_id, peer in self.peers.items(): for node_id, peer in self.peers.items():
if not peer.initialized.is_set(): if not peer.initialized.is_set():
@ -233,20 +377,6 @@ class LNWorker(PrintError):
with self.lock: with self.lock:
return {x: y for (x, y) in self.channels.items() if y.node_id == node_id} return {x: y for (x, y) in self.channels.items() if y.node_id == node_id}
async def add_peer(self, host, port, node_id):
if node_id in self.peers:
return self.peers[node_id]
port = int(port)
peer_addr = LNPeerAddr(host, port, node_id)
transport = LNTransport(self.node_keypair.privkey, peer_addr)
self._last_tried_peer[peer_addr] = time.time()
self.print_error("adding peer", peer_addr)
peer = Peer(self, node_id, transport)
await self.network.main_taskgroup.spawn(peer.main_loop())
self.peers[node_id] = peer
self.network.trigger_callback('ln_status')
return peer
def save_channel(self, openchannel): def save_channel(self, openchannel):
assert type(openchannel) is Channel assert type(openchannel) is Channel
if openchannel.config[REMOTE].next_per_commitment_point == openchannel.config[REMOTE].current_per_commitment_point: if openchannel.config[REMOTE].next_per_commitment_point == openchannel.config[REMOTE].current_per_commitment_point:
@ -692,77 +822,6 @@ class LNWorker(PrintError):
await self.network.broadcast_transaction(tx) await self.network.broadcast_transaction(tx)
return tx.txid() return tx.txid()
def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]:
now = time.time()
recent_peers = self.channel_db.get_recent_peers()
# maintenance for last tried times
# due to this, below we can just test membership in _last_tried_peer
for peer in list(self._last_tried_peer):
if now >= self._last_tried_peer[peer] + PEER_RETRY_INTERVAL:
del self._last_tried_peer[peer]
# first try from recent peers
for peer in recent_peers:
if peer.pubkey in self.peers: continue
if peer in self._last_tried_peer: continue
return [peer]
# try random peer from graph
unconnected_nodes = self.channel_db.get_200_randomly_sorted_nodes_not_in(self.peers.keys())
if unconnected_nodes:
for node in unconnected_nodes:
addrs = self.channel_db.get_node_addresses(node)
if not addrs:
continue
host, port = self.choose_preferred_address(addrs)
peer = LNPeerAddr(host, port, bytes.fromhex(node.node_id))
if peer in self._last_tried_peer: continue
self.print_error('taking random ln peer from our channel db')
return [peer]
# TODO remove this. For some reason the dns seeds seem to ignore the realm byte
# and only return mainnet nodes. so for the time being dns seeding is disabled:
if constants.net in (constants.BitcoinTestnet, ):
return [random.choice(FALLBACK_NODE_LIST_TESTNET)]
elif constants.net in (constants.BitcoinMainnet, ):
return [random.choice(FALLBACK_NODE_LIST_MAINNET)]
else:
return []
# try peers from dns seed.
# return several peers to reduce the number of dns queries.
if not constants.net.LN_DNS_SEEDS:
return []
dns_seed = random.choice(constants.net.LN_DNS_SEEDS)
self.print_error('asking dns seed "{}" for ln peers'.format(dns_seed))
try:
# note: this might block for several seconds
# this will include bech32-encoded-pubkeys and ports
srv_answers = resolve_dns_srv('r{}.{}'.format(
constants.net.LN_REALM_BYTE, dns_seed))
except dns.exception.DNSException as e:
return []
random.shuffle(srv_answers)
num_peers = 2 * NUM_PEERS_TARGET
srv_answers = srv_answers[:num_peers]
# we now have pubkeys and ports but host is still needed
peers = []
for srv_ans in srv_answers:
try:
# note: this might block for several seconds
answers = dns.resolver.query(srv_ans['host'])
except dns.exception.DNSException:
continue
try:
ln_host = str(answers[0])
port = int(srv_ans['port'])
bech32_pubkey = srv_ans['host'].split('.')[0]
pubkey = get_compressed_pubkey_from_bech32(bech32_pubkey)
peers.append(LNPeerAddr(ln_host, port, pubkey))
except Exception as e:
self.print_error('error with parsing peer from dns seed: {}'.format(e))
continue
self.print_error('got {} ln peers from dns seed'.format(len(peers)))
return peers
async def reestablish_peers_and_channels(self): async def reestablish_peers_and_channels(self):
async def reestablish_peer_for_given_channel(): async def reestablish_peer_for_given_channel():
# try last good address first # try last good address first
@ -784,21 +843,23 @@ class LNWorker(PrintError):
if last_tried + PEER_RETRY_INTERVAL_FOR_CHANNELS < now: if last_tried + PEER_RETRY_INTERVAL_FOR_CHANNELS < now:
await self.add_peer(host, port, chan.node_id) await self.add_peer(host, port, chan.node_id)
with self.lock: while True:
channels = list(self.channels.values()) await asyncio.sleep(1)
now = time.time() with self.lock:
for chan in channels: channels = list(self.channels.values())
if chan.is_closed(): now = time.time()
continue for chan in channels:
if constants.net is not constants.BitcoinRegtest: if chan.is_closed():
ratio = chan.constraints.feerate / self.current_feerate_per_kw() continue
if ratio < 0.5: if constants.net is not constants.BitcoinRegtest:
self.print_error(f"WARNING: fee level for channel {bh2u(chan.channel_id)} is {chan.constraints.feerate} sat/kiloweight, current recommended feerate is {self.current_feerate_per_kw()} sat/kiloweight, consider force closing!") ratio = chan.constraints.feerate / self.current_feerate_per_kw()
if not chan.should_try_to_reestablish_peer(): if ratio < 0.5:
continue self.print_error(f"WARNING: fee level for channel {bh2u(chan.channel_id)} is {chan.constraints.feerate} sat/kiloweight, current recommended feerate is {self.current_feerate_per_kw()} sat/kiloweight, consider force closing!")
peer = self.peers.get(chan.node_id, None) if not chan.should_try_to_reestablish_peer():
coro = peer.reestablish_channel(chan) if peer else reestablish_peer_for_given_channel() continue
await self.network.main_taskgroup.spawn(coro) peer = self.peers.get(chan.node_id, None)
coro = peer.reestablish_channel(chan) if peer else reestablish_peer_for_given_channel()
await self.network.main_taskgroup.spawn(coro)
def current_feerate_per_kw(self): def current_feerate_per_kw(self):
from .simple_config import FEE_LN_ETA_TARGET, FEERATE_FALLBACK_STATIC_FEE, FEERATE_REGTEST_HARDCODED from .simple_config import FEE_LN_ETA_TARGET, FEERATE_FALLBACK_STATIC_FEE, FEERATE_REGTEST_HARDCODED
@ -808,37 +869,3 @@ class LNWorker(PrintError):
if feerate_per_kvbyte is None: if feerate_per_kvbyte is None:
feerate_per_kvbyte = FEERATE_FALLBACK_STATIC_FEE feerate_per_kvbyte = FEERATE_FALLBACK_STATIC_FEE
return max(253, feerate_per_kvbyte // 4) return max(253, feerate_per_kvbyte // 4)
async def main_loop(self):
await self.on_network_update('network_updated') # shortcut (don't block) if funding tx locked and verified
await self.network.lnwatcher.on_network_update('network_updated') # ping watcher to check our channels
listen_addr = self.config.get('lightning_listen')
if listen_addr:
addr, port = listen_addr.rsplit(':', 2)
if addr[0] == '[':
# ipv6
addr = addr[1:-1]
async def cb(reader, writer):
transport = LNResponderTransport(self.node_keypair.privkey, reader, writer)
try:
node_id = await transport.handshake()
except:
self.print_error('handshake failure from incoming connection')
return
peer = Peer(self, node_id, transport)
self.peers[node_id] = peer
await self.network.main_taskgroup.spawn(peer.main_loop())
self.network.trigger_callback('ln_status')
await asyncio.start_server(cb, addr, int(port))
while True:
await asyncio.sleep(1)
now = time.time()
await self.reestablish_peers_and_channels()
if len(self.peers) >= NUM_PEERS_TARGET:
continue
peers = self._get_next_peers_to_try()
for peer in peers:
last_tried = self._last_tried_peer.get(peer, 0)
if last_tried + PEER_RETRY_INTERVAL < now:
await self.add_peer(peer.host, peer.port, peer.pubkey)

4
electrum/network.py

@ -297,10 +297,12 @@ class Network(Logger):
# lightning network # lightning network
from . import lnwatcher from . import lnwatcher
from . import lnworker
from . import lnrouter from . import lnrouter
self.channel_db = lnrouter.ChannelDB(self) self.channel_db = lnrouter.ChannelDB(self)
self.path_finder = lnrouter.LNPathFinder(self.channel_db) self.path_finder = lnrouter.LNPathFinder(self.channel_db)
self.lnwatcher = lnwatcher.LNWatcher(self) self.lnwatcher = lnwatcher.LNWatcher(self)
self.lngossip = lnworker.LNGossip(self)
def run_from_another_thread(self, coro): def run_from_another_thread(self, coro):
assert self._loop_thread != threading.current_thread(), 'must not be called from network thread' assert self._loop_thread != threading.current_thread(), 'must not be called from network thread'
@ -1151,6 +1153,8 @@ class Network(Logger):
asyncio.run_coroutine_threadsafe(main(), self.asyncio_loop) asyncio.run_coroutine_threadsafe(main(), self.asyncio_loop)
self.trigger_callback('network_updated') self.trigger_callback('network_updated')
#
self.lngossip.start_network(self)
def start(self, jobs: List=None): def start(self, jobs: List=None):
self._jobs = jobs or [] self._jobs = jobs or []

27
electrum/tests/test_lnpeer.py

@ -14,9 +14,9 @@ from electrum.util import bh2u, set_verbosity, create_and_start_event_loop
from electrum.lnpeer import Peer from electrum.lnpeer import Peer
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
from electrum.lnutil import PaymentFailure from electrum.lnutil import PaymentFailure, LnLocalFeatures
from electrum.lnrouter import ChannelDB, LNPathFinder from electrum.lnrouter import ChannelDB, LNPathFinder
from electrum.lnworker import LNWorker from electrum.lnworker import LNWallet
from electrum.lnmsg import encode_msg, decode_msg from electrum.lnmsg import encode_msg, decode_msg
from .test_lnchannel import create_test_channels from .test_lnchannel import create_test_channels
@ -74,7 +74,7 @@ class MockStorage:
class MockWallet: class MockWallet:
storage = MockStorage() storage = MockStorage()
class MockLNWorker: class MockLNWallet:
def __init__(self, remote_keypair, local_keypair, chan, tx_queue): def __init__(self, remote_keypair, local_keypair, chan, tx_queue):
self.chan = chan self.chan = chan
self.remote_keypair = remote_keypair self.remote_keypair = remote_keypair
@ -85,6 +85,7 @@ class MockLNWorker:
self.preimages = {} self.preimages = {}
self.inflight = {} self.inflight = {}
self.wallet = MockWallet() self.wallet = MockWallet()
self.localfeatures = LnLocalFeatures(0)
@property @property
def lock(self): def lock(self):
@ -112,12 +113,12 @@ class MockLNWorker:
def save_invoice(*args, is_paid=False): def save_invoice(*args, is_paid=False):
pass pass
get_invoice = LNWorker.get_invoice get_invoice = LNWallet.get_invoice
get_preimage = LNWorker.get_preimage get_preimage = LNWallet.get_preimage
_create_route_from_invoice = LNWorker._create_route_from_invoice _create_route_from_invoice = LNWallet._create_route_from_invoice
_check_invoice = staticmethod(LNWorker._check_invoice) _check_invoice = staticmethod(LNWallet._check_invoice)
_pay_to_route = LNWorker._pay_to_route _pay_to_route = LNWallet._pay_to_route
force_close_channel = LNWorker.force_close_channel force_close_channel = LNWallet.force_close_channel
get_first_timestamp = lambda self: 0 get_first_timestamp = lambda self: 0
class MockTransport: class MockTransport:
@ -179,8 +180,8 @@ class TestPeer(SequentialTestCase):
k1, k2 = keypair(), keypair() k1, k2 = keypair(), keypair()
t1, t2 = transport_pair(self.alice_channel.name, self.bob_channel.name) t1, t2 = transport_pair(self.alice_channel.name, self.bob_channel.name)
q1, q2 = asyncio.Queue(), asyncio.Queue() q1, q2 = asyncio.Queue(), asyncio.Queue()
w1 = MockLNWorker(k1, k2, self.alice_channel, tx_queue=q1) w1 = MockLNWallet(k1, k2, self.alice_channel, tx_queue=q1)
w2 = MockLNWorker(k2, k1, self.bob_channel, tx_queue=q2) w2 = MockLNWallet(k2, k1, self.bob_channel, tx_queue=q2)
p1 = Peer(w1, k1.pubkey, t1) p1 = Peer(w1, k1.pubkey, t1)
p2 = Peer(w2, k2.pubkey, t2) p2 = Peer(w2, k2.pubkey, t2)
w1.peer = p1 w1.peer = p1
@ -215,7 +216,7 @@ class TestPeer(SequentialTestCase):
def prepare_ln_message_future(w2 # receiver def prepare_ln_message_future(w2 # receiver
): ):
fut = asyncio.Future() fut = asyncio.Future()
def evt_set(event, _lnworker, msg, _htlc_id): def evt_set(event, _lnwallet, msg, _htlc_id):
fut.set_result(msg) fut.set_result(msg)
w2.network.register_callback(evt_set, ['ln_message']) w2.network.register_callback(evt_set, ['ln_message'])
return fut return fut
@ -226,7 +227,7 @@ class TestPeer(SequentialTestCase):
fut = self.prepare_ln_message_future(w2) fut = self.prepare_ln_message_future(w2)
async def pay(): async def pay():
addr, peer, coro = await LNWorker._pay(w1, pay_req, same_thread=True) addr, peer, coro = await LNWallet._pay(w1, pay_req, same_thread=True)
await coro await coro
print("HTLC ADDED") print("HTLC ADDED")
self.assertEqual(await fut, 'Payment received') self.assertEqual(await fut, 'Payment received')

4
electrum/wallet.py

@ -64,7 +64,7 @@ from .interface import RequestTimedOut
from .ecc_fast import is_using_fast_ecc from .ecc_fast import is_using_fast_ecc
from .mnemonic import Mnemonic from .mnemonic import Mnemonic
from .logging import get_logger from .logging import get_logger
from .lnworker import LNWorker from .lnworker import LNWallet
if TYPE_CHECKING: if TYPE_CHECKING:
from .network import Network from .network import Network
@ -230,7 +230,7 @@ class Abstract_Wallet(AddressSynchronizer):
self.storage.put('wallet_type', self.wallet_type) self.storage.put('wallet_type', self.wallet_type)
# lightning # lightning
self.lnworker = LNWorker(self) self.lnworker = LNWallet(self)
# invoices and contacts # invoices and contacts
self.invoices = InvoiceStore(self.storage) self.invoices = InvoiceStore(self.storage)
self.contacts = Contacts(self.storage) self.contacts = Contacts(self.storage)

Loading…
Cancel
Save