Browse Source

Handle client protocol range requests.

Add more tests.
master
Neil Booth 7 years ago
parent
commit
9c25685eb9
  1. 24
      lib/util.py
  2. 21
      server/session.py
  3. 17
      tests/lib/test_util.py

24
lib/util.py

@ -278,3 +278,27 @@ def protocol_tuple(s):
return tuple(int(part) for part in s.split('.'))
except Exception:
return (0, )
def protocol_version(client_req, server_min, server_max):
'''Given a client protocol request, return the protocol version
to use as a tuple.
If a mutually acceptable protocol version does not exist, return None.
'''
if isinstance(client_req, list) and len(client_req) == 2:
client_min, client_max = client_req
elif client_req is None:
client_min = client_max = server_min
else:
client_min = client_max = client_req
client_min = protocol_tuple(client_min)
client_max = protocol_tuple(client_max)
server_min = protocol_tuple(server_min)
server_max = protocol_tuple(server_max)
result = min(client_max, server_max)
if result < max(client_min, server_min) or result == (0, ):
result = None
return result

21
server/session.py

@ -367,19 +367,18 @@ class ElectrumX(SessionBase):
return message
def set_protocol_handlers(self, version_str):
controller = self.controller
if version_str is None:
version_str = version.PROTOCOL_MIN
ptuple = util.protocol_tuple(version_str)
# Disconnect if requested protocol version in unsupported
if (ptuple < util.protocol_tuple(version.PROTOCOL_MIN)
or ptuple > util.protocol_tuple(version.PROTOCOL_MAX)):
self.log_info('unsupported protocol version {}'
.format(version_str))
def set_protocol_handlers(self, version_req):
# Find the highest common protocol version. Disconnect if
# that protocol version in unsupported.
ptuple = util.protocol_version(version_req, version.PROTOCOL_MIN,
version.PROTOCOL_MAX)
if ptuple is None:
self.log_info('unsupported protocol version request {}'
.format(version_req))
raise RPCError('unsupported protocol version: {}'
.format(version_str), JSONRPC.FATAL_ERROR)
.format(version_req), JSONRPC.FATAL_ERROR)
controller = self.controller
handlers = {
'blockchain.address.get_balance': controller.address_get_balance,
'blockchain.address.get_history': controller.address_get_history,

17
tests/lib/test_util.py

@ -93,3 +93,20 @@ def test_protocol_tuple():
assert util.protocol_tuple("0.1") == (0, 1)
assert util.protocol_tuple("0.10") == (0, 10)
assert util.protocol_tuple("2.5.3") == (2, 5, 3)
def test_protocol_version():
assert util.protocol_version(None, "1.0", "1.0") == (1, 0)
assert util.protocol_version("0.10", "0.10", "1.1") == (0, 10)
assert util.protocol_version("1.0", "1.0", "1.0") == (1, 0)
assert util.protocol_version("1.0", "1.0", "1.1") == (1, 0)
assert util.protocol_version("1.1", "1.0", "1.1") == (1, 1)
assert util.protocol_version("1.2", "1.0", "1.1") is None
assert util.protocol_version("0.9", "1.0", "1.1") is None
assert util.protocol_version(["0.9", "1.0"], "1.0", "1.1") == (1, 0)
assert util.protocol_version(["0.9", "1.1"], "1.0", "1.1") == (1, 1)
assert util.protocol_version(["1.1", "0.9"], "1.0", "1.1") is None
assert util.protocol_version(["0.8", "0.9"], "1.0", "1.1") is None
assert util.protocol_version(["1.1", "1.2"], "1.0", "1.1") == (1, 1)
assert util.protocol_version(["1.2", "1.3"], "1.0", "1.1") is None

Loading…
Cancel
Save