diff --git a/electrumx/lib/util.py b/electrumx/lib/util.py index e678626..8f49e2d 100644 --- a/electrumx/lib/util.py +++ b/electrumx/lib/util.py @@ -304,7 +304,7 @@ def version_string(ptuple): return '.'.join(str(p) for p in ptuple) -def protocol_version(client_req, server_min, server_max): +def protocol_version(client_req, min_tuple, max_tuple): '''Given a client's protocol version string, return a pair of protocol tuples: @@ -312,27 +312,19 @@ def protocol_version(client_req, server_min, server_max): If the request is unsupported, the negotiated protocol tuple is None. - - ''' - '''Given a client protocol request, return the protocol version - to use as a tuple. - - If a mutually acceptable protocol version does not exist, return None. ''' - if isinstance(client_req, list) and len(client_req) == 2: - client_min, client_max = client_req - elif client_req is None: - client_min = client_max = server_min + if client_req is None: + client_min = client_max = min_tuple else: - client_min = client_max = client_req - - client_min = protocol_tuple(client_min) - client_max = protocol_tuple(client_max) - server_min = protocol_tuple(server_min) - server_max = protocol_tuple(server_max) - - result = min(client_max, server_max) - if result < max(client_min, server_min) or result == (0, ): + if isinstance(client_req, list) and len(client_req) == 2: + client_min, client_max = client_req + else: + client_min = client_max = client_req + client_min = protocol_tuple(client_min) + client_max = protocol_tuple(client_max) + + result = min(client_max, max_tuple) + if result < max(client_min, min_tuple) or result == (0, ): result = None return result, client_min diff --git a/electrumx/server/controller.py b/electrumx/server/controller.py index 4bf3509..0686eed 100644 --- a/electrumx/server/controller.py +++ b/electrumx/server/controller.py @@ -61,11 +61,10 @@ class Controller(ServerBase): raise RuntimeError('ElectrumX requires aiorpcX >= ' f'{version_string(self.AIORPCX_MIN)}') - sclass = env.coin.SESSIONCLS + min_str, max_str = env.coin.SESSIONCLS.protocol_min_max_strings() self.logger.info(f'software version: {electrumx.version}') self.logger.info(f'aiorpcX version: {version_string(aiorpcx_version)}') - self.logger.info(f'supported protocol versions: ' - f'{sclass.PROTOCOL_MIN}-{sclass.PROTOCOL_MAX}') + self.logger.info(f'supported protocol versions: {min_str}-{max_str}') self.logger.info(f'event loop policy: {env.loop_policy}') self.coin = env.coin diff --git a/electrumx/server/session.py b/electrumx/server/session.py index 1485749..b706c49 100644 --- a/electrumx/server/session.py +++ b/electrumx/server/session.py @@ -133,8 +133,8 @@ class SessionBase(ServerSession): class ElectrumX(SessionBase): '''A TCP server that handles incoming Electrum connections.''' - PROTOCOL_MIN = '1.1' - PROTOCOL_MAX = '1.4' + PROTOCOL_MIN = (1, 1) + PROTOCOL_MAX = (1, 4) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -146,17 +146,23 @@ class ElectrumX(SessionBase): self.hashX_subs = {} self.sv_seen = False self.mempool_statuses = {} - self.set_protocol_handlers(util.protocol_tuple(self.PROTOCOL_MIN)) + self.set_protocol_handlers(self.PROTOCOL_MIN) + + @classmethod + def protocol_min_max_strings(cls): + return [util.version_string(ver) + for ver in (cls.PROTOCOL_MIN, cls.PROTOCOL_MAX)] @classmethod def server_features(cls, env): '''Return the server features dictionary.''' + min_str, max_str = cls.protocol_min_max_strings() return { 'hosts': env.hosts_dict(), 'pruning': None, 'server_version': electrumx.version, - 'protocol_min': cls.PROTOCOL_MIN, - 'protocol_max': cls.PROTOCOL_MAX, + 'protocol_min': min_str, + 'protocol_max': max_str, 'genesis_hash': env.coin.GENESIS_HASH, 'hash_function': 'sha256', } @@ -164,7 +170,7 @@ class ElectrumX(SessionBase): @classmethod def server_version_args(cls): '''The arguments to a server.version RPC call to a peer.''' - return [electrumx.version, [cls.PROTOCOL_MIN, cls.PROTOCOL_MAX]] + return [electrumx.version, cls.protocol_min_max_strings()] def protocol_version_string(self): return util.version_string(self.protocol_tuple) @@ -463,9 +469,9 @@ class ElectrumX(SessionBase): ptuple, client_min = util.protocol_version( protocol_version, self.PROTOCOL_MIN, self.PROTOCOL_MAX) if ptuple is None: - if client_min > util.protocol_tuple(self.PROTOCOL_MIN): + if client_min > self.PROTOCOL_MIN: self.logger.info(f'client requested future protocol version ' - f'{version_string(client_min)} ' + f'{util.version_string(client_min)} ' f'- is your software out of date?') self.close_after_send = True raise RPCError(BAD_REQUEST, diff --git a/tests/lib/test_util.py b/tests/lib/test_util.py index 891b648..0351471 100644 --- a/tests/lib/test_util.py +++ b/tests/lib/test_util.py @@ -183,26 +183,26 @@ def test_version_string(): assert util.version_string((1, 3, 2)) == "1.3.2" def test_protocol_version(): - assert util.protocol_version(None, "1.0", "1.0") == ((1, 0), (1, 0)) - assert util.protocol_version("0.10", "0.10", "1.1") == ((0, 10), (0, 10)) + assert util.protocol_version(None, (1, 0), (1, 0)) == ((1, 0), (1, 0)) + assert util.protocol_version("0.10", (0, 1), (1, 1)) == ((0, 10), (0, 10)) - assert util.protocol_version("1.0", "1.0", "1.0") == ((1, 0), (1, 0)) - assert util.protocol_version("1.0", "1.0", "1.1") == ((1, 0), (1, 0)) - assert util.protocol_version("1.1", "1.0", "1.1") == ((1, 1), (1, 1)) - assert util.protocol_version("1.2", "1.0", "1.1") == (None, (1, 2)) - assert util.protocol_version("0.9", "1.0", "1.1") == (None, (0, 9)) + assert util.protocol_version("1.0", (1, 0), (1, 0)) == ((1, 0), (1, 0)) + assert util.protocol_version("1.0", (1, 0), (1, 1)) == ((1, 0), (1, 0)) + assert util.protocol_version("1.1", (1, 0), (1, 1)) == ((1, 1), (1, 1)) + assert util.protocol_version("1.2", (1, 0), (1, 1)) == (None, (1, 2)) + assert util.protocol_version("0.9", (1, 0), (1, 1)) == (None, (0, 9)) - assert util.protocol_version(["0.9", "1.0"], "1.0", "1.1") \ + assert util.protocol_version(["0.9", "1.0"], (1, 0), (1, 1)) \ == ((1, 0), (0, 9)) - assert util.protocol_version(["0.9", "1.1"], "1.0", "1.1") \ + assert util.protocol_version(["0.9", "1.1"], (1, 0), (1, 1)) \ == ((1, 1), (0,9)) - assert util.protocol_version(["1.1", "0.9"], "1.0", "1.1") \ + assert util.protocol_version(["1.1", "0.9"], (1, 0), (1, 1)) \ == (None, (1, 1)) - assert util.protocol_version(["0.8", "0.9"], "1.0", "1.1") \ + assert util.protocol_version(["0.8", "0.9"], (1, 0), (1, 1)) \ == (None, (0, 8)) - assert util.protocol_version(["1.1", "1.2"], "1.0", "1.1") \ + assert util.protocol_version(["1.1", "1.2"], (1, 0), (1, 1)) \ == ((1, 1), (1, 1)) - assert util.protocol_version(["1.2", "1.3"], "1.0", "1.1") \ + assert util.protocol_version(["1.2", "1.3"], (1, 0), (1, 1)) \ == (None, (1, 2)) diff --git a/tests/server/test_api.py b/tests/server/test_api.py index eb8b951..a8c4abb 100644 --- a/tests/server/test_api.py +++ b/tests/server/test_api.py @@ -10,6 +10,7 @@ loop = asyncio.get_event_loop() def set_env(): env = mock.create_autospec(Env) env.coin = mock.Mock() + env.coin.SESSIONCLS.protocol_min_max_strings = lambda : ["1.1", "1.4"] env.loop_policy = None env.max_sessions = 0 env.max_subs = 0