From 59c1d03f018026ac301c4e74facfc64da8ae4708 Mon Sep 17 00:00:00 2001 From: SomberNight Date: Sat, 16 Jun 2018 06:34:03 +0200 Subject: [PATCH] ecc.py: properly handle point at infinity --- lib/ecc.py | 33 +++++++++++++++++++++++++++------ lib/tests/test_bitcoin.py | 26 ++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 6 deletions(-) diff --git a/lib/ecc.py b/lib/ecc.py index b46a91069..585883fc2 100644 --- a/lib/ecc.py +++ b/lib/ecc.py @@ -49,6 +49,10 @@ def generator(): return ECPubkey.from_point(generator_secp256k1) +def point_at_infinity(): + return ECPubkey(None) + + def sig_string_from_der_sig(der_sig, order=CURVE_ORDER): r, s = ecdsa.util.sigdecode_der(der_sig, order) return ecdsa.util.sigencode_string(r, s, order) @@ -83,6 +87,8 @@ def point_to_ser(P, compressed=True) -> bytes: x, y = P else: x, y = P.x(), P.y() + if x is None or y is None: # infinity + return None if compressed: return bfh(('%02x' % (2+(y&1))) + ('%064x' % x)) return bfh('04'+('%064x' % x)+('%064x' % y)) @@ -115,7 +121,10 @@ def ser_to_point(ser: bytes) -> (int, int): def _ser_to_python_ecdsa_point(ser: bytes) -> ecdsa.ellipticcurve.Point: x, y = ser_to_point(ser) - return Point(curve_secp256k1, x, y, CURVE_ORDER) + try: + return Point(curve_secp256k1, x, y, CURVE_ORDER) + except: + raise InvalidECPointException() class InvalidECPointException(Exception): @@ -166,12 +175,19 @@ class _MySigningKey(ecdsa.SigningKey): return r, s +class _PubkeyForPointAtInfinity: + point = ecdsa.ellipticcurve.INFINITY + + class ECPubkey(object): def __init__(self, b: bytes): - assert_bytes(b) - point = _ser_to_python_ecdsa_point(b) - self._pubkey = ecdsa.ecdsa.Public_key(generator_secp256k1, point) + if b is not None: + assert_bytes(b) + point = _ser_to_python_ecdsa_point(b) + self._pubkey = ecdsa.ecdsa.Public_key(generator_secp256k1, point) + else: + self._pubkey = _PubkeyForPointAtInfinity() @classmethod def from_sig_string(cls, sig_string: bytes, recid: int, msg_hash: bytes): @@ -205,6 +221,7 @@ class ECPubkey(object): return ECPubkey(_bytes) def get_public_key_bytes(self, compressed=True): + if self.is_at_infinity(): raise Exception('point is at infinity') return point_to_ser(self.point(), compressed) def get_public_key_hex(self, compressed=True): @@ -229,7 +246,8 @@ class ECPubkey(object): return self.from_point(ecdsa_point) def __eq__(self, other): - return self.get_public_key_bytes() == other.get_public_key_bytes() + return self._pubkey.point.x() == other._pubkey.point.x() \ + and self._pubkey.point.y() == other._pubkey.point.y() def __ne__(self, other): return not (self == other) @@ -275,6 +293,9 @@ class ECPubkey(object): def order(cls): return CURVE_ORDER + def is_at_infinity(self): + return self == point_at_infinity() + def msg_magic(message: bytes) -> bytes: from .bitcoin import var_int @@ -318,7 +339,7 @@ class ECPrivkey(ECPubkey): raise Exception('unexpected size for secret. should be 32 bytes, not {}'.format(len(privkey_bytes))) secret = string_to_number(privkey_bytes) if not is_secret_within_curve_range(secret): - raise Exception('Invalid secret scalar (not within curve order)') + raise InvalidECPointException('Invalid secret scalar (not within curve order)') self.secret_scalar = secret point = generator_secp256k1 * secret diff --git a/lib/tests/test_bitcoin.py b/lib/tests/test_bitcoin.py index d0a303858..20ec5a30f 100644 --- a/lib/tests/test_bitcoin.py +++ b/lib/tests/test_bitcoin.py @@ -125,6 +125,32 @@ class Test_bitcoin(SequentialTestCase): #print signature eck.verify_message_for_address(signature, message) + @needs_test_with_all_ecc_implementations + def test_ecc_sanity(self): + G = ecc.generator() + n = G.order() + self.assertEqual(ecc.CURVE_ORDER, n) + inf = n * G + self.assertEqual(ecc.point_at_infinity(), inf) + self.assertTrue(inf.is_at_infinity()) + self.assertFalse(G.is_at_infinity()) + self.assertEqual(11 * G, 7 * G + 4 * G) + self.assertEqual((n + 2) * G, 2 * G) + self.assertEqual((n - 2) * G, -2 * G) + A = (n - 2) * G + B = (n - 1) * G + C = n * G + D = (n + 1) * G + self.assertFalse(A.is_at_infinity()) + self.assertFalse(B.is_at_infinity()) + self.assertTrue(C.is_at_infinity()) + self.assertTrue((C * 5).is_at_infinity()) + self.assertFalse(D.is_at_infinity()) + self.assertEqual(inf, C) + self.assertEqual(inf, A + 2 * G) + self.assertEqual(inf, D + (-1) * G) + self.assertNotEqual(A, B) + @needs_test_with_all_ecc_implementations def test_msg_signing(self): msg1 = b'Chancellor on brink of second bailout for banks'