Browse Source

Merge branch 'master' into patch-2

patch-2
ghost43 5 years ago
committed by GitHub
parent
commit
f114d1ffe2
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      .gitignore
  2. 2
      contrib/build-linux/appimage/Dockerfile
  3. 2
      contrib/build-wine/Dockerfile
  4. 8
      electrum/address_synchronizer.py
  5. 8
      electrum/channel_db.py
  6. 15
      electrum/commands.py
  7. 9
      electrum/daemon.py
  8. 12
      electrum/exchange_rate.py
  9. 47
      electrum/gui/kivy/main_window.py
  10. 23
      electrum/gui/kivy/uix/ui_screens/server.kv
  11. 9
      electrum/gui/qt/channel_details.py
  12. 7
      electrum/gui/qt/lightning_dialog.py
  13. 11
      electrum/gui/qt/main_window.py
  14. 124
      electrum/gui/qt/network_dialog.py
  15. 3
      electrum/gui/stdio.py
  16. 31
      electrum/gui/text.py
  17. 129
      electrum/interface.py
  18. 8
      electrum/lnchannel.py
  19. 25
      electrum/lnpeer.py
  20. 15
      electrum/lntransport.py
  21. 6
      electrum/lnwatcher.py
  22. 228
      electrum/lnworker.py
  23. 302
      electrum/network.py
  24. 2
      electrum/scripts/peers.py
  25. 2
      electrum/scripts/txradar.py
  26. 6
      electrum/sql_db.py
  27. 5
      electrum/synchronizer.py
  28. 14
      electrum/tests/test_lnpeer.py
  29. 2
      electrum/tests/test_lntransport.py
  30. 4
      electrum/tests/test_network.py
  31. 128
      electrum/util.py

1
.gitignore

@ -16,6 +16,7 @@ bin/
.idea .idea
.mypy_cache .mypy_cache
.vscode .vscode
electrum_data
# icons # icons
electrum/gui/kivy/theming/light-0.png electrum/gui/kivy/theming/light-0.png

2
contrib/build-linux/appimage/Dockerfile

@ -4,7 +4,7 @@ ENV LC_ALL=C.UTF-8 LANG=C.UTF-8
RUN apt-get update -q && \ RUN apt-get update -q && \
apt-get install -qy \ apt-get install -qy \
git=1:2.7.4-0ubuntu1.7 \ git=1:2.7.4-0ubuntu1.8 \
wget=1.17.1-1ubuntu1.5 \ wget=1.17.1-1ubuntu1.5 \
make=4.1-6 \ make=4.1-6 \
autotools-dev=20150820.1 \ autotools-dev=20150820.1 \

2
contrib/build-wine/Dockerfile

@ -13,7 +13,7 @@ RUN dpkg --add-architecture i386 && \
RUN apt-get update -q && \ RUN apt-get update -q && \
apt-get install -qy \ apt-get install -qy \
git=1:2.17.1-1ubuntu0.5 \ git=1:2.17.1-1ubuntu0.6 \
p7zip-full=16.02+dfsg-6 \ p7zip-full=16.02+dfsg-6 \
make=4.1-9.1ubuntu1 \ make=4.1-9.1ubuntu1 \
mingw-w64=5.0.3-1 \ mingw-w64=5.0.3-1 \

8
electrum/address_synchronizer.py

@ -28,7 +28,7 @@ import itertools
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple, NamedTuple, Sequence, List from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple, NamedTuple, Sequence, List
from . import bitcoin from . import bitcoin, util
from .bitcoin import COINBASE_MATURITY from .bitcoin import COINBASE_MATURITY
from .util import profiler, bfh, TxMinedInfo from .util import profiler, bfh, TxMinedInfo
from .transaction import Transaction, TxOutput, TxInput, PartialTxInput, TxOutpoint, PartialTransaction from .transaction import Transaction, TxOutput, TxInput, PartialTxInput, TxOutpoint, PartialTransaction
@ -161,7 +161,7 @@ class AddressSynchronizer(Logger):
if self.network is not None: if self.network is not None:
self.synchronizer = Synchronizer(self) self.synchronizer = Synchronizer(self)
self.verifier = SPV(self.network, self) self.verifier = SPV(self.network, self)
self.network.register_callback(self.on_blockchain_updated, ['blockchain_updated']) util.register_callback(self.on_blockchain_updated, ['blockchain_updated'])
def on_blockchain_updated(self, event, *args): def on_blockchain_updated(self, event, *args):
self._get_addr_balance_cache = {} # invalidate cache self._get_addr_balance_cache = {} # invalidate cache
@ -174,7 +174,7 @@ class AddressSynchronizer(Logger):
if self.verifier: if self.verifier:
asyncio.run_coroutine_threadsafe(self.verifier.stop(), self.network.asyncio_loop) asyncio.run_coroutine_threadsafe(self.verifier.stop(), self.network.asyncio_loop)
self.verifier = None self.verifier = None
self.network.unregister_callback(self.on_blockchain_updated) util.unregister_callback(self.on_blockchain_updated)
self.db.put('stored_height', self.get_local_height()) self.db.put('stored_height', self.get_local_height())
def add_address(self, address): def add_address(self, address):
@ -546,7 +546,7 @@ class AddressSynchronizer(Logger):
self.unverified_tx.pop(tx_hash, None) self.unverified_tx.pop(tx_hash, None)
self.db.add_verified_tx(tx_hash, info) self.db.add_verified_tx(tx_hash, info)
tx_mined_status = self.get_tx_height(tx_hash) tx_mined_status = self.get_tx_height(tx_hash)
self.network.trigger_callback('verified', self, tx_hash, tx_mined_status) util.trigger_callback('verified', self, tx_hash, tx_mined_status)
def get_unverified_txs(self): def get_unverified_txs(self):
'''Returns a map from tx hash to transaction height''' '''Returns a map from tx hash to transaction height'''

8
electrum/channel_db.py

