Browse Source

network: tighten checks of server responses for type/sanity

patch-4
SomberNight 4 years ago
parent
commit
c5da22a9dd
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 76
      electrum/interface.py
  2. 21
      electrum/network.py
  3. 4
      electrum/simple_config.py
  4. 87
      electrum/tests/test_util.py
  5. 40
      electrum/util.py

76
electrum/interface.py

@ -29,7 +29,7 @@ import sys
import traceback
import asyncio
import socket
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple, Any
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple, Any, Sequence
from collections import defaultdict
from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address, IPv4Address
import itertools
@ -46,13 +46,14 @@ import certifi
from .util import (ignore_exceptions, log_exceptions, bfh, SilentTaskGroup, MySocksProxy,
is_integer, is_non_negative_integer, is_hash256_str, is_hex_str,
is_real_number)
is_int_or_float, is_non_negative_int_or_float)
from . import util
from . import x509
from . import pem
from . import version
from . import blockchain
from .blockchain import Blockchain, HEADER_SIZE
from . import bitcoin
from . import constants
from .i18n import _
from .logging import Logger
@ -96,9 +97,14 @@ def assert_integer(val: Any) -> None:
raise RequestCorrupted(f'{val!r} should be an integer')
def assert_real_number(val: Any, *, as_str: bool = False) -> None:
if not is_real_number(val, as_str=as_str):
raise RequestCorrupted(f'{val!r} should be a number')
def assert_int_or_float(val: Any) -> None:
if not is_int_or_float(val):
raise RequestCorrupted(f'{val!r} should be int or float')
def assert_non_negative_int_or_float(val: Any) -> None:
if not is_non_negative_int_or_float(val):
raise RequestCorrupted(f'{val!r} should be a non-negative int or float')
def assert_hash256_str(val: Any) -> None:
@ -656,14 +662,13 @@ class Interface(Logger):
async def request_fee_estimates(self):
from .simple_config import FEE_ETA_TARGETS
from .bitcoin import COIN
while True:
async with TaskGroup() as group:
fee_tasks = []
for i in FEE_ETA_TARGETS:
fee_tasks.append((i, await group.spawn(self.session.send_request('blockchain.estimatefee', [i]))))
fee_tasks.append((i, await group.spawn(self.get_estimatefee(i))))
for nblock_target, task in fee_tasks:
fee = int(task.result() * COIN)
fee = task.result()
if fee < 0: continue
self.fee_estimates_eta[nblock_target] = fee
self.network.update_fee_estimates()
@ -983,6 +988,61 @@ class Interface(Logger):
assert_hash256_str(res)
return res
async def get_fee_histogram(self) -> Sequence[Tuple[Union[float, int], int]]:
# do request
res = await self.session.send_request('mempool.get_fee_histogram')
# check response
assert_list_or_tuple(res)
for fee, s in res:
assert_non_negative_int_or_float(fee)
assert_non_negative_integer(s)
return res
async def get_server_banner(self) -> str:
# do request
res = await self.session.send_request('server.banner')
# check response
if not isinstance(res, str):
raise RequestCorrupted(f'{res!r} should be a str')
return res
async def get_donation_address(self) -> str:
# do request
res = await self.session.send_request('server.donation_address')
# check response
if not res: # ignore empty string
return ''
if not bitcoin.is_address(res):
# note: do not hard-fail -- allow server to use future-type
# bitcoin address we do not recognize
self.logger.info(f"invalid donation address from server: {repr(res)}")
res = ''
return res
async def get_relay_fee(self) -> int:
"""Returns the min relay feerate in sat/kbyte."""
# do request
res = await self.session.send_request('blockchain.relayfee')
# check response
assert_non_negative_int_or_float(res)
relayfee = int(res * bitcoin.COIN)
relayfee = max(0, relayfee)
return relayfee
async def get_estimatefee(self, num_blocks: int) -> int:
"""Returns a feerate estimate for getting confirmed within
num_blocks blocks, in sat/kbyte.
"""
if not is_non_negative_integer(num_blocks):
raise Exception(f"{repr(num_blocks)} is not a num_blocks")
# do request
res = await self.session.send_request('blockchain.estimatefee', [num_blocks])
# check response
if res != -1:
assert_non_negative_int_or_float(res)
res = int(res * bitcoin.COIN)
return res
def _assert_header_does_not_check_against_any_chain(header: dict) -> None:
chain_bad = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)

21
electrum/network.py

@ -418,20 +418,15 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
def is_connecting(self):
return self.connection_status == 'connecting'
async def _request_server_info(self, interface):
async def _request_server_info(self, interface: 'Interface'):
await interface.ready
session = interface.session
async def get_banner():
self.banner = await session.send_request('server.banner')
self.banner = await interface.get_server_banner()
self.notify('banner')
async def get_donation_address():
addr = await session.send_request('server.donation_address')
if not bitcoin.is_address(addr):
if addr: # ignore empty string
self.logger.info(f"invalid donation address from server: {repr(addr)}")
addr = ''
self.donation_address = addr
self.donation_address = await interface.get_donation_address()
async def get_server_peers():
server_peers = await session.send_request('server.peers.subscribe')
random.shuffle(server_peers)
@ -441,12 +436,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
self.server_peers = parse_servers(server_peers)
self.notify('servers')
async def get_relay_fee():
relayfee = await session.send_request('blockchain.relayfee')
if relayfee is None:
self.relay_fee = None
else:
relayfee = int(relayfee * COIN)
self.relay_fee = max(0, relayfee)
self.relay_fee = await interface.get_relay_fee()
async with TaskGroup() as group:
await group.spawn(get_banner)
@ -456,9 +446,8 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
await group.spawn(self._request_fee_estimates(interface))
async def _request_fee_estimates(self, interface):
session = interface.session
self.config.requested_fee_estimates()
histogram = await session.send_request('mempool.get_fee_histogram')
histogram = await interface.get_fee_histogram()
self.config.mempool_fees = histogram
self.logger.info(f'fee_histogram {histogram}')
self.notify('fee_histogram')

4
electrum/simple_config.py

@ -5,7 +5,7 @@ import os
import stat
import ssl
from decimal import Decimal
from typing import Union, Optional, Dict
from typing import Union, Optional, Dict, Sequence, Tuple
from numbers import Real
from copy import deepcopy
@ -65,7 +65,7 @@ class SimpleConfig(Logger):
# a thread-safe way.
self.lock = threading.RLock()
self.mempool_fees = {} # type: Dict[Union[float, int], int]
self.mempool_fees = [] # type: Sequence[Tuple[Union[float, int], int]]
self.fee_estimates = {}
self.fee_estimates_last_updated = {}
self.last_time_fee_estimates_requested = 0 # zero ensures immediate fees

87
electrum/tests/test_util.py

@ -2,7 +2,9 @@ from decimal import Decimal
from electrum.util import (format_satoshis, format_fee_satoshis, parse_URI,
is_hash256_str, chunks, is_ip_address, list_enabled_bits,
format_satoshis_plain, is_private_netaddress)
format_satoshis_plain, is_private_netaddress, is_hex_str,
is_integer, is_non_negative_integer, is_int_or_float,
is_non_negative_int_or_float)
from . import ElectrumTestCase
@ -121,6 +123,89 @@ class TestUtil(ElectrumTestCase):
self.assertFalse(is_hash256_str(None))
self.assertFalse(is_hash256_str(7))
def test_is_hex_str(self):
self.assertTrue(is_hex_str('09a4'))
self.assertTrue(is_hex_str('2A5C3F4062E4F2FCCE7A1C7B4310CB647B327409F580F4ED72CB8FC0B1804DFA'))
self.assertTrue(is_hex_str('00' * 33))
self.assertFalse(is_hex_str('000'))
self.assertFalse(is_hex_str('qweqwe'))
self.assertFalse(is_hex_str(None))
self.assertFalse(is_hex_str(7))
def test_is_integer(self):
self.assertTrue(is_integer(7))
self.assertTrue(is_integer(0))
self.assertTrue(is_integer(-1))
self.assertTrue(is_integer(-7))
self.assertFalse(is_integer(Decimal("2.0")))
self.assertFalse(is_integer(Decimal(2.0)))
self.assertFalse(is_integer(Decimal(2)))
self.assertFalse(is_integer(0.72))
self.assertFalse(is_integer(2.0))
self.assertFalse(is_integer(-2.0))
self.assertFalse(is_integer('09a4'))
self.assertFalse(is_integer('2A5C3F4062E4F2FCCE7A1C7B4310CB647B327409F580F4ED72CB8FC0B1804DFA'))
self.assertFalse(is_integer('000'))
self.assertFalse(is_integer('qweqwe'))
self.assertFalse(is_integer(None))
def test_is_non_negative_integer(self):
self.assertTrue(is_non_negative_integer(7))
self.assertTrue(is_non_negative_integer(0))
self.assertFalse(is_non_negative_integer(Decimal("2.0")))
self.assertFalse(is_non_negative_integer(Decimal(2.0)))
self.assertFalse(is_non_negative_integer(Decimal(2)))
self.assertFalse(is_non_negative_integer(0.72))
self.assertFalse(is_non_negative_integer(2.0))
self.assertFalse(is_non_negative_integer(-2.0))
self.assertFalse(is_non_negative_integer(-1))
self.assertFalse(is_non_negative_integer(-7))
self.assertFalse(is_non_negative_integer('09a4'))
self.assertFalse(is_non_negative_integer('2A5C3F4062E4F2FCCE7A1C7B4310CB647B327409F580F4ED72CB8FC0B1804DFA'))
self.assertFalse(is_non_negative_integer('000'))
self.assertFalse(is_non_negative_integer('qweqwe'))
self.assertFalse(is_non_negative_integer(None))
def test_is_int_or_float(self):
self.assertTrue(is_int_or_float(7))
self.assertTrue(is_int_or_float(0))
self.assertTrue(is_int_or_float(-1))
self.assertTrue(is_int_or_float(-7))
self.assertTrue(is_int_or_float(0.72))
self.assertTrue(is_int_or_float(2.0))
self.assertTrue(is_int_or_float(-2.0))
self.assertFalse(is_int_or_float(Decimal("2.0")))
self.assertFalse(is_int_or_float(Decimal(2.0)))
self.assertFalse(is_int_or_float(Decimal(2)))
self.assertFalse(is_int_or_float('09a4'))
self.assertFalse(is_int_or_float('2A5C3F4062E4F2FCCE7A1C7B4310CB647B327409F580F4ED72CB8FC0B1804DFA'))
self.assertFalse(is_int_or_float('000'))
self.assertFalse(is_int_or_float('qweqwe'))
self.assertFalse(is_int_or_float(None))
def test_is_non_negative_int_or_float(self):
self.assertTrue(is_non_negative_int_or_float(7))
self.assertTrue(is_non_negative_int_or_float(0))
self.assertTrue(is_non_negative_int_or_float(0.0))
self.assertTrue(is_non_negative_int_or_float(0.72))
self.assertTrue(is_non_negative_int_or_float(2.0))
self.assertFalse(is_non_negative_int_or_float(-1))
self.assertFalse(is_non_negative_int_or_float(-7))
self.assertFalse(is_non_negative_int_or_float(-2.0))
self.assertFalse(is_non_negative_int_or_float(Decimal("2.0")))
self.assertFalse(is_non_negative_int_or_float(Decimal(2.0)))
self.assertFalse(is_non_negative_int_or_float(Decimal(2)))
self.assertFalse(is_non_negative_int_or_float('09a4'))
self.assertFalse(is_non_negative_int_or_float('2A5C3F4062E4F2FCCE7A1C7B4310CB647B327409F580F4ED72CB8FC0B1804DFA'))
self.assertFalse(is_non_negative_int_or_float('000'))
self.assertFalse(is_non_negative_int_or_float('qweqwe'))
self.assertFalse(is_non_negative_int_or_float(None))
def test_chunks(self):
self.assertEqual([[1, 2], [3, 4], [5]],
list(chunks([1, 2, 3, 4, 5], 2)))

40
electrum/util.py

@ -588,38 +588,24 @@ def is_hex_str(text: Any) -> bool:
return True
def is_non_negative_integer(val) -> bool:
try:
val = int(val)
if val >= 0:
return True
except:
pass
def is_integer(val: Any) -> bool:
return isinstance(val, int)
def is_non_negative_integer(val: Any) -> bool:
if is_integer(val):
return val >= 0
return False
def is_integer(val) -> bool:
try:
int(val)
except:
return False
else:
return True
def is_int_or_float(val: Any) -> bool:
return isinstance(val, (int, float))
def is_real_number(val, *, as_str: bool = False) -> bool:
if as_str: # only accept str
if not isinstance(val, str):
return False
else: # only accept int/float/etc.
if isinstance(val, str):
return False
try:
Decimal(val)
except:
return False
else:
return True
def is_non_negative_int_or_float(val: Any) -> bool:
if is_int_or_float(val):
return val >= 0
return False
def chunks(items, size: int):

Loading…
Cancel
Save