diff --git a/contrib/pyln-client/pyln/client/lightning.py b/contrib/pyln-client/pyln/client/lightning.py index 096809e38..e3545d9a9 100644 --- a/contrib/pyln-client/pyln/client/lightning.py +++ b/contrib/pyln-client/pyln/client/lightning.py @@ -1,5 +1,6 @@ from decimal import Decimal from math import floor, log10 +from typing import Optional, Union import json import logging import os @@ -23,9 +24,12 @@ def monkey_patch_json(patch=True): class RpcError(ValueError): - def __init__(self, method, payload, error): - super(ValueError, self).__init__("RPC call failed: method: {}, payload: {}, error: {}" - .format(method, payload, error)) + def __init__(self, method: str, payload: dict, error: str): + super(ValueError, self).__init__( + "RPC call failed: method: {}, payload: {}, error: {}".format( + method, payload, error + ) + ) self.method = method self.payload = payload @@ -36,10 +40,10 @@ class Millisatoshi: """ A subtype to represent thousandths of a satoshi. - Many JSON API fields are expressed in millisatoshis: these automatically get - turned into Millisatoshi types. Converts to and from int. + Many JSON API fields are expressed in millisatoshis: these automatically + get turned into Millisatoshi types. Converts to and from int. """ - def __init__(self, v): + def __init__(self, v: Union[int, str, Decimal]): """ Takes either a string ending in 'msat', 'sat', 'btc' or an integer. """ @@ -47,43 +51,50 @@ class Millisatoshi: if v.endswith("msat"): self.millisatoshis = int(v[0:-4]) elif v.endswith("sat"): - self.millisatoshis = Decimal(v[0:-3]) * 1000 + self.millisatoshis = int(v[0:-3]) * 1000 elif v.endswith("btc"): - self.millisatoshis = Decimal(v[0:-3]) * 1000 * 10**8 + self.millisatoshis = int(v[0:-3]) * 1000 * 10**8 else: - raise TypeError("Millisatoshi must be string with msat/sat/btc suffix or int") + raise TypeError( + "Millisatoshi must be string with msat/sat/btc suffix or" + " int" + ) if self.millisatoshis != int(self.millisatoshis): raise ValueError("Millisatoshi must be a whole number") self.millisatoshis = int(self.millisatoshis) + elif isinstance(v, Millisatoshi): self.millisatoshis = v.millisatoshis + elif int(v) == v: self.millisatoshis = int(v) else: - raise TypeError("Millisatoshi must be string with msat/sat/btc suffix or int") + raise TypeError( + "Millisatoshi must be string with msat/sat/btc suffix or int" + ) if self.millisatoshis < 0: raise ValueError("Millisatoshi must be >= 0") - def __repr__(self): + def __repr__(self) -> str: """ Appends the 'msat' as expected for this type. """ return str(self.millisatoshis) + "msat" - def to_satoshi(self): + def to_satoshi(self) -> Decimal: """ Return a Decimal representing the number of satoshis. """ return Decimal(self.millisatoshis) / 1000 - def to_btc(self): + def to_btc(self) -> Decimal: """ Return a Decimal representing the number of bitcoin. """ return Decimal(self.millisatoshis) / 1000 / 10**8 - def to_satoshi_str(self): + def to_satoshi_str(self) -> str: """ Return a string of form 1234sat or 1234.567sat. """ @@ -92,7 +103,7 @@ class Millisatoshi: else: return '{:.0f}sat'.format(self.to_satoshi()) - def to_btc_str(self): + def to_btc_str(self) -> str: """ Return a string of form 12.34567890btc or 12.34567890123btc. """ @@ -101,13 +112,14 @@ class Millisatoshi: else: return '{:.8f}btc'.format(self.to_btc()) - def to_approx_str(self, digits: int = 3): + def to_approx_str(self, digits: int = 3) -> str: """Returns the shortmost string using common units representation. Rounds to significant `digits`. Default: 3 """ - round_to_n = lambda x, n: round(x, -int(floor(log10(x))) + (n - 1)) - result = None + def round_to_n(x: int, n: int) -> float: + return round(x, -int(floor(log10(x))) + (n - 1)) + result = self.to_satoshi_str() # we try to increase digits to check if we did loose out on precision # without gaining a shorter string, since this is a rarely used UI @@ -132,46 +144,51 @@ class Millisatoshi: else: return result - def to_json(self): + def to_json(self) -> str: return self.__repr__() - def __int__(self): + def __int__(self) -> int: return self.millisatoshis - def __lt__(self, other): + def __lt__(self, other: 'Millisatoshi') -> bool: return self.millisatoshis < other.millisatoshis - def __le__(self, other): + def __le__(self, other: 'Millisatoshi') -> bool: return self.millisatoshis <= other.millisatoshis - def __eq__(self, other): - return self.millisatoshis == other.millisatoshis + def __eq__(self, other: object) -> bool: + if isinstance(other, Millisatoshi): + return self.millisatoshis == other.millisatoshis + elif isinstance(other, int): + return self.millisatoshis == other + else: + return False - def __gt__(self, other): + def __gt__(self, other: 'Millisatoshi') -> bool: return self.millisatoshis > other.millisatoshis - def __ge__(self, other): + def __ge__(self, other: 'Millisatoshi') -> bool: return self.millisatoshis >= other.millisatoshis - def __add__(self, other): + def __add__(self, other: 'Millisatoshi') -> 'Millisatoshi': return Millisatoshi(int(self) + int(other)) - def __sub__(self, other): + def __sub__(self, other: 'Millisatoshi') -> 'Millisatoshi': return Millisatoshi(int(self) - int(other)) - def __mul__(self, other): - return Millisatoshi(int(int(self) * other)) + def __mul__(self, other: int) -> 'Millisatoshi': + return Millisatoshi(self.millisatoshis * other) - def __truediv__(self, other): - return Millisatoshi(int(int(self) / other)) + def __truediv__(self, other: Union[int, float]) -> 'Millisatoshi': + return Millisatoshi(int(self.millisatoshis / other)) - def __floordiv__(self, other): - return Millisatoshi(int(self) // other) + def __floordiv__(self, other: Union[int, float]) -> 'Millisatoshi': + return Millisatoshi(int(self.millisatoshis // float(other))) - def __mod__(self, other): - return Millisatoshi(int(self) % other) + def __mod__(self, other: Union[float, int]) -> 'Millisatoshi': + return Millisatoshi(int(self.millisatoshis % other)) - def __radd__(self, other): + def __radd__(self, other: 'Millisatoshi') -> 'Millisatoshi': return Millisatoshi(int(self) + int(other)) @@ -188,17 +205,17 @@ class UnixSocket(object): """ - def __init__(self, path): + def __init__(self, path: str): self.path = path - self.sock = None + self.sock: Optional[socket.SocketType] = None self.connect() - def connect(self): + def connect(self) -> None: try: self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - return self.sock.connect(self.path) + self.sock.connect(self.path) except OSError as e: - self.sock.close() + self.close() if (e.args[0] == "AF_UNIX path too long" and os.uname()[0] == "Linux"): # If this is a Linux system we may be able to work around this @@ -216,29 +233,29 @@ class UnixSocket(object): dirfd = os.open(dirname, os.O_DIRECTORY | os.O_RDONLY) short_path = "/proc/self/fd/%d/%s" % (dirfd, basename) self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - return self.sock.connect(short_path) + self.sock.connect(short_path) else: # There is no good way to recover from this. raise - def close(self): + def close(self) -> None: if self.sock is not None: self.sock.close() self.sock = None - def sendall(self, b): + def sendall(self, b: bytes) -> None: if self.sock is None: raise socket.error("not connected") self.sock.sendall(b) - def recv(self, length): + def recv(self, length: int) -> bytes: if self.sock is None: raise socket.error("not connected") return self.sock.recv(length) - def __del__(self): + def __del__(self) -> None: self.close()