
Merge branch 'master' into patch-2

Authored by ghost43, 5 years ago; committed by GitHub.
Parent commit: f114d1ffe2
31 changed files (changed lines per file):

 1. .gitignore (1)
 2. contrib/build-linux/appimage/Dockerfile (2)
 3. contrib/build-wine/Dockerfile (2)
 4. electrum/address_synchronizer.py (8)
 5. electrum/channel_db.py (8)
 6. electrum/commands.py (15)
 7. electrum/daemon.py (9)
 8. electrum/exchange_rate.py (12)
 9. electrum/gui/kivy/main_window.py (47)
10. electrum/gui/kivy/uix/ui_screens/server.kv (23)
11. electrum/gui/qt/channel_details.py (9)
12. electrum/gui/qt/lightning_dialog.py (7)
13. electrum/gui/qt/main_window.py (11)
14. electrum/gui/qt/network_dialog.py (124)
15. electrum/gui/stdio.py (3)
16. electrum/gui/text.py (31)
17. electrum/interface.py (145)
18. electrum/lnchannel.py (8)
19. electrum/lnpeer.py (25)
20. electrum/lntransport.py (17)
21. electrum/lnwatcher.py (8)
22. electrum/lnworker.py (228)
23. electrum/network.py (308)
24. electrum/scripts/peers.py (2)
25. electrum/scripts/txradar.py (2)
26. electrum/sql_db.py (6)
27. electrum/synchronizer.py (5)
28. electrum/tests/test_lnpeer.py (14)
29. electrum/tests/test_lntransport.py (2)
30. electrum/tests/test_network.py (4)
31. electrum/util.py (128)

.gitignore (1 changed line)

@@ -16,6 +16,7 @@ bin/
 .idea
 .mypy_cache
 .vscode
+electrum_data
 # icons
 electrum/gui/kivy/theming/light-0.png

contrib/build-linux/appimage/Dockerfile (2 changed lines)

@@ -4,7 +4,7 @@ ENV LC_ALL=C.UTF-8 LANG=C.UTF-8
 RUN apt-get update -q && \
     apt-get install -qy \
-        git=1:2.7.4-0ubuntu1.7 \
+        git=1:2.7.4-0ubuntu1.8 \
         wget=1.17.1-1ubuntu1.5 \
         make=4.1-6 \
         autotools-dev=20150820.1 \

contrib/build-wine/Dockerfile (2 changed lines)

@@ -13,7 +13,7 @@ RUN dpkg --add-architecture i386 && \
 RUN apt-get update -q && \
     apt-get install -qy \
-        git=1:2.17.1-1ubuntu0.5 \
+        git=1:2.17.1-1ubuntu0.6 \
         p7zip-full=16.02+dfsg-6 \
         make=4.1-9.1ubuntu1 \
         mingw-w64=5.0.3-1 \

electrum/address_synchronizer.py (8 changed lines)

@@ -28,7 +28,7 @@ import itertools
 from collections import defaultdict
 from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple, NamedTuple, Sequence, List

-from . import bitcoin
+from . import bitcoin, util
 from .bitcoin import COINBASE_MATURITY
 from .util import profiler, bfh, TxMinedInfo
 from .transaction import Transaction, TxOutput, TxInput, PartialTxInput, TxOutpoint, PartialTransaction
@@ -161,7 +161,7 @@ class AddressSynchronizer(Logger):
         if self.network is not None:
             self.synchronizer = Synchronizer(self)
             self.verifier = SPV(self.network, self)
-            self.network.register_callback(self.on_blockchain_updated, ['blockchain_updated'])
+            util.register_callback(self.on_blockchain_updated, ['blockchain_updated'])

     def on_blockchain_updated(self, event, *args):
         self._get_addr_balance_cache = {}  # invalidate cache
@@ -174,7 +174,7 @@ class AddressSynchronizer(Logger):
             if self.verifier:
                 asyncio.run_coroutine_threadsafe(self.verifier.stop(), self.network.asyncio_loop)
                 self.verifier = None
-            self.network.unregister_callback(self.on_blockchain_updated)
+            util.unregister_callback(self.on_blockchain_updated)
            self.db.put('stored_height', self.get_local_height())

     def add_address(self, address):
@@ -546,7 +546,7 @@ class AddressSynchronizer(Logger):
             self.unverified_tx.pop(tx_hash, None)
         self.db.add_verified_tx(tx_hash, info)
         tx_mined_status = self.get_tx_height(tx_hash)
-        self.network.trigger_callback('verified', self, tx_hash, tx_mined_status)
+        util.trigger_callback('verified', self, tx_hash, tx_mined_status)

     def get_unverified_txs(self):
         '''Returns a map from tx hash to transaction height'''
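
The hunks above show the pattern that repeats through most of this commit: callback registration moves from the Network instance to module-level functions in electrum.util, so components no longer need a Network reference just to subscribe to events. The util.py side of the change is listed above but not rendered on this page; the following is a minimal standalone sketch of such a registry (hypothetical names, plain callables only) to illustrate the register/trigger/unregister calls seen in this diff.

    from collections import defaultdict
    from typing import Callable, Dict, List

    # Hypothetical module-level registry; Electrum's real one (in electrum/util.py,
    # not shown on this page) also dispatches coroutine callbacks onto the asyncio loop.
    _callbacks: Dict[str, List[Callable]] = defaultdict(list)

    def register_callback(func: Callable, events: List[str]) -> None:
        """Subscribe func to each named event."""
        for event in events:
            _callbacks[event].append(func)

    def unregister_callback(func: Callable) -> None:
        """Remove func from every event it was subscribed to."""
        for funcs in _callbacks.values():
            if func in funcs:
                funcs.remove(func)

    def trigger_callback(event: str, *args) -> None:
        """Invoke every callback subscribed to event, passing the event name first."""
        for func in list(_callbacks[event]):
            func(event, *args)

    if __name__ == '__main__':
        def on_verified(event, wallet, tx_hash, tx_mined_status):
            print(f'{event}: {tx_hash}')

        register_callback(on_verified, ['verified'])
        trigger_callback('verified', None, 'deadbeef', None)
        unregister_callback(on_verified)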

electrum/channel_db.py (8 changed lines)

@@ -35,7 +35,7 @@ import threading

 from .sql_db import SqlDB, sql
-from . import constants
+from . import constants, util
 from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits
 from .logging import Logger
 from .lnutil import (LNPeerAddr, format_short_channel_id, ShortChannelID,
@@ -242,7 +242,7 @@ class ChannelDB(SqlDB):

     def __init__(self, network: 'Network'):
         path = os.path.join(get_headers_dir(network.config), 'gossip_db')
-        super().__init__(network, path, commit_interval=100)
+        super().__init__(network.asyncio_loop, path, commit_interval=100)
         self.lock = threading.RLock()
         self.num_nodes = 0
         self.num_channels = 0
@@ -269,8 +269,8 @@ class ChannelDB(SqlDB):
         self.num_nodes = len(self._nodes)
         self.num_channels = len(self._channels)
         self.num_policies = len(self._policies)
-        self.network.trigger_callback('channel_db', self.num_nodes, self.num_channels, self.num_policies)
-        self.network.trigger_callback('ln_gossip_sync_progress')
+        util.trigger_callback('channel_db', self.num_nodes, self.num_channels, self.num_policies)
+        util.trigger_callback('ln_gossip_sync_progress')

     def get_channel_ids(self):
         with self.lock:

electrum/commands.py (15 changed lines)

@@ -53,6 +53,7 @@ from .wallet import Abstract_Wallet, create_new_wallet, restore_wallet_from_text
from .address_synchronizer import TX_HEIGHT_LOCAL
from .mnemonic import Mnemonic
from .lnutil import SENT, RECEIVED
from .lnutil import LnFeatures
from .lnutil import ln_dummy_address
from .lnpeer import channel_id_from_funding_tx
from .plugin import run_hook
@@ -186,7 +187,7 @@ class Commands:
net_params = self.network.get_parameters()
response = {
'path': self.network.config.path,
'server': net_params.host,
'server': net_params.server.host,
'blockchain_height': self.network.get_local_height(),
'server_height': self.network.get_server_height(),
'spv_nodes': len(self.network.get_interfaces()),
@@ -965,18 +966,21 @@ class Commands:
# lightning network commands
@command('wn')
async def add_peer(self, connection_string, timeout=20, wallet: Abstract_Wallet = None):
await wallet.lnworker.add_peer(connection_string)
async def add_peer(self, connection_string, timeout=20, gossip=False, wallet: Abstract_Wallet = None):
lnworker = self.network.lngossip if gossip else wallet.lnworker
await lnworker.add_peer(connection_string)
return True
@command('wn')
async def list_peers(self, wallet: Abstract_Wallet = None):
async def list_peers(self, gossip=False, wallet: Abstract_Wallet = None):
lnworker = self.network.lngossip if gossip else wallet.lnworker
return [{
'node_id':p.pubkey.hex(),
'address':p.transport.name(),
'initialized':p.is_initialized(),
'features': str(LnFeatures(p.features)),
'channels': [c.funding_outpoint.to_str() for c in p.channels.values()],
} for p in wallet.lnworker.peers.values()]
} for p in lnworker.peers.values()]
@command('wpn')
async def open_channel(self, connection_string, amount, push_amount=0, password=None, wallet: Abstract_Wallet = None):
@@ -1165,6 +1169,7 @@ command_options = {
'from_height': (None, "Only show transactions that confirmed after given block height"),
'to_height': (None, "Only show transactions that confirmed before given block height"),
'iknowwhatimdoing': (None, "Acknowledge that I understand the full implications of what I am about to do"),
'gossip': (None, "Apply command to gossip node instead of wallet"),
}

electrum/daemon.py (9 changed lines)

@@ -32,6 +32,8 @@ import threading
from typing import Dict, Optional, Tuple, Iterable
from base64 import b64decode, b64encode
from collections import defaultdict
import concurrent
from concurrent import futures
import aiohttp
from aiohttp import web, client_exceptions
@@ -41,6 +43,7 @@ from jsonrpcserver import response
from jsonrpcclient.clients.aiohttp_client import AiohttpClient
from aiorpcx import TaskGroup
from . import util
from .network import Network
from .util import (json_decode, to_bytes, to_string, profiler, standardize_path, constant_time_compare)
from .util import PR_PAID, PR_EXPIRED, get_request_status
@@ -181,7 +184,7 @@ class PayServer(Logger):
self.daemon = daemon
self.config = daemon.config
self.pending = defaultdict(asyncio.Event)
self.daemon.network.register_callback(self.on_payment, ['payment_received'])
util.register_callback(self.on_payment, ['payment_received'])
async def on_payment(self, evt, wallet, key, status):
if status == PR_PAID:
@@ -269,6 +272,8 @@ class AuthenticationCredentialsInvalid(AuthenticationError):
class Daemon(Logger):
network: Optional[Network]
@profiler
def __init__(self, config: SimpleConfig, fd=None, *, listen_jsonrpc=True):
Logger.__init__(self)
@@ -504,7 +509,7 @@ class Daemon(Logger):
fut = asyncio.run_coroutine_threadsafe(self.taskgroup.cancel_remaining(), self.asyncio_loop)
try:
fut.result(timeout=2)
except (asyncio.TimeoutError, asyncio.CancelledError):
except (concurrent.futures.TimeoutError, concurrent.futures.CancelledError, asyncio.CancelledError):
pass
self.logger.info("removing lockfile")
remove_lockfile(get_lockfile(self.config))
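
One subtle fix above: the future returned by asyncio.run_coroutine_threadsafe() is a concurrent.futures.Future, so a timed-out result() raises concurrent.futures.TimeoutError rather than asyncio.TimeoutError (distinct classes on Python 3.8 to 3.10; they were only re-unified as the builtin TimeoutError in 3.11), which is why the except clause is widened. A small self-contained demonstration, not Electrum code:

    import asyncio
    import concurrent.futures
    import threading

    # run an event loop in a background thread, as a daemon process would
    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()

    async def slow():
        await asyncio.sleep(10)

    fut = asyncio.run_coroutine_threadsafe(slow(), loop)  # concurrent.futures.Future
    try:
        fut.result(timeout=0.1)
    except concurrent.futures.TimeoutError:
        print("timed out")  # this branch runs, not asyncio.TimeoutError
    finally:
        loop.call_soon_threadsafe(loop.stop)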

electrum/exchange_rate.py (12 changed lines)

@@ -12,6 +12,7 @@ from typing import Sequence, Optional

 from aiorpcx.curio import timeout_after, TaskTimeout, TaskGroup

+from . import util
 from .bitcoin import COIN
 from .i18n import _
 from .util import (ThreadJob, make_dir, log_exceptions,
@@ -452,12 +453,11 @@ def get_exchanges_by_ccy(history=True):

 class FxThread(ThreadJob):

-    def __init__(self, config: SimpleConfig, network: Network):
+    def __init__(self, config: SimpleConfig, network: Optional[Network]):
         ThreadJob.__init__(self)
         self.config = config
         self.network = network
-        if self.network:
-            self.network.register_callback(self.set_proxy, ['proxy_set'])
+        util.register_callback(self.set_proxy, ['proxy_set'])
         self.ccy = self.get_currency()
         self.history_used_spot = False
         self.ccy_combo = None
@@ -567,12 +567,10 @@ class FxThread(ThreadJob):
             self.exchange.read_historical_rates(self.ccy, self.cache_dir)

     def on_quotes(self):
-        if self.network:
-            self.network.trigger_callback('on_quotes')
+        util.trigger_callback('on_quotes')

     def on_history(self):
-        if self.network:
-            self.network.trigger_callback('on_history')
+        util.trigger_callback('on_history')

     def exchange_rate(self) -> Decimal:
         """Returns the exchange rate as a Decimal"""

electrum/gui/kivy/main_window.py (47 changed lines)

@@ -13,6 +13,7 @@ from electrum.storage import WalletStorage, StorageReadWriteError
from electrum.wallet_db import WalletDB
from electrum.wallet import Wallet, InternalAddressCorruption, Abstract_Wallet
from electrum.plugin import run_hook
from electrum import util
from electrum.util import (profiler, InvalidPassword, send_exception_to_crash_reporter,
format_satoshis, format_satoshis_plain, format_fee_satoshis,
PR_PAID, PR_FAILED, maybe_extract_bolt11_invoice)
@@ -50,7 +51,6 @@ from .uix.dialogs.question import Question
# delayed imports: for startup speed on android
notification = app = ref = None
util = False
# register widget cache for keeping memory down timeout to forever to cache
# the data
@@ -145,6 +145,17 @@ class ElectrumWindow(App):
servers = self.network.get_servers()
ChoiceDialog(_('Choose a server'), sorted(servers), popup.ids.host.text, cb2).open()
def maybe_switch_to_server(self, server_str: str):
from electrum.interface import ServerAddr
net_params = self.network.get_parameters()
try:
server = ServerAddr.from_str_with_inference(server_str)
except Exception as e:
self.show_error(_("Invalid server details: {}").format(repr(e)))
return
net_params = net_params._replace(server=server)
self.network.run_from_another_thread(self.network.set_parameters(net_params))
def choose_blockchain_dialog(self, dt):
from .uix.dialogs.choice_dialog import ChoiceDialog
chains = self.network.get_blockchains()
@@ -348,8 +359,8 @@ class ElectrumWindow(App):
self.num_blocks = self.network.get_local_height()
self.num_nodes = len(self.network.get_interfaces())
net_params = self.network.get_parameters()
self.server_host = net_params.host
self.server_port = net_params.port
self.server_host = net_params.server.host
self.server_port = str(net_params.server.port)
self.auto_connect = net_params.auto_connect
self.oneserver = net_params.oneserver
self.proxy_config = net_params.proxy if net_params.proxy else {}
@@ -565,20 +576,20 @@ class ElectrumWindow(App):
if self.network:
interests = ['wallet_updated', 'network_updated', 'blockchain_updated',
'status', 'new_transaction', 'verified']
self.network.register_callback(self.on_network_event, interests)
self.network.register_callback(self.on_fee, ['fee'])
self.network.register_callback(self.on_fee_histogram, ['fee_histogram'])
self.network.register_callback(self.on_quotes, ['on_quotes'])
self.network.register_callback(self.on_history, ['on_history'])
self.network.register_callback(self.on_channels, ['channels_updated'])
self.network.register_callback(self.on_channel, ['channel'])
self.network.register_callback(self.on_invoice_status, ['invoice_status'])
self.network.register_callback(self.on_request_status, ['request_status'])
self.network.register_callback(self.on_payment_failed, ['payment_failed'])
self.network.register_callback(self.on_payment_succeeded, ['payment_succeeded'])
self.network.register_callback(self.on_channel_db, ['channel_db'])
self.network.register_callback(self.set_num_peers, ['gossip_peers'])
self.network.register_callback(self.set_unknown_channels, ['unknown_channels'])
util.register_callback(self.on_network_event, interests)
util.register_callback(self.on_fee, ['fee'])
util.register_callback(self.on_fee_histogram, ['fee_histogram'])
util.register_callback(self.on_quotes, ['on_quotes'])
util.register_callback(self.on_history, ['on_history'])
util.register_callback(self.on_channels, ['channels_updated'])
util.register_callback(self.on_channel, ['channel'])
util.register_callback(self.on_invoice_status, ['invoice_status'])
util.register_callback(self.on_request_status, ['request_status'])
util.register_callback(self.on_payment_failed, ['payment_failed'])
util.register_callback(self.on_payment_succeeded, ['payment_succeeded'])
util.register_callback(self.on_channel_db, ['channel_db'])
util.register_callback(self.set_num_peers, ['gossip_peers'])
util.register_callback(self.set_unknown_channels, ['unknown_channels'])
# load wallet
self.load_wallet_by_name(self.electrum_config.get_wallet_path(use_gui_last_wallet=True))
# URI passed in config
@@ -814,7 +825,7 @@ class ElectrumWindow(App):
if interface:
self.server_host = interface.host
else:
self.server_host = str(net_params.host) + ' (connecting...)'
self.server_host = str(net_params.server.host) + ' (connecting...)'
self.proxy_config = net_params.proxy or {}
self.update_proxy_str(self.proxy_config)
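
The new maybe_switch_to_server() above parses whatever the user typed and swaps only the server field of the current network parameters. A standalone sketch of that update pattern, using simplified stand-ins (the real NetworkParameters lives in electrum/network.py, whose diff is not rendered on this page, and the real ServerAddr is a plain class defined in electrum/interface.py further down):

    from typing import NamedTuple, Optional

    # Simplified stand-ins for illustration only.
    class ServerAddr(NamedTuple):
        host: str
        port: int
        protocol: str

    class NetworkParameters(NamedTuple):
        server: ServerAddr
        proxy: Optional[dict]
        auto_connect: bool

    net_params = NetworkParameters(
        server=ServerAddr("electrum.example.org", 50002, "s"),
        proxy=None,
        auto_connect=True,
    )

    # _replace() returns a new tuple with only the named fields swapped,
    # which is how the GUI applies a user-chosen server.
    new_server = ServerAddr("other.example.net", 50002, "s")
    net_params = net_params._replace(server=new_server)
    print(net_params.server, net_params.auto_connect)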

electrum/gui/kivy/uix/ui_screens/server.kv (23 changed lines)

@@ -16,27 +16,14 @@ Popup:
height: '36dp'
size_hint_x: 1
size_hint_y: None
text: _('Host') + ':'
text: _('Server') + ':'
TextInput:
id: host
id: server_str
multiline: False
height: '36dp'
size_hint_x: 3
size_hint_y: None
text: app.network.get_parameters().host
Label:
height: '36dp'
size_hint_x: 1
size_hint_y: None
text: _('Port') + ':'
TextInput:
id: port
multiline: False
input_type: 'number'
height: '36dp'
size_hint_x: 3
size_hint_y: None
text: app.network.get_parameters().port
text: app.network.get_parameters().server.net_addr_str()
Widget
Button:
id: chooser
@@ -56,7 +43,5 @@ Popup:
height: '48dp'
text: _('OK')
on_release:
net_params = app.network.get_parameters()
net_params = net_params._replace(host=str(root.ids.host.text), port=str(root.ids.port.text))
app.network.run_from_another_thread(app.network.set_parameters(net_params))
app.maybe_switch_to_server(str(root.ids.server_str.text))
nd.dismiss()

electrum/gui/qt/channel_details.py (9 changed lines)

@@ -5,6 +5,7 @@ import PyQt5.QtWidgets as QtWidgets
 import PyQt5.QtCore as QtCore
 from PyQt5.QtWidgets import QLabel, QLineEdit

+from electrum import util
 from electrum.i18n import _
 from electrum.util import bh2u, format_time
 from electrum.lnutil import format_short_channel_id, LOCAL, REMOTE, UpdateAddHtlc, Direction
@@ -132,10 +133,10 @@ class ChannelDetailsDialog(QtWidgets.QDialog):
         self.htlc_added.connect(self.do_htlc_added)

         # register callbacks for updating
-        window.network.register_callback(self.ln_payment_completed.emit, ['ln_payment_completed'])
-        window.network.register_callback(self.ln_payment_failed.emit, ['ln_payment_failed'])
-        window.network.register_callback(self.htlc_added.emit, ['htlc_added'])
-        window.network.register_callback(self.state_changed.emit, ['channel'])
+        util.register_callback(self.ln_payment_completed.emit, ['ln_payment_completed'])
+        util.register_callback(self.ln_payment_failed.emit, ['ln_payment_failed'])
+        util.register_callback(self.htlc_added.emit, ['htlc_added'])
+        util.register_callback(self.state_changed.emit, ['channel'])

         # set attributes of QDialog
         self.setWindowTitle(_('Channel Details'))

electrum/gui/qt/lightning_dialog.py (7 changed lines)

@@ -27,6 +27,7 @@ from typing import TYPE_CHECKING

 from PyQt5.QtWidgets import (QDialog, QLabel, QVBoxLayout, QPushButton)

+from electrum import util
 from electrum.i18n import _

 from .util import Buttons
@@ -58,9 +59,9 @@ class LightningDialog(QDialog):
         b = QPushButton(_('Close'))
         b.clicked.connect(self.close)
         vbox.addLayout(Buttons(b))
-        self.network.register_callback(self.on_channel_db, ['channel_db'])
-        self.network.register_callback(self.set_num_peers, ['gossip_peers'])
-        self.network.register_callback(self.set_unknown_channels, ['unknown_channels'])
+        util.register_callback(self.on_channel_db, ['channel_db'])
+        util.register_callback(self.set_num_peers, ['gossip_peers'])
+        util.register_callback(self.set_unknown_channels, ['unknown_channels'])
         self.network.channel_db.update_counts()  # trigger callback
         self.set_num_peers('', self.network.lngossip.num_peers())
         self.set_unknown_channels('', len(self.network.lngossip.unknown_ids))

electrum/gui/qt/main_window.py (11 changed lines)

@@ -272,7 +272,7 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger):
# window from being GC-ed when closed, callbacks should be
# methods of this class only, and specifically not be
# partials, lambdas or methods of subobjects. Hence...
self.network.register_callback(self.on_network, interests)
util.register_callback(self.on_network, interests)
# set initial message
self.console.showMessage(self.network.banner)
@@ -466,8 +466,8 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger):
def load_wallet(self, wallet):
wallet.thread = TaskThread(self, self.on_error)
self.update_recently_visited(wallet.storage.path)
if wallet.lnworker and wallet.network:
wallet.network.trigger_callback('channels_updated', wallet)
if wallet.lnworker:
util.trigger_callback('channels_updated', wallet)
self.need_update.set()
# Once GUI has been initialized check if we want to announce something since the callback has been called before the GUI was initialized
# update menus
@@ -738,7 +738,7 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger):
def donate_to_server(self):
d = self.network.get_donation_address()
if d:
host = self.network.get_parameters().host
host = self.network.get_parameters().server.host
self.pay_to_URI('bitcoin:%s?message=donation for %s'%(d, host))
else:
self.show_error(_('No donation address for this server'))
@@ -2889,8 +2889,7 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger):
def clean_up(self):
self.wallet.thread.stop()
if self.network:
self.network.unregister_callback(self.on_network)
util.unregister_callback(self.on_network)
self.config.set_key("is_maximized", self.isMaximized())
if not self.isMaximized():
g = self.geometry()

electrum/gui/qt/network_dialog.py (124 changed lines)

@@ -35,8 +35,8 @@ from PyQt5.QtWidgets import (QTreeWidget, QTreeWidgetItem, QMenu, QGridLayout, Q
from PyQt5.QtGui import QFontMetrics
from electrum.i18n import _
from electrum import constants, blockchain
from electrum.interface import serialize_server, deserialize_server
from electrum import constants, blockchain, util
from electrum.interface import ServerAddr, PREFERRED_NETWORK_PROTOCOL
from electrum.network import Network
from electrum.logging import get_logger
@@ -61,7 +61,7 @@ class NetworkDialog(QDialog):
vbox.addLayout(Buttons(CloseButton(self)))
self.network_updated_signal_obj.network_updated_signal.connect(
self.on_update)
network.register_callback(self.on_network, ['network_updated'])
util.register_callback(self.on_network, ['network_updated'])
def on_network(self, event, *args):
self.network_updated_signal_obj.network_updated_signal.emit(event, args)
@@ -72,10 +72,15 @@ class NodesListWidget(QTreeWidget):
class NodesListWidget(QTreeWidget):
"""List of connected servers."""
SERVER_ADDR_ROLE = Qt.UserRole + 100
CHAIN_ID_ROLE = Qt.UserRole + 101
IS_SERVER_ROLE = Qt.UserRole + 102
def __init__(self, parent):
QTreeWidget.__init__(self)
self.parent = parent
self.parent = parent # type: NetworkChoiceLayout
self.setHeaderLabels([_('Connected node'), _('Height')])
self.setContextMenuPolicy(Qt.CustomContextMenu)
self.customContextMenuRequested.connect(self.create_menu)
@@ -84,13 +89,13 @@ class NodesListWidget(QTreeWidget):
item = self.currentItem()
if not item:
return
is_server = not bool(item.data(0, Qt.UserRole))
is_server = bool(item.data(0, self.IS_SERVER_ROLE))
menu = QMenu()
if is_server:
server = item.data(1, Qt.UserRole)
server = item.data(0, self.SERVER_ADDR_ROLE) # type: ServerAddr
menu.addAction(_("Use as server"), lambda: self.parent.follow_server(server))
else:
chain_id = item.data(1, Qt.UserRole)
chain_id = item.data(0, self.CHAIN_ID_ROLE)
menu.addAction(_("Follow this branch"), lambda: self.parent.follow_branch(chain_id))
menu.exec_(self.viewport().mapToGlobal(position))
@@ -117,15 +122,16 @@ class NodesListWidget(QTreeWidget):
name = b.get_name()
if n_chains > 1:
x = QTreeWidgetItem([name + '@%d'%b.get_max_forkpoint(), '%d'%b.height()])
x.setData(0, Qt.UserRole, 1)
x.setData(1, Qt.UserRole, b.get_id())
x.setData(0, self.IS_SERVER_ROLE, 0)
x.setData(0, self.CHAIN_ID_ROLE, b.get_id())
else:
x = self
for i in interfaces:
star = ' *' if i == network.interface else ''
item = QTreeWidgetItem([i.host + star, '%d'%i.tip])
item.setData(0, Qt.UserRole, 0)
item.setData(1, Qt.UserRole, i.server)
item.setData(0, self.IS_SERVER_ROLE, 1)
item.setData(0, self.SERVER_ADDR_ROLE, i.server)
item.setToolTip(0, str(i.server))
x.addChild(item)
if n_chains > 1:
self.addTopLevelItem(x)
@@ -140,15 +146,17 @@ class NodesListWidget(QTreeWidget):
class ServerListWidget(QTreeWidget):
"""List of all known servers."""
class Columns(IntEnum):
HOST = 0
PORT = 1
SERVER_STR_ROLE = Qt.UserRole + 100
SERVER_ADDR_ROLE = Qt.UserRole + 100
def __init__(self, parent):
QTreeWidget.__init__(self)
self.parent = parent
self.parent = parent # type: NetworkChoiceLayout
self.setHeaderLabels([_('Host'), _('Port')])
self.setContextMenuPolicy(Qt.CustomContextMenu)
self.customContextMenuRequested.connect(self.create_menu)
@@ -158,14 +166,12 @@ class ServerListWidget(QTreeWidget):
if not item:
return
menu = QMenu()
server = item.data(self.Columns.HOST, self.SERVER_STR_ROLE)
server = item.data(self.Columns.HOST, self.SERVER_ADDR_ROLE)
menu.addAction(_("Use as server"), lambda: self.set_server(server))
menu.exec_(self.viewport().mapToGlobal(position))
def set_server(self, s):
host, port, protocol = deserialize_server(s)
self.parent.server_host.setText(host)
self.parent.server_port.setText(port)
def set_server(self, server: ServerAddr):
self.parent.server_e.setText(server.net_addr_str())
self.parent.set_server()
def keyPressEvent(self, event):
@@ -180,16 +186,17 @@ class ServerListWidget(QTreeWidget):
pt.setX(50)
self.customContextMenuRequested.emit(pt)
def update(self, servers, protocol, use_tor):
def update(self, servers, use_tor):
self.clear()
protocol = PREFERRED_NETWORK_PROTOCOL
for _host, d in sorted(servers.items()):
if _host.endswith('.onion') and not use_tor:
continue
port = d.get(protocol)
if port:
x = QTreeWidgetItem([_host, port])
server = serialize_server(_host, port, protocol)
x.setData(self.Columns.HOST, self.SERVER_STR_ROLE, server)
server = ServerAddr(_host, port, protocol=protocol)
x.setData(self.Columns.HOST, self.SERVER_ADDR_ROLE, server)
self.addTopLevelItem(x)
h = self.header()
@@ -205,7 +212,6 @@ class NetworkChoiceLayout(object):
def __init__(self, network: Network, config, wizard=False):
self.network = network
self.config = config
self.protocol = None
self.tor_proxy = None
self.tabs = tabs = QTabWidget()
@@ -223,15 +229,12 @@ class NetworkChoiceLayout(object):
grid = QGridLayout(server_tab)
grid.setSpacing(8)
self.server_host = QLineEdit()
self.server_host.setFixedWidth(fixed_width_hostname)
self.server_port = QLineEdit()
self.server_port.setFixedWidth(fixed_width_port)
self.server_e = QLineEdit()
self.server_e.setFixedWidth(fixed_width_hostname + fixed_width_port)
self.autoconnect_cb = QCheckBox(_('Select server automatically'))
self.autoconnect_cb.setEnabled(self.config.is_modifiable('auto_connect'))
self.server_host.editingFinished.connect(self.set_server)
self.server_port.editingFinished.connect(self.set_server)
self.server_e.editingFinished.connect(self.set_server)
self.autoconnect_cb.clicked.connect(self.set_server)
self.autoconnect_cb.clicked.connect(self.update)
@@ -243,8 +246,7 @@ class NetworkChoiceLayout(object):
grid.addWidget(HelpButton(msg), 0, 4)
grid.addWidget(QLabel(_('Server') + ':'), 1, 0)
grid.addWidget(self.server_host, 1, 1, 1, 2)
grid.addWidget(self.server_port, 1, 3)
grid.addWidget(self.server_e, 1, 1, 1, 3)
label = _('Server peers') if network.is_connected() else _('Default Servers')
grid.addWidget(QLabel(label), 2, 0, 1, 5)
@@ -348,29 +350,26 @@ class NetworkChoiceLayout(object):
def enable_set_server(self):
if self.config.is_modifiable('server'):
enabled = not self.autoconnect_cb.isChecked()
self.server_host.setEnabled(enabled)
self.server_port.setEnabled(enabled)
self.server_e.setEnabled(enabled)
self.servers_list.setEnabled(enabled)
else:
for w in [self.autoconnect_cb, self.server_host, self.server_port, self.servers_list]:
for w in [self.autoconnect_cb, self.server_e, self.servers_list]:
w.setEnabled(False)
def update(self):
net_params = self.network.get_parameters()
host, port, protocol = net_params.host, net_params.port, net_params.protocol
server = net_params.server
proxy_config, auto_connect = net_params.proxy, net_params.auto_connect
if not self.server_host.hasFocus() and not self.server_port.hasFocus():
self.server_host.setText(host)
self.server_port.setText(str(port))
if not self.server_e.hasFocus():
self.server_e.setText(server.net_addr_str())
self.autoconnect_cb.setChecked(auto_connect)
interface = self.network.interface
host = interface.host if interface else _('None')
self.server_label.setText(host)
self.set_protocol(protocol)
self.servers = self.network.get_servers()
self.servers_list.update(self.servers, self.protocol, self.tor_cb.isChecked())
self.servers_list.update(self.servers, self.tor_cb.isChecked())
self.enable_set_server()
height_str = "%d "%(self.network.get_local_height()) + _('blocks')
@@ -411,59 +410,24 @@ class NetworkChoiceLayout(object):
def layout(self):
return self.layout_
def set_protocol(self, protocol):
if protocol != self.protocol:
self.protocol = protocol
def change_protocol(self, use_ssl):
p = 's' if use_ssl else 't'
host = self.server_host.text()
pp = self.servers.get(host, constants.net.DEFAULT_PORTS)
if p not in pp.keys():
p = list(pp.keys())[0]
port = pp[p]
self.server_host.setText(host)
self.server_port.setText(port)
self.set_protocol(p)
self.set_server()
def follow_branch(self, chain_id):
self.network.run_from_another_thread(self.network.follow_chain_given_id(chain_id))
self.update()
def follow_server(self, server):
def follow_server(self, server: ServerAddr):
self.network.run_from_another_thread(self.network.follow_chain_given_server(server))
self.update()
def server_changed(self, x):
if x:
self.change_server(str(x.text(0)), self.protocol)
def change_server(self, host, protocol):
pp = self.servers.get(host, constants.net.DEFAULT_PORTS)
if protocol and protocol not in protocol_letters:
protocol = None
if protocol:
port = pp.get(protocol)
if port is None:
protocol = None
if not protocol:
if 's' in pp.keys():
protocol = 's'
port = pp.get(protocol)
else:
protocol = list(pp.keys())[0]
port = pp.get(protocol)
self.server_host.setText(host)
self.server_port.setText(port)
def accept(self):
pass
def set_server(self):
net_params = self.network.get_parameters()
net_params = net_params._replace(host=str(self.server_host.text()),
port=str(self.server_port.text()),
try:
server = ServerAddr.from_str_with_inference(str(self.server_e.text()))
except Exception:
return
net_params = net_params._replace(server=server,
auto_connect=self.autoconnect_cb.isChecked())
self.network.run_from_another_thread(self.network.set_parameters(net_params))
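
The widget changes above replace serialized host/port strings stored under Qt.UserRole with whole ServerAddr objects stored under dedicated roles (SERVER_ADDR_ROLE, CHAIN_ID_ROLE, IS_SERVER_ROLE). A minimal sketch of that technique, assuming PyQt5 is installed (the offscreen platform lets it run without a display); the item values are placeholders:

    import os
    os.environ.setdefault("QT_QPA_PLATFORM", "offscreen")

    from PyQt5.QtCore import Qt
    from PyQt5.QtWidgets import QApplication, QTreeWidget, QTreeWidgetItem

    SERVER_ADDR_ROLE = Qt.UserRole + 100  # mirrors NodesListWidget above

    app = QApplication([])
    tree = QTreeWidget()
    item = QTreeWidgetItem(["electrum.example.org", "700000"])
    # any Python object can be attached under a custom role and read back intact,
    # so the menu handler gets the full server object instead of re-parsing text
    item.setData(0, SERVER_ADDR_ROLE, ("electrum.example.org", 50002, "s"))
    tree.addTopLevelItem(item)
    print(tree.topLevelItem(0).data(0, SERVER_ADDR_ROLE))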

electrum/gui/stdio.py (3 changed lines)

@@ -3,6 +3,7 @@ import getpass
 import datetime
 import logging

+from electrum import util
 from electrum import WalletStorage, Wallet
 from electrum.util import format_satoshis
 from electrum.bitcoin import is_address, COIN
@@ -43,7 +44,7 @@ class ElectrumGui:
         self.wallet.start_network(self.network)
         self.contacts = self.wallet.contacts
-        self.network.register_callback(self.on_network, ['wallet_updated', 'network_updated', 'banner'])
+        util.register_callback(self.on_network, ['wallet_updated', 'network_updated', 'banner'])
         self.commands = [_("[h] - displays this help text"), \
                          _("[i] - display transaction history"), \
                          _("[o] - enter payment order"), \

electrum/gui/text.py (31 changed lines)

@@ -6,23 +6,31 @@ import locale
from decimal import Decimal
import getpass
import logging
from typing import TYPE_CHECKING
import electrum
from electrum import util
from electrum.util import format_satoshis
from electrum.bitcoin import is_address, COIN
from electrum.transaction import PartialTxOutput
from electrum.wallet import Wallet
from electrum.storage import WalletStorage
from electrum.network import NetworkParameters, TxBroadcastError, BestEffortRequestFailed
from electrum.interface import deserialize_server
from electrum.interface import ServerAddr
from electrum.logging import console_stderr_handler
if TYPE_CHECKING:
from electrum.daemon import Daemon
from electrum.simple_config import SimpleConfig
from electrum.plugin import Plugins
_ = lambda x:x # i18n
class ElectrumGui:
def __init__(self, config, daemon, plugins):
def __init__(self, config: 'SimpleConfig', daemon: 'Daemon', plugins: 'Plugins'):
self.config = config
self.network = daemon.network
@@ -65,8 +73,7 @@ class ElectrumGui:
self.str_fee = ""
self.history = None
if self.network:
self.network.register_callback(self.update, ['wallet_updated', 'network_updated'])
util.register_callback(self.update, ['wallet_updated', 'network_updated'])
self.tab_names = [_("History"), _("Send"), _("Receive"), _("Addresses"), _("Contacts"), _("Banner")]
self.num_tabs = len(self.tab_names)
@@ -402,26 +409,28 @@ class ElectrumGui:
if not self.network:
return
net_params = self.network.get_parameters()
host, port, protocol = net_params.host, net_params.port, net_params.protocol
server_addr = net_params.server
proxy_config, auto_connect = net_params.proxy, net_params.auto_connect
srv = 'auto-connect' if auto_connect else self.network.default_server
srv = 'auto-connect' if auto_connect else str(self.network.default_server)
out = self.run_dialog('Network', [
{'label':'server', 'type':'str', 'value':srv},
{'label':'proxy', 'type':'str', 'value':self.config.get('proxy', '')},
], buttons = 1)
if out:
if out.get('server'):
server = out.get('server')
auto_connect = server == 'auto-connect'
server_str = out.get('server')
auto_connect = server_str == 'auto-connect'
if not auto_connect:
try:
host, port, protocol = deserialize_server(server)
server_addr = ServerAddr.from_str(server_str)
except Exception:
self.show_message("Error:" + server + "\nIn doubt, type \"auto-connect\"")
self.show_message("Error:" + server_str + "\nIn doubt, type \"auto-connect\"")
return False
if out.get('server') or out.get('proxy'):
proxy = electrum.network.deserialize_proxy(out.get('proxy')) if out.get('proxy') else proxy_config
net_params = NetworkParameters(host, port, protocol, proxy, auto_connect)
net_params = NetworkParameters(server=server_addr,
proxy=proxy,
auto_connect=auto_connect)
self.network.run_from_another_thread(self.network.set_parameters(net_params))
def settings_dialog(self):

electrum/interface.py (145 changed lines)

@@ -29,7 +29,7 @@ import sys
import traceback
import asyncio
import socket
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple
from collections import defaultdict
from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address
import itertools
@@ -43,7 +43,7 @@ from aiorpcx.jsonrpc import JSONRPC, CodeMessageError
from aiorpcx.rawsocket import RSClient
import certifi
from .util import ignore_exceptions, log_exceptions, bfh, SilentTaskGroup
from .util import ignore_exceptions, log_exceptions, bfh, SilentTaskGroup, MySocksProxy
from . import util
from . import x509
from . import pem
@@ -65,6 +65,10 @@ BUCKET_NAME_OF_ONION_SERVERS = 'onion'
MAX_INCOMING_MSG_SIZE = 1_000_000 # in bytes
_KNOWN_NETWORK_PROTOCOLS = {'t', 's'}
PREFERRED_NETWORK_PROTOCOL = 's'
assert PREFERRED_NETWORK_PROTOCOL in _KNOWN_NETWORK_PROTOCOLS
class NetworkTimeout:
# seconds
@@ -198,22 +202,75 @@ class _RSClient(RSClient):
raise ConnectError(e) from e
def deserialize_server(server_str: str) -> Tuple[str, str, str]:
# host might be IPv6 address, hence do rsplit:
host, port, protocol = str(server_str).rsplit(':', 2)
if not host:
raise ValueError('host must not be empty')
if host[0] == '[' and host[-1] == ']': # IPv6
host = host[1:-1]
if protocol not in ('s', 't'):
raise ValueError('invalid network protocol: {}'.format(protocol))
net_addr = NetAddress(host, port) # this validates host and port
host = str(net_addr.host) # canonical form (if e.g. IPv6 address)
return host, port, protocol
class ServerAddr:
def __init__(self, host: str, port: Union[int, str], *, protocol: str = None):
assert isinstance(host, str), repr(host)
if protocol is None:
protocol = 's'
if not host:
raise ValueError('host must not be empty')
if host[0] == '[' and host[-1] == ']': # IPv6
host = host[1:-1]
try:
net_addr = NetAddress(host, port) # this validates host and port
except Exception as e:
raise ValueError(f"cannot construct ServerAddr: invalid host or port (host={host}, port={port})") from e
if protocol not in _KNOWN_NETWORK_PROTOCOLS:
raise ValueError(f"invalid network protocol: {protocol}")
self.host = str(net_addr.host) # canonical form (if e.g. IPv6 address)
self.port = int(net_addr.port)
self.protocol = protocol
self._net_addr_str = str(net_addr)
def serialize_server(host: str, port: Union[str, int], protocol: str) -> str:
return str(':'.join([host, str(port), protocol]))
@classmethod
def from_str(cls, s: str) -> 'ServerAddr':
# host might be IPv6 address, hence do rsplit:
host, port, protocol = str(s).rsplit(':', 2)
return ServerAddr(host=host, port=port, protocol=protocol)
@classmethod
def from_str_with_inference(cls, s: str) -> Optional['ServerAddr']:
"""Construct ServerAddr from str, guessing missing details.
Ongoing compatibility not guaranteed.
"""
if not s:
return None
items = str(s).rsplit(':', 2)
if len(items) < 2:
return None # although maybe we could guess the port too?
host = items[0]
port = items[1]
if len(items) >= 3:
protocol = items[2]
else:
protocol = PREFERRED_NETWORK_PROTOCOL
return ServerAddr(host=host, port=port, protocol=protocol)
def __str__(self):
return '{}:{}'.format(self.net_addr_str(), self.protocol)
def to_json(self) -> str:
return str(self)
def __repr__(self):
return f'<ServerAddr host={self.host} port={self.port} protocol={self.protocol}>'
def net_addr_str(self) -> str:
return self._net_addr_str
def __eq__(self, other):
if not isinstance(other, ServerAddr):
return False
return (self.host == other.host
and self.port == other.port
and self.protocol == other.protocol)
def __ne__(self, other):
return not (self == other)
def __hash__(self):
return hash((self.host, self.port, self.protocol))
def _get_cert_path_for_host(*, config: 'SimpleConfig', host: str) -> str:
@@ -232,19 +289,17 @@ class Interface(Logger):
LOGGING_SHORTCUT = 'i'
def __init__(self, network: 'Network', server: str, proxy: Optional[dict]):
def __init__(self, *, network: 'Network', server: ServerAddr, proxy: Optional[dict]):
self.ready = asyncio.Future()
self.got_disconnected = asyncio.Future()
self.server = server
self.host, self.port, self.protocol = deserialize_server(self.server)
self.port = int(self.port)
Logger.__init__(self)
assert network.config.path
self.cert_path = _get_cert_path_for_host(config=network.config, host=self.host)
self.blockchain = None # type: Optional[Blockchain]
self._requested_chunks = set() # type: Set[int]
self.network = network
self._set_proxy(proxy)
self.proxy = MySocksProxy.from_proxy_dict(proxy)
self.session = None # type: Optional[NotificationSession]
self._ipaddr_bucket = None
@@ -259,29 +314,24 @@ class Interface(Logger):
self.network.taskgroup.spawn(self.run()), self.network.asyncio_loop)
self.taskgroup = SilentTaskGroup()
@property
def host(self):
return self.server.host
@property
def port(self):
return self.server.port
@property
def protocol(self):
return self.server.protocol
def diagnostic_name(self):
return str(NetAddress(self.host, self.port))
return self.server.net_addr_str()
def __str__(self):
return f"<Interface {self.diagnostic_name()}>"
def _set_proxy(self, proxy: dict):
if proxy:
username, pw = proxy.get('user'), proxy.get('password')
if not username or not pw:
auth = None
else:
auth = aiorpcx.socks.SOCKSUserAuth(username, pw)
addr = NetAddress(proxy['host'], proxy['port'])
if proxy['mode'] == "socks4":
self.proxy = aiorpcx.socks.SOCKSProxy(addr, aiorpcx.socks.SOCKS4a, auth)
elif proxy['mode'] == "socks5":
self.proxy = aiorpcx.socks.SOCKSProxy(addr, aiorpcx.socks.SOCKS5, auth)
else:
raise NotImplementedError # http proxy not available with aiorpcx
else:
self.proxy = None
async def is_server_ca_signed(self, ca_ssl_context):
"""Given a CA enforcing SSL context, returns True if the connection
can be established. Returns False if the server has a self-signed
@@ -435,13 +485,12 @@ class Interface(Logger):
async def get_certificate(self):
sslc = ssl.SSLContext()
try:
async with _RSClient(session_factory=RPCSession,
host=self.host, port=self.port,
ssl=sslc, proxy=self.proxy) as session:
return session.transport._asyncio_transport._ssl_protocol._sslpipe._sslobj.getpeercert(True)
except ValueError:
return None
async with _RSClient(session_factory=RPCSession,
host=self.host, port=self.port,
ssl=sslc, proxy=self.proxy) as session:
asyncio_transport = session.transport._asyncio_transport # type: asyncio.BaseTransport
ssl_object = asyncio_transport.get_extra_info("ssl_object") # type: ssl.SSLObject
return ssl_object.getpeercert(binary_form=True)
async def get_block_header(self, height, assert_mode):
self.logger.info(f'requesting block header {height} in mode {assert_mode}')
@@ -548,7 +597,7 @@ class Interface(Logger):
raise GracefulDisconnect('server tip below max checkpoint')
self._mark_ready()
await self._process_header_at_tip()
self.network.trigger_callback('network_updated')
util.trigger_callback('network_updated')
await self.network.switch_unwanted_fork_interface()
await self.network.switch_lagging_interface()
@@ -563,7 +612,7 @@ class Interface(Logger):
# in the simple case, height == self.tip+1
if height <= self.tip:
await self.sync_until(height)
self.network.trigger_callback('blockchain_updated')
util.trigger_callback('blockchain_updated')
async def sync_until(self, height, next_height=None):
if next_height is None:
@@ -578,7 +627,7 @@ class Interface(Logger):
raise GracefulDisconnect('server chain conflicts with checkpoints or genesis')
last, height = await self.step(height)
continue
self.network.trigger_callback('network_updated')
util.trigger_callback('network_updated')
height = (height // 2016 * 2016) + num_headers
assert height <= next_height+1, (height, self.tip)
last = 'catchup'
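
A short usage sketch for the ServerAddr class introduced above, assuming the electrum package from this branch is importable; the hostname is a placeholder:

    from electrum.interface import ServerAddr

    a = ServerAddr.from_str("electrum.example.org:50002:s")               # strict host:port:protocol
    b = ServerAddr.from_str_with_inference("electrum.example.org:50002")  # protocol defaults to 's'
    assert a == b and hash(a) == hash(b)  # value equality, usable as dict/set keys
    print(str(a))                         # electrum.example.org:50002:s
    print(a.net_addr_str())               # electrum.example.org:50002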

electrum/lnchannel.py (8 changed lines)

@@ -33,7 +33,7 @@ from aiorpcx import NetAddress
 import attr

 from . import ecc
-from . import constants
+from . import constants, util
 from .util import bfh, bh2u, chunks, TxMinedInfo
 from .bitcoin import redeem_script_to_address
 from .crypto import sha256, sha256d
@@ -679,16 +679,14 @@ class Channel(AbstractChannel):

     def set_frozen_for_sending(self, b: bool) -> None:
         self.storage['frozen_for_sending'] = bool(b)
-        if self.lnworker:
-            self.lnworker.network.trigger_callback('channel', self)
+        util.trigger_callback('channel', self)

     def is_frozen_for_receiving(self) -> bool:
         return self.storage.get('frozen_for_receiving', False)

     def set_frozen_for_receiving(self, b: bool) -> None:
         self.storage['frozen_for_receiving'] = bool(b)
-        if self.lnworker:
-            self.lnworker.network.trigger_callback('channel', self)
+        util.trigger_callback('channel', self)

     def _assert_can_add_htlc(self, *, htlc_proposer: HTLCOwner, amount_msat: int) -> None:
         """Raises PaymentFailure if the htlc_proposer cannot add this new HTLC.

electrum/lnpeer.py (25 changed lines)

@@ -19,7 +19,7 @@ from datetime import datetime
import aiorpcx
from .crypto import sha256, sha256d
from . import bitcoin
from . import bitcoin, util
from . import ecc
from .ecc import sig_string_from_r_and_s, get_r_and_s_from_sig_string, der_sig_from_sig_string
from . import constants
@@ -74,6 +74,7 @@ class Peer(Logger):
self.lnworker = lnworker
self.privkey = self.transport.privkey # local privkey
self.features = self.lnworker.features
self.their_features = 0
self.node_ids = [self.pubkey, privkey_to_pubkey(self.privkey)]
self.network = lnworker.network
self.channel_db = lnworker.network.channel_db
@@ -200,15 +201,15 @@ class Peer(Logger):
if self._received_init:
self.logger.info("ALREADY INITIALIZED BUT RECEIVED INIT")
return
their_features = LnFeatures(int.from_bytes(payload['features'], byteorder="big"))
self.their_features = LnFeatures(int.from_bytes(payload['features'], byteorder="big"))
their_globalfeatures = int.from_bytes(payload['globalfeatures'], byteorder="big")
their_features |= their_globalfeatures
self.their_features |= their_globalfeatures
# check transitive dependencies for received features
if not their_features.validate_transitive_dependecies():
if not self.their_features.validate_transitive_dependecies():
raise GracefulDisconnect("remote did not set all dependencies for the features they sent")
# check if features are compatible, and set self.features to what we negotiated
try:
self.features = ln_compare_features(self.features, their_features)
self.features = ln_compare_features(self.features, self.their_features)
except IncompatibleLightningFeatures as e:
self.initialized.set_exception(e)
raise GracefulDisconnect(f"{str(e)}")
@@ -219,10 +220,7 @@ class Peer(Logger):
if constants.net.rev_genesis_bytes() not in their_chains:
raise GracefulDisconnect(f"no common chain found with remote. (they sent: {their_chains})")
# all checks passed
if self.channel_db and isinstance(self.transport, LNTransport):
self.channel_db.add_recent_peer(self.transport.peer_addr)
for chan in self.channels.values():
chan.add_or_update_peer_addr(self.transport.peer_addr)
self.lnworker.on_peer_successfully_established(self)
self._received_init = True
self.maybe_set_initialized()
@@ -254,7 +252,8 @@ class Peer(Logger):
return await func(self, *args, **kwargs)
except GracefulDisconnect as e:
self.logger.log(e.log_level, f"Disconnecting: {repr(e)}")
except (LightningPeerConnectionClosed, IncompatibleLightningFeatures) as e:
except (LightningPeerConnectionClosed, IncompatibleLightningFeatures,
aiorpcx.socks.SOCKSError) as e:
self.logger.info(f"Disconnecting: {repr(e)}")
finally:
self.close_and_cleanup()
@@ -744,7 +743,7 @@ class Peer(Logger):
f'already in peer_state {chan.peer_state}')
return
chan.peer_state = PeerState.REESTABLISHING
self.network.trigger_callback('channel', chan)
util.trigger_callback('channel', chan)
# BOLT-02: "A node [...] upon disconnection [...] MUST reverse any uncommitted updates sent by the other side"
chan.hm.discard_unsigned_remote_updates()
# ctns
@@ -891,7 +890,7 @@ class Peer(Logger):
# checks done
if chan.is_funded() and chan.config[LOCAL].funding_locked_received:
self.mark_open(chan)
self.network.trigger_callback('channel', chan)
util.trigger_callback('channel', chan)
if chan.get_state() == ChannelState.CLOSING:
await self.send_shutdown(chan)
@@ -979,7 +978,7 @@ class Peer(Logger):
return
assert chan.config[LOCAL].funding_locked_received
chan.set_state(ChannelState.OPEN)
self.network.trigger_callback('channel', chan)
util.trigger_callback('channel', chan)
# peer may have sent us a channel update for the incoming direction previously
pending_channel_update = self.orphan_channel_updates.get(chan.short_channel_id)
if pending_channel_update:
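
The init handling above now keeps the remote flags on the Peer (self.their_features) and still folds the legacy globalfeatures field in with a bitwise OR. A standalone illustration of that merge with enum.IntFlag (LnFeatures in electrum/lnutil.py is an IntFlag-style enum; the flag names and bit positions below are illustrative only):

    from enum import IntFlag

    # Illustrative flags; the real values are defined on LnFeatures in electrum/lnutil.py.
    class Features(IntFlag):
        DATA_LOSS_PROTECT_OPT = 1 << 1
        GOSSIP_QUERIES_OPT = 1 << 7
        VAR_ONION_OPT = 1 << 9

    # parse the two wire fields, big-endian, exactly as in the diff above
    features = Features(int.from_bytes(bytes([0x02, 0x02]), byteorder="big"))
    globalfeatures = int.from_bytes(bytes([0x00, 0x80]), byteorder="big")

    # ORing a plain int into an IntFlag keeps the IntFlag type, so legacy
    # globalfeatures bits simply merge into the negotiated feature set.
    features |= globalfeatures
    print(features)  # all three flags are now set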

electrum/lntransport.py (17 changed lines)

@@ -8,12 +8,14 @@
 import hashlib
 import asyncio
 from asyncio import StreamReader, StreamWriter
+from typing import Optional

 from .crypto import sha256, hmac_oneshot, chacha20_poly1305_encrypt, chacha20_poly1305_decrypt
 from .lnutil import (get_ecdh, privkey_to_pubkey, LightningPeerConnectionClosed,
                      HandshakeFailed, LNPeerAddr)
 from . import ecc
-from .util import bh2u
+from .util import bh2u, MySocksProxy


 class HandshakeState(object):
     prologue = b"lightning"
@@ -155,6 +157,8 @@ class LNTransportBase:


 class LNResponderTransport(LNTransportBase):
+    """Transport initiated by remote party."""
+
     def __init__(self, privkey: bytes, reader: StreamReader, writer: StreamWriter):
         LNTransportBase.__init__(self)
         self.reader = reader
@@ -211,19 +215,26 @@ class LNResponderTransport(LNTransportBase):
         self.init_counters(ck)
         return rs


 class LNTransport(LNTransportBase):
+    """Transport initiated by local party."""
+
-    def __init__(self, privkey: bytes, peer_addr: LNPeerAddr):
+    def __init__(self, privkey: bytes, peer_addr: LNPeerAddr, *,
+                 proxy: Optional[dict]):
         LNTransportBase.__init__(self)
         assert type(privkey) is bytes and len(privkey) == 32
         self.privkey = privkey
         self.peer_addr = peer_addr
+        self.proxy = MySocksProxy.from_proxy_dict(proxy)
+
+    def name(self):
+        return self.peer_addr.net_addr_str()

     async def handshake(self):
-        self.reader, self.writer = await asyncio.open_connection(self.peer_addr.host, self.peer_addr.port)
+        if not self.proxy:
+            self.reader, self.writer = await asyncio.open_connection(self.peer_addr.host, self.peer_addr.port)
+        else:
+            self.reader, self.writer = await self.proxy.open_connection(self.peer_addr.host, self.peer_addr.port)
         hs = HandshakeState(self.peer_addr.pubkey)
         # Get a new ephemeral key
         epriv, epub = create_ephemeral_key()
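
The handshake change above routes the Lightning TCP connection through a SOCKS proxy when one is configured. A standalone sketch of that dispatch, with a hypothetical Proxy stand-in for util.MySocksProxy (whose diff is not rendered on this page):

    import asyncio
    from typing import Optional, Tuple

    class Proxy:
        """Hypothetical stand-in for electrum.util.MySocksProxy."""
        async def open_connection(self, host: str, port: int) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
            raise NotImplementedError  # would tunnel the stream through SOCKS4a/SOCKS5

    async def open_stream(host: str, port: int, *, proxy: Optional[Proxy] = None):
        # mirror LNTransport.handshake(): connect directly unless a proxy is set
        if proxy is None:
            return await asyncio.open_connection(host, port)
        return await proxy.open_connection(host, port)

    async def main():
        reader, writer = await open_stream("example.com", 80)
        writer.close()
        await writer.wait_closed()

    if __name__ == "__main__":
        asyncio.run(main())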

electrum/lnwatcher.py (8 changed lines)

@@ -8,6 +8,7 @@ import asyncio
 from enum import IntEnum, auto
 from typing import NamedTuple, Dict

+from . import util
 from .sql_db import SqlDB, sql
 from .wallet_db import WalletDB
 from .util import bh2u, bfh, log_exceptions, ignore_exceptions, TxMinedInfo
@@ -139,8 +140,9 @@ class LNWatcher(AddressSynchronizer):
         self.config = network.config
         self.channels = {}
         self.network = network
-        self.network.register_callback(self.on_network_update,
-                                       ['network_updated', 'blockchain_updated', 'verified', 'wallet_updated', 'fee'])
+        util.register_callback(
+            self.on_network_update,
+            ['network_updated', 'blockchain_updated', 'verified', 'wallet_updated', 'fee'])
         # status gets populated when we run
         self.channel_status = {}
@@ -420,4 +422,4 @@ class LNWalletWatcher(LNWatcher):
                 tx_was_added = False
            if tx_was_added:
                self.logger.info(f'added future tx: {name}. prevout: {prevout}')
-                self.network.trigger_callback('wallet_updated', self.lnworker.wallet)
+                util.trigger_callback('wallet_updated', self.lnworker.wallet)

electrum/lnworker.py (228 changed lines)

@@ -7,7 +7,7 @@ import os
from decimal import Decimal
import random
import time
from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING, NamedTuple, Union
from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING, NamedTuple, Union, Mapping
import threading
import socket
import json
@@ -21,11 +21,11 @@ import dns.exception
import dns.exception
from aiorpcx import run_in_thread
from . import constants
from . import constants, util
from . import keystore
from .util import profiler
from .util import PR_UNPAID, PR_EXPIRED, PR_PAID, PR_INFLIGHT, PR_FAILED, PR_ROUTING
from .util import PR_TYPE_LN
from .util import PR_TYPE_LN, NetworkRetryManager
from .lnutil import LN_MAX_FUNDING_SAT
from .keystore import BIP32_KeyStore
from .bitcoin import COIN
@@ -77,9 +77,7 @@ SAVED_PR_STATUS = [PR_PAID, PR_UNPAID, PR_INFLIGHT] # status that are persisted
NUM_PEERS_TARGET = 4
PEER_RETRY_INTERVAL = 600 # seconds
PEER_RETRY_INTERVAL_FOR_CHANNELS = 30 # seconds
GRAPH_DOWNLOAD_SECONDS = 600
FALLBACK_NODE_LIST_TESTNET = (
LNPeerAddr(host='203.132.95.10', port=9735, pubkey=bfh('038863cf8ab91046230f561cd5b386cbff8309fa02e3f0c3ed161a3aeb64a643b9')),
@@ -141,12 +139,20 @@ class NoPathFound(PaymentFailure):
return _('No path found')
class LNWorker(Logger):
class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
def __init__(self, xprv):
Logger.__init__(self)
NetworkRetryManager.__init__(
self,
max_retry_delay_normal=3600,
init_retry_delay_normal=600,
max_retry_delay_urgent=300,
init_retry_delay_urgent=4,
)
self.lock = threading.RLock()
self.node_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NODE_KEY)
self.peers = {} # type: Dict[bytes, Peer] # pubkey -> Peer
self._peers = {} # type: Dict[bytes, Peer] # pubkey -> Peer # needs self.lock
self.taskgroup = SilentTaskGroup()
# set some feature flags as baseline for both LNWallet and LNGossip
# note that e.g. DATA_LOSS_PROTECT is needed for LNGossip as many peers require it
@@ -156,6 +162,14 @@ class LNWorker(Logger):
self.features |= LnFeatures.VAR_ONION_OPT
self.features |= LnFeatures.PAYMENT_SECRET_OPT
util.register_callback(self.on_proxy_changed, ['proxy_set'])
@property
def peers(self) -> Mapping[bytes, Peer]:
"""Returns a read-only copy of peers."""
with self.lock:
return self._peers.copy()
def channels_for_peer(self, node_id):
return {}
@@ -175,10 +189,12 @@ class LNWorker(Logger):
self.logger.info('handshake failure from incoming connection')
return
peer = Peer(self, node_id, transport)
self.peers[node_id] = peer
with self.lock:
self._peers[node_id] = peer
await self.taskgroup.spawn(peer.main_loop())
try:
# FIXME: server.close(), server.wait_closed(), etc... ?
# TODO: onion hidden service?
server = await asyncio.start_server(cb, addr, int(port))
except OSError as e:
self.logger.error(f"cannot listen for lightning p2p. error: {e!r}")
@@ -200,30 +216,31 @@ class LNWorker(Logger):
while True:
await asyncio.sleep(1)
now = time.time()
if len(self.peers) >= NUM_PEERS_TARGET:
if len(self._peers) >= NUM_PEERS_TARGET:
continue
peers = await self._get_next_peers_to_try()
for peer in peers:
last_tried = self._last_tried_peer.get(peer, 0)
if last_tried + PEER_RETRY_INTERVAL < now:
if self._can_retry_addr(peer, now=now):
await self._add_peer(peer.host, peer.port, peer.pubkey)
async def _add_peer(self, host, port, node_id) -> Peer:
if node_id in self.peers:
return self.peers[node_id]
async def _add_peer(self, host: str, port: int, node_id: bytes) -> Peer:
if node_id in self._peers:
return self._peers[node_id]
port = int(port)
peer_addr = LNPeerAddr(host, port, node_id)
transport = LNTransport(self.node_keypair.privkey, peer_addr)
self._last_tried_peer[peer_addr] = time.time()
transport = LNTransport(self.node_keypair.privkey, peer_addr,
proxy=self.network.proxy)
self._trying_addr_now(peer_addr)
self.logger.info(f"adding peer {peer_addr}")
peer = Peer(self, node_id, transport)
await self.taskgroup.spawn(peer.main_loop())
self.peers[node_id] = peer
with self.lock:
self._peers[node_id] = peer
return peer
def peer_closed(self, peer: Peer) -> None:
if peer.pubkey in self.peers:
self.peers.pop(peer.pubkey)
with self.lock:
self._peers.pop(peer.pubkey, None)
def num_peers(self) -> int:
return sum([p.is_initialized() for p in self.peers.values()])
@@ -232,11 +249,9 @@ class LNWorker(Logger):
assert network
self.network = network
self.config = network.config
daemon = network.daemon
self.channel_db = self.network.channel_db
self._last_tried_peer = {} # type: Dict[LNPeerAddr, float] # LNPeerAddr -> unix timestamp
self._add_peers_from_config()
asyncio.run_coroutine_threadsafe(daemon.taskgroup.spawn(self.main_loop()), self.network.asyncio_loop)
asyncio.run_coroutine_threadsafe(self.main_loop(), self.network.asyncio_loop)
def _add_peers_from_config(self):
peer_list = self.config.get('lightning_peers', [])
@@ -260,20 +275,29 @@ class LNWorker(Logger):
#self.logger.info(f'is_good {peer.host}')
return True
def on_peer_successfully_established(self, peer: Peer) -> None:
if isinstance(peer.transport, LNTransport):
peer_addr = peer.transport.peer_addr
# reset connection attempt count
self._on_connection_successfully_established(peer_addr)
# add into channel db
if self.channel_db:
self.channel_db.add_recent_peer(peer_addr)
# save network address into channels we might have with peer
for chan in peer.channels.values():
chan.add_or_update_peer_addr(peer_addr)
async def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]:
now = time.time()
await self.channel_db.data_loaded.wait()
recent_peers = self.channel_db.get_recent_peers()
# maintenance for last tried times
# due to this, below we can just test membership in _last_tried_peer
for peer in list(self._last_tried_peer):
if now >= self._last_tried_peer[peer] + PEER_RETRY_INTERVAL:
del self._last_tried_peer[peer]
# first try from recent peers
recent_peers = self.channel_db.get_recent_peers()
for peer in recent_peers:
if peer.pubkey in self.peers:
if not peer:
continue
if peer.pubkey in self._peers:
continue
if peer in self._last_tried_peer:
if not self._can_retry_addr(peer, now=now):
continue
if not self.is_good_peer(peer):
continue
@@ -290,7 +314,7 @@ class LNWorker(Logger):
peer = LNPeerAddr(host, port, node_id)
except ValueError:
continue
if peer in self._last_tried_peer:
if not self._can_retry_addr(peer, now=now):
continue
if not self.is_good_peer(peer):
continue
@@ -305,7 +329,7 @@ class LNWorker(Logger):
else:
return [] # regtest??
fallback_list = [peer for peer in fallback_list if peer not in self._last_tried_peer]
fallback_list = [peer for peer in fallback_list if self._can_retry_addr(peer, now=now)]
if fallback_list:
return [random.choice(fallback_list)]
@@ -363,12 +387,40 @@ class LNWorker(Logger):
choice = random.choice(addr_list)
return choice
def on_proxy_changed(self, event, *args):
for peer in self.peers.values():
peer.close_and_cleanup()
self._clear_addr_retry_times()
@log_exceptions
async def add_peer(self, connect_str: str) -> Peer:
node_id, rest = extract_nodeid(connect_str)
peer = self._peers.get(node_id)
if not peer:
if rest is not None:
host, port = split_host_port(rest)
else:
addrs = self.channel_db.get_node_addresses(node_id)
if not addrs:
raise ConnStringFormatError(_('Don\'t know any addresses for node:') + ' ' + bh2u(node_id))
host, port, timestamp = self.choose_preferred_address(addrs)
port = int(port)
# Try DNS-resolving the host (if needed). This is simply so that
# the caller gets a nice exception if it cannot be resolved.
try:
await asyncio.get_event_loop().getaddrinfo(host, port)
except socket.gaierror:
raise ConnStringFormatError(_('Hostname does not resolve (getaddrinfo failed)'))
# add peer
peer = await self._add_peer(host, port, node_id)
return peer
class LNGossip(LNWorker):
max_age = 14*24*3600
LOGGING_SHORTCUT = 'g'
def __init__(self, network):
def __init__(self):
seed = os.urandom(32)
node = BIP32Node.from_rootseed(seed, xtype='standard')
xprv = node.to_xprv()
@@ -394,16 +446,16 @@ class LNGossip(LNWorker):
known = self.channel_db.get_channel_ids()
new = set(ids) - set(known)
self.unknown_ids.update(new)
self.network.trigger_callback('unknown_channels', len(self.unknown_ids))
self.network.trigger_callback('gossip_peers', self.num_peers())
self.network.trigger_callback('ln_gossip_sync_progress')
util.trigger_callback('unknown_channels', len(self.unknown_ids))
util.trigger_callback('gossip_peers', self.num_peers())
util.trigger_callback('ln_gossip_sync_progress')
def get_ids_to_query(self):
N = 500
l = list(self.unknown_ids)
self.unknown_ids = set(l[N:])
self.network.trigger_callback('unknown_channels', len(self.unknown_ids))
self.network.trigger_callback('ln_gossip_sync_progress')
util.trigger_callback('unknown_channels', len(self.unknown_ids))
util.trigger_callback('ln_gossip_sync_progress')
return l[0:N]
def get_sync_progress_estimate(self) -> Tuple[Optional[int], Optional[int]]:
@@ -431,7 +483,6 @@ class LNWallet(LNWorker):
self.payments = self.db.get_dict('lightning_payments') # RHASH -> amount, direction, is_paid
self.preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage
self.sweep_address = wallet.get_receiving_address()
self.lock = threading.RLock()
self.logs = defaultdict(list) # type: Dict[str, List[PaymentAttemptLog]] # key is RHASH # (not persisted)
self.is_routing = set() # (not persisted) keys of invoices that are in PR_ROUTING state
# used in tests
@@ -515,7 +566,7 @@ class LNWallet(LNWorker):
def peer_closed(self, peer):
for chan in self.channels_for_peer(peer.pubkey).values():
chan.peer_state = PeerState.DISCONNECTED
self.network.trigger_callback('channel', chan)
util.trigger_callback('channel', chan)
super().peer_closed(peer)
def get_settled_payments(self):
@ -646,14 +697,14 @@ class LNWallet(LNWorker):
def channel_state_changed(self, chan):
self.save_channel(chan)
self.network.trigger_callback('channel', chan)
util.trigger_callback('channel', chan)
def save_channel(self, chan):
assert type(chan) is Channel
if chan.config[REMOTE].next_per_commitment_point == chan.config[REMOTE].current_per_commitment_point:
raise Exception("Tried to save channel with next_point == current_point, this should not happen")
self.wallet.save_db()
self.network.trigger_callback('channel', chan)
util.trigger_callback('channel', chan)
def channel_by_txo(self, txo):
with self.lock:
@ -669,12 +720,12 @@ class LNWallet(LNWorker):
await self.try_force_closing(chan.channel_id)
elif chan.get_state() == ChannelState.FUNDED:
peer = self.peers.get(chan.node_id)
peer = self._peers.get(chan.node_id)
if peer and peer.is_initialized():
peer.send_funding_locked(chan)
elif chan.get_state() == ChannelState.OPEN:
peer = self.peers.get(chan.node_id)
peer = self._peers.get(chan.node_id)
if peer:
await peer.maybe_update_fee(chan)
conf = self.lnwatcher.get_tx_height(chan.funding_outpoint.txid).conf
@ -688,9 +739,6 @@ class LNWallet(LNWorker):
self.logger.info('REBROADCASTING CLOSING TX')
await self.network.try_broadcasting(force_close_tx, 'force-close')
@log_exceptions
async def _open_channel_coroutine(self, *, connect_str: str, funding_tx: PartialTransaction,
funding_sat: int, push_sat: int,
@ -704,7 +752,7 @@ class LNWallet(LNWorker):
funding_sat=funding_sat,
push_msat=push_sat * 1000,
temp_channel_id=os.urandom(32))
self.network.trigger_callback('channels_updated', self.wallet)
util.trigger_callback('channels_updated', self.wallet)
self.wallet.add_transaction(funding_tx) # save tx as local into the wallet
self.wallet.set_label(funding_tx.txid(), _('Open channel'))
if funding_tx.is_complete():
@ -722,29 +770,6 @@ class LNWallet(LNWorker):
channels_db[chan.channel_id.hex()] = chan.storage
self.wallet.save_backup()
@log_exceptions
async def add_peer(self, connect_str: str) -> Peer:
node_id, rest = extract_nodeid(connect_str)
peer = self.peers.get(node_id)
if not peer:
if rest is not None:
host, port = split_host_port(rest)
else:
addrs = self.channel_db.get_node_addresses(node_id)
if not addrs:
raise ConnStringFormatError(_('Don\'t know any addresses for node:') + ' ' + bh2u(node_id))
host, port, timestamp = self.choose_preferred_address(addrs)
port = int(port)
# Try DNS-resolving the host (if needed). This is simply so that
# the caller gets a nice exception if it cannot be resolved.
try:
await asyncio.get_event_loop().getaddrinfo(host, port)
except socket.gaierror:
raise ConnStringFormatError(_('Hostname does not resolve (getaddrinfo failed)'))
# add peer
peer = await self._add_peer(host, port, node_id)
return peer
def mktx_for_open_channel(self, *, coins: Sequence[PartialTxInput], funding_sat: int,
fee_est=None) -> PartialTransaction:
dummy_address = ln_dummy_address()
@ -805,10 +830,10 @@ class LNWallet(LNWorker):
# note: path-finding runs in a separate thread so that we don't block the asyncio loop
# graph updates might occur during the computation
self.set_invoice_status(key, PR_ROUTING)
self.network.trigger_callback('invoice_status', key)
util.trigger_callback('invoice_status', key)
route = await run_in_thread(self._create_route_from_invoice, lnaddr)
self.set_invoice_status(key, PR_INFLIGHT)
self.network.trigger_callback('invoice_status', key)
util.trigger_callback('invoice_status', key)
payment_attempt_log = await self._pay_to_route(route, lnaddr)
except Exception as e:
log.append(PaymentAttemptLog(success=False, exception=e))
@ -821,17 +846,17 @@ class LNWallet(LNWorker):
break
else:
reason = _('Failed after {} attempts').format(attempts)
self.network.trigger_callback('invoice_status', key)
util.trigger_callback('invoice_status', key)
if success:
self.network.trigger_callback('payment_succeeded', key)
util.trigger_callback('payment_succeeded', key)
else:
self.network.trigger_callback('payment_failed', key, reason)
util.trigger_callback('payment_failed', key, reason)
return success
async def _pay_to_route(self, route: LNPaymentRoute, lnaddr: LnAddr) -> PaymentAttemptLog:
short_channel_id = route[0].short_channel_id
chan = self.get_channel_by_short_id(short_channel_id)
peer = self.peers.get(route[0].node_id)
peer = self._peers.get(route[0].node_id)
if not peer:
raise Exception('Dropped peer')
await peer.initialized
@ -841,7 +866,7 @@ class LNWallet(LNWorker):
payment_hash=lnaddr.paymenthash,
min_final_cltv_expiry=lnaddr.get_min_final_cltv_expiry(),
payment_secret=lnaddr.payment_secret)
self.network.trigger_callback('htlc_added', htlc, lnaddr, SENT)
util.trigger_callback('htlc_added', htlc, lnaddr, SENT)
payment_attempt = await self.await_payment(lnaddr.paymenthash)
if payment_attempt.success:
failure_log = None
@ -1140,9 +1165,9 @@ class LNWallet(LNWorker):
f.set_result(payment_attempt)
else:
chan.logger.info('received unexpected payment_failed, probably from previous session')
self.network.trigger_callback('invoice_status', key)
self.network.trigger_callback('payment_failed', key, '')
self.network.trigger_callback('ln_payment_failed', payment_hash, chan.channel_id)
util.trigger_callback('invoice_status', key)
util.trigger_callback('payment_failed', key, '')
util.trigger_callback('ln_payment_failed', payment_hash, chan.channel_id)
def payment_sent(self, chan, payment_hash: bytes):
self.set_payment_status(payment_hash, PR_PAID)
@ -1156,14 +1181,14 @@ class LNWallet(LNWorker):
f.set_result(payment_attempt)
else:
chan.logger.info('received unexpected payment_sent, probably from previous session')
self.network.trigger_callback('invoice_status', key)
self.network.trigger_callback('payment_succeeded', key)
self.network.trigger_callback('ln_payment_completed', payment_hash, chan.channel_id)
util.trigger_callback('invoice_status', key)
util.trigger_callback('payment_succeeded', key)
util.trigger_callback('ln_payment_completed', payment_hash, chan.channel_id)
def payment_received(self, chan, payment_hash: bytes):
self.set_payment_status(payment_hash, PR_PAID)
self.network.trigger_callback('request_status', payment_hash.hex(), PR_PAID)
self.network.trigger_callback('ln_payment_completed', payment_hash, chan.channel_id)
util.trigger_callback('request_status', payment_hash.hex(), PR_PAID)
util.trigger_callback('ln_payment_completed', payment_hash, chan.channel_id)
async def _calc_routing_hints_for_invoice(self, amount_sat):
"""calculate routing hints (BOLT-11 'r' field)"""
@ -1227,7 +1252,7 @@ class LNWallet(LNWorker):
async def close_channel(self, chan_id):
chan = self.channels[chan_id]
peer = self.peers[chan.node_id]
peer = self._peers[chan.node_id]
return await peer.close_channel(chan_id)
async def force_close_channel(self, chan_id):
@ -1252,8 +1277,8 @@ class LNWallet(LNWorker):
self.channels.pop(chan_id)
self.db.get('channels').pop(chan_id.hex())
self.network.trigger_callback('channels_updated', self.wallet)
self.network.trigger_callback('wallet_updated', self.wallet)
util.trigger_callback('channels_updated', self.wallet)
util.trigger_callback('wallet_updated', self.wallet)
@ignore_exceptions
@log_exceptions
@ -1270,18 +1295,10 @@ class LNWallet(LNWorker):
peer_addresses.append(LNPeerAddr(host, port, chan.node_id))
# will try addresses stored in channel storage
peer_addresses += list(chan.get_peer_addresses())
# Done gathering addresses.
# Now select first one that has not failed recently.
# Use long retry interval to check. This ensures each address we gathered gets a chance.
for peer in peer_addresses:
last_tried = self._last_tried_peer.get(peer, 0)
if last_tried + PEER_RETRY_INTERVAL < now:
await self._add_peer(peer.host, peer.port, peer.pubkey)
return
# Still here? That means all addresses failed ~recently.
# Use short retry interval now.
for peer in peer_addresses:
last_tried = self._last_tried_peer.get(peer, 0)
if last_tried + PEER_RETRY_INTERVAL_FOR_CHANNELS < now:
if self._can_retry_addr(peer, urgent=True, now=now):
await self._add_peer(peer.host, peer.port, peer.pubkey)
return
@ -1296,7 +1313,7 @@ class LNWallet(LNWorker):
# reestablish
if not chan.should_try_to_reestablish_peer():
continue
peer = self.peers.get(chan.node_id, None)
peer = self._peers.get(chan.node_id, None)
if peer:
await peer.taskgroup.spawn(peer.reestablish_channel(chan))
else:
@ -1356,7 +1373,7 @@ class LNBackups(Logger):
self.channel_backups[bfh(channel_id)] = ChannelBackup(cb, sweep_address=self.sweep_address, lnworker=self)
def channel_state_changed(self, chan):
self.network.trigger_callback('channel', chan)
util.trigger_callback('channel', chan)
def peer_closed(self, chan):
pass
@ -1390,7 +1407,7 @@ class LNBackups(Logger):
d[channel_id] = cb_storage
self.channel_backups[bfh(channel_id)] = cb = ChannelBackup(cb_storage, sweep_address=self.sweep_address, lnworker=self)
self.wallet.save_db()
self.network.trigger_callback('channels_updated', self.wallet)
util.trigger_callback('channels_updated', self.wallet)
self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address())
def remove_channel_backup(self, channel_id):
@ -1400,13 +1417,14 @@ class LNBackups(Logger):
d.pop(channel_id.hex())
self.channel_backups.pop(channel_id)
self.wallet.save_db()
self.network.trigger_callback('channels_updated', self.wallet)
util.trigger_callback('channels_updated', self.wallet)
@log_exceptions
async def request_force_close(self, channel_id):
cb = self.channel_backups[channel_id].cb
peer_addr = LNPeerAddr(cb.host, cb.port, cb.node_id)
transport = LNTransport(cb.privkey, peer_addr)
transport = LNTransport(cb.privkey, peer_addr,
proxy=self.network.proxy)
peer = Peer(self, cb.node_id, transport)
await self.taskgroup.spawn(peer._message_loop())
await peer.initialized

308
electrum/network.py

@ -32,7 +32,7 @@ import socket
import json
import sys
import asyncio
from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING, Iterable
from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING, Iterable, Set
import traceback
import concurrent
from concurrent import futures
@ -44,7 +44,7 @@ from aiohttp import ClientResponse
from . import util
from .util import (log_exceptions, ignore_exceptions,
bfh, SilentTaskGroup, make_aiohttp_session, send_exception_to_crash_reporter,
is_hash256_str, is_non_negative_integer)
is_hash256_str, is_non_negative_integer, MyEncoder, NetworkRetryManager)
from .bitcoin import COIN
from . import constants
@ -53,9 +53,9 @@ from . import bitcoin
from . import dns_hacks
from .transaction import Transaction
from .blockchain import Blockchain, HEADER_SIZE
from .interface import (Interface, serialize_server, deserialize_server,
from .interface import (Interface, PREFERRED_NETWORK_PROTOCOL,
RequestTimedOut, NetworkTimeout, BUCKET_NAME_OF_ONION_SERVERS,
NetworkException, RequestCorrupted)
NetworkException, RequestCorrupted, ServerAddr)
from .version import PROTOCOL_VERSION
from .simple_config import SimpleConfig
from .i18n import _
@ -71,10 +71,8 @@ if TYPE_CHECKING:
_logger = get_logger(__name__)
NODES_RETRY_INTERVAL = 60
SERVER_RETRY_INTERVAL = 10
NUM_TARGET_CONNECTED_SERVERS = 10
NUM_STICKY_SERVERS = 4
NUM_RECENT_SERVERS = 20
@ -117,30 +115,32 @@ def filter_noonion(servers):
return {k: v for k, v in servers.items() if not k.endswith('.onion')}
def filter_protocol(hostmap, protocol='s'):
'''Filters the hostmap for those implementing protocol.
The result is a list in serialized form.'''
def filter_protocol(hostmap, *, allowed_protocols: Iterable[str] = None) -> Sequence[ServerAddr]:
"""Filters the hostmap for those implementing protocol."""
if allowed_protocols is None:
allowed_protocols = {PREFERRED_NETWORK_PROTOCOL}
eligible = []
for host, portmap in hostmap.items():
port = portmap.get(protocol)
if port:
eligible.append(serialize_server(host, port, protocol))
for protocol in allowed_protocols:
port = portmap.get(protocol)
if port:
eligible.append(ServerAddr(host, port, protocol=protocol))
return eligible
def pick_random_server(hostmap=None, protocol='s', exclude_set=None):
def pick_random_server(hostmap=None, *, allowed_protocols: Iterable[str],
exclude_set: Set[ServerAddr] = None) -> Optional[ServerAddr]:
if hostmap is None:
hostmap = constants.net.DEFAULT_SERVERS
if exclude_set is None:
exclude_set = set()
eligible = list(set(filter_protocol(hostmap, protocol)) - exclude_set)
servers = set(filter_protocol(hostmap, allowed_protocols=allowed_protocols))
eligible = list(servers - exclude_set)
return random.choice(eligible) if eligible else None
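For reference, a small sketch of the reworked helpers' call pattern; the hostmap literal and server name are invented for illustration:
# Sketch only -- exercising the new keyword-only signatures.
from electrum.network import filter_protocol, pick_random_server
from electrum.interface import ServerAddr, PREFERRED_NETWORK_PROTOCOL

hostmap = {"server.example.org": {"s": "50002", "t": "50001"}}     # made-up entry
eligible = filter_protocol(hostmap, allowed_protocols={"s", "t"})  # -> [ServerAddr, ...]
choice = pick_random_server(hostmap,
                            allowed_protocols={PREFERRED_NETWORK_PROTOCOL},
                            exclude_set=set())                     # ServerAddr or None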
class NetworkParameters(NamedTuple):
host: str
port: str
protocol: str
server: ServerAddr
proxy: Optional[dict]
auto_connect: bool
oneserver: bool = False
@ -233,19 +233,33 @@ class UntrustedServerReturnedError(NetworkException):
_INSTANCE = None
class Network(Logger):
class Network(Logger, NetworkRetryManager[ServerAddr]):
"""The Network class manages a set of connections to remote electrum
servers; each connected socket is handled by an Interface() object.
"""
LOGGING_SHORTCUT = 'n'
taskgroup: Optional[TaskGroup]
interface: Optional[Interface]
interfaces: Dict[ServerAddr, Interface]
_connecting: Set[ServerAddr]
default_server: ServerAddr
_recent_servers: List[ServerAddr]
def __init__(self, config: SimpleConfig, *, daemon: 'Daemon' = None):
global _INSTANCE
assert _INSTANCE is None, "Network is a singleton!"
_INSTANCE = self
Logger.__init__(self)
NetworkRetryManager.__init__(
self,
max_retry_delay_normal=600,
init_retry_delay_normal=15,
max_retry_delay_urgent=10,
init_retry_delay_urgent=1,
)
self.asyncio_loop = asyncio.get_event_loop()
assert self.asyncio_loop.is_running(), "event loop not running"
@ -261,50 +275,47 @@ class Network(Logger):
self.logger.info(f"blockchains {list(map(lambda b: b.forkpoint, blockchain.blockchains.values()))}")
self._blockchain_preferred_block = self.config.get('blockchain_preferred_block', None) # type: Optional[Dict]
self._blockchain = blockchain.get_best_chain()
self._allowed_protocols = {PREFERRED_NETWORK_PROTOCOL}
# Server for addresses and transactions
self.default_server = self.config.get('server', None)
# Sanitize default server
if self.default_server:
try:
deserialize_server(self.default_server)
self.default_server = ServerAddr.from_str(self.default_server)
except:
self.logger.warning('failed to parse server-string; falling back to localhost:1:s.')
self.default_server = "localhost:1:s"
if not self.default_server:
self.default_server = pick_random_server()
self.default_server = ServerAddr.from_str("localhost:1:s")
else:
self.default_server = pick_random_server(allowed_protocols=self._allowed_protocols)
assert isinstance(self.default_server, ServerAddr), f"invalid type for default_server: {self.default_server!r}"
self.taskgroup = None # type: TaskGroup
self.taskgroup = None
# locks
self.restart_lock = asyncio.Lock()
self.bhi_lock = asyncio.Lock()
self.callback_lock = threading.Lock()
self.recent_servers_lock = threading.RLock() # <- re-entrant
self.interfaces_lock = threading.Lock() # for mutating/iterating self.interfaces
self.server_peers = {} # returned by interface (servers that the main interface knows about)
self.recent_servers = self._read_recent_servers() # note: needs self.recent_servers_lock
self._recent_servers = self._read_recent_servers() # note: needs self.recent_servers_lock
self.banner = ''
self.donation_address = ''
self.relay_fee = None # type: Optional[int]
# callbacks set by the GUI
self.callbacks = defaultdict(list) # note: needs self.callback_lock
dir_path = os.path.join(self.config.path, 'certs')
util.make_dir(dir_path)
# retry times
self.server_retry_time = time.time()
self.nodes_retry_time = time.time()
# the main server we are currently communicating with
self.interface = None # type: Optional[Interface]
self.interface = None
self.default_server_changed_event = asyncio.Event()
# set of servers we have an ongoing connection with
self.interfaces = {} # type: Dict[str, Interface]
self.interfaces = {}
self.auto_connect = self.config.get('auto_connect', True)
self.connecting = set()
self.server_queue = None
self._connecting = set()
self.proxy = None
# Dump network messages (all interfaces). Set at runtime from the console.
@ -332,7 +343,7 @@ class Network(Logger):
from . import channel_db
self.channel_db = channel_db.ChannelDB(self)
self.path_finder = lnrouter.LNPathFinder(self.channel_db)
self.lngossip = lnworker.LNGossip(self)
self.lngossip = lnworker.LNGossip()
self.lngossip.start_network(self)
def run_from_another_thread(self, coro, *, timeout=None):
@ -350,35 +361,15 @@ class Network(Logger):
return func(self, *args, **kwargs)
return func_wrapper
def register_callback(self, callback, events):
with self.callback_lock:
for event in events:
self.callbacks[event].append(callback)
def unregister_callback(self, callback):
with self.callback_lock:
for callbacks in self.callbacks.values():
if callback in callbacks:
callbacks.remove(callback)
def trigger_callback(self, event, *args):
with self.callback_lock:
callbacks = self.callbacks[event][:]
for callback in callbacks:
# FIXME: if callback throws, we will lose the traceback
if asyncio.iscoroutinefunction(callback):
asyncio.run_coroutine_threadsafe(callback(event, *args), self.asyncio_loop)
else:
self.asyncio_loop.call_soon_threadsafe(callback, event, *args)
def _read_recent_servers(self):
def _read_recent_servers(self) -> List[ServerAddr]:
if not self.config.path:
return []
path = os.path.join(self.config.path, "recent_servers")
try:
with open(path, "r", encoding='utf-8') as f:
data = f.read()
return json.loads(data)
servers_list = json.loads(data)
return [ServerAddr.from_str(s) for s in servers_list]
except:
return []
@ -387,7 +378,7 @@ class Network(Logger):
if not self.config.path:
return
path = os.path.join(self.config.path, "recent_servers")
s = json.dumps(self.recent_servers, indent=4, sort_keys=True)
s = json.dumps(self._recent_servers, indent=4, sort_keys=True, cls=MyEncoder)
try:
with open(path, "w", encoding='utf-8') as f:
f.write(s)
@ -481,15 +472,12 @@ class Network(Logger):
def notify(self, key):
if key in ['status', 'updated']:
self.trigger_callback(key)
util.trigger_callback(key)
else:
self.trigger_callback(key, self.get_status_value(key))
util.trigger_callback(key, self.get_status_value(key))
def get_parameters(self) -> NetworkParameters:
host, port, protocol = deserialize_server(self.default_server)
return NetworkParameters(host=host,
port=port,
protocol=protocol,
return NetworkParameters(server=self.default_server,
proxy=self.proxy,
auto_connect=self.auto_connect,
oneserver=self.oneserver)
@ -498,7 +486,7 @@ class Network(Logger):
if self.is_connected():
return self.donation_address
def get_interfaces(self) -> List[str]:
def get_interfaces(self) -> List[ServerAddr]:
"""The list of servers for the connected interfaces."""
with self.interfaces_lock:
return list(self.interfaces)
@ -540,51 +528,60 @@ class Network(Logger):
# hardcoded servers
out.update(constants.net.DEFAULT_SERVERS)
# add recent servers
for s in self.recent_servers:
try:
host, port, protocol = deserialize_server(s)
except:
continue
if host in out:
out[host].update({protocol: port})
for server in self._recent_servers:
port = str(server.port)
if server.host in out:
out[server.host].update({server.protocol: port})
else:
out[host] = {protocol: port}
out[server.host] = {server.protocol: port}
# potentially filter out some
if self.config.get('noonion'):
out = filter_noonion(out)
return out
def _start_interface(self, server: str):
if server not in self.interfaces and server not in self.connecting:
if server == self.default_server:
self.logger.info(f"connecting to {server} as new interface")
self._set_status('connecting')
self.connecting.add(server)
self.server_queue.put(server)
def _start_random_interface(self):
def _get_next_server_to_try(self) -> Optional[ServerAddr]:
now = time.time()
with self.interfaces_lock:
exclude_set = self.disconnected_servers | set(self.interfaces) | self.connecting
server = pick_random_server(self.get_servers(), self.protocol, exclude_set)
if server:
self._start_interface(server)
return server
connected_servers = set(self.interfaces) | self._connecting
# First try from recent servers (which are persisted).
# As these are servers we successfully connected to recently, they are
# most likely to work. This also makes servers "sticky".
# Note: with sticky servers, it is more difficult for an attacker to eclipse the client;
# however, if they succeed, the eclipsing would persist. To try to balance this,
# we only give priority to recent_servers up to NUM_STICKY_SERVERS.
with self.recent_servers_lock:
recent_servers = list(self._recent_servers)
recent_servers = [s for s in recent_servers if s.protocol in self._allowed_protocols]
if len(connected_servers & set(recent_servers)) < NUM_STICKY_SERVERS:
for server in recent_servers:
if server in connected_servers:
continue
if not self._can_retry_addr(server, now=now):
continue
return server
# try all servers we know about, pick one at random
hostmap = self.get_servers()
servers = list(set(filter_protocol(hostmap, allowed_protocols=self._allowed_protocols)) - connected_servers)
random.shuffle(servers)
for server in servers:
if not self._can_retry_addr(server, now=now):
continue
return server
return None
def _set_proxy(self, proxy: Optional[dict]):
self.proxy = proxy
dns_hacks.configure_dns_depending_on_proxy(bool(proxy))
self.logger.info(f'setting proxy {proxy}')
self.trigger_callback('proxy_set', self.proxy)
util.trigger_callback('proxy_set', self.proxy)
@log_exceptions
async def set_parameters(self, net_params: NetworkParameters):
proxy = net_params.proxy
proxy_str = serialize_proxy(proxy)
host, port, protocol = net_params.host, net_params.port, net_params.protocol
server_str = serialize_server(host, port, protocol)
server = net_params.server
# sanitize parameters
try:
deserialize_server(serialize_server(host, port, protocol))
if proxy:
proxy_modes.index(proxy['mode']) + 1
int(proxy['port'])
@ -593,22 +590,22 @@ class Network(Logger):
self.config.set_key('auto_connect', net_params.auto_connect, False)
self.config.set_key('oneserver', net_params.oneserver, False)
self.config.set_key('proxy', proxy_str, False)
self.config.set_key('server', server_str, True)
self.config.set_key('server', str(server), True)
# abort if changes were not allowed by config
if self.config.get('server') != server_str \
if self.config.get('server') != str(server) \
or self.config.get('proxy') != proxy_str \
or self.config.get('oneserver') != net_params.oneserver:
return
async with self.restart_lock:
self.auto_connect = net_params.auto_connect
if self.proxy != proxy or self.protocol != protocol or self.oneserver != net_params.oneserver:
if self.proxy != proxy or self.oneserver != net_params.oneserver:
# Restart the network defaulting to the given server
await self._stop()
self.default_server = server_str
self.default_server = server
await self._start()
elif self.default_server != server_str:
await self.switch_to_interface(server_str)
elif self.default_server != server:
await self.switch_to_interface(server)
else:
await self.switch_lagging_interface()
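Callers that previously passed host/port/protocol now hand over a ServerAddr; a hedged sketch of the new flow (the server string is a placeholder):
# Sketch only -- switching servers through the reworked NetworkParameters.
from electrum.interface import ServerAddr

async def switch_server(network):
    server = ServerAddr.from_str("electrum.example.org:50002:s")   # placeholder
    net_params = network.get_parameters()._replace(server=server)  # NamedTuple copy
    await network.set_parameters(net_params)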
@ -670,7 +667,7 @@ class Network(Logger):
# FIXME switch to best available?
self.logger.info("tried to switch to best chain but no interfaces are on it")
async def switch_to_interface(self, server: str):
async def switch_to_interface(self, server: ServerAddr):
"""Switch to server as our main interface. If no connection exists,
queue interface to be started. The actual switch will
happen when the interface becomes ready.
@ -686,11 +683,11 @@ class Network(Logger):
if old_server and old_server != server:
await self._close_interface(old_interface)
if len(self.interfaces) <= self.num_server:
self._start_interface(old_server)
await self.taskgroup.spawn(self._run_new_interface(old_server))
if server not in self.interfaces:
self.interface = None
self._start_interface(server)
await self.taskgroup.spawn(self._run_new_interface(server))
return
i = self.interfaces[server]
@ -700,12 +697,13 @@ class Network(Logger):
blockchain_updated = i.blockchain != self.blockchain()
self.interface = i
await i.taskgroup.spawn(self._request_server_info(i))
self.trigger_callback('default_server_changed')
util.trigger_callback('default_server_changed')
self.default_server_changed_event.set()
self.default_server_changed_event.clear()
self._set_status('connected')
self.trigger_callback('network_updated')
if blockchain_updated: self.trigger_callback('blockchain_updated')
util.trigger_callback('network_updated')
if blockchain_updated:
util.trigger_callback('blockchain_updated')
async def _close_interface(self, interface: Interface):
if interface:
@ -717,12 +715,13 @@ class Network(Logger):
await interface.close()
@with_recent_servers_lock
def _add_recent_server(self, server):
def _add_recent_server(self, server: ServerAddr) -> None:
self._on_connection_successfully_established(server)
# list is ordered
if server in self.recent_servers:
self.recent_servers.remove(server)
self.recent_servers.insert(0, server)
self.recent_servers = self.recent_servers[:NUM_RECENT_SERVERS]
if server in self._recent_servers:
self._recent_servers.remove(server)
self._recent_servers.insert(0, server)
self._recent_servers = self._recent_servers[:NUM_RECENT_SERVERS]
self._save_recent_servers()
async def connection_down(self, interface: Interface):
@ -730,11 +729,10 @@ class Network(Logger):
We distinguish by whether it is in self.interfaces.'''
if not interface: return
server = interface.server
self.disconnected_servers.add(server)
if server == self.default_server:
self._set_status('disconnected')
await self._close_interface(interface)
self.trigger_callback('network_updated')
util.trigger_callback('network_updated')
def get_network_timeout_seconds(self, request_type=NetworkTimeout.Generic) -> int:
if self.oneserver and not self.auto_connect:
@ -743,10 +741,18 @@ class Network(Logger):
return request_type.RELAXED
return request_type.NORMAL
@ignore_exceptions # do not kill main_taskgroup
@ignore_exceptions # do not kill outer taskgroup
@log_exceptions
async def _run_new_interface(self, server):
interface = Interface(self, server, self.proxy)
async def _run_new_interface(self, server: ServerAddr):
if server in self.interfaces or server in self._connecting:
return
self._connecting.add(server)
if server == self.default_server:
self.logger.info(f"connecting to {server} as new interface")
self._set_status('connecting')
self._trying_addr_now(server)
interface = Interface(network=self, server=server, proxy=self.proxy)
# note: using longer timeouts here as DNS can sometimes be slow!
timeout = self.get_network_timeout_seconds(NetworkTimeout.Generic)
try:
@ -760,16 +766,16 @@ class Network(Logger):
assert server not in self.interfaces
self.interfaces[server] = interface
finally:
try: self.connecting.remove(server)
try: self._connecting.remove(server)
except KeyError: pass
if server == self.default_server:
await self.switch_to_interface(server)
self._add_recent_server(server)
self.trigger_callback('network_updated')
util.trigger_callback('network_updated')
def check_interface_against_healthy_spread_of_connected_servers(self, iface_to_check) -> bool:
def check_interface_against_healthy_spread_of_connected_servers(self, iface_to_check: Interface) -> bool:
# main interface is exempt. this makes switching servers easier
if iface_to_check.is_main_server():
return True
@ -1093,23 +1099,21 @@ class Network(Logger):
with self.interfaces_lock: interfaces = list(self.interfaces.values())
interfaces_on_selected_chain = list(filter(lambda iface: iface.blockchain == bc, interfaces))
if len(interfaces_on_selected_chain) == 0: return
chosen_iface = random.choice(interfaces_on_selected_chain)
chosen_iface = random.choice(interfaces_on_selected_chain) # type: Interface
# switch to server (and save to config)
net_params = self.get_parameters()
host, port, protocol = deserialize_server(chosen_iface.server)
net_params = net_params._replace(host=host, port=port, protocol=protocol)
net_params = net_params._replace(server=chosen_iface.server)
await self.set_parameters(net_params)
async def follow_chain_given_server(self, server_str: str) -> None:
async def follow_chain_given_server(self, server: ServerAddr) -> None:
# note that 'server' should correspond to a connected interface
iface = self.interfaces.get(server_str)
iface = self.interfaces.get(server)
if iface is None:
return
self._set_preferred_chain(iface.blockchain)
# switch to server (and save to config)
net_params = self.get_parameters()
host, port, protocol = deserialize_server(server_str)
net_params = net_params._replace(host=host, port=port, protocol=protocol)
net_params = net_params._replace(server=server)
await self.set_parameters(net_params)
def get_local_height(self):
@ -1127,14 +1131,12 @@ class Network(Logger):
assert not self.taskgroup
self.taskgroup = taskgroup = SilentTaskGroup()
assert not self.interface and not self.interfaces
assert not self.connecting and not self.server_queue
assert not self._connecting
self.logger.info('starting network')
self.disconnected_servers = set([])
self.protocol = deserialize_server(self.default_server)[2]
self.server_queue = queue.Queue()
self._clear_addr_retry_times()
self._set_proxy(deserialize_proxy(self.config.get('proxy')))
self._set_oneserver(self.config.get('oneserver', False))
self._start_interface(self.default_server)
await self.taskgroup.spawn(self._run_new_interface(self.default_server))
async def main():
self.logger.info("starting taskgroup.")
@ -1152,7 +1154,7 @@ class Network(Logger):
self.logger.info("taskgroup stopped.")
asyncio.run_coroutine_threadsafe(main(), self.asyncio_loop)
self.trigger_callback('network_updated')
util.trigger_callback('network_updated')
def start(self, jobs: Iterable = None):
"""Schedule starting the network, along with the given job co-routines.
@ -1170,13 +1172,12 @@ class Network(Logger):
await asyncio.wait_for(self.taskgroup.cancel_remaining(), timeout=2)
except (asyncio.TimeoutError, asyncio.CancelledError) as e:
self.logger.info(f"exc during main_taskgroup cancellation: {repr(e)}")
self.taskgroup = None # type: TaskGroup
self.interface = None # type: Interface
self.interfaces = {} # type: Dict[str, Interface]
self.connecting.clear()
self.server_queue = None
self.taskgroup = None
self.interface = None
self.interfaces = {}
self._connecting.clear()
if not full_shutdown:
self.trigger_callback('network_updated')
util.trigger_callback('network_updated')
def stop(self):
assert self._loop_thread != threading.current_thread(), 'must not be called from network thread'
@ -1188,33 +1189,21 @@ class Network(Logger):
async def _ensure_there_is_a_main_interface(self):
if self.is_connected():
return
now = time.time()
# if auto_connect is set, try a different server
if self.auto_connect and not self.is_connecting():
await self._switch_to_random_interface()
# if auto_connect is not set, or still no main interface, retry current
if not self.is_connected() and not self.is_connecting():
if self.default_server in self.disconnected_servers:
if now - self.server_retry_time > SERVER_RETRY_INTERVAL:
self.disconnected_servers.remove(self.default_server)
self.server_retry_time = now
else:
if self._can_retry_addr(self.default_server, urgent=True):
await self.switch_to_interface(self.default_server)
async def _maintain_sessions(self):
async def launch_already_queued_up_new_interfaces():
while self.server_queue.qsize() > 0:
server = self.server_queue.get()
await self.taskgroup.spawn(self._run_new_interface(server))
async def maybe_queue_new_interfaces_to_be_launched_later():
now = time.time()
for i in range(self.num_server - len(self.interfaces) - len(self.connecting)):
async def maybe_start_new_interfaces():
for i in range(self.num_server - len(self.interfaces) - len(self._connecting)):
# FIXME this should try to honour "healthy spread of connected servers"
self._start_random_interface()
if now - self.nodes_retry_time > NODES_RETRY_INTERVAL:
self.logger.info('network: retrying connections')
self.disconnected_servers = set([])
self.nodes_retry_time = now
server = self._get_next_server_to_try()
if server:
await self.taskgroup.spawn(self._run_new_interface(server))
async def maintain_healthy_spread_of_connected_servers():
with self.interfaces_lock: interfaces = list(self.interfaces.values())
random.shuffle(interfaces)
@ -1231,8 +1220,7 @@ class Network(Logger):
while True:
try:
await launch_already_queued_up_new_interfaces()
await maybe_queue_new_interfaces_to_be_launched_later()
await maybe_start_new_interfaces()
await maintain_healthy_spread_of_connected_servers()
await maintain_main_interface()
except asyncio.CancelledError:
@ -1289,10 +1277,10 @@ class Network(Logger):
session = self.interface.session
return parse_servers(await session.send_request('server.peers.subscribe'))
async def send_multiple_requests(self, servers: List[str], method: str, params: Sequence):
async def send_multiple_requests(self, servers: Sequence[ServerAddr], method: str, params: Sequence):
responses = dict()
async def get_response(server):
interface = Interface(self, server, self.proxy)
async def get_response(server: ServerAddr):
interface = Interface(network=self, server=server, proxy=self.proxy)
timeout = self.get_network_timeout_seconds(NetworkTimeout.Urgent)
try:
await asyncio.wait_for(interface.ready, timeout)

2
electrum/scripts/peers.py

@ -17,7 +17,7 @@ network.start()
async def f():
try:
peers = await network.get_peers()
peers = filter_protocol(peers, 's')
peers = filter_protocol(peers)
results = await network.send_multiple_requests(peers, 'blockchain.headers.subscribe', [])
for server, header in sorted(results.items(), key=lambda x: x[1].get('height')):
height = header.get('height')

2
electrum/scripts/txradar.py

@ -23,7 +23,7 @@ network.start()
async def f():
try:
peers = await network.get_peers()
peers = filter_protocol(peers, 's')
peers = filter_protocol(peers)
results = await network.send_multiple_requests(peers, 'blockchain.transaction.get', [txid])
r1, r2 = [], []
for k, v in results.items():

6
electrum/sql_db.py

@ -19,9 +19,9 @@ def sql(func):
class SqlDB(Logger):
def __init__(self, network, path, commit_interval=None):
def __init__(self, asyncio_loop, path, commit_interval=None):
Logger.__init__(self)
self.network = network
self.asyncio_loop = asyncio_loop
self.path = path
self.commit_interval = commit_interval
self.db_requests = queue.Queue()
@ -34,7 +34,7 @@ class SqlDB(Logger):
self.logger.info("Creating database")
self.create_database()
i = 0
while self.network.asyncio_loop.is_running():
while self.asyncio_loop.is_running():
try:
future, func, args, kwargs = self.db_requests.get(timeout=0.1)
except queue.Empty:

5
electrum/synchronizer.py

@ -30,6 +30,7 @@ import logging
from aiorpcx import TaskGroup, run_in_thread, RPCError
from . import util
from .transaction import Transaction, PartialTransaction
from .util import bh2u, make_aiohttp_session, NetworkJobOnDefaultServer
from .bitcoin import address_to_scripthash, is_address
@ -227,7 +228,7 @@ class Synchronizer(SynchronizerBase):
self.wallet.receive_tx_callback(tx_hash, tx, tx_height)
self.logger.info(f"received tx {tx_hash} height: {tx_height} bytes: {len(raw_tx)}")
# callbacks
self.wallet.network.trigger_callback('new_transaction', self.wallet, tx)
util.trigger_callback('new_transaction', self.wallet, tx)
async def main(self):
self.wallet.set_up_to_date(False)
@ -252,7 +253,7 @@ class Synchronizer(SynchronizerBase):
if up_to_date:
self._reset_request_counters()
self.wallet.set_up_to_date(up_to_date)
self.wallet.network.trigger_callback('wallet_updated', self.wallet)
util.trigger_callback('wallet_updated', self.wallet)
class Notifier(SynchronizerBase):

14
electrum/tests/test_lnpeer.py

@ -17,7 +17,7 @@ from electrum.ecc import ECPrivkey
from electrum import simple_config, lnutil
from electrum.lnaddr import lnencode, LnAddr, lndecode
from electrum.bitcoin import COIN, sha256
from electrum.util import bh2u, create_and_start_event_loop
from electrum.util import bh2u, create_and_start_event_loop, NetworkRetryManager
from electrum.lnpeer import Peer
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
@ -64,10 +64,6 @@ class MockNetwork:
def callback_lock(self):
return noop_lock()
register_callback = Network.register_callback
unregister_callback = Network.unregister_callback
trigger_callback = Network.trigger_callback
def get_local_height(self):
return 0
@ -99,9 +95,10 @@ class MockWallet:
def is_lightning_backup(self):
return False
class MockLNWallet(Logger):
class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
def __init__(self, remote_keypair, local_keypair, chan: 'Channel', tx_queue):
Logger.__init__(self)
NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1)
self.remote_keypair = remote_keypair
self.node_keypair = local_keypair
self.network = MockNetwork(tx_queue)
@ -127,6 +124,10 @@ class MockLNWallet(Logger):
@property
def peers(self):
return self._peers
@property
def _peers(self):
return {self.remote_keypair.pubkey: self.peer}
def channels_for_peer(self, pubkey):
@ -164,6 +165,7 @@ class MockLNWallet(Logger):
force_close_channel = LNWallet.force_close_channel
try_force_closing = LNWallet.try_force_closing
get_first_timestamp = lambda self: 0
on_peer_successfully_established = LNWallet.on_peer_successfully_established
class MockTransport:

2
electrum/tests/test_lntransport.py

@ -57,7 +57,7 @@ class TestLNTransport(ElectrumTestCase):
server = server_future.result() # type: asyncio.Server
async def connect():
peer_addr = LNPeerAddr('127.0.0.1', 42898, responder_key.get_public_key_bytes())
t = LNTransport(initiator_key.get_secret_bytes(), peer_addr)
t = LNTransport(initiator_key.get_secret_bytes(), peer_addr, proxy=None)
await t.handshake()
t.send_bytes(b'hello from client')
self.assertEqual(await t.read_messages().__anext__(), b'hello from server')

4
electrum/tests/test_network.py

@ -5,7 +5,7 @@ import unittest
from electrum import constants
from electrum.simple_config import SimpleConfig
from electrum import blockchain
from electrum.interface import Interface
from electrum.interface import Interface, ServerAddr
from electrum.crypto import sha256
from electrum.util import bh2u
@ -24,7 +24,7 @@ class MockInterface(Interface):
self.config = config
network = MockNetwork()
network.config = config
super().__init__(network, 'mock-server:50000:t', None)
super().__init__(network=network, server=ServerAddr.from_str('mock-server:50000:t'), proxy=None)
self.q = asyncio.Queue()
self.blockchain = blockchain.Blockchain(config=self.config, forkpoint=0,
parent=None, forkpoint_hash=constants.net.GENESIS, prev_hash=None)

128
electrum/util.py

@ -23,7 +23,8 @@
import binascii
import os, sys, re, json
from collections import defaultdict, OrderedDict
from typing import NamedTuple, Union, TYPE_CHECKING, Tuple, Optional, Callable, Any, Sequence
from typing import (NamedTuple, Union, TYPE_CHECKING, Tuple, Optional, Callable, Any,
Sequence, Dict, Generic, TypeVar)
from datetime import datetime
import decimal
from decimal import Decimal
@ -41,9 +42,11 @@ import time
from typing import NamedTuple, Optional
import ssl
import ipaddress
import random
import aiohttp
from aiohttp_socks import ProxyConnector, ProxyType
import aiorpcx
from aiorpcx import TaskGroup
import certifi
import dns.resolver
@ -1130,7 +1133,7 @@ class NetworkJobOnDefaultServer(Logger):
self._restart_lock = asyncio.Lock()
self._reset()
asyncio.run_coroutine_threadsafe(self._restart(), network.asyncio_loop)
network.register_callback(self._restart, ['default_server_changed'])
register_callback(self._restart, ['default_server_changed'])
def _reset(self):
"""Initialise fields. Called every time the underlying
@ -1304,3 +1307,124 @@ def randrange(bound: int) -> int:
"""Return a random integer k such that 1 <= k < bound, uniformly
distributed across that range."""
return ecdsa.util.randrange(bound)
class CallbackManager:
# callbacks set by the GUI
def __init__(self):
self.callback_lock = threading.Lock()
self.callbacks = defaultdict(list) # note: needs self.callback_lock
self.asyncio_loop = None
def register_callback(self, callback, events):
with self.callback_lock:
for event in events:
self.callbacks[event].append(callback)
def unregister_callback(self, callback):
with self.callback_lock:
for callbacks in self.callbacks.values():
if callback in callbacks:
callbacks.remove(callback)
def trigger_callback(self, event, *args):
if self.asyncio_loop is None:
self.asyncio_loop = asyncio.get_event_loop()
assert self.asyncio_loop.is_running(), "event loop not running"
with self.callback_lock:
callbacks = self.callbacks[event][:]
for callback in callbacks:
# FIXME: if callback throws, we will lose the traceback
if asyncio.iscoroutinefunction(callback):
asyncio.run_coroutine_threadsafe(callback(event, *args), self.asyncio_loop)
else:
self.asyncio_loop.call_soon_threadsafe(callback, event, *args)
callback_mgr = CallbackManager()
trigger_callback = callback_mgr.trigger_callback
register_callback = callback_mgr.register_callback
unregister_callback = callback_mgr.unregister_callback
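A minimal sketch of the module-level callback API that replaces Network.register_callback/trigger_callback throughout this patch; it assumes an asyncio event loop is already running when trigger_callback fires:
# Sketch only -- assumes a running asyncio event loop.
from electrum import util

def on_network_event(event, *args):
    print("event:", event, args)

util.register_callback(on_network_event, ['network_updated', 'status'])
util.trigger_callback('network_updated')      # dispatched onto the asyncio loop
util.unregister_callback(on_network_event)    # detach when no longer needed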
_NetAddrType = TypeVar("_NetAddrType")
class NetworkRetryManager(Generic[_NetAddrType]):
"""Truncated Exponential Backoff for network connections."""
def __init__(
self, *,
max_retry_delay_normal: float,
init_retry_delay_normal: float,
max_retry_delay_urgent: float = None,
init_retry_delay_urgent: float = None,
):
self._last_tried_addr = {} # type: Dict[_NetAddrType, Tuple[float, int]] # (unix ts, num_attempts)
# note: these all use "seconds" as unit
if max_retry_delay_urgent is None:
max_retry_delay_urgent = max_retry_delay_normal
if init_retry_delay_urgent is None:
init_retry_delay_urgent = init_retry_delay_normal
self._max_retry_delay_normal = max_retry_delay_normal
self._init_retry_delay_normal = init_retry_delay_normal
self._max_retry_delay_urgent = max_retry_delay_urgent
self._init_retry_delay_urgent = init_retry_delay_urgent
def _trying_addr_now(self, addr: _NetAddrType) -> None:
last_time, num_attempts = self._last_tried_addr.get(addr, (0, 0))
# we add up to 1 second of noise to the time, so that clients are less likely
# to get synchronised and bombard the remote in connection waves:
cur_time = time.time() + random.random()
self._last_tried_addr[addr] = cur_time, num_attempts + 1
def _on_connection_successfully_established(self, addr: _NetAddrType) -> None:
self._last_tried_addr[addr] = time.time(), 0
def _can_retry_addr(self, peer: _NetAddrType, *,
now: float = None, urgent: bool = False) -> bool:
if now is None:
now = time.time()
last_time, num_attempts = self._last_tried_addr.get(peer, (0, 0))
if urgent:
delay = min(self._max_retry_delay_urgent,
self._init_retry_delay_urgent * 2 ** num_attempts)
else:
delay = min(self._max_retry_delay_normal,
self._init_retry_delay_normal * 2 ** num_attempts)
next_time = last_time + delay
return next_time < now
def _clear_addr_retry_times(self) -> None:
self._last_tried_addr.clear()
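A short usage sketch of the retry manager; the address value is arbitrary, and the delays follow min(max_delay, init_delay * 2**attempts) from _can_retry_addr:
# Sketch only -- backoff bookkeeping for a single (hashable) address.
mgr = NetworkRetryManager(max_retry_delay_normal=600, init_retry_delay_normal=15)
addr = "server.example.org:50002"                  # placeholder address

if mgr._can_retry_addr(addr):                      # True: never tried before
    mgr._trying_addr_now(addr)                     # records timestamp, attempts -> 1
assert not mgr._can_retry_addr(addr)               # throttled ~30s now; doubles per attempt, capped at 600s
mgr._on_connection_successfully_established(addr)  # resets the attempt counter to 0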
class MySocksProxy(aiorpcx.SOCKSProxy):
async def open_connection(self, host=None, port=None, **kwargs):
loop = asyncio.get_event_loop()
reader = asyncio.StreamReader(loop=loop)
protocol = asyncio.StreamReaderProtocol(reader, loop=loop)
transport, _ = await self.create_connection(
lambda: protocol, host, port, **kwargs)
writer = asyncio.StreamWriter(transport, protocol, reader, loop)
return reader, writer
@classmethod
def from_proxy_dict(cls, proxy: dict = None) -> Optional['MySocksProxy']:
if not proxy:
return None
username, pw = proxy.get('user'), proxy.get('password')
if not username or not pw:
auth = None
else:
auth = aiorpcx.socks.SOCKSUserAuth(username, pw)
addr = aiorpcx.NetAddress(proxy['host'], proxy['port'])
if proxy['mode'] == "socks4":
ret = cls(addr, aiorpcx.socks.SOCKS4a, auth)
elif proxy['mode'] == "socks5":
ret = cls(addr, aiorpcx.socks.SOCKS5, auth)
else:
raise NotImplementedError # http proxy not available with aiorpcx
return ret
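Finally, a hedged sketch of driving the new MySocksProxy helper; the SOCKS5 endpoint and target host are placeholders and assume something is actually listening there:
# Sketch only -- assumes a SOCKS5 proxy listening on 127.0.0.1:9050.
async def open_via_proxy():
    proxy_dict = {'mode': 'socks5', 'host': '127.0.0.1', 'port': 9050,
                  'user': '', 'password': ''}                   # empty creds -> no SOCKS auth
    proxy = MySocksProxy.from_proxy_dict(proxy_dict)
    reader, writer = await proxy.open_connection('example.com', 80)
    writer.write(b'HEAD / HTTP/1.0\r\nHost: example.com\r\n\r\n')
    await writer.drain()
    print(await reader.readline())
    writer.close()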