@ -35,7 +35,7 @@ import threading
from .sql_db import SqlDB, sql from .sql_db import SqlDB, sql
from . import constants from . import constants, util
from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits
from .logging import Logger from .logging import Logger
from .lnutil import (LNPeerAddr, format_short_channel_id, ShortChannelID, from .lnutil import (LNPeerAddr, format_short_channel_id, ShortChannelID,
@ -242,7 +242,7 @@ class ChannelDB(SqlDB):
def __init__(self, network: 'Network'): def __init__(self, network: 'Network'):
path = os.path.join(get_headers_dir(network.config), 'gossip_db') path = os.path.join(get_headers_dir(network.config), 'gossip_db')
super().__init__(network, path, commit_interval=100) super().__init__(network.asyncio_loop, path, commit_interval=100)
self.lock = threading.RLock() self.lock = threading.RLock()
self.num_nodes = 0 self.num_nodes = 0
self.num_channels = 0 self.num_channels = 0
@ -269,8 +269,8 @@ class ChannelDB(SqlDB):
self.num_nodes = len(self._nodes) self.num_nodes = len(self._nodes)
self.num_channels = len(self._channels) self.num_channels = len(self._channels)
self.num_policies = len(self._policies) self.num_policies = len(self._policies)
self.network.trigger_callback('channel_db', self.num_nodes, self.num_channels, self.num_policies) util.trigger_callback('channel_db', self.num_nodes, self.num_channels, self.num_policies)
self.network.trigger_callback('ln_gossip_sync_progress') util.trigger_callback('ln_gossip_sync_progress')
def get_channel_ids(self): def get_channel_ids(self):
with self.lock: with self.lock:

15
electrum/commands.py

@ -53,6 +53,7 @@ from .wallet import Abstract_Wallet, create_new_wallet, restore_wallet_from_text
from .address_synchronizer import TX_HEIGHT_LOCAL from .address_synchronizer import TX_HEIGHT_LOCAL
from .mnemonic import Mnemonic from .mnemonic import Mnemonic
from .lnutil import SENT, RECEIVED from .lnutil import SENT, RECEIVED
from .lnutil import LnFeatures
from .lnutil import ln_dummy_address from .lnutil import ln_dummy_address
from .lnpeer import channel_id_from_funding_tx from .lnpeer import channel_id_from_funding_tx
from .plugin import run_hook from .plugin import run_hook
@ -186,7 +187,7 @@ class Commands:
net_params = self.network.get_parameters() net_params = self.network.get_parameters()
response = { response = {
'path': self.network.config.path, 'path': self.network.config.path,
'server': net_params.host, 'server': net_params.server.host,
'blockchain_height': self.network.get_local_height(), 'blockchain_height': self.network.get_local_height(),
'server_height': self.network.get_server_height(), 'server_height': self.network.get_server_height(),
'spv_nodes': len(self.network.get_interfaces()), 'spv_nodes': len(self.network.get_interfaces()),
@ -965,18 +966,21 @@ class Commands:
# lightning network commands # lightning network commands
@command('wn') @command('wn')
async def add_peer(self, connection_string, timeout=20, wallet: Abstract_Wallet = None): async def add_peer(self, connection_string, timeout=20, gossip=False, wallet: Abstract_Wallet = None):
await wallet.lnworker.add_peer(connection_string) lnworker = self.network.lngossip if gossip else wallet.lnworker
await lnworker.add_peer(connection_string)
return True return True
@command('wn') @command('wn')
async def list_peers(self, wallet: Abstract_Wallet = None): async def list_peers(self, gossip=False, wallet: Abstract_Wallet = None):
lnworker = self.network.lngossip if gossip else wallet.lnworker
return [{ return [{
'node_id':p.pubkey.hex(), 'node_id':p.pubkey.hex(),
'address':p.transport.name(), 'address':p.transport.name(),
'initialized':p.is_initialized(), 'initialized':p.is_initialized(),
'features': str(LnFeatures(p.features)),
'channels': [c.funding_outpoint.to_str() for c in p.channels.values()], 'channels': [c.funding_outpoint.to_str() for c in p.channels.values()],
} for p in wallet.lnworker.peers.values()] } for p in lnworker.peers.values()]
@command('wpn') @command('wpn')
async def open_channel(self, connection_string, amount, push_amount=0, password=None, wallet: Abstract_Wallet = None): async def open_channel(self, connection_string, amount, push_amount=0, password=None, wallet: Abstract_Wallet = None):
@ -1165,6 +1169,7 @@ command_options = {
'from_height': (None, "Only show transactions that confirmed after given block height"), 'from_height': (None, "Only show transactions that confirmed after given block height"),
'to_height': (None, "Only show transactions that confirmed before given block height"), 'to_height': (None, "Only show transactions that confirmed before given block height"),
'iknowwhatimdoing': (None, "Acknowledge that I understand the full implications of what I am about to do"), 'iknowwhatimdoing': (None, "Acknowledge that I understand the full implications of what I am about to do"),
'gossip': (None, "Apply command to gossip node instead of wallet"),
} }

9
electrum/daemon.py

@ -32,6 +32,8 @@ import threading
from typing import Dict, Optional, Tuple, Iterable from typing import Dict, Optional, Tuple, Iterable
from base64 import b64decode, b64encode from base64 import b64decode, b64encode
from collections import defaultdict from collections import defaultdict
import concurrent
from concurrent import futures
import aiohttp import aiohttp
from aiohttp import web, client_exceptions from aiohttp import web, client_exceptions
@ -41,6 +43,7 @@ from jsonrpcserver import response
from jsonrpcclient.clients.aiohttp_client import AiohttpClient from jsonrpcclient.clients.aiohttp_client import AiohttpClient
from aiorpcx import TaskGroup from aiorpcx import TaskGroup
from . import util
from .network import Network from .network import Network
from .util import (json_decode, to_bytes, to_string, profiler, standardize_path, constant_time_compare) from .util import (json_decode, to_bytes, to_string, profiler, standardize_path, constant_time_compare)
from .util import PR_PAID, PR_EXPIRED, get_request_status from .util import PR_PAID, PR_EXPIRED, get_request_status
@ -181,7 +184,7 @@ class PayServer(Logger):
self.daemon = daemon self.daemon = daemon
self.config = daemon.config self.config = daemon.config
self.pending = defaultdict(asyncio.Event) self.pending = defaultdict(asyncio.Event)
self.daemon.network.register_callback(self.on_payment, ['payment_received']) util.register_callback(self.on_payment, ['payment_received'])
async def on_payment(self, evt, wallet, key, status): async def on_payment(self, evt, wallet, key, status):
if status == PR_PAID: if status == PR_PAID:
@ -269,6 +272,8 @@ class AuthenticationCredentialsInvalid(AuthenticationError):
class Daemon(Logger): class Daemon(Logger):
network: Optional[Network]
@profiler @profiler
def __init__(self, config: SimpleConfig, fd=None, *, listen_jsonrpc=True): def __init__(self, config: SimpleConfig, fd=None, *, listen_jsonrpc=True):
Logger.__init__(self) Logger.__init__(self)
@ -504,7 +509,7 @@ class Daemon(Logger):
fut = asyncio.run_coroutine_threadsafe(self.taskgroup.cancel_remaining(), self.asyncio_loop) fut = asyncio.run_coroutine_threadsafe(self.taskgroup.cancel_remaining(), self.asyncio_loop)
try: try:
fut.result(timeout=2) fut.result(timeout=2)
except (asyncio.TimeoutError, asyncio.CancelledError): except (concurrent.futures.TimeoutError, concurrent.futures.CancelledError, asyncio.CancelledError):
pass pass
self.logger.info("removing lockfile") self.logger.info("removing lockfile")
remove_lockfile(get_lockfile(self.config)) remove_lockfile(get_lockfile(self.config))

12
electrum/exchange_rate.py

@ -12,6 +12,7 @@ from typing import Sequence, Optional
from aiorpcx.curio import timeout_after, TaskTimeout, TaskGroup from aiorpcx.curio import timeout_after, TaskTimeout, TaskGroup
from . import util
from .bitcoin import COIN from .bitcoin import COIN
from .i18n import _ from .i18n import _
from .util import (ThreadJob, make_dir, log_exceptions, from .util import (ThreadJob, make_dir, log_exceptions,
@ -452,12 +453,11 @@ def get_exchanges_by_ccy(history=True):
class FxThread(ThreadJob): class FxThread(ThreadJob):
def __init__(self, config: SimpleConfig, network: Network): def __init__(self, config: SimpleConfig, network: Optional[Network]):
ThreadJob.__init__(self) ThreadJob.__init__(self)
self.config = config self.config = config
self.network = network self.network = network
if self.network: util.register_callback(self.set_proxy, ['proxy_set'])
self.network.register_callback(self.set_proxy, ['proxy_set'])
self.ccy = self.get_currency() self.ccy = self.get_currency()
self.history_used_spot = False self.history_used_spot = False
self.ccy_combo = None self.ccy_combo = None
@ -567,12 +567,10 @@ class FxThread(ThreadJob):
self.exchange.read_historical_rates(self.ccy, self.cache_dir) self.exchange.read_historical_rates(self.ccy, self.cache_dir)
def on_quotes(self): def on_quotes(self):
if self.network: util.trigger_callback('on_quotes')
self.network.trigger_callback('on_quotes')
def on_history(self): def on_history(self):
if self.network: util.trigger_callback('on_history')
self.network.trigger_callback('on_history')
def exchange_rate(self) -> Decimal: def exchange_rate(self) -> Decimal:
"""Returns the exchange rate as a Decimal""" """Returns the exchange rate as a Decimal"""

47
electrum/gui/kivy/main_window.py

@ -13,6 +13,7 @@ from electrum.storage import WalletStorage, StorageReadWriteError
from electrum.wallet_db import WalletDB from electrum.wallet_db import WalletDB
from electrum.wallet import Wallet, InternalAddressCorruption, Abstract_Wallet from electrum.wallet import Wallet, InternalAddressCorruption, Abstract_Wallet
from electrum.plugin import run_hook from electrum.plugin import run_hook
from electrum import util
from electrum.util import (profiler, InvalidPassword, send_exception_to_crash_reporter, from electrum.util import (profiler, InvalidPassword, send_exception_to_crash_reporter,
format_satoshis, format_satoshis_plain, format_fee_satoshis, format_satoshis, format_satoshis_plain, format_fee_satoshis,
PR_PAID, PR_FAILED, maybe_extract_bolt11_invoice) PR_PAID, PR_FAILED, maybe_extract_bolt11_invoice)
@ -50,7 +51,6 @@ from .uix.dialogs.question import Question
# delayed imports: for startup speed on android # delayed imports: for startup speed on android
notification = app = ref = None notification = app = ref = None
util = False
# register widget cache for keeping memory down timeout to forever to cache # register widget cache for keeping memory down timeout to forever to cache
# the data # the data
@ -145,6 +145,17 @@ class ElectrumWindow(App):
servers = self.network.get_servers() servers = self.network.get_servers()
ChoiceDialog(_('Choose a server'), sorted(servers), popup.ids.host.text, cb2).open() ChoiceDialog(_('Choose a server'), sorted(servers), popup.ids.host.text, cb2).open()
def maybe_switch_to_server(self, server_str: str):
from electrum.interface import ServerAddr
net_params = self.network.get_parameters()
try:
server = ServerAddr.from_str_with_inference(server_str)
except Exception as e:
self.show_error(_("Invalid server details: {}").format(repr(e)))
return
net_params = net_params._replace(server=server)
self.network.run_from_another_thread(self.network.set_parameters(net_params))
def choose_blockchain_dialog(self, dt): def choose_blockchain_dialog(self, dt):
from .uix.dialogs.choice_dialog import ChoiceDialog from .uix.dialogs.choice_dialog import ChoiceDialog
chains = self.network.get_blockchains() chains = self.network.get_blockchains()
@ -348,8 +359,8 @@ class ElectrumWindow(App):
self.num_blocks = self.network.get_local_height() self.num_blocks = self.network.get_local_height()
self.num_nodes = len(self.network.get_interfaces()) self.num_nodes = len(self.network.get_interfaces())
net_params = self.network.get_parameters() net_params = self.network.get_parameters()
self.server_host = net_params.host self.server_host = net_params.server.host
self.server_port = net_params.port self.server_port = str(net_params.server.port)
self.auto_connect = net_params.auto_connect self.auto_connect = net_params.auto_connect
self.oneserver = net_params.oneserver self.oneserver = net_params.oneserver
self.proxy_config = net_params.proxy if net_params.proxy else {} self.proxy_config = net_params.proxy if net_params.proxy else {}
@ -565,20 +576,20 @@ class ElectrumWindow(App):
if self.network: if self.network:
interests = ['wallet_updated', 'network_updated', 'blockchain_updated', interests = ['wallet_updated', 'network_updated', 'blockchain_updated',
'status', 'new_transaction', 'verified'] 'status', 'new_transaction', 'verified']
self.network.register_callback(self.on_network_event, interests) util.register_callback(self.on_network_event, interests)
self.network.register_callback(self.on_fee, ['fee']) util.register_callback(self.on_fee, ['fee'])
self.network.register_callback(self.on_fee_histogram, ['fee_histogram']) util.register_callback(self.on_fee_histogram, ['fee_histogram'])
self.network.register_callback(self.on_quotes, ['on_quotes']) util.register_callback(self.on_quotes, ['on_quotes'])
self.network.register_callback(self.on_history, ['on_history']) util.register_callback(self.on_history, ['on_history'])
self.network.register_callback(self.on_channels, ['channels_updated']) util.register_callback(self.on_channels, ['channels_updated'])
self.network.register_callback(self.on_channel, ['channel']) util.register_callback(self.on_channel, ['channel'])
self.network.register_callback(self.on_invoice_status, ['invoice_status']) util.register_callback(self.on_invoice_status, ['invoice_status'])
self.network.register_callback(self.on_request_status, ['request_status']) util.register_callback(self.on_request_status, ['request_status'])
self.network.register_callback(self.on_payment_failed, ['payment_failed']) util.register_callback(self.on_payment_failed, ['payment_failed'])
self.network.register_callback(self.on_payment_succeeded, ['payment_succeeded']) util.register_callback(self.on_payment_succeeded, ['payment_succeeded'])
self.network.register_callback(self.on_channel_db, ['channel_db']) util.register_callback(self.on_channel_db, ['channel_db'])
self.network.register_callback(self.set_num_peers, ['gossip_peers']) util.register_callback(self.set_num_peers, ['gossip_peers'])
self.network.register_callback(self.set_unknown_channels, ['unknown_channels']) util.register_callback(self.set_unknown_channels, ['unknown_channels'])
# load wallet # load wallet
self.load_wallet_by_name(self.electrum_config.get_wallet_path(use_gui_last_wallet=True)) self.load_wallet_by_name(self.electrum_config.get_wallet_path(use_gui_last_wallet=True))
# URI passed in config # URI passed in config
@ -814,7 +825,7 @@ class ElectrumWindow(App):
if interface: if interface:
self.server_host = interface.host self.server_host = interface.host
else: else:
self.server_host = str(net_params.host) + ' (connecting...)' self.server_host = str(net_params.server.host) + ' (connecting...)'
self.proxy_config = net_params.proxy or {} self.proxy_config = net_params.proxy or {}
self.update_proxy_str(self.proxy_config) self.update_proxy_str(self.proxy_config)

23
electrum/gui/kivy/uix/ui_screens/server.kv

@ -16,27 +16,14 @@ Popup:
height: '36dp' height: '36dp'
size_hint_x: 1 size_hint_x: 1
size_hint_y: None size_hint_y: None
text: _('Host') + ':' text: _('Server') + ':'
TextInput: TextInput:
id: host id: server_str
multiline: False multiline: False
height: '36dp' height: '36dp'
size_hint_x: 3 size_hint_x: 3
size_hint_y: None size_hint_y: None
text: app.network.get_parameters().host text: app.network.get_parameters().server.net_addr_str()
Label:
height: '36dp'
size_hint_x: 1
size_hint_y: None
text: _('Port') + ':'
TextInput:
id: port
multiline: False
input_type: 'number'
height: '36dp'
size_hint_x: 3
size_hint_y: None
text: app.network.get_parameters().port
Widget Widget
Button: Button:
id: chooser id: chooser
@ -56,7 +43,5 @@ Popup:
height: '48dp' height: '48dp'
text: _('OK') text: _('OK')
on_release: on_release:
net_params = app.network.get_parameters() app.maybe_switch_to_server(str(root.ids.server_str.text))
net_params = net_params._replace(host=str(root.ids.host.text), port=str(root.ids.port.text))
app.network.run_from_another_thread(app.network.set_parameters(net_params))
nd.dismiss() nd.dismiss()

9
electrum/gui/qt/channel_details.py

@ -5,6 +5,7 @@ import PyQt5.QtWidgets as QtWidgets
import PyQt5.QtCore as QtCore import PyQt5.QtCore as QtCore
from PyQt5.QtWidgets import QLabel, QLineEdit from PyQt5.QtWidgets import QLabel, QLineEdit
from electrum import util
from electrum.i18n import _ from electrum.i18n import _
from electrum.util import bh2u, format_time from electrum.util import bh2u, format_time
from electrum.lnutil import format_short_channel_id, LOCAL, REMOTE, UpdateAddHtlc, Direction from electrum.lnutil import format_short_channel_id, LOCAL, REMOTE, UpdateAddHtlc, Direction
@ -132,10 +133,10 @@ class ChannelDetailsDialog(QtWidgets.QDialog):
self.htlc_added.connect(self.do_htlc_added) self.htlc_added.connect(self.do_htlc_added)
# register callbacks for updating # register callbacks for updating
window.network.register_callback(self.ln_payment_completed.emit, ['ln_payment_completed']) util.register_callback(self.ln_payment_completed.emit, ['ln_payment_completed'])
window.network.register_callback(self.ln_payment_failed.emit, ['ln_payment_failed']) util.register_callback(self.ln_payment_failed.emit, ['ln_payment_failed'])
window.network.register_callback(self.htlc_added.emit, ['htlc_added']) util.register_callback(self.htlc_added.emit, ['htlc_added'])
window.network.register_callback(self.state_changed.emit, ['channel']) util.register_callback(self.state_changed.emit, ['channel'])
# set attributes of QDialog # set attributes of QDialog
self.setWindowTitle(_('Channel Details')) self.setWindowTitle(_('Channel Details'))

7
electrum/gui/qt/lightning_dialog.py

@ -27,6 +27,7 @@ from typing import TYPE_CHECKING
from PyQt5.QtWidgets import (QDialog, QLabel, QVBoxLayout, QPushButton) from PyQt5.QtWidgets import (QDialog, QLabel, QVBoxLayout, QPushButton)
from electrum import util
from electrum.i18n import _ from electrum.i18n import _
from .util import Buttons from .util import Buttons
@ -58,9 +59,9 @@ class LightningDialog(QDialog):
b = QPushButton(_('Close')) b = QPushButton(_('Close'))
b.clicked.connect(self.close) b.clicked.connect(self.close)
vbox.addLayout(Buttons(b)) vbox.addLayout(Buttons(b))
self.network.register_callback(self.on_channel_db, ['channel_db']) util.register_callback(self.on_channel_db, ['channel_db'])
self.network.register_callback(self.set_num_peers, ['gossip_peers']) util.register_callback(self.set_num_peers, ['gossip_peers'])
self.network.register_callback(self.set_unknown_channels, ['unknown_channels']) util.register_callback(self.set_unknown_channels, ['unknown_channels'])
self.network.channel_db.update_counts() # trigger callback self.network.channel_db.update_counts() # trigger callback
self.set_num_peers('', self.network.lngossip.num_peers()) self.set_num_peers('', self.network.lngossip.num_peers())
self.set_unknown_channels('', len(self.network.lngossip.unknown_ids)) self.set_unknown_channels('', len(self.network.lngossip.unknown_ids))

11
electrum/gui/qt/main_window.py

@ -272,7 +272,7 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger):
# window from being GC-ed when closed, callbacks should be # window from being GC-ed when closed, callbacks should be
# methods of this class only, and specifically not be # methods of this class only, and specifically not be
# partials, lambdas or methods of subobjects. Hence... # partials, lambdas or methods of subobjects. Hence...
self.network.register_callback(self.on_network, interests) util.register_callback(self.on_network, interests)
# set initial message # set initial message
self.console.showMessage(self.network.banner) self.console.showMessage(self.network.banner)
@ -466,8 +466,8 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger):
def load_wallet(self, wallet): def load_wallet(self, wallet):
wallet.thread = TaskThread(self, self.on_error) wallet.thread = TaskThread(self, self.on_error)
self.update_recently_visited(wallet.storage.path) self.update_recently_visited(wallet.storage.path)
if wallet.lnworker and wallet.network: if wallet.lnworker:
wallet.network.trigger_callback('channels_updated', wallet) util.trigger_callback('channels_updated', wallet)
self.need_update.set() self.need_update.set()
# Once GUI has been initialized check if we want to announce something since the callback has been called before the GUI was initialized # Once GUI has been initialized check if we want to announce something since the callback has been called before the GUI was initialized
# update menus # update menus
@ -738,7 +738,7 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger):
def donate_to_server(self): def donate_to_server(self):
d = self.network.get_donation_address() d = self.network.get_donation_address()
if d: if d:
host = self.network.get_parameters().host host = self.network.get_parameters().server.host
self.pay_to_URI('bitcoin:%s?message=donation for %s'%(d, host)) self.pay_to_URI('bitcoin:%s?message=donation for %s'%(d, host))
else: else:
self.show_error(_('No donation address for this server')) self.show_error(_('No donation address for this server'))
@ -2889,8 +2889,7 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger):
def clean_up(self): def clean_up(self):
self.wallet.thread.stop() self.wallet.thread.stop()
if self.network: util.unregister_callback(self.on_network)
self.network.unregister_callback(self.on_network)
self.config.set_key("is_maximized", self.isMaximized()) self.config.set_key("is_maximized", self.isMaximized())
if not self.isMaximized(): if not self.isMaximized():
g = self.geometry() g = self.geometry()

124
electrum/gui/qt/network_dialog.py

@ -35,8 +35,8 @@ from PyQt5.QtWidgets import (QTreeWidget, QTreeWidgetItem, QMenu, QGridLayout, Q
from PyQt5.QtGui import QFontMetrics from PyQt5.QtGui import QFontMetrics
from electrum.i18n import _ from electrum.i18n import _
from electrum import constants, blockchain from electrum import constants, blockchain, util
from electrum.interface import serialize_server, deserialize_server from electrum.interface import ServerAddr, PREFERRED_NETWORK_PROTOCOL
from electrum.network import Network from electrum.network import Network
from electrum.logging import get_logger from electrum.logging import get_logger
@ -61,7 +61,7 @@ class NetworkDialog(QDialog):
vbox.addLayout(Buttons(CloseButton(self))) vbox.addLayout(Buttons(CloseButton(self)))
self.network_updated_signal_obj.network_updated_signal.connect( self.network_updated_signal_obj.network_updated_signal.connect(
self.on_update) self.on_update)
network.register_callback(self.on_network, ['network_updated']) util.register_callback(self.on_network, ['network_updated'])
def on_network(self, event, *args): def on_network(self, event, *args):
self.network_updated_signal_obj.network_updated_signal.emit(event, args) self.network_updated_signal_obj.network_updated_signal.emit(event, args)
@ -72,10 +72,15 @@ class NetworkDialog(QDialog):
class NodesListWidget(QTreeWidget): class NodesListWidget(QTreeWidget):
"""List of connected servers."""
SERVER_ADDR_ROLE = Qt.UserRole + 100
CHAIN_ID_ROLE = Qt.UserRole + 101
IS_SERVER_ROLE = Qt.UserRole + 102
def __init__(self, parent): def __init__(self, parent):
QTreeWidget.__init__(self) QTreeWidget.__init__(self)
self.parent = parent self.parent = parent # type: NetworkChoiceLayout
self.setHeaderLabels([_('Connected node'), _('Height')]) self.setHeaderLabels([_('Connected node'), _('Height')])
self.setContextMenuPolicy(Qt.CustomContextMenu) self.setContextMenuPolicy(Qt.CustomContextMenu)
self.customContextMenuRequested.connect(self.create_menu) self.customContextMenuRequested.connect(self.create_menu)
@ -84,13 +89,13 @@ class NodesListWidget(QTreeWidget):
item = self.currentItem() item = self.currentItem()
if not item: if not item:
return return
is_server = not bool(item.data(0, Qt.UserRole)) is_server = bool(item.data(0, self.IS_SERVER_ROLE))
menu = QMenu() menu = QMenu()
if is_server: if is_server:
server = item.data(1, Qt.UserRole) server = item.data(0, self.SERVER_ADDR_ROLE) # type: ServerAddr
menu.addAction(_("Use as server"), lambda: self.parent.follow_server(server)) menu.addAction(_("Use as server"), lambda: self.parent.follow_server(server))
else: else:
chain_id = item.data(1, Qt.UserRole) chain_id = item.data(0, self.CHAIN_ID_ROLE)
menu.addAction(_("Follow this branch"), lambda: self.parent.follow_branch(chain_id)) menu.addAction(_("Follow this branch"), lambda: self.parent.follow_branch(chain_id))
menu.exec_(self.viewport().mapToGlobal(position)) menu.exec_(self.viewport().mapToGlobal(position))
@ -117,15 +122,16 @@ class NodesListWidget(QTreeWidget):
name = b.get_name() name = b.get_name()
if n_chains > 1: if n_chains > 1:
x = QTreeWidgetItem([name + '@%d'%b.get_max_forkpoint(), '%d'%b.height()]) x = QTreeWidgetItem([name + '@%d'%b.get_max_forkpoint(), '%d'%b.height()])
x.setData(0, Qt.UserRole, 1) x.setData(0, self.IS_SERVER_ROLE, 0)
x.setData(1, Qt.UserRole, b.get_id()) x.setData(0, self.CHAIN_ID_ROLE, b.get_id())
else: else:
x = self x = self
for i in interfaces: for i in interfaces:
star = ' *' if i == network.interface else '' star = ' *' if i == network.interface else ''
item = QTreeWidgetItem([i.host + star, '%d'%i.tip]) item = QTreeWidgetItem([i.host + star, '%d'%i.tip])
item.setData(0, Qt.UserRole, 0) item.setData(0, self.IS_SERVER_ROLE, 1)
item.setData(1, Qt.UserRole, i.server) item.setData(0, self.SERVER_ADDR_ROLE, i.server)
item.setToolTip(0, str(i.server))
x.addChild(item) x.addChild(item)
if n_chains > 1: if n_chains > 1:
self.addTopLevelItem(x) self.addTopLevelItem(x)
@ -140,15 +146,17 @@ class NodesListWidget(QTreeWidget):
class ServerListWidget(QTreeWidget): class ServerListWidget(QTreeWidget):
"""List of all known servers."""
class Columns(IntEnum): class Columns(IntEnum):
HOST = 0 HOST = 0
PORT = 1 PORT = 1
SERVER_STR_ROLE = Qt.UserRole + 100 SERVER_ADDR_ROLE = Qt.UserRole + 100
def __init__(self, parent): def __init__(self, parent):
QTreeWidget.__init__(self) QTreeWidget.__init__(self)
self.parent = parent self.parent = parent # type: NetworkChoiceLayout
self.setHeaderLabels([_('Host'), _('Port')]) self.setHeaderLabels([_('Host'), _('Port')])
self.setContextMenuPolicy(Qt.CustomContextMenu) self.setContextMenuPolicy(Qt.CustomContextMenu)
self.customContextMenuRequested.connect(self.create_menu) self.customContextMenuRequested.connect(self.create_menu)
@ -158,14 +166,12 @@ class ServerListWidget(QTreeWidget):
if not item: if not item:
return return
menu = QMenu() menu = QMenu()
server = item.data(self.Columns.HOST, self.SERVER_STR_ROLE) server = item.data(self.Columns.HOST, self.SERVER_ADDR_ROLE)
menu.addAction(_("Use as server"), lambda: self.set_server(server)) menu.addAction(_("Use as server"), lambda: self.set_server(server))
menu.exec_(self.viewport().mapToGlobal(position)) menu.exec_(self.viewport().mapToGlobal(position))
def set_server(self, s): def set_server(self, server: ServerAddr):
host, port, protocol = deserialize_server(s) self.parent.server_e.setText(server.net_addr_str())
self.parent.server_host.setText(host)
self.parent.server_port.setText(port)
self.parent.set_server() self.parent.set_server()
def keyPressEvent(self, event): def keyPressEvent(self, event):
@ -180,16 +186,17 @@ class ServerListWidget(QTreeWidget):
pt.setX(50) pt.setX(50)
self.customContextMenuRequested.emit(pt) self.customContextMenuRequested.emit(pt)
def update(self, servers, protocol, use_tor): def update(self, servers, use_tor):
self.clear() self.clear()
protocol = PREFERRED_NETWORK_PROTOCOL
for _host, d in sorted(servers.items()): for _host, d in sorted(servers.items()):
if _host.endswith('.onion') and not use_tor: if _host.endswith('.onion') and not use_tor:
continue continue
port = d.get(protocol) port = d.get(protocol)
if port: if port:
x = QTreeWidgetItem([_host, port]) x = QTreeWidgetItem([_host, port])
server = serialize_server(_host, port, protocol) server = ServerAddr(_host, port, protocol=protocol)
x.setData(self.Columns.HOST, self.SERVER_STR_ROLE, server) x.setData(self.Columns.HOST, self.SERVER_ADDR_ROLE, server)
self.addTopLevelItem(x) self.addTopLevelItem(x)
h = self.header() h = self.header()
@ -205,7 +212,6 @@ class NetworkChoiceLayout(object):
def __init__(self, network: Network, config, wizard=False): def __init__(self, network: Network, config, wizard=False):
self.network = network self.network = network
self.config = config self.config = config
self.protocol = None
self.tor_proxy = None self.tor_proxy = None
self.tabs = tabs = QTabWidget() self.tabs = tabs = QTabWidget()
@ -223,15 +229,12 @@ class NetworkChoiceLayout(object):
grid = QGridLayout(server_tab) grid = QGridLayout(server_tab)
grid.setSpacing(8) grid.setSpacing(8)
self.server_host = QLineEdit() self.server_e = QLineEdit()
self.server_host.setFixedWidth(fixed_width_hostname) self.server_e.setFixedWidth(fixed_width_hostname + fixed_width_port)
self.server_port = QLineEdit()
self.server_port.setFixedWidth(fixed_width_port)
self.autoconnect_cb = QCheckBox(_('Select server automatically')) self.autoconnect_cb = QCheckBox(_('Select server automatically'))
self.autoconnect_cb.setEnabled(self.config.is_modifiable('auto_connect')) self.autoconnect_cb.setEnabled(self.config.is_modifiable('auto_connect'))
self.server_host.editingFinished.connect(self.set_server) self.server_e.editingFinished.connect(self.set_server)
self.server_port.editingFinished.connect(self.set_server)
self.autoconnect_cb.clicked.connect(self.set_server) self.autoconnect_cb.clicked.connect(self.set_server)
self.autoconnect_cb.clicked.connect(self.update) self.autoconnect_cb.clicked.connect(self.update)
@ -243,8 +246,7 @@ class NetworkChoiceLayout(object):
grid.addWidget(HelpButton(msg), 0, 4) grid.addWidget(HelpButton(msg), 0, 4)
grid.addWidget(QLabel(_('Server') + ':'), 1, 0) grid.addWidget(QLabel(_('Server') + ':'), 1, 0)
grid.addWidget(self.server_host, 1, 1, 1, 2) grid.addWidget(self.server_e, 1, 1, 1, 3)
grid.addWidget(self.server_port, 1, 3)
label = _('Server peers') if network.is_connected() else _('Default Servers') label = _('Server peers') if network.is_connected() else _('Default Servers')
grid.addWidget(QLabel(label), 2, 0, 1, 5) grid.addWidget(QLabel(label), 2, 0, 1, 5)
@ -348,29 +350,26 @@ class NetworkChoiceLayout(object):
def enable_set_server(self): def enable_set_server(self):
if self.config.is_modifiable('server'): if self.config.is_modifiable('server'):
enabled = not self.autoconnect_cb.isChecked() enabled = not self.autoconnect_cb.isChecked()
self.server_host.setEnabled(enabled) self.server_e.setEnabled(enabled)
self.server_port.setEnabled(enabled)
self.servers_list.setEnabled(enabled) self.servers_list.setEnabled(enabled)
else: else:
for w in [self.autoconnect_cb, self.server_host, self.server_port, self.servers_list]: for w in [self.autoconnect_cb, self.server_e, self.servers_list]:
w.setEnabled(False) w.setEnabled(False)
def update(self): def update(self):
net_params = self.network.get_parameters() net_params = self.network.get_parameters()
host, port, protocol = net_params.host, net_params.port, net_params.protocol server = net_params.server
proxy_config, auto_connect = net_params.proxy, net_params.auto_connect proxy_config, auto_connect = net_params.proxy, net_params.auto_connect
if not self.server_host.hasFocus() and not self.server_port.hasFocus(): if not self.server_e.hasFocus():
self.server_host.setText(host) self.server_e.setText(server.net_addr_str())
self.server_port.setText(str(port))
self.autoconnect_cb.setChecked(auto_connect) self.autoconnect_cb.setChecked(auto_connect)
interface = self.network.interface interface = self.network.interface
host = interface.host if interface else _('None') host = interface.host if interface else _('None')
self.server_label.setText(host) self.server_label.setText(host)
self.set_protocol(protocol)
self.servers = self.network.get_servers() self.servers = self.network.get_servers()
self.servers_list.update(self.servers, self.protocol, self.tor_cb.isChecked()) self.servers_list.update(self.servers, self.tor_cb.isChecked())
self.enable_set_server() self.enable_set_server()
height_str = "%d "%(self.network.get_local_height()) + _('blocks') height_str = "%d "%(self.network.get_local_height()) + _('blocks')
@ -411,59 +410,24 @@ class NetworkChoiceLayout(object):
def layout(self): def layout(self):
return self.layout_ return self.layout_
def set_protocol(self, protocol):
if protocol != self.protocol:
self.protocol = protocol
def change_protocol(self, use_ssl):
p = 's' if use_ssl else 't'
host = self.server_host.text()
pp = self.servers.get(host, constants.net.DEFAULT_PORTS)
if p not in pp.keys():
p = list(pp.keys())[0]
port = pp[p]
self.server_host.setText(host)
self.server_port.setText(port)
self.set_protocol(p)
self.set_server()
def follow_branch(self, chain_id): def follow_branch(self, chain_id):
self.network.run_from_another_thread(self.network.follow_chain_given_id(chain_id)) self.network.run_from_another_thread(self.network.follow_chain_given_id(chain_id))
self.update() self.update()
def follow_server(self, server): def follow_server(self, server: ServerAddr):
self.network.run_from_another_thread(self.network.follow_chain_given_server(server)) self.network.run_from_another_thread(self.network.follow_chain_given_server(server))
self.update() self.update()
def server_changed(self, x):
if x:
self.change_server(str(x.text(0)), self.protocol)
def change_server(self, host, protocol):
pp = self.servers.get(host, constants.net.DEFAULT_PORTS)
if protocol and protocol not in protocol_letters:
protocol = None
if protocol:
port = pp.get(protocol)
if port is None:
protocol = None
if not protocol:
if 's' in pp.keys():
protocol = 's'
port = pp.get(protocol)
else:
protocol = list(pp.keys())[0]
port = pp.get(protocol)
self.server_host.setText(host)
self.server_port.setText(port)
def accept(self): def accept(self):
pass pass
def set_server(self): def set_server(self):
net_params = self.network.get_parameters() net_params = self.network.get_parameters()
net_params = net_params._replace(host=str(self.server_host.text()), try:
port=str(self.server_port.text()), server = ServerAddr.from_str_with_inference(str(self.server_e.text()))
except Exception:
return
net_params = net_params._replace(server=server,
auto_connect=self.autoconnect_cb.isChecked()) auto_connect=self.autoconnect_cb.isChecked())
self.network.run_from_another_thread(self.network.set_parameters(net_params)) self.network.run_from_another_thread(self.network.set_parameters(net_params))

3
electrum/gui/stdio.py

@ -3,6 +3,7 @@ import getpass
import datetime import datetime
import logging import logging
from electrum import util
from electrum import WalletStorage, Wallet from electrum import WalletStorage, Wallet
from electrum.util import format_satoshis from electrum.util import format_satoshis
from electrum.bitcoin import is_address, COIN from electrum.bitcoin import is_address, COIN
@ -43,7 +44,7 @@ class ElectrumGui:
self.wallet.start_network(self.network) self.wallet.start_network(self.network)
self.contacts = self.wallet.contacts self.contacts = self.wallet.contacts
self.network.register_callback(self.on_network, ['wallet_updated', 'network_updated', 'banner']) util.register_callback(self.on_network, ['wallet_updated', 'network_updated', 'banner'])
self.commands = [_("[h] - displays this help text"), \ self.commands = [_("[h] - displays this help text"), \
_("[i] - display transaction history"), \ _("[i] - display transaction history"), \
_("[o] - enter payment order"), \ _("[o] - enter payment order"), \

31
electrum/gui/text.py

@ -6,23 +6,31 @@ import locale
from decimal import Decimal from decimal import Decimal
import getpass import getpass
import logging import logging
from typing import TYPE_CHECKING
import electrum import electrum
from electrum import util
from electrum.util import format_satoshis from electrum.util import format_satoshis
from electrum.bitcoin import is_address, COIN from electrum.bitcoin import is_address, COIN
from electrum.transaction import PartialTxOutput from electrum.transaction import PartialTxOutput
from electrum.wallet import Wallet from electrum.wallet import Wallet
from electrum.storage import WalletStorage from electrum.storage import WalletStorage
from electrum.network import NetworkParameters, TxBroadcastError, BestEffortRequestFailed from electrum.network import NetworkParameters, TxBroadcastError, BestEffortRequestFailed
from electrum.interface import deserialize_server from electrum.interface import ServerAddr
from electrum.logging import console_stderr_handler from electrum.logging import console_stderr_handler
if TYPE_CHECKING:
from electrum.daemon import Daemon
from electrum.simple_config import SimpleConfig
from electrum.plugin import Plugins
_ = lambda x:x # i18n _ = lambda x:x # i18n
class ElectrumGui: class ElectrumGui:
def __init__(self, config, daemon, plugins): def __init__(self, config: 'SimpleConfig', daemon: 'Daemon', plugins: 'Plugins'):
self.config = config self.config = config
self.network = daemon.network self.network = daemon.network
@ -65,8 +73,7 @@ class ElectrumGui:
self.str_fee = "" self.str_fee = ""
self.history = None self.history = None
if self.network: util.register_callback(self.update, ['wallet_updated', 'network_updated'])
self.network.register_callback(self.update, ['wallet_updated', 'network_updated'])
self.tab_names = [_("History"), _("Send"), _("Receive"), _("Addresses"), _("Contacts"), _("Banner")] self.tab_names = [_("History"), _("Send"), _("Receive"), _("Addresses"), _("Contacts"), _("Banner")]
self.num_tabs = len(self.tab_names) self.num_tabs = len(self.tab_names)
@ -402,26 +409,28 @@ class ElectrumGui:
if not self.network: if not self.network:
return return
net_params = self.network.get_parameters() net_params = self.network.get_parameters()
host, port, protocol = net_params.host, net_params.port, net_params.protocol server_addr = net_params.server
proxy_config, auto_connect = net_params.proxy, net_params.auto_connect proxy_config, auto_connect = net_params.proxy, net_params.auto_connect
srv = 'auto-connect' if auto_connect else self.network.default_server srv = 'auto-connect' if auto_connect else str(self.network.default_server)
out = self.run_dialog('Network', [ out = self.run_dialog('Network', [
{'label':'server', 'type':'str', 'value':srv}, {'label':'server', 'type':'str', 'value':srv},
{'label':'proxy', 'type':'str', 'value':self.config.get('proxy', '')}, {'label':'proxy', 'type':'str', 'value':self.config.get('proxy', '')},
], buttons = 1) ], buttons = 1)
if out: if out:
if out.get('server'): if out.get('server'):
server = out.get('server') server_str = out.get('server')
auto_connect = server == 'auto-connect' auto_connect = server_str == 'auto-connect'
if not auto_connect: if not auto_connect:
try: try:
host, port, protocol = deserialize_server(server) server_addr = ServerAddr.from_str(server_str)
except Exception: except Exception:
self.show_message("Error:" + server + "\nIn doubt, type \"auto-connect\"") self.show_message("Error:" + server_str + "\nIn doubt, type \"auto-connect\"")
return False return False
if out.get('server') or out.get('proxy'): if out.get('server') or out.get('proxy'):
proxy = electrum.network.deserialize_proxy(out.get('proxy')) if out.get('proxy') else proxy_config proxy = electrum.network.deserialize_proxy(out.get('proxy')) if out.get('proxy') else proxy_config
net_params = NetworkParameters(host, port, protocol, proxy, auto_connect) net_params = NetworkParameters(server=server_addr,
proxy=proxy,
auto_connect=auto_connect)
self.network.run_from_another_thread(self.network.set_parameters(net_params)) self.network.run_from_another_thread(self.network.set_parameters(net_params))
def settings_dialog(self): def settings_dialog(self):

129
electrum/interface.py

@ -29,7 +29,7 @@ import sys
import traceback import traceback
import asyncio import asyncio
import socket import socket
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple
from collections import defaultdict from collections import defaultdict
from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address
import itertools import itertools
@ -43,7 +43,7 @@ from aiorpcx.jsonrpc import JSONRPC, CodeMessageError
from aiorpcx.rawsocket import RSClient from aiorpcx.rawsocket import RSClient
import certifi import certifi
from .util import ignore_exceptions, log_exceptions, bfh, SilentTaskGroup from .util import ignore_exceptions, log_exceptions, bfh, SilentTaskGroup, MySocksProxy
from . import util from . import util
from . import x509 from . import x509
from . import pem from . import pem
@ -65,6 +65,10 @@ BUCKET_NAME_OF_ONION_SERVERS = 'onion'
MAX_INCOMING_MSG_SIZE = 1_000_000 # in bytes MAX_INCOMING_MSG_SIZE = 1_000_000 # in bytes
_KNOWN_NETWORK_PROTOCOLS = {'t', 's'}
PREFERRED_NETWORK_PROTOCOL = 's'
assert PREFERRED_NETWORK_PROTOCOL in _KNOWN_NETWORK_PROTOCOLS
class NetworkTimeout: class NetworkTimeout:
# seconds # seconds
@ -198,22 +202,75 @@ class _RSClient(RSClient):
raise ConnectError(e) from e raise ConnectError(e) from e
def deserialize_server(server_str: str) -> Tuple[str, str, str]: class ServerAddr:
# host might be IPv6 address, hence do rsplit:
host, port, protocol = str(server_str).rsplit(':', 2) def __init__(self, host: str, port: Union[int, str], *, protocol: str = None):
assert isinstance(host, str), repr(host)
if protocol is None:
protocol = 's'
if not host: if not host:
raise ValueError('host must not be empty') raise ValueError('host must not be empty')
if host[0] == '[' and host[-1] == ']': # IPv6 if host[0] == '[' and host[-1] == ']': # IPv6
host = host[1:-1] host = host[1:-1]
if protocol not in ('s', 't'): try:
raise ValueError('invalid network protocol: {}'.format(protocol))
net_addr = NetAddress(host, port) # this validates host and port net_addr = NetAddress(host, port) # this validates host and port
host = str(net_addr.host) # canonical form (if e.g. IPv6 address) except Exception as e:
return host, port, protocol raise ValueError(f"cannot construct ServerAddr: invalid host or port (host={host}, port={port})") from e
if protocol not in _KNOWN_NETWORK_PROTOCOLS:
raise ValueError(f"invalid network protocol: {protocol}")
self.host = str(net_addr.host) # canonical form (if e.g. IPv6 address)
self.port = int(net_addr.port)
self.protocol = protocol
self._net_addr_str = str(net_addr)
@classmethod
def from_str(cls, s: str) -> 'ServerAddr':
# host might be IPv6 address, hence do rsplit:
host, port, protocol = str(s).rsplit(':', 2)
return ServerAddr(host=host, port=port, protocol=protocol)
@classmethod
def from_str_with_inference(cls, s: str) -> Optional['ServerAddr']:
"""Construct ServerAddr from str, guessing missing details.
Ongoing compatibility not guaranteed.
"""
if not s:
return None
items = str(s).rsplit(':', 2)
if len(items) < 2:
return None # although maybe we could guess the port too?
host = items[0]
port = items[1]
if len(items) >= 3:
protocol = items[2]
else:
protocol = PREFERRED_NETWORK_PROTOCOL
return ServerAddr(host=host, port=port, protocol=protocol)
def __str__(self):
return '{}:{}'.format(self.net_addr_str(), self.protocol)
def to_json(self) -> str:
return str(self)
def serialize_server(host: str, port: Union[str, int], protocol: str) -> str: def __repr__(self):
return str(':'.join([host, str(port), protocol])) return f'<ServerAddr host={self.host} port={self.port} protocol={self.protocol}>'
def net_addr_str(self) -> str:
return self._net_addr_str
def __eq__(self, other):
if not isinstance(other, ServerAddr):
return False
return (self.host == other.host
and self.port == other.port
and self.protocol == other.protocol)
def __ne__(self, other):
return not (self == other)
def __hash__(self):
return hash((self.host, self.port, self.protocol))
def _get_cert_path_for_host(*, config: 'SimpleConfig', host: str) -> str: def _get_cert_path_for_host(*, config: 'SimpleConfig', host: str) -> str:
@ -232,19 +289,17 @@ class Interface(Logger):
LOGGING_SHORTCUT = 'i' LOGGING_SHORTCUT = 'i'
def __init__(self, network: 'Network', server: str, proxy: Optional[dict]): def __init__(self, *, network: 'Network', server: ServerAddr, proxy: Optional[dict]):
self.ready = asyncio.Future() self.ready = asyncio.Future()
self.got_disconnected = asyncio.Future() self.got_disconnected = asyncio.Future()
self.server = server self.server = server
self.host, self.port, self.protocol = deserialize_server(self.server)
self.port = int(self.port)
Logger.__init__(self) Logger.__init__(self)
assert network.config.path assert network.config.path
self.cert_path = _get_cert_path_for_host(config=network.config, host=self.host) self.cert_path = _get_cert_path_for_host(config=network.config, host=self.host)
self.blockchain = None # type: Optional[Blockchain] self.blockchain = None # type: Optional[Blockchain]
self._requested_chunks = set() # type: Set[int] self._requested_chunks = set() # type: Set[int]
self.network = network self.network = network
self._set_proxy(proxy) self.proxy = MySocksProxy.from_proxy_dict(proxy)
self.session = None # type: Optional[NotificationSession] self.session = None # type: Optional[NotificationSession]
self._ipaddr_bucket = None self._ipaddr_bucket = None
@ -259,29 +314,24 @@ class Interface(Logger):
self.network.taskgroup.spawn(self.run()), self.network.asyncio_loop) self.network.taskgroup.spawn(self.run()), self.network.asyncio_loop)
self.taskgroup = SilentTaskGroup() self.taskgroup = SilentTaskGroup()
@property
def host(self):
return self.server.host
@property
def port(self):
return self.server.port
@property
def protocol(self):
return self.server.protocol
def diagnostic_name(self): def diagnostic_name(self):
return str(NetAddress(self.host, self.port)) return self.server.net_addr_str()
def __str__(self): def __str__(self):
return f"<Interface {self.diagnostic_name()}>" return f"<Interface {self.diagnostic_name()}>"
def _set_proxy(self, proxy: dict):
if proxy:
username, pw = proxy.get('user'), proxy.get('password')
if not username or not pw:
auth = None
else:
auth = aiorpcx.socks.SOCKSUserAuth(username, pw)
addr = NetAddress(proxy['host'], proxy['port'])
if proxy['mode'] == "socks4":
self.proxy = aiorpcx.socks.SOCKSProxy(addr, aiorpcx.socks.SOCKS4a, auth)
elif proxy['mode'] == "socks5":
self.proxy = aiorpcx.socks.SOCKSProxy(addr, aiorpcx.socks.SOCKS5, auth)
else:
raise NotImplementedError # http proxy not available with aiorpcx
else:
self.proxy = None
async def is_server_ca_signed(self, ca_ssl_context): async def is_server_ca_signed(self, ca_ssl_context):
"""Given a CA enforcing SSL context, returns True if the connection """Given a CA enforcing SSL context, returns True if the connection
can be established. Returns False if the server has a self-signed can be established. Returns False if the server has a self-signed
@ -435,13 +485,12 @@ class Interface(Logger):
async def get_certificate(self): async def get_certificate(self):
sslc = ssl.SSLContext() sslc = ssl.SSLContext()
try:
async with _RSClient(session_factory=RPCSession, async with _RSClient(session_factory=RPCSession,
host=self.host, port=self.port, host=self.host, port=self.port,
ssl=sslc, proxy=self.proxy) as session: ssl=sslc, proxy=self.proxy) as session:
return session.transport._asyncio_transport._ssl_protocol._sslpipe._sslobj.getpeercert(True) asyncio_transport = session.transport._asyncio_transport # type: asyncio.BaseTransport
except ValueError: ssl_object = asyncio_transport.get_extra_info("ssl_object") # type: ssl.SSLObject
return None return ssl_object.getpeercert(binary_form=True)
async def get_block_header(self, height, assert_mode): async def get_block_header(self, height, assert_mode):
self.logger.info(f'requesting block header {height} in mode {assert_mode}') self.logger.info(f'requesting block header {height} in mode {assert_mode}')
@ -548,7 +597,7 @@ class Interface(Logger):
raise GracefulDisconnect('server tip below max checkpoint') raise GracefulDisconnect('server tip below max checkpoint')
self._mark_ready() self._mark_ready()
await self._process_header_at_tip() await self._process_header_at_tip()
self.network.trigger_callback('network_updated') util.trigger_callback('network_updated')
await self.network.switch_unwanted_fork_interface() await self.network.switch_unwanted_fork_interface()
await self.network.switch_lagging_interface() await self.network.switch_lagging_interface()
@ -563,7 +612,7 @@ class Interface(Logger):
# in the simple case, height == self.tip+1 # in the simple case, height == self.tip+1
if height <= self.tip: if height <= self.tip:
await self.sync_until(height) await self.sync_until(height)
self.network.trigger_callback('blockchain_updated') util.trigger_callback('blockchain_updated')
async def sync_until(self, height, next_height=None): async def sync_until(self, height, next_height=None):
if next_height is None: if next_height is None:
@ -578,7 +627,7 @@ class Interface(Logger):
raise GracefulDisconnect('server chain conflicts with checkpoints or genesis') raise GracefulDisconnect('server chain conflicts with checkpoints or genesis')
last, height = await self.step(height) last, height = await self.step(height)
continue continue
self.network.trigger_callback('network_updated') util.trigger_callback('network_updated')
height = (height // 2016 * 2016) + num_headers height = (height // 2016 * 2016) + num_headers
assert height <= next_height+1, (height, self.tip) assert height <= next_height+1, (height, self.tip)
last = 'catchup' last = 'catchup'

8
electrum/lnchannel.py

@ -33,7 +33,7 @@ from aiorpcx import NetAddress
import attr import attr
from . import ecc from . import ecc
from . import constants from . import constants, util
from .util import bfh, bh2u, chunks, TxMinedInfo from .util import bfh, bh2u, chunks, TxMinedInfo
from .bitcoin import redeem_script_to_address from .bitcoin import redeem_script_to_address
from .crypto import sha256, sha256d from .crypto import sha256, sha256d
@ -679,16 +679,14 @@ class Channel(AbstractChannel):
def set_frozen_for_sending(self, b: bool) -> None: def set_frozen_for_sending(self, b: bool) -> None:
self.storage['frozen_for_sending'] = bool(b) self.storage['frozen_for_sending'] = bool(b)
if self.lnworker: util.trigger_callback('channel', self)
self.lnworker.network.trigger_callback('channel', self)
def is_frozen_for_receiving(self) -> bool: def is_frozen_for_receiving(self) -> bool:
return self.storage.get('frozen_for_receiving', False) return self.storage.get('frozen_for_receiving', False)
def set_frozen_for_receiving(self, b: bool) -> None: def set_frozen_for_receiving(self, b: bool) -> None:
self.storage['frozen_for_receiving'] = bool(b) self.storage['frozen_for_receiving'] = bool(b)
if self.lnworker: util.trigger_callback('channel', self)
self.lnworker.network.trigger_callback('channel', self)
def _assert_can_add_htlc(self, *, htlc_proposer: HTLCOwner, amount_msat: int) -> None: def _assert_can_add_htlc(self, *, htlc_proposer: HTLCOwner, amount_msat: int) -> None:
"""Raises PaymentFailure if the htlc_proposer cannot add this new HTLC. """Raises PaymentFailure if the htlc_proposer cannot add this new HTLC.

25
electrum/lnpeer.py

@ -19,7 +19,7 @@ from datetime import datetime
import aiorpcx import aiorpcx
from .crypto import sha256, sha256d from .crypto import sha256, sha256d
from . import bitcoin from . import bitcoin, util
from . import ecc from . import ecc
from .ecc import sig_string_from_r_and_s, get_r_and_s_from_sig_string, der_sig_from_sig_string from .ecc import sig_string_from_r_and_s, get_r_and_s_from_sig_string, der_sig_from_sig_string
from . import constants from . import constants
@ -74,6 +74,7 @@ class Peer(Logger):
self.lnworker = lnworker self.lnworker = lnworker
self.privkey = self.transport.privkey # local privkey self.privkey = self.transport.privkey # local privkey
self.features = self.lnworker.features self.features = self.lnworker.features
self.their_features = 0
self.node_ids = [self.pubkey, privkey_to_pubkey(self.privkey)] self.node_ids = [self.pubkey, privkey_to_pubkey(self.privkey)]
self.network = lnworker.network self.network = lnworker.network
self.channel_db = lnworker.network.channel_db self.channel_db = lnworker.network.channel_db
@ -200,15 +201,15 @@ class Peer(Logger):
if self._received_init: if self._received_init:
self.logger.info("ALREADY INITIALIZED BUT RECEIVED INIT") self.logger.info("ALREADY INITIALIZED BUT RECEIVED INIT")
return return
their_features = LnFeatures(int.from_bytes(payload['features'], byteorder="big")) self.their_features = LnFeatures(int.from_bytes(payload['features'], byteorder="big"))
their_globalfeatures = int.from_bytes(payload['globalfeatures'], byteorder="big") their_globalfeatures = int.from_bytes(payload['globalfeatures'], byteorder="big")
their_features |= their_globalfeatures self.their_features |= their_globalfeatures
# check transitive dependencies for received features # check transitive dependencies for received features
if not their_features.validate_transitive_dependecies(): if not self.their_features.validate_transitive_dependecies():
raise GracefulDisconnect("remote did not set all dependencies for the features they sent") raise GracefulDisconnect("remote did not set all dependencies for the features they sent")
# check if features are compatible, and set self.features to what we negotiated # check if features are compatible, and set self.features to what we negotiated
try: try:
self.features = ln_compare_features(self.features, their_features) self.features = ln_compare_features(self.features, self.their_features)
except IncompatibleLightningFeatures as e: except IncompatibleLightningFeatures as e:
self.initialized.set_exception(e) self.initialized.set_exception(e)
raise GracefulDisconnect(f"{str(e)}") raise GracefulDisconnect(f"{str(e)}")
@ -219,10 +220,7 @@ class Peer(Logger):
if constants.net.rev_genesis_bytes() not in their_chains: if constants.net.rev_genesis_bytes() not in their_chains:
raise GracefulDisconnect(f"no common chain found with remote. (they sent: {their_chains})") raise GracefulDisconnect(f"no common chain found with remote. (they sent: {their_chains})")
# all checks passed # all checks passed
if self.channel_db and isinstance(self.transport, LNTransport): self.lnworker.on_peer_successfully_established(self)
self.channel_db.add_recent_peer(self.transport.peer_addr)
for chan in self.channels.values():
chan.add_or_update_peer_addr(self.transport.peer_addr)
self._received_init = True self._received_init = True
self.maybe_set_initialized() self.maybe_set_initialized()
@ -254,7 +252,8 @@ class Peer(Logger):
return await func(self, *args, **kwargs) return await func(self, *args, **kwargs)
except GracefulDisconnect as e: except GracefulDisconnect as e:
self.logger.log(e.log_level, f"Disconnecting: {repr(e)}") self.logger.log(e.log_level, f"Disconnecting: {repr(e)}")
except (LightningPeerConnectionClosed, IncompatibleLightningFeatures) as e: except (LightningPeerConnectionClosed, IncompatibleLightningFeatures,
aiorpcx.socks.SOCKSError) as e:
self.logger.info(f"Disconnecting: {repr(e)}") self.logger.info(f"Disconnecting: {repr(e)}")
finally: finally:
self.close_and_cleanup() self.close_and_cleanup()
@ -744,7 +743,7 @@ class Peer(Logger):
f'already in peer_state {chan.peer_state}') f'already in peer_state {chan.peer_state}')
return return
chan.peer_state = PeerState.REESTABLISHING chan.peer_state = PeerState.REESTABLISHING
self.network.trigger_callback('channel', chan) util.trigger_callback('channel', chan)
# BOLT-02: "A node [...] upon disconnection [...] MUST reverse any uncommitted updates sent by the other side" # BOLT-02: "A node [...] upon disconnection [...] MUST reverse any uncommitted updates sent by the other side"
chan.hm.discard_unsigned_remote_updates() chan.hm.discard_unsigned_remote_updates()
# ctns # ctns
@ -891,7 +890,7 @@ class Peer(Logger):
# checks done # checks done
if chan.is_funded() and chan.config[LOCAL].funding_locked_received: if chan.is_funded() and chan.config[LOCAL].funding_locked_received:
self.mark_open(chan) self.mark_open(chan)
self.network.trigger_callback('channel', chan) util.trigger_callback('channel', chan)
if chan.get_state() == ChannelState.CLOSING: if chan.get_state() == ChannelState.CLOSING:
await self.send_shutdown(chan) await self.send_shutdown(chan)
@ -979,7 +978,7 @@ class Peer(Logger):
return return
assert chan.config[LOCAL].funding_locked_received assert chan.config[LOCAL].funding_locked_received
chan.set_state(ChannelState.OPEN) chan.set_state(ChannelState.OPEN)
self.network.trigger_callback('channel', chan) util.trigger_callback('channel', chan)
# peer may have sent us a channel update for the incoming direction previously # peer may have sent us a channel update for the incoming direction previously
pending_channel_update = self.orphan_channel_updates.get(chan.short_channel_id) pending_channel_update = self.orphan_channel_updates.get(chan.short_channel_id)
if pending_channel_update: if pending_channel_update:

15
electrum/lntransport.py

@ -8,12 +8,14 @@
import hashlib import hashlib
import asyncio import asyncio
from asyncio import StreamReader, StreamWriter from asyncio import StreamReader, StreamWriter
from typing import Optional
from .crypto import sha256, hmac_oneshot, chacha20_poly1305_encrypt, chacha20_poly1305_decrypt from .crypto import sha256, hmac_oneshot, chacha20_poly1305_encrypt, chacha20_poly1305_decrypt
from .lnutil import (get_ecdh, privkey_to_pubkey, LightningPeerConnectionClosed, from .lnutil import (get_ecdh, privkey_to_pubkey, LightningPeerConnectionClosed,
HandshakeFailed, LNPeerAddr) HandshakeFailed, LNPeerAddr)
from . import ecc from . import ecc
from .util import bh2u from .util import bh2u, MySocksProxy
class HandshakeState(object): class HandshakeState(object):
prologue = b"lightning" prologue = b"lightning"
@ -155,6 +157,8 @@ class LNTransportBase:
class LNResponderTransport(LNTransportBase): class LNResponderTransport(LNTransportBase):
"""Transport initiated by remote party."""
def __init__(self, privkey: bytes, reader: StreamReader, writer: StreamWriter): def __init__(self, privkey: bytes, reader: StreamReader, writer: StreamWriter):
LNTransportBase.__init__(self) LNTransportBase.__init__(self)
self.reader = reader self.reader = reader
@ -211,19 +215,26 @@ class LNResponderTransport(LNTransportBase):
self.init_counters(ck) self.init_counters(ck)
return rs return rs
class LNTransport(LNTransportBase): class LNTransport(LNTransportBase):
"""Transport initiated by local party."""
def __init__(self, privkey: bytes, peer_addr: LNPeerAddr): def __init__(self, privkey: bytes, peer_addr: LNPeerAddr, *,
proxy: Optional[dict]):
LNTransportBase.__init__(self) LNTransportBase.__init__(self)
assert type(privkey) is bytes and len(privkey) == 32 assert type(privkey) is bytes and len(privkey) == 32
self.privkey = privkey self.privkey = privkey
self.peer_addr = peer_addr self.peer_addr = peer_addr
self.proxy = MySocksProxy.from_proxy_dict(proxy)
def name(self): def name(self):
return self.peer_addr.net_addr_str() return self.peer_addr.net_addr_str()
async def handshake(self): async def handshake(self):
if not self.proxy:
self.reader, self.writer = await asyncio.open_connection(self.peer_addr.host, self.peer_addr.port) self.reader, self.writer = await asyncio.open_connection(self.peer_addr.host, self.peer_addr.port)
else:
self.reader, self.writer = await self.proxy.open_connection(self.peer_addr.host, self.peer_addr.port)
hs = HandshakeState(self.peer_addr.pubkey) hs = HandshakeState(self.peer_addr.pubkey)
# Get a new ephemeral key # Get a new ephemeral key
epriv, epub = create_ephemeral_key() epriv, epub = create_ephemeral_key()

6
electrum/lnwatcher.py

@ -8,6 +8,7 @@ import asyncio
from enum import IntEnum, auto from enum import IntEnum, auto
from typing import NamedTuple, Dict from typing import NamedTuple, Dict
from . import util
from .sql_db import SqlDB, sql from .sql_db import SqlDB, sql
from .wallet_db import WalletDB from .wallet_db import WalletDB
from .util import bh2u, bfh, log_exceptions, ignore_exceptions, TxMinedInfo from .util import bh2u, bfh, log_exceptions, ignore_exceptions, TxMinedInfo
@ -139,7 +140,8 @@ class LNWatcher(AddressSynchronizer):
self.config = network.config self.config = network.config
self.channels = {} self.channels = {}
self.network = network self.network = network
self.network.register_callback(self.on_network_update, util.register_callback(
self.on_network_update,
['network_updated', 'blockchain_updated', 'verified', 'wallet_updated', 'fee']) ['network_updated', 'blockchain_updated', 'verified', 'wallet_updated', 'fee'])
# status gets populated when we run # status gets populated when we run
@ -420,4 +422,4 @@ class LNWalletWatcher(LNWatcher):
tx_was_added = False tx_was_added = False
if tx_was_added: if tx_was_added:
self.logger.info(f'added future tx: {name}. prevout: {prevout}') self.logger.info(f'added future tx: {name}. prevout: {prevout}')
self.network.trigger_callback('wallet_updated', self.lnworker.wallet) util.trigger_callback('wallet_updated', self.lnworker.wallet)

228
electrum/lnworker.py

@ -7,7 +7,7 @@ import os
from decimal import Decimal from decimal import Decimal
import random import random
import time import time
from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING, NamedTuple, Union from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING, NamedTuple, Union, Mapping
import threading import threading
import socket import socket
import json import json
@ -21,11 +21,11 @@ import dns.resolver
import dns.exception import dns.exception
from aiorpcx import run_in_thread from aiorpcx import run_in_thread
from . import constants from . import constants, util
from . import keystore from . import keystore
from .util import profiler from .util import profiler
from .util import PR_UNPAID, PR_EXPIRED, PR_PAID, PR_INFLIGHT, PR_FAILED, PR_ROUTING from .util import PR_UNPAID, PR_EXPIRED, PR_PAID, PR_INFLIGHT, PR_FAILED, PR_ROUTING
from .util import PR_TYPE_LN from .util import PR_TYPE_LN, NetworkRetryManager
from .lnutil import LN_MAX_FUNDING_SAT from .lnutil import LN_MAX_FUNDING_SAT
from .keystore import BIP32_KeyStore from .keystore import BIP32_KeyStore
from .bitcoin import COIN from .bitcoin import COIN
@ -77,9 +77,7 @@ SAVED_PR_STATUS = [PR_PAID, PR_UNPAID, PR_INFLIGHT] # status that are persisted
NUM_PEERS_TARGET = 4 NUM_PEERS_TARGET = 4
PEER_RETRY_INTERVAL = 600 # seconds
PEER_RETRY_INTERVAL_FOR_CHANNELS = 30 # seconds
GRAPH_DOWNLOAD_SECONDS = 600
FALLBACK_NODE_LIST_TESTNET = ( FALLBACK_NODE_LIST_TESTNET = (
LNPeerAddr(host='203.132.95.10', port=9735, pubkey=bfh('038863cf8ab91046230f561cd5b386cbff8309fa02e3f0c3ed161a3aeb64a643b9')), LNPeerAddr(host='203.132.95.10', port=9735, pubkey=bfh('038863cf8ab91046230f561cd5b386cbff8309fa02e3f0c3ed161a3aeb64a643b9')),
@ -141,12 +139,20 @@ class NoPathFound(PaymentFailure):
return _('No path found') return _('No path found')
class LNWorker(Logger): class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
def __init__(self, xprv): def __init__(self, xprv):
Logger.__init__(self) Logger.__init__(self)
NetworkRetryManager.__init__(
self,
max_retry_delay_normal=3600,
init_retry_delay_normal=600,
max_retry_delay_urgent=300,
init_retry_delay_urgent=4,
)
self.lock = threading.RLock()
self.node_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NODE_KEY) self.node_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NODE_KEY)
self.peers = {} # type: Dict[bytes, Peer] # pubkey -> Peer self._peers = {} # type: Dict[bytes, Peer] # pubkey -> Peer # needs self.lock
self.taskgroup = SilentTaskGroup() self.taskgroup = SilentTaskGroup()
# set some feature flags as baseline for both LNWallet and LNGossip # set some feature flags as baseline for both LNWallet and LNGossip
# note that e.g. DATA_LOSS_PROTECT is needed for LNGossip as many peers require it # note that e.g. DATA_LOSS_PROTECT is needed for LNGossip as many peers require it
@ -156,6 +162,14 @@ class LNWorker(Logger):
self.features |= LnFeatures.VAR_ONION_OPT self.features |= LnFeatures.VAR_ONION_OPT
self.features |= LnFeatures.PAYMENT_SECRET_OPT self.features |= LnFeatures.PAYMENT_SECRET_OPT
util.register_callback(self.on_proxy_changed, ['proxy_set'])
@property
def peers(self) -> Mapping[bytes, Peer]:
"""Returns a read-only copy of peers."""
with self.lock:
return self._peers.copy()
def channels_for_peer(self, node_id): def channels_for_peer(self, node_id):
return {} return {}
@ -175,10 +189,12 @@ class LNWorker(Logger):
self.logger.info('handshake failure from incoming connection') self.logger.info('handshake failure from incoming connection')
return return
peer = Peer(self, node_id, transport) peer = Peer(self, node_id, transport)
self.peers[node_id] = peer with self.lock:
self._peers[node_id] = peer
await self.taskgroup.spawn(peer.main_loop()) await self.taskgroup.spawn(peer.main_loop())
try: try:
# FIXME: server.close(), server.wait_closed(), etc... ? # FIXME: server.close(), server.wait_closed(), etc... ?
# TODO: onion hidden service?
server = await asyncio.start_server(cb, addr, int(port)) server = await asyncio.start_server(cb, addr, int(port))
except OSError as e: except OSError as e:
self.logger.error(f"cannot listen for lightning p2p. error: {e!r}") self.logger.error(f"cannot listen for lightning p2p. error: {e!r}")
@ -200,30 +216,31 @@ class LNWorker(Logger):
while True: while True:
await asyncio.sleep(1) await asyncio.sleep(1)
now = time.time() now = time.time()
if len(self.peers) >= NUM_PEERS_TARGET: if len(self._peers) >= NUM_PEERS_TARGET:
continue continue
peers = await self._get_next_peers_to_try() peers = await self._get_next_peers_to_try()
for peer in peers: for peer in peers:
last_tried = self._last_tried_peer.get(peer, 0) if self._can_retry_addr(peer, now=now):
if last_tried + PEER_RETRY_INTERVAL < now:
await self._add_peer(peer.host, peer.port, peer.pubkey) await self._add_peer(peer.host, peer.port, peer.pubkey)
async def _add_peer(self, host, port, node_id) -> Peer: async def _add_peer(self, host: str, port: int, node_id: bytes) -> Peer:
if node_id in self.peers: if node_id in self._peers:
return self.peers[node_id] return self._peers[node_id]
port = int(port) port = int(port)
peer_addr = LNPeerAddr(host, port, node_id) peer_addr = LNPeerAddr(host, port, node_id)
transport = LNTransport(self.node_keypair.privkey, peer_addr) transport = LNTransport(self.node_keypair.privkey, peer_addr,
self._last_tried_peer[peer_addr] = time.time() proxy=self.network.proxy)
self._trying_addr_now(peer_addr)
self.logger.info(f"adding peer {peer_addr}") self.logger.info(f"adding peer {peer_addr}")
peer = Peer(self, node_id, transport) peer = Peer(self, node_id, transport)
await self.taskgroup.spawn(peer.main_loop()) await self.taskgroup.spawn(peer.main_loop())
self.peers[node_id] = peer with self.lock:
self._peers[node_id] = peer
return peer return peer
def peer_closed(self, peer: Peer) -> None: def peer_closed(self, peer: Peer) -> None:
if peer.pubkey in self.peers: with self.lock:
self.peers.pop(peer.pubkey) self._peers.pop(peer.pubkey, None)
def num_peers(self) -> int: def num_peers(self) -> int:
return sum([p.is_initialized() for p in self.peers.values()]) return sum([p.is_initialized() for p in self.peers.values()])
@ -232,11 +249,9 @@ class LNWorker(Logger):
assert network assert network
self.network = network self.network = network
self.config = network.config self.config = network.config
daemon = network.daemon
self.channel_db = self.network.channel_db self.channel_db = self.network.channel_db
self._last_tried_peer = {} # type: Dict[LNPeerAddr, float] # LNPeerAddr -> unix timestamp
self._add_peers_from_config() self._add_peers_from_config()
asyncio.run_coroutine_threadsafe(daemon.taskgroup.spawn(self.main_loop()), self.network.asyncio_loop) asyncio.run_coroutine_threadsafe(self.main_loop(), self.network.asyncio_loop)
def _add_peers_from_config(self): def _add_peers_from_config(self):
peer_list = self.config.get('lightning_peers', []) peer_list = self.config.get('lightning_peers', [])
@ -260,20 +275,29 @@ class LNWorker(Logger):
#self.logger.info(f'is_good {peer.host}') #self.logger.info(f'is_good {peer.host}')
return True return True
def on_peer_successfully_established(self, peer: Peer) -> None:
if isinstance(peer.transport, LNTransport):
peer_addr = peer.transport.peer_addr
# reset connection attempt count
self._on_connection_successfully_established(peer_addr)
# add into channel db
if self.channel_db:
self.channel_db.add_recent_peer(peer_addr)
# save network address into channels we might have with peer
for chan in peer.channels.values():
chan.add_or_update_peer_addr(peer_addr)
async def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]: async def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]:
now = time.time() now = time.time()
await self.channel_db.data_loaded.wait() await self.channel_db.data_loaded.wait()
recent_peers = self.channel_db.get_recent_peers()
# maintenance for last tried times
# due to this, below we can just test membership in _last_tried_peer
for peer in list(self._last_tried_peer):
if now >= self._last_tried_peer[peer] + PEER_RETRY_INTERVAL:
del self._last_tried_peer[peer]
# first try from recent peers # first try from recent peers
recent_peers = self.channel_db.get_recent_peers()
for peer in recent_peers: for peer in recent_peers:
if peer.pubkey in self.peers: if not peer:
continue
if peer.pubkey in self._peers:
continue continue
if peer in self._last_tried_peer: if not self._can_retry_addr(peer, now=now):
continue continue
if not self.is_good_peer(peer): if not self.is_good_peer(peer):
continue continue
@ -290,7 +314,7 @@ class LNWorker(Logger):
peer = LNPeerAddr(host, port, node_id) peer = LNPeerAddr(host, port, node_id)
except ValueError: except ValueError:
continue continue
if peer in self._last_tried_peer: if not self._can_retry_addr(peer, now=now):
continue continue
if not self.is_good_peer(peer): if not self.is_good_peer(peer):
continue continue
@ -305,7 +329,7 @@ class LNWorker(Logger):
else: else:
return [] # regtest?? return [] # regtest??
fallback_list = [peer for peer in fallback_list if peer not in self._last_tried_peer] fallback_list = [peer for peer in fallback_list if self._can_retry_addr(peer, now=now)]
if fallback_list: if fallback_list:
return [random.choice(fallback_list)] return [random.choice(fallback_list)]
@ -363,12 +387,40 @@ class LNWorker(Logger):
choice = random.choice(addr_list) choice = random.choice(addr_list)
return choice return choice
def on_proxy_changed(self, event, *args):
for peer in self.peers.values():
peer.close_and_cleanup()
self._clear_addr_retry_times()
@log_exceptions
async def add_peer(self, connect_str: str) -> Peer:
node_id, rest = extract_nodeid(connect_str)
peer = self._peers.get(node_id)
if not peer:
if rest is not None:
host, port = split_host_port(rest)
else:
addrs = self.channel_db.get_node_addresses(node_id)
if not addrs:
raise ConnStringFormatError(_('Don\'t know any addresses for node:') + ' ' + bh2u(node_id))
host, port, timestamp = self.choose_preferred_address(addrs)
port = int(port)
# Try DNS-resolving the host (if needed). This is simply so that
# the caller gets a nice exception if it cannot be resolved.
try:
await asyncio.get_event_loop().getaddrinfo(host, port)
except socket.gaierror:
raise ConnStringFormatError(_('Hostname does not resolve (getaddrinfo failed)'))
# add peer
peer = await self._add_peer(host, port, node_id)
return peer
class LNGossip(LNWorker): class LNGossip(LNWorker):
max_age = 14*24*3600 max_age = 14*24*3600
LOGGING_SHORTCUT = 'g' LOGGING_SHORTCUT = 'g'
def __init__(self, network): def __init__(self):
seed = os.urandom(32) seed = os.urandom(32)
node = BIP32Node.from_rootseed(seed, xtype='standard') node = BIP32Node.from_rootseed(seed, xtype='standard')
xprv = node.to_xprv() xprv = node.to_xprv()
@ -394,16 +446,16 @@ class LNGossip(LNWorker):
known = self.channel_db.get_channel_ids() known = self.channel_db.get_channel_ids()
new = set(ids) - set(known) new = set(ids) - set(known)
self.unknown_ids.update(new) self.unknown_ids.update(new)
self.network.trigger_callback('unknown_channels', len(self.unknown_ids)) util.trigger_callback('unknown_channels', len(self.unknown_ids))
self.network.trigger_callback('gossip_peers', self.num_peers()) util.trigger_callback('gossip_peers', self.num_peers())
self.network.trigger_callback('ln_gossip_sync_progress') util.trigger_callback('ln_gossip_sync_progress')
def get_ids_to_query(self): def get_ids_to_query(self):
N = 500 N = 500
l = list(self.unknown_ids) l = list(self.unknown_ids)
self.unknown_ids = set(l[N:]) self.unknown_ids = set(l[N:])
self.network.trigger_callback('unknown_channels', len(self.unknown_ids)) util.trigger_callback('unknown_channels', len(self.unknown_ids))
self.network.trigger_callback('ln_gossip_sync_progress') util.trigger_callback('ln_gossip_sync_progress')
return l[0:N] return l[0:N]
def get_sync_progress_estimate(self) -> Tuple[Optional[int], Optional[int]]: def get_sync_progress_estimate(self) -> Tuple[Optional[int], Optional[int]]:
@ -431,7 +483,6 @@ class LNWallet(LNWorker):
self.payments = self.db.get_dict('lightning_payments') # RHASH -> amount, direction, is_paid self.payments = self.db.get_dict('lightning_payments') # RHASH -> amount, direction, is_paid
self.preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage self.preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage
self.sweep_address = wallet.get_receiving_address() self.sweep_address = wallet.get_receiving_address()
self.lock = threading.RLock()
self.logs = defaultdict(list) # type: Dict[str, List[PaymentAttemptLog]] # key is RHASH # (not persisted) self.logs = defaultdict(list) # type: Dict[str, List[PaymentAttemptLog]] # key is RHASH # (not persisted)
self.is_routing = set() # (not persisted) keys of invoices that are in PR_ROUTING state self.is_routing = set() # (not persisted) keys of invoices that are in PR_ROUTING state
# used in tests # used in tests
@ -515,7 +566,7 @@ class LNWallet(LNWorker):
def peer_closed(self, peer): def peer_closed(self, peer):
for chan in self.channels_for_peer(peer.pubkey).values(): for chan in self.channels_for_peer(peer.pubkey).values():
chan.peer_state = PeerState.DISCONNECTED chan.peer_state = PeerState.DISCONNECTED
self.network.trigger_callback('channel', chan) util.trigger_callback('channel', chan)
super().peer_closed(peer) super().peer_closed(peer)
def get_settled_payments(self): def get_settled_payments(self):
@ -646,14 +697,14 @@ class LNWallet(LNWorker):
def channel_state_changed(self, chan): def channel_state_changed(self, chan):
self.save_channel(chan) self.save_channel(chan)
self.network.trigger_callback('channel', chan) util.trigger_callback('channel', chan)
def save_channel(self, chan): def save_channel(self, chan):
assert type(chan) is Channel assert type(chan) is Channel
if chan.config[REMOTE].next_per_commitment_point == chan.config[REMOTE].current_per_commitment_point: if chan.config[REMOTE].next_per_commitment_point == chan.config[REMOTE].current_per_commitment_point:
raise Exception("Tried to save channel with next_point == current_point, this should not happen") raise Exception("Tried to save channel with next_point == current_point, this should not happen")
self.wallet.save_db() self.wallet.save_db()
self.network.trigger_callback('channel', chan) util.trigger_callback('channel', chan)
def channel_by_txo(self, txo): def channel_by_txo(self, txo):
with self.lock: with self.lock:
@ -669,12 +720,12 @@ class LNWallet(LNWorker):
await self.try_force_closing(chan.channel_id) await self.try_force_closing(chan.channel_id)
elif chan.get_state() == ChannelState.FUNDED: elif chan.get_state() == ChannelState.FUNDED:
peer = self.peers.get(chan.node_id) peer = self._peers.get(chan.node_id)
if peer and peer.is_initialized(): if peer and peer.is_initialized():
peer.send_funding_locked(chan) peer.send_funding_locked(chan)
elif chan.get_state() == ChannelState.OPEN: elif chan.get_state() == ChannelState.OPEN:
peer = self.peers.get(chan.node_id) peer = self._peers.get(chan.node_id)
if peer: if peer:
await peer.maybe_update_fee(chan) await peer.maybe_update_fee(chan)
conf = self.lnwatcher.get_tx_height(chan.funding_outpoint.txid).conf conf = self.lnwatcher.get_tx_height(chan.funding_outpoint.txid).conf
@ -688,9 +739,6 @@ class LNWallet(LNWorker):
self.logger.info('REBROADCASTING CLOSING TX') self.logger.info('REBROADCASTING CLOSING TX')
await self.network.try_broadcasting(force_close_tx, 'force-close') await self.network.try_broadcasting(force_close_tx, 'force-close')
@log_exceptions @log_exceptions
async def _open_channel_coroutine(self, *, connect_str: str, funding_tx: PartialTransaction, async def _open_channel_coroutine(self, *, connect_str: str, funding_tx: PartialTransaction,
funding_sat: int, push_sat: int, funding_sat: int, push_sat: int,
@ -704,7 +752,7 @@ class LNWallet(LNWorker):
funding_sat=funding_sat, funding_sat=funding_sat,
push_msat=push_sat * 1000, push_msat=push_sat * 1000,
temp_channel_id=os.urandom(32)) temp_channel_id=os.urandom(32))
self.network.trigger_callback('channels_updated', self.wallet) util.trigger_callback('channels_updated', self.wallet)
self.wallet.add_transaction(funding_tx) # save tx as local into the wallet self.wallet.add_transaction(funding_tx) # save tx as local into the wallet
self.wallet.set_label(funding_tx.txid(), _('Open channel')) self.wallet.set_label(funding_tx.txid(), _('Open channel'))
if funding_tx.is_complete(): if funding_tx.is_complete():
@ -722,29 +770,6 @@ class LNWallet(LNWorker):
channels_db[chan.channel_id.hex()] = chan.storage channels_db[chan.channel_id.hex()] = chan.storage
self.wallet.save_backup() self.wallet.save_backup()
@log_exceptions
async def add_peer(self, connect_str: str) -> Peer:
node_id, rest = extract_nodeid(connect_str)
peer = self.peers.get(node_id)
if not peer:
if rest is not None:
host, port = split_host_port(rest)
else:
addrs = self.channel_db.get_node_addresses(node_id)
if not addrs:
raise ConnStringFormatError(_('Don\'t know any addresses for node:') + ' ' + bh2u(node_id))
host, port, timestamp = self.choose_preferred_address(addrs)
port = int(port)
# Try DNS-resolving the host (if needed). This is simply so that
# the caller gets a nice exception if it cannot be resolved.
try:
await asyncio.get_event_loop().getaddrinfo(host, port)
except socket.gaierror:
raise ConnStringFormatError(_('Hostname does not resolve (getaddrinfo failed)'))
# add peer
peer = await self._add_peer(host, port, node_id)
return peer
def mktx_for_open_channel(self, *, coins: Sequence[PartialTxInput], funding_sat: int, def mktx_for_open_channel(self, *, coins: Sequence[PartialTxInput], funding_sat: int,
fee_est=None) -> PartialTransaction: fee_est=None) -> PartialTransaction:
dummy_address = ln_dummy_address() dummy_address = ln_dummy_address()
@ -805,10 +830,10 @@ class LNWallet(LNWorker):
# note: path-finding runs in a separate thread so that we don't block the asyncio loop # note: path-finding runs in a separate thread so that we don't block the asyncio loop
# graph updates might occur during the computation # graph updates might occur during the computation
self.set_invoice_status(key, PR_ROUTING) self.set_invoice_status(key, PR_ROUTING)
self.network.trigger_callback('invoice_status', key) util.trigger_callback('invoice_status', key)
route = await run_in_thread(self._create_route_from_invoice, lnaddr) route = await run_in_thread(self._create_route_from_invoice, lnaddr)
self.set_invoice_status(key, PR_INFLIGHT) self.set_invoice_status(key, PR_INFLIGHT)
self.network.trigger_callback('invoice_status', key) util.trigger_callback('invoice_status', key)
payment_attempt_log = await self._pay_to_route(route, lnaddr) payment_attempt_log = await self._pay_to_route(route, lnaddr)
except Exception as e: except Exception as e:
log.append(PaymentAttemptLog(success=False, exception=e)) log.append(PaymentAttemptLog(success=False, exception=e))
@ -821,17 +846,17 @@ class LNWallet(LNWorker):
break break
else: else:
reason = _('Failed after {} attempts').format(attempts) reason = _('Failed after {} attempts').format(attempts)
self.network.trigger_callback('invoice_status', key) util.trigger_callback('invoice_status', key)
if success: if success:
self.network.trigger_callback('payment_succeeded', key) util.trigger_callback('payment_succeeded', key)
else: else:
self.network.trigger_callback('payment_failed', key, reason) util.trigger_callback('payment_failed', key, reason)
return success return success
async def _pay_to_route(self, route: LNPaymentRoute, lnaddr: LnAddr) -> PaymentAttemptLog: async def _pay_to_route(self, route: LNPaymentRoute, lnaddr: LnAddr) -> PaymentAttemptLog:
short_channel_id = route[0].short_channel_id short_channel_id = route[0].short_channel_id
chan = self.get_channel_by_short_id(short_channel_id) chan = self.get_channel_by_short_id(short_channel_id)
peer = self.peers.get(route[0].node_id) peer = self._peers.get(route[0].node_id)
if not peer: if not peer:
raise Exception('Dropped peer') raise Exception('Dropped peer')
await peer.initialized await peer.initialized
@ -841,7 +866,7 @@ class LNWallet(LNWorker):
payment_hash=lnaddr.paymenthash, payment_hash=lnaddr.paymenthash,
min_final_cltv_expiry=lnaddr.get_min_final_cltv_expiry(), min_final_cltv_expiry=lnaddr.get_min_final_cltv_expiry(),
payment_secret=lnaddr.payment_secret) payment_secret=lnaddr.payment_secret)
self.network.trigger_callback('htlc_added', htlc, lnaddr, SENT) util.trigger_callback('htlc_added', htlc, lnaddr, SENT)
payment_attempt = await self.await_payment(lnaddr.paymenthash) payment_attempt = await self.await_payment(lnaddr.paymenthash)
if payment_attempt.success: if payment_attempt.success:
failure_log = None failure_log = None
@ -1140,9 +1165,9 @@ class LNWallet(LNWorker):
f.set_result(payment_attempt) f.set_result(payment_attempt)
else: else:
chan.logger.info('received unexpected payment_failed, probably from previous session') chan.logger.info('received unexpected payment_failed, probably from previous session')
self.network.trigger_callback('invoice_status', key) util.trigger_callback('invoice_status', key)
self.network.trigger_callback('payment_failed', key, '') util.trigger_callback('payment_failed', key, '')
self.network.trigger_callback('ln_payment_failed', payment_hash, chan.channel_id) util.trigger_callback('ln_payment_failed', payment_hash, chan.channel_id)
def payment_sent(self, chan, payment_hash: bytes): def payment_sent(self, chan, payment_hash: bytes):
self.set_payment_status(payment_hash, PR_PAID) self.set_payment_status(payment_hash, PR_PAID)
@ -1156,14 +1181,14 @@ class LNWallet(LNWorker):
f.set_result(payment_attempt) f.set_result(payment_attempt)
else: else:
chan.logger.info('received unexpected payment_sent, probably from previous session') chan.logger.info('received unexpected payment_sent, probably from previous session')
self.network.trigger_callback('invoice_status', key) util.trigger_callback('invoice_status', key)
self.network.trigger_callback('payment_succeeded', key) util.trigger_callback('payment_succeeded', key)
self.network.trigger_callback('ln_payment_completed', payment_hash, chan.channel_id) util.trigger_callback('ln_payment_completed', payment_hash, chan.channel_id)
def payment_received(self, chan, payment_hash: bytes): def payment_received(self, chan, payment_hash: bytes):
self.set_payment_status(payment_hash, PR_PAID) self.set_payment_status(payment_hash, PR_PAID)
self.network.trigger_callback('request_status', payment_hash.hex(), PR_PAID) util.trigger_callback('request_status', payment_hash.hex(), PR_PAID)
self.network.trigger_callback('ln_payment_completed', payment_hash, chan.channel_id) util.trigger_callback('ln_payment_completed', payment_hash, chan.channel_id)
async def _calc_routing_hints_for_invoice(self, amount_sat): async def _calc_routing_hints_for_invoice(self, amount_sat):
"""calculate routing hints (BOLT-11 'r' field)""" """calculate routing hints (BOLT-11 'r' field)"""
@ -1227,7 +1252,7 @@ class LNWallet(LNWorker):
async def close_channel(self, chan_id): async def close_channel(self, chan_id):
chan = self.channels[chan_id] chan = self.channels[chan_id]
peer = self.peers[chan.node_id] peer = self._peers[chan.node_id]
return await peer.close_channel(chan_id) return await peer.close_channel(chan_id)
async def force_close_channel(self, chan_id): async def force_close_channel(self, chan_id):
@ -1252,8 +1277,8 @@ class LNWallet(LNWorker):
self.channels.pop(chan_id) self.channels.pop(chan_id)
self.db.get('channels').pop(chan_id.hex()) self.db.get('channels').pop(chan_id.hex())
self.network.trigger_callback('channels_updated', self.wallet) util.trigger_callback('channels_updated', self.wallet)
self.network.trigger_callback('wallet_updated', self.wallet) util.trigger_callback('wallet_updated', self.wallet)
@ignore_exceptions @ignore_exceptions
@log_exceptions @log_exceptions
@ -1270,18 +1295,10 @@ class LNWallet(LNWorker):
peer_addresses.append(LNPeerAddr(host, port, chan.node_id)) peer_addresses.append(LNPeerAddr(host, port, chan.node_id))
# will try addresses stored in channel storage # will try addresses stored in channel storage
peer_addresses += list(chan.get_peer_addresses()) peer_addresses += list(chan.get_peer_addresses())
# Done gathering addresses.
# Now select first one that has not failed recently. # Now select first one that has not failed recently.
# Use long retry interval to check. This ensures each address we gathered gets a chance.
for peer in peer_addresses:
last_tried = self._last_tried_peer.get(peer, 0)
if last_tried + PEER_RETRY_INTERVAL < now:
await self._add_peer(peer.host, peer.port, peer.pubkey)
return
# Still here? That means all addresses failed ~recently.
# Use short retry interval now.
for peer in peer_addresses: for peer in peer_addresses:
last_tried = self._last_tried_peer.get(peer, 0) if self._can_retry_addr(peer, urgent=True, now=now):
if last_tried + PEER_RETRY_INTERVAL_FOR_CHANNELS < now:
await self._add_peer(peer.host, peer.port, peer.pubkey) await self._add_peer(peer.host, peer.port, peer.pubkey)
return return
@ -1296,7 +1313,7 @@ class LNWallet(LNWorker):
# reestablish # reestablish
if not chan.should_try_to_reestablish_peer(): if not chan.should_try_to_reestablish_peer():
continue continue
peer = self.peers.get(chan.node_id, None) peer = self._peers.get(chan.node_id, None)
if peer: if peer:
await peer.taskgroup.spawn(peer.reestablish_channel(chan)) await peer.taskgroup.spawn(peer.reestablish_channel(chan))
else: else:
@ -1356,7 +1373,7 @@ class LNBackups(Logger):
self.channel_backups[bfh(channel_id)] = ChannelBackup(cb, sweep_address=self.sweep_address, lnworker=self) self.channel_backups[bfh(channel_id)] = ChannelBackup(cb, sweep_address=self.sweep_address, lnworker=self)
def channel_state_changed(self, chan): def channel_state_changed(self, chan):
self.network.trigger_callback('channel', chan) util.trigger_callback('channel', chan)
def peer_closed(self, chan): def peer_closed(self, chan):
pass pass
@ -1390,7 +1407,7 @@ class LNBackups(Logger):
d[channel_id] = cb_storage d[channel_id] = cb_storage
self.channel_backups[bfh(channel_id)] = cb = ChannelBackup(cb_storage, sweep_address=self.sweep_address, lnworker=self) self.channel_backups[bfh(channel_id)] = cb = ChannelBackup(cb_storage, sweep_address=self.sweep_address, lnworker=self)
self.wallet.save_db() self.wallet.save_db()
self.network.trigger_callback('channels_updated', self.wallet) util.trigger_callback('channels_updated', self.wallet)
self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address()) self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address())
def remove_channel_backup(self, channel_id): def remove_channel_backup(self, channel_id):
@ -1400,13 +1417,14 @@ class LNBackups(Logger):
d.pop(channel_id.hex()) d.pop(channel_id.hex())
self.channel_backups.pop(channel_id) self.channel_backups.pop(channel_id)
self.wallet.save_db() self.wallet.save_db()
self.network.trigger_callback('channels_updated', self.wallet) util.trigger_callback('channels_updated', self.wallet)
@log_exceptions @log_exceptions
async def request_force_close(self, channel_id): async def request_force_close(self, channel_id):
cb = self.channel_backups[channel_id].cb cb = self.channel_backups[channel_id].cb
peer_addr = LNPeerAddr(cb.host, cb.port, cb.node_id) peer_addr = LNPeerAddr(cb.host, cb.port, cb.node_id)
transport = LNTransport(cb.privkey, peer_addr) transport = LNTransport(cb.privkey, peer_addr,
proxy=self.network.proxy)
peer = Peer(self, cb.node_id, transport) peer = Peer(self, cb.node_id, transport)
await self.taskgroup.spawn(peer._message_loop()) await self.taskgroup.spawn(peer._message_loop())
await peer.initialized await peer.initialized

302
electrum/network.py

@ -32,7 +32,7 @@ import socket
import json import json
import sys import sys
import asyncio import asyncio
from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING, Iterable from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING, Iterable, Set
import traceback import traceback
import concurrent import concurrent
from concurrent import futures from concurrent import futures
@ -44,7 +44,7 @@ from aiohttp import ClientResponse
from . import util from . import util
from .util import (log_exceptions, ignore_exceptions, from .util import (log_exceptions, ignore_exceptions,
bfh, SilentTaskGroup, make_aiohttp_session, send_exception_to_crash_reporter, bfh, SilentTaskGroup, make_aiohttp_session, send_exception_to_crash_reporter,
is_hash256_str, is_non_negative_integer) is_hash256_str, is_non_negative_integer, MyEncoder, NetworkRetryManager)
from .bitcoin import COIN from .bitcoin import COIN
from . import constants from . import constants
@ -53,9 +53,9 @@ from . import bitcoin
from . import dns_hacks from . import dns_hacks
from .transaction import Transaction from .transaction import Transaction
from .blockchain import Blockchain, HEADER_SIZE from .blockchain import Blockchain, HEADER_SIZE
from .interface import (Interface, serialize_server, deserialize_server, from .interface import (Interface, PREFERRED_NETWORK_PROTOCOL,
RequestTimedOut, NetworkTimeout, BUCKET_NAME_OF_ONION_SERVERS, RequestTimedOut, NetworkTimeout, BUCKET_NAME_OF_ONION_SERVERS,
NetworkException, RequestCorrupted) NetworkException, RequestCorrupted, ServerAddr)
from .version import PROTOCOL_VERSION from .version import PROTOCOL_VERSION
from .simple_config import SimpleConfig from .simple_config import SimpleConfig
from .i18n import _ from .i18n import _
@ -71,10 +71,8 @@ if TYPE_CHECKING:
_logger = get_logger(__name__) _logger = get_logger(__name__)
NODES_RETRY_INTERVAL = 60
SERVER_RETRY_INTERVAL = 10
NUM_TARGET_CONNECTED_SERVERS = 10 NUM_TARGET_CONNECTED_SERVERS = 10
NUM_STICKY_SERVERS = 4
NUM_RECENT_SERVERS = 20 NUM_RECENT_SERVERS = 20
@ -117,30 +115,32 @@ def filter_noonion(servers):
return {k: v for k, v in servers.items() if not k.endswith('.onion')} return {k: v for k, v in servers.items() if not k.endswith('.onion')}
def filter_protocol(hostmap, protocol='s'): def filter_protocol(hostmap, *, allowed_protocols: Iterable[str] = None) -> Sequence[ServerAddr]:
'''Filters the hostmap for those implementing protocol. """Filters the hostmap for those implementing protocol."""
The result is a list in serialized form.''' if allowed_protocols is None:
allowed_protocols = {PREFERRED_NETWORK_PROTOCOL}
eligible = [] eligible = []
for host, portmap in hostmap.items(): for host, portmap in hostmap.items():
for protocol in allowed_protocols:
port = portmap.get(protocol) port = portmap.get(protocol)
if port: if port:
eligible.append(serialize_server(host, port, protocol)) eligible.append(ServerAddr(host, port, protocol=protocol))
return eligible return eligible
def pick_random_server(hostmap=None, protocol='s', exclude_set=None): def pick_random_server(hostmap=None, *, allowed_protocols: Iterable[str],
exclude_set: Set[ServerAddr] = None) -> Optional[ServerAddr]:
if hostmap is None: if hostmap is None:
hostmap = constants.net.DEFAULT_SERVERS hostmap = constants.net.DEFAULT_SERVERS
if exclude_set is None: if exclude_set is None:
exclude_set = set() exclude_set = set()
eligible = list(set(filter_protocol(hostmap, protocol)) - exclude_set) servers = set(filter_protocol(hostmap, allowed_protocols=allowed_protocols))
eligible = list(servers - exclude_set)
return random.choice(eligible) if eligible else None return random.choice(eligible) if eligible else None
class NetworkParameters(NamedTuple): class NetworkParameters(NamedTuple):
host: str server: ServerAddr
port: str
protocol: str
proxy: Optional[dict] proxy: Optional[dict]
auto_connect: bool auto_connect: bool
oneserver: bool = False oneserver: bool = False
@ -233,19 +233,33 @@ class UntrustedServerReturnedError(NetworkException):
_INSTANCE = None _INSTANCE = None
class Network(Logger): class Network(Logger, NetworkRetryManager[ServerAddr]):
"""The Network class manages a set of connections to remote electrum """The Network class manages a set of connections to remote electrum
servers, each connected socket is handled by an Interface() object. servers, each connected socket is handled by an Interface() object.
""" """
LOGGING_SHORTCUT = 'n' LOGGING_SHORTCUT = 'n'
taskgroup: Optional[TaskGroup]
interface: Optional[Interface]
interfaces: Dict[ServerAddr, Interface]
_connecting: Set[ServerAddr]
default_server: ServerAddr
_recent_servers: List[ServerAddr]
def __init__(self, config: SimpleConfig, *, daemon: 'Daemon' = None): def __init__(self, config: SimpleConfig, *, daemon: 'Daemon' = None):
global _INSTANCE global _INSTANCE
assert _INSTANCE is None, "Network is a singleton!" assert _INSTANCE is None, "Network is a singleton!"
_INSTANCE = self _INSTANCE = self
Logger.__init__(self) Logger.__init__(self)
NetworkRetryManager.__init__(
self,
max_retry_delay_normal=600,
init_retry_delay_normal=15,
max_retry_delay_urgent=10,
init_retry_delay_urgent=1,
)
self.asyncio_loop = asyncio.get_event_loop() self.asyncio_loop = asyncio.get_event_loop()
assert self.asyncio_loop.is_running(), "event loop not running" assert self.asyncio_loop.is_running(), "event loop not running"
@ -261,50 +275,47 @@ class Network(Logger):
self.logger.info(f"blockchains {list(map(lambda b: b.forkpoint, blockchain.blockchains.values()))}") self.logger.info(f"blockchains {list(map(lambda b: b.forkpoint, blockchain.blockchains.values()))}")
self._blockchain_preferred_block = self.config.get('blockchain_preferred_block', None) # type: Optional[Dict] self._blockchain_preferred_block = self.config.get('blockchain_preferred_block', None) # type: Optional[Dict]
self._blockchain = blockchain.get_best_chain() self._blockchain = blockchain.get_best_chain()
self._allowed_protocols = {PREFERRED_NETWORK_PROTOCOL}
# Server for addresses and transactions # Server for addresses and transactions
self.default_server = self.config.get('server', None) self.default_server = self.config.get('server', None)
# Sanitize default server # Sanitize default server
if self.default_server: if self.default_server:
try: try:
deserialize_server(self.default_server) self.default_server = ServerAddr.from_str(self.default_server)
except: except:
self.logger.warning('failed to parse server-string; falling back to localhost:1:s.') self.logger.warning('failed to parse server-string; falling back to localhost:1:s.')
self.default_server = "localhost:1:s" self.default_server = ServerAddr.from_str("localhost:1:s")
if not self.default_server: else:
self.default_server = pick_random_server() self.default_server = pick_random_server(allowed_protocols=self._allowed_protocols)
assert isinstance(self.default_server, ServerAddr), f"invalid type for default_server: {self.default_server!r}"
self.taskgroup = None # type: TaskGroup self.taskgroup = None
# locks # locks
self.restart_lock = asyncio.Lock() self.restart_lock = asyncio.Lock()
self.bhi_lock = asyncio.Lock() self.bhi_lock = asyncio.Lock()
self.callback_lock = threading.Lock()
self.recent_servers_lock = threading.RLock() # <- re-entrant self.recent_servers_lock = threading.RLock() # <- re-entrant
self.interfaces_lock = threading.Lock() # for mutating/iterating self.interfaces self.interfaces_lock = threading.Lock() # for mutating/iterating self.interfaces
self.server_peers = {} # returned by interface (servers that the main interface knows about) self.server_peers = {} # returned by interface (servers that the main interface knows about)
self.recent_servers = self._read_recent_servers() # note: needs self.recent_servers_lock self._recent_servers = self._read_recent_servers() # note: needs self.recent_servers_lock
self.banner = '' self.banner = ''
self.donation_address = '' self.donation_address = ''
self.relay_fee = None # type: Optional[int] self.relay_fee = None # type: Optional[int]
# callbacks set by the GUI
self.callbacks = defaultdict(list) # note: needs self.callback_lock
dir_path = os.path.join(self.config.path, 'certs') dir_path = os.path.join(self.config.path, 'certs')
util.make_dir(dir_path) util.make_dir(dir_path)
# retry times
self.server_retry_time = time.time()
self.nodes_retry_time = time.time()
# the main server we are currently communicating with # the main server we are currently communicating with
self.interface = None # type: Optional[Interface] self.interface = None
self.default_server_changed_event = asyncio.Event() self.default_server_changed_event = asyncio.Event()
# set of servers we have an ongoing connection with # set of servers we have an ongoing connection with
self.interfaces = {} # type: Dict[str, Interface] self.interfaces = {}
self.auto_connect = self.config.get('auto_connect', True) self.auto_connect = self.config.get('auto_connect', True)
self.connecting = set() self._connecting = set()
self.server_queue = None
self.proxy = None self.proxy = None
# Dump network messages (all interfaces). Set at runtime from the console. # Dump network messages (all interfaces). Set at runtime from the console.
@ -332,7 +343,7 @@ class Network(Logger):
from . import channel_db from . import channel_db
self.channel_db = channel_db.ChannelDB(self) self.channel_db = channel_db.ChannelDB(self)
self.path_finder = lnrouter.LNPathFinder(self.channel_db) self.path_finder = lnrouter.LNPathFinder(self.channel_db)
self.lngossip = lnworker.LNGossip(self) self.lngossip = lnworker.LNGossip()
self.lngossip.start_network(self) self.lngossip.start_network(self)
def run_from_another_thread(self, coro, *, timeout=None): def run_from_another_thread(self, coro, *, timeout=None):
@ -350,35 +361,15 @@ class Network(Logger):
return func(self, *args, **kwargs) return func(self, *args, **kwargs)
return func_wrapper return func_wrapper
def register_callback(self, callback, events): def _read_recent_servers(self) -> List[ServerAddr]:
with self.callback_lock:
for event in events:
self.callbacks[event].append(callback)
def unregister_callback(self, callback):
with self.callback_lock:
for callbacks in self.callbacks.values():
if callback in callbacks:
callbacks.remove(callback)
def trigger_callback(self, event, *args):
with self.callback_lock:
callbacks = self.callbacks[event][:]
for callback in callbacks:
# FIXME: if callback throws, we will lose the traceback
if asyncio.iscoroutinefunction(callback):
asyncio.run_coroutine_threadsafe(callback(event, *args), self.asyncio_loop)
else:
self.asyncio_loop.call_soon_threadsafe(callback, event, *args)
def _read_recent_servers(self):
if not self.config.path: if not self.config.path:
return [] return []
path = os.path.join(self.config.path, "recent_servers") path = os.path.join(self.config.path, "recent_servers")
try: try:
with open(path, "r", encoding='utf-8') as f: with open(path, "r", encoding='utf-8') as f:
data = f.read() data = f.read()
return json.loads(data) servers_list = json.loads(data)
return [ServerAddr.from_str(s) for s in servers_list]
except: except:
return [] return []
@ -387,7 +378,7 @@ class Network(Logger):
if not self.config.path: if not self.config.path:
return return
path = os.path.join(self.config.path, "recent_servers") path = os.path.join(self.config.path, "recent_servers")
s = json.dumps(self.recent_servers, indent=4, sort_keys=True) s = json.dumps(self._recent_servers, indent=4, sort_keys=True, cls=MyEncoder)
try: try:
with open(path, "w", encoding='utf-8') as f: with open(path, "w", encoding='utf-8') as f:
f.write(s) f.write(s)
@ -481,15 +472,12 @@ class Network(Logger):
def notify(self, key): def notify(self, key):
if key in ['status', 'updated']: if key in ['status', 'updated']:
self.trigger_callback(key) util.trigger_callback(key)
else: else:
self.trigger_callback(key, self.get_status_value(key)) util.trigger_callback(key, self.get_status_value(key))
def get_parameters(self) -> NetworkParameters: def get_parameters(self) -> NetworkParameters:
host, port, protocol = deserialize_server(self.default_server) return NetworkParameters(server=self.default_server,
return NetworkParameters(host=host,
port=port,
protocol=protocol,
proxy=self.proxy, proxy=self.proxy,
auto_connect=self.auto_connect, auto_connect=self.auto_connect,
oneserver=self.oneserver) oneserver=self.oneserver)
@ -498,7 +486,7 @@ class Network(Logger):
if self.is_connected(): if self.is_connected():
return self.donation_address return self.donation_address
def get_interfaces(self) -> List[str]: def get_interfaces(self) -> List[ServerAddr]:
"""The list of servers for the connected interfaces.""" """The list of servers for the connected interfaces."""
with self.interfaces_lock: with self.interfaces_lock:
return list(self.interfaces) return list(self.interfaces)
@ -540,51 +528,60 @@ class Network(Logger):
# hardcoded servers # hardcoded servers
out.update(constants.net.DEFAULT_SERVERS) out.update(constants.net.DEFAULT_SERVERS)
# add recent servers # add recent servers
for s in self.recent_servers: for server in self._recent_servers:
try: port = str(server.port)
host, port, protocol = deserialize_server(s) if server.host in out:
except: out[server.host].update({server.protocol: port})
continue
if host in out:
out[host].update({protocol: port})
else: else:
out[host] = {protocol: port} out[server.host] = {server.protocol: port}
# potentially filter out some # potentially filter out some
if self.config.get('noonion'): if self.config.get('noonion'):
out = filter_noonion(out) out = filter_noonion(out)
return out return out
def _start_interface(self, server: str): def _get_next_server_to_try(self) -> Optional[ServerAddr]:
if server not in self.interfaces and server not in self.connecting: now = time.time()
if server == self.default_server:
self.logger.info(f"connecting to {server} as new interface")
self._set_status('connecting')
self.connecting.add(server)
self.server_queue.put(server)
def _start_random_interface(self):
with self.interfaces_lock: with self.interfaces_lock:
exclude_set = self.disconnected_servers | set(self.interfaces) | self.connecting connected_servers = set(self.interfaces) | self._connecting
server = pick_random_server(self.get_servers(), self.protocol, exclude_set) # First try from recent servers. (which are persisted)
if server: # As these are servers we successfully connected to recently, they are
self._start_interface(server) # most likely to work. This also makes servers "sticky".
# Note: with sticky servers, it is more difficult for an attacker to eclipse the client,
# however if they succeed, the eclipsing would persist. To try to balance this,
# we only give priority to recent_servers up to NUM_STICKY_SERVERS.
with self.recent_servers_lock:
recent_servers = list(self._recent_servers)
recent_servers = [s for s in recent_servers if s.protocol in self._allowed_protocols]
if len(connected_servers & set(recent_servers)) < NUM_STICKY_SERVERS:
for server in recent_servers:
if server in connected_servers:
continue
if not self._can_retry_addr(server, now=now):
continue
return server return server
# try all servers we know about, pick one at random
hostmap = self.get_servers()
servers = list(set(filter_protocol(hostmap, allowed_protocols=self._allowed_protocols)) - connected_servers)
random.shuffle(servers)
for server in servers:
if not self._can_retry_addr(server, now=now):
continue
return server
return None
def _set_proxy(self, proxy: Optional[dict]): def _set_proxy(self, proxy: Optional[dict]):
self.proxy = proxy self.proxy = proxy
dns_hacks.configure_dns_depending_on_proxy(bool(proxy)) dns_hacks.configure_dns_depending_on_proxy(bool(proxy))
self.logger.info(f'setting proxy {proxy}') self.logger.info(f'setting proxy {proxy}')
self.trigger_callback('proxy_set', self.proxy) util.trigger_callback('proxy_set', self.proxy)
@log_exceptions @log_exceptions
async def set_parameters(self, net_params: NetworkParameters): async def set_parameters(self, net_params: NetworkParameters):
proxy = net_params.proxy proxy = net_params.proxy
proxy_str = serialize_proxy(proxy) proxy_str = serialize_proxy(proxy)
host, port, protocol = net_params.host, net_params.port, net_params.protocol server = net_params.server
server_str = serialize_server(host, port, protocol)
# sanitize parameters # sanitize parameters
try: try:
deserialize_server(serialize_server(host, port, protocol))
if proxy: if proxy:
proxy_modes.index(proxy['mode']) + 1 proxy_modes.index(proxy['mode']) + 1
int(proxy['port']) int(proxy['port'])
@ -593,22 +590,22 @@ class Network(Logger):
self.config.set_key('auto_connect', net_params.auto_connect, False) self.config.set_key('auto_connect', net_params.auto_connect, False)
self.config.set_key('oneserver', net_params.oneserver, False) self.config.set_key('oneserver', net_params.oneserver, False)
self.config.set_key('proxy', proxy_str, False) self.config.set_key('proxy', proxy_str, False)
self.config.set_key('server', server_str, True) self.config.set_key('server', str(server), True)
# abort if changes were not allowed by config # abort if changes were not allowed by config
if self.config.get('server') != server_str \ if self.config.get('server') != str(server) \
or self.config.get('proxy') != proxy_str \ or self.config.get('proxy') != proxy_str \
or self.config.get('oneserver') != net_params.oneserver: or self.config.get('oneserver') != net_params.oneserver:
return return
async with self.restart_lock: async with self.restart_lock:
self.auto_connect = net_params.auto_connect self.auto_connect = net_params.auto_connect
if self.proxy != proxy or self.protocol != protocol or self.oneserver != net_params.oneserver: if self.proxy != proxy or self.oneserver != net_params.oneserver:
# Restart the network defaulting to the given server # Restart the network defaulting to the given server
await self._stop() await self._stop()
self.default_server = server_str self.default_server = server
await self._start() await self._start()
elif self.default_server != server_str: elif self.default_server != server:
await self.switch_to_interface(server_str) await self.switch_to_interface(server)
else: else:
await self.switch_lagging_interface() await self.switch_lagging_interface()
@ -670,7 +667,7 @@ class Network(Logger):
# FIXME switch to best available? # FIXME switch to best available?
self.logger.info("tried to switch to best chain but no interfaces are on it") self.logger.info("tried to switch to best chain but no interfaces are on it")
async def switch_to_interface(self, server: str): async def switch_to_interface(self, server: ServerAddr):
"""Switch to server as our main interface. If no connection exists, """Switch to server as our main interface. If no connection exists,
queue interface to be started. The actual switch will queue interface to be started. The actual switch will
happen when the interface becomes ready. happen when the interface becomes ready.
@ -686,11 +683,11 @@ class Network(Logger):
if old_server and old_server != server: if old_server and old_server != server:
await self._close_interface(old_interface) await self._close_interface(old_interface)
if len(self.interfaces) <= self.num_server: if len(self.interfaces) <= self.num_server:
self._start_interface(old_server) await self.taskgroup.spawn(self._run_new_interface(old_server))
if server not in self.interfaces: if server not in self.interfaces:
self.interface = None self.interface = None
self._start_interface(server) await self.taskgroup.spawn(self._run_new_interface(server))
return return
i = self.interfaces[server] i = self.interfaces[server]
@ -700,12 +697,13 @@ class Network(Logger):
blockchain_updated = i.blockchain != self.blockchain() blockchain_updated = i.blockchain != self.blockchain()
self.interface = i self.interface = i
await i.taskgroup.spawn(self._request_server_info(i)) await i.taskgroup.spawn(self._request_server_info(i))
self.trigger_callback('default_server_changed') util.trigger_callback('default_server_changed')
self.default_server_changed_event.set() self.default_server_changed_event.set()
self.default_server_changed_event.clear() self.default_server_changed_event.clear()
self._set_status('connected') self._set_status('connected')
self.trigger_callback('network_updated') util.trigger_callback('network_updated')
if blockchain_updated: self.trigger_callback('blockchain_updated') if blockchain_updated:
util.trigger_callback('blockchain_updated')
async def _close_interface(self, interface: Interface): async def _close_interface(self, interface: Interface):
if interface: if interface:
@ -717,12 +715,13 @@ class Network(Logger):
await interface.close() await interface.close()
@with_recent_servers_lock @with_recent_servers_lock
def _add_recent_server(self, server): def _add_recent_server(self, server: ServerAddr) -> None:
self._on_connection_successfully_established(server)
# list is ordered # list is ordered
if server in self.recent_servers: if server in self._recent_servers:
self.recent_servers.remove(server) self._recent_servers.remove(server)
self.recent_servers.insert(0, server) self._recent_servers.insert(0, server)
self.recent_servers = self.recent_servers[:NUM_RECENT_SERVERS] self._recent_servers = self._recent_servers[:NUM_RECENT_SERVERS]
self._save_recent_servers() self._save_recent_servers()
async def connection_down(self, interface: Interface): async def connection_down(self, interface: Interface):
@ -730,11 +729,10 @@ class Network(Logger):
We distinguish by whether it is in self.interfaces.''' We distinguish by whether it is in self.interfaces.'''
if not interface: return if not interface: return
server = interface.server server = interface.server
self.disconnected_servers.add(server)
if server == self.default_server: if server == self.default_server:
self._set_status('disconnected') self._set_status('disconnected')
await self._close_interface(interface) await self._close_interface(interface)
self.trigger_callback('network_updated') util.trigger_callback('network_updated')
def get_network_timeout_seconds(self, request_type=NetworkTimeout.Generic) -> int: def get_network_timeout_seconds(self, request_type=NetworkTimeout.Generic) -> int:
if self.oneserver and not self.auto_connect: if self.oneserver and not self.auto_connect:
@ -743,10 +741,18 @@ class Network(Logger):
return request_type.RELAXED return request_type.RELAXED
return request_type.NORMAL return request_type.NORMAL
@ignore_exceptions # do not kill main_taskgroup @ignore_exceptions # do not kill outer taskgroup
@log_exceptions @log_exceptions
async def _run_new_interface(self, server): async def _run_new_interface(self, server: ServerAddr):
interface = Interface(self, server, self.proxy) if server in self.interfaces or server in self._connecting:
return
self._connecting.add(server)
if server == self.default_server:
self.logger.info(f"connecting to {server} as new interface")
self._set_status('connecting')
self._trying_addr_now(server)
interface = Interface(network=self, server=server, proxy=self.proxy)
# note: using longer timeouts here as DNS can sometimes be slow! # note: using longer timeouts here as DNS can sometimes be slow!
timeout = self.get_network_timeout_seconds(NetworkTimeout.Generic) timeout = self.get_network_timeout_seconds(NetworkTimeout.Generic)
try: try:
@ -760,16 +766,16 @@ class Network(Logger):
assert server not in self.interfaces assert server not in self.interfaces
self.interfaces[server] = interface self.interfaces[server] = interface
finally: finally:
try: self.connecting.remove(server) try: self._connecting.remove(server)
except KeyError: pass except KeyError: pass
if server == self.default_server: if server == self.default_server:
await self.switch_to_interface(server) await self.switch_to_interface(server)
self._add_recent_server(server) self._add_recent_server(server)
self.trigger_callback('network_updated') util.trigger_callback('network_updated')
def check_interface_against_healthy_spread_of_connected_servers(self, iface_to_check) -> bool: def check_interface_against_healthy_spread_of_connected_servers(self, iface_to_check: Interface) -> bool:
# main interface is exempt. this makes switching servers easier # main interface is exempt. this makes switching servers easier
if iface_to_check.is_main_server(): if iface_to_check.is_main_server():
return True return True
@ -1093,23 +1099,21 @@ class Network(Logger):
with self.interfaces_lock: interfaces = list(self.interfaces.values()) with self.interfaces_lock: interfaces = list(self.interfaces.values())
interfaces_on_selected_chain = list(filter(lambda iface: iface.blockchain == bc, interfaces)) interfaces_on_selected_chain = list(filter(lambda iface: iface.blockchain == bc, interfaces))
if len(interfaces_on_selected_chain) == 0: return if len(interfaces_on_selected_chain) == 0: return
chosen_iface = random.choice(interfaces_on_selected_chain) chosen_iface = random.choice(interfaces_on_selected_chain) # type: Interface
# switch to server (and save to config) # switch to server (and save to config)
net_params = self.get_parameters() net_params = self.get_parameters()
host, port, protocol = deserialize_server(chosen_iface.server) net_params = net_params._replace(server=chosen_iface.server)
net_params = net_params._replace(host=host, port=port, protocol=protocol)
await self.set_parameters(net_params) await self.set_parameters(net_params)
async def follow_chain_given_server(self, server_str: str) -> None: async def follow_chain_given_server(self, server: ServerAddr) -> None:
# note that server_str should correspond to a connected interface # note that server_str should correspond to a connected interface
iface = self.interfaces.get(server_str) iface = self.interfaces.get(server)
if iface is None: if iface is None:
return return
self._set_preferred_chain(iface.blockchain) self._set_preferred_chain(iface.blockchain)
# switch to server (and save to config) # switch to server (and save to config)
net_params = self.get_parameters() net_params = self.get_parameters()
host, port, protocol = deserialize_server(server_str) net_params = net_params._replace(server=server)
net_params = net_params._replace(host=host, port=port, protocol=protocol)
await self.set_parameters(net_params) await self.set_parameters(net_params)
def get_local_height(self): def get_local_height(self):
@ -1127,14 +1131,12 @@ class Network(Logger):
assert not self.taskgroup assert not self.taskgroup
self.taskgroup = taskgroup = SilentTaskGroup() self.taskgroup = taskgroup = SilentTaskGroup()
assert not self.interface and not self.interfaces assert not self.interface and not self.interfaces
assert not self.connecting and not self.server_queue assert not self._connecting
self.logger.info('starting network') self.logger.info('starting network')
self.disconnected_servers = set([]) self._clear_addr_retry_times()
self.protocol = deserialize_server(self.default_server)[2]
self.server_queue = queue.Queue()
self._set_proxy(deserialize_proxy(self.config.get('proxy'))) self._set_proxy(deserialize_proxy(self.config.get('proxy')))
self._set_oneserver(self.config.get('oneserver', False)) self._set_oneserver(self.config.get('oneserver', False))
self._start_interface(self.default_server) await self.taskgroup.spawn(self._run_new_interface(self.default_server))
async def main(): async def main():
self.logger.info("starting taskgroup.") self.logger.info("starting taskgroup.")
@ -1152,7 +1154,7 @@ class Network(Logger):
self.logger.info("taskgroup stopped.") self.logger.info("taskgroup stopped.")
asyncio.run_coroutine_threadsafe(main(), self.asyncio_loop) asyncio.run_coroutine_threadsafe(main(), self.asyncio_loop)
self.trigger_callback('network_updated') util.trigger_callback('network_updated')
def start(self, jobs: Iterable = None): def start(self, jobs: Iterable = None):
"""Schedule starting the network, along with the given job co-routines. """Schedule starting the network, along with the given job co-routines.
@ -1170,13 +1172,12 @@ class Network(Logger):
await asyncio.wait_for(self.taskgroup.cancel_remaining(), timeout=2) await asyncio.wait_for(self.taskgroup.cancel_remaining(), timeout=2)
except (asyncio.TimeoutError, asyncio.CancelledError) as e: except (asyncio.TimeoutError, asyncio.CancelledError) as e:
self.logger.info(f"exc during main_taskgroup cancellation: {repr(e)}") self.logger.info(f"exc during main_taskgroup cancellation: {repr(e)}")
self.taskgroup = None # type: TaskGroup self.taskgroup = None
self.interface = None # type: Interface self.interface = None
self.interfaces = {} # type: Dict[str, Interface] self.interfaces = {}
self.connecting.clear() self._connecting.clear()
self.server_queue = None
if not full_shutdown: if not full_shutdown:
self.trigger_callback('network_updated') util.trigger_callback('network_updated')
def stop(self): def stop(self):
assert self._loop_thread != threading.current_thread(), 'must not be called from network thread' assert self._loop_thread != threading.current_thread(), 'must not be called from network thread'
@ -1188,33 +1189,21 @@ class Network(Logger):
async def _ensure_there_is_a_main_interface(self): async def _ensure_there_is_a_main_interface(self):
if self.is_connected(): if self.is_connected():
return return
now = time.time()
# if auto_connect is set, try a different server # if auto_connect is set, try a different server
if self.auto_connect and not self.is_connecting(): if self.auto_connect and not self.is_connecting():
await self._switch_to_random_interface() await self._switch_to_random_interface()
# if auto_connect is not set, or still no main interface, retry current # if auto_connect is not set, or still no main interface, retry current
if not self.is_connected() and not self.is_connecting(): if not self.is_connected() and not self.is_connecting():
if self.default_server in self.disconnected_servers: if self._can_retry_addr(self.default_server, urgent=True):
if now - self.server_retry_time > SERVER_RETRY_INTERVAL:
self.disconnected_servers.remove(self.default_server)
self.server_retry_time = now
else:
await self.switch_to_interface(self.default_server) await self.switch_to_interface(self.default_server)
async def _maintain_sessions(self): async def _maintain_sessions(self):
async def launch_already_queued_up_new_interfaces(): async def maybe_start_new_interfaces():
while self.server_queue.qsize() > 0: for i in range(self.num_server - len(self.interfaces) - len(self._connecting)):
server = self.server_queue.get()
await self.taskgroup.spawn(self._run_new_interface(server))
async def maybe_queue_new_interfaces_to_be_launched_later():
now = time.time()
for i in range(self.num_server - len(self.interfaces) - len(self.connecting)):
# FIXME this should try to honour "healthy spread of connected servers" # FIXME this should try to honour "healthy spread of connected servers"
self._start_random_interface() server = self._get_next_server_to_try()
if now - self.nodes_retry_time > NODES_RETRY_INTERVAL: if server:
self.logger.info('network: retrying connections') await self.taskgroup.spawn(self._run_new_interface(server))
self.disconnected_servers = set([])
self.nodes_retry_time = now
async def maintain_healthy_spread_of_connected_servers(): async def maintain_healthy_spread_of_connected_servers():
with self.interfaces_lock: interfaces = list(self.interfaces.values()) with self.interfaces_lock: interfaces = list(self.interfaces.values())
random.shuffle(interfaces) random.shuffle(interfaces)
@ -1231,8 +1220,7 @@ class Network(Logger):
while True: while True:
try: try:
await launch_already_queued_up_new_interfaces() await maybe_start_new_interfaces()
await maybe_queue_new_interfaces_to_be_launched_later()
await maintain_healthy_spread_of_connected_servers() await maintain_healthy_spread_of_connected_servers()
await maintain_main_interface() await maintain_main_interface()
except asyncio.CancelledError: except asyncio.CancelledError:
@ -1289,10 +1277,10 @@ class Network(Logger):
session = self.interface.session session = self.interface.session
return parse_servers(await session.send_request('server.peers.subscribe')) return parse_servers(await session.send_request('server.peers.subscribe'))
async def send_multiple_requests(self, servers: List[str], method: str, params: Sequence): async def send_multiple_requests(self, servers: Sequence[ServerAddr], method: str, params: Sequence):
responses = dict() responses = dict()
async def get_response(server): async def get_response(server: ServerAddr):
interface = Interface(self, server, self.proxy) interface = Interface(network=self, server=server, proxy=self.proxy)
timeout = self.get_network_timeout_seconds(NetworkTimeout.Urgent) timeout = self.get_network_timeout_seconds(NetworkTimeout.Urgent)
try: try:
await asyncio.wait_for(interface.ready, timeout) await asyncio.wait_for(interface.ready, timeout)

2
electrum/scripts/peers.py

@ -17,7 +17,7 @@ network.start()
async def f(): async def f():
try: try:
peers = await network.get_peers() peers = await network.get_peers()
peers = filter_protocol(peers, 's') peers = filter_protocol(peers)
results = await network.send_multiple_requests(peers, 'blockchain.headers.subscribe', []) results = await network.send_multiple_requests(peers, 'blockchain.headers.subscribe', [])
for server, header in sorted(results.items(), key=lambda x: x[1].get('height')): for server, header in sorted(results.items(), key=lambda x: x[1].get('height')):
height = header.get('height') height = header.get('height')

2
electrum/scripts/txradar.py

@ -23,7 +23,7 @@ network.start()
async def f(): async def f():
try: try:
peers = await network.get_peers() peers = await network.get_peers()
peers = filter_protocol(peers, 's') peers = filter_protocol(peers)
results = await network.send_multiple_requests(peers, 'blockchain.transaction.get', [txid]) results = await network.send_multiple_requests(peers, 'blockchain.transaction.get', [txid])
r1, r2 = [], [] r1, r2 = [], []
for k, v in results.items(): for k, v in results.items():

6
electrum/sql_db.py

@ -19,9 +19,9 @@ def sql(func):
class SqlDB(Logger): class SqlDB(Logger):
def __init__(self, network, path, commit_interval=None): def __init__(self, asyncio_loop, path, commit_interval=None):
Logger.__init__(self) Logger.__init__(self)
self.network = network self.asyncio_loop = asyncio_loop
self.path = path self.path = path
self.commit_interval = commit_interval self.commit_interval = commit_interval
self.db_requests = queue.Queue() self.db_requests = queue.Queue()
@ -34,7 +34,7 @@ class SqlDB(Logger):
self.logger.info("Creating database") self.logger.info("Creating database")
self.create_database() self.create_database()
i = 0 i = 0
while self.network.asyncio_loop.is_running(): while self.asyncio_loop.is_running():
try: try:
future, func, args, kwargs = self.db_requests.get(timeout=0.1) future, func, args, kwargs = self.db_requests.get(timeout=0.1)
except queue.Empty: except queue.Empty:

5
electrum/synchronizer.py

@ -30,6 +30,7 @@ import logging
from aiorpcx import TaskGroup, run_in_thread, RPCError from aiorpcx import TaskGroup, run_in_thread, RPCError
from . import util
from .transaction import Transaction, PartialTransaction from .transaction import Transaction, PartialTransaction
from .util import bh2u, make_aiohttp_session, NetworkJobOnDefaultServer from .util import bh2u, make_aiohttp_session, NetworkJobOnDefaultServer
from .bitcoin import address_to_scripthash, is_address from .bitcoin import address_to_scripthash, is_address
@ -227,7 +228,7 @@ class Synchronizer(SynchronizerBase):
self.wallet.receive_tx_callback(tx_hash, tx, tx_height) self.wallet.receive_tx_callback(tx_hash, tx, tx_height)
self.logger.info(f"received tx {tx_hash} height: {tx_height} bytes: {len(raw_tx)}") self.logger.info(f"received tx {tx_hash} height: {tx_height} bytes: {len(raw_tx)}")
# callbacks # callbacks
self.wallet.network.trigger_callback('new_transaction', self.wallet, tx) util.trigger_callback('new_transaction', self.wallet, tx)
async def main(self): async def main(self):
self.wallet.set_up_to_date(False) self.wallet.set_up_to_date(False)
@ -252,7 +253,7 @@ class Synchronizer(SynchronizerBase):
if up_to_date: if up_to_date:
self._reset_request_counters() self._reset_request_counters()
self.wallet.set_up_to_date(up_to_date) self.wallet.set_up_to_date(up_to_date)
self.wallet.network.trigger_callback('wallet_updated', self.wallet) util.trigger_callback('wallet_updated', self.wallet)
class Notifier(SynchronizerBase): class Notifier(SynchronizerBase):

14
electrum/tests/test_lnpeer.py

@ -17,7 +17,7 @@ from electrum.ecc import ECPrivkey
from electrum import simple_config, lnutil from electrum import simple_config, lnutil
from electrum.lnaddr import lnencode, LnAddr, lndecode from electrum.lnaddr import lnencode, LnAddr, lndecode
from electrum.bitcoin import COIN, sha256 from electrum.bitcoin import COIN, sha256
from electrum.util import bh2u, create_and_start_event_loop from electrum.util import bh2u, create_and_start_event_loop, NetworkRetryManager
from electrum.lnpeer import Peer from electrum.lnpeer import Peer
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
@ -64,10 +64,6 @@ class MockNetwork:
def callback_lock(self): def callback_lock(self):
return noop_lock() return noop_lock()
register_callback = Network.register_callback
unregister_callback = Network.unregister_callback
trigger_callback = Network.trigger_callback
def get_local_height(self): def get_local_height(self):
return 0 return 0
@ -99,9 +95,10 @@ class MockWallet:
def is_lightning_backup(self): def is_lightning_backup(self):
return False return False
class MockLNWallet(Logger): class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
def __init__(self, remote_keypair, local_keypair, chan: 'Channel', tx_queue): def __init__(self, remote_keypair, local_keypair, chan: 'Channel', tx_queue):
Logger.__init__(self) Logger.__init__(self)
NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1)
self.remote_keypair = remote_keypair self.remote_keypair = remote_keypair
self.node_keypair = local_keypair self.node_keypair = local_keypair
self.network = MockNetwork(tx_queue) self.network = MockNetwork(tx_queue)
@ -127,6 +124,10 @@ class MockLNWallet(Logger):
@property @property
def peers(self): def peers(self):
return self._peers
@property
def _peers(self):
return {self.remote_keypair.pubkey: self.peer} return {self.remote_keypair.pubkey: self.peer}
def channels_for_peer(self, pubkey): def channels_for_peer(self, pubkey):
@ -164,6 +165,7 @@ class MockLNWallet(Logger):
force_close_channel = LNWallet.force_close_channel force_close_channel = LNWallet.force_close_channel
try_force_closing = LNWallet.try_force_closing try_force_closing = LNWallet.try_force_closing
get_first_timestamp = lambda self: 0 get_first_timestamp = lambda self: 0
on_peer_successfully_established = LNWallet.on_peer_successfully_established
class MockTransport: class MockTransport:

2
electrum/tests/test_lntransport.py

@ -57,7 +57,7 @@ class TestLNTransport(ElectrumTestCase):
server = server_future.result() # type: asyncio.Server server = server_future.result() # type: asyncio.Server
async def connect(): async def connect():
peer_addr = LNPeerAddr('127.0.0.1', 42898, responder_key.get_public_key_bytes()) peer_addr = LNPeerAddr('127.0.0.1', 42898, responder_key.get_public_key_bytes())
t = LNTransport(initiator_key.get_secret_bytes(), peer_addr) t = LNTransport(initiator_key.get_secret_bytes(), peer_addr, proxy=None)
await t.handshake() await t.handshake()
t.send_bytes(b'hello from client') t.send_bytes(b'hello from client')
self.assertEqual(await t.read_messages().__anext__(), b'hello from server') self.assertEqual(await t.read_messages().__anext__(), b'hello from server')

4
electrum/tests/test_network.py

@ -5,7 +5,7 @@ import unittest
from electrum import constants from electrum import constants
from electrum.simple_config import SimpleConfig from electrum.simple_config import SimpleConfig
from electrum import blockchain from electrum import blockchain
from electrum.interface import Interface from electrum.interface import Interface, ServerAddr
from electrum.crypto import sha256 from electrum.crypto import sha256
from electrum.util import bh2u from electrum.util import bh2u
@ -24,7 +24,7 @@ class MockInterface(Interface):
self.config = config self.config = config
network = MockNetwork() network = MockNetwork()
network.config = config network.config = config
super().__init__(network, 'mock-server:50000:t', None) super().__init__(network=network, server=ServerAddr.from_str('mock-server:50000:t'), proxy=None)
self.q = asyncio.Queue() self.q = asyncio.Queue()
self.blockchain = blockchain.Blockchain(config=self.config, forkpoint=0, self.blockchain = blockchain.Blockchain(config=self.config, forkpoint=0,
parent=None, forkpoint_hash=constants.net.GENESIS, prev_hash=None) parent=None, forkpoint_hash=constants.net.GENESIS, prev_hash=None)

128
electrum/util.py

@ -23,7 +23,8 @@
import binascii import binascii
import os, sys, re, json import os, sys, re, json
from collections import defaultdict, OrderedDict from collections import defaultdict, OrderedDict
from typing import NamedTuple, Union, TYPE_CHECKING, Tuple, Optional, Callable, Any, Sequence from typing import (NamedTuple, Union, TYPE_CHECKING, Tuple, Optional, Callable, Any,
Sequence, Dict, Generic, TypeVar)
from datetime import datetime from datetime import datetime
import decimal import decimal
from decimal import Decimal from decimal import Decimal
@ -41,9 +42,11 @@ import time
from typing import NamedTuple, Optional from typing import NamedTuple, Optional
import ssl import ssl
import ipaddress import ipaddress
import random
import aiohttp import aiohttp
from aiohttp_socks import ProxyConnector, ProxyType from aiohttp_socks import ProxyConnector, ProxyType
import aiorpcx
from aiorpcx import TaskGroup from aiorpcx import TaskGroup
import certifi import certifi
import dns.resolver import dns.resolver
@ -1130,7 +1133,7 @@ class NetworkJobOnDefaultServer(Logger):
self._restart_lock = asyncio.Lock() self._restart_lock = asyncio.Lock()
self._reset() self._reset()
asyncio.run_coroutine_threadsafe(self._restart(), network.asyncio_loop) asyncio.run_coroutine_threadsafe(self._restart(), network.asyncio_loop)
network.register_callback(self._restart, ['default_server_changed']) register_callback(self._restart, ['default_server_changed'])
def _reset(self): def _reset(self):
"""Initialise fields. Called every time the underlying """Initialise fields. Called every time the underlying
@ -1304,3 +1307,124 @@ def randrange(bound: int) -> int:
"""Return a random integer k such that 1 <= k < bound, uniformly """Return a random integer k such that 1 <= k < bound, uniformly
distributed across that range.""" distributed across that range."""
return ecdsa.util.randrange(bound) return ecdsa.util.randrange(bound)
class CallbackManager:
# callbacks set by the GUI
def __init__(self):
self.callback_lock = threading.Lock()
self.callbacks = defaultdict(list) # note: needs self.callback_lock
self.asyncio_loop = None
def register_callback(self, callback, events):
with self.callback_lock:
for event in events:
self.callbacks[event].append(callback)
def unregister_callback(self, callback):
with self.callback_lock:
for callbacks in self.callbacks.values():
if callback in callbacks:
callbacks.remove(callback)
def trigger_callback(self, event, *args):
if self.asyncio_loop is None:
self.asyncio_loop = asyncio.get_event_loop()
assert self.asyncio_loop.is_running(), "event loop not running"
with self.callback_lock:
callbacks = self.callbacks[event][:]
for callback in callbacks:
# FIXME: if callback throws, we will lose the traceback
if asyncio.iscoroutinefunction(callback):
asyncio.run_coroutine_threadsafe(callback(event, *args), self.asyncio_loop)
else:
self.asyncio_loop.call_soon_threadsafe(callback, event, *args)
callback_mgr = CallbackManager()
trigger_callback = callback_mgr.trigger_callback
register_callback = callback_mgr.register_callback
unregister_callback = callback_mgr.unregister_callback
_NetAddrType = TypeVar("_NetAddrType")
class NetworkRetryManager(Generic[_NetAddrType]):
"""Truncated Exponential Backoff for network connections."""
def __init__(
self, *,
max_retry_delay_normal: float,
init_retry_delay_normal: float,
max_retry_delay_urgent: float = None,
init_retry_delay_urgent: float = None,
):
self._last_tried_addr = {} # type: Dict[_NetAddrType, Tuple[float, int]] # (unix ts, num_attempts)
# note: these all use "seconds" as unit
if max_retry_delay_urgent is None:
max_retry_delay_urgent = max_retry_delay_normal
if init_retry_delay_urgent is None:
init_retry_delay_urgent = init_retry_delay_normal
self._max_retry_delay_normal = max_retry_delay_normal
self._init_retry_delay_normal = init_retry_delay_normal
self._max_retry_delay_urgent = max_retry_delay_urgent
self._init_retry_delay_urgent = init_retry_delay_urgent
def _trying_addr_now(self, addr: _NetAddrType) -> None:
last_time, num_attempts = self._last_tried_addr.get(addr, (0, 0))
# we add up to 1 second of noise to the time, so that clients are less likely
# to get synchronised and bombard the remote in connection waves:
cur_time = time.time() + random.random()
self._last_tried_addr[addr] = cur_time, num_attempts + 1
def _on_connection_successfully_established(self, addr: _NetAddrType) -> None:
self._last_tried_addr[addr] = time.time(), 0
def _can_retry_addr(self, peer: _NetAddrType, *,
now: float = None, urgent: bool = False) -> bool:
if now is None:
now = time.time()
last_time, num_attempts = self._last_tried_addr.get(peer, (0, 0))
if urgent:
delay = min(self._max_retry_delay_urgent,
self._init_retry_delay_urgent * 2 ** num_attempts)
else:
delay = min(self._max_retry_delay_normal,
self._init_retry_delay_normal * 2 ** num_attempts)
next_time = last_time + delay
return next_time < now
def _clear_addr_retry_times(self) -> None:
self._last_tried_addr.clear()
class MySocksProxy(aiorpcx.SOCKSProxy):
async def open_connection(self, host=None, port=None, **kwargs):
loop = asyncio.get_event_loop()
reader = asyncio.StreamReader(loop=loop)
protocol = asyncio.StreamReaderProtocol(reader, loop=loop)
transport, _ = await self.create_connection(
lambda: protocol, host, port, **kwargs)
writer = asyncio.StreamWriter(transport, protocol, reader, loop)
return reader, writer
@classmethod
def from_proxy_dict(cls, proxy: dict = None) -> Optional['MySocksProxy']:
if not proxy:
return None
username, pw = proxy.get('user'), proxy.get('password')
if not username or not pw:
auth = None
else:
auth = aiorpcx.socks.SOCKSUserAuth(username, pw)
addr = aiorpcx.NetAddress(proxy['host'], proxy['port'])
if proxy['mode'] == "socks4":
ret = cls(addr, aiorpcx.socks.SOCKS4a, auth)
elif proxy['mode'] == "socks5":
ret = cls(addr, aiorpcx.socks.SOCKS5, auth)
else:
raise NotImplementedError # http proxy not available with aiorpcx
return ret

Loading…
Cancel
Save