Browse Source

pyln: Add type annotations to lightning.py

travis-experimental
Christian Decker 5 years ago
committed by Rusty Russell
parent
commit
8ecb157af6
  1. 109
      contrib/pyln-client/pyln/client/lightning.py

109
contrib/pyln-client/pyln/client/lightning.py

@ -1,5 +1,6 @@
from decimal import Decimal from decimal import Decimal
from math import floor, log10 from math import floor, log10
from typing import Optional, Union
import json import json
import logging import logging
import os import os
@ -23,9 +24,12 @@ def monkey_patch_json(patch=True):
class RpcError(ValueError): class RpcError(ValueError):
def __init__(self, method, payload, error): def __init__(self, method: str, payload: dict, error: str):
super(ValueError, self).__init__("RPC call failed: method: {}, payload: {}, error: {}" super(ValueError, self).__init__(
.format(method, payload, error)) "RPC call failed: method: {}, payload: {}, error: {}".format(
method, payload, error
)
)
self.method = method self.method = method
self.payload = payload self.payload = payload
@ -36,10 +40,10 @@ class Millisatoshi:
""" """
A subtype to represent thousandths of a satoshi. A subtype to represent thousandths of a satoshi.
Many JSON API fields are expressed in millisatoshis: these automatically get Many JSON API fields are expressed in millisatoshis: these automatically
turned into Millisatoshi types. Converts to and from int. get turned into Millisatoshi types. Converts to and from int.
""" """
def __init__(self, v): def __init__(self, v: Union[int, str, Decimal]):
""" """
Takes either a string ending in 'msat', 'sat', 'btc' or an integer. Takes either a string ending in 'msat', 'sat', 'btc' or an integer.
""" """
@ -47,43 +51,50 @@ class Millisatoshi:
if v.endswith("msat"): if v.endswith("msat"):
self.millisatoshis = int(v[0:-4]) self.millisatoshis = int(v[0:-4])
elif v.endswith("sat"): elif v.endswith("sat"):
self.millisatoshis = Decimal(v[0:-3]) * 1000 self.millisatoshis = int(v[0:-3]) * 1000
elif v.endswith("btc"): elif v.endswith("btc"):
self.millisatoshis = Decimal(v[0:-3]) * 1000 * 10**8 self.millisatoshis = int(v[0:-3]) * 1000 * 10**8
else: else:
raise TypeError("Millisatoshi must be string with msat/sat/btc suffix or int") raise TypeError(
"Millisatoshi must be string with msat/sat/btc suffix or"
" int"
)
if self.millisatoshis != int(self.millisatoshis): if self.millisatoshis != int(self.millisatoshis):
raise ValueError("Millisatoshi must be a whole number") raise ValueError("Millisatoshi must be a whole number")
self.millisatoshis = int(self.millisatoshis) self.millisatoshis = int(self.millisatoshis)
elif isinstance(v, Millisatoshi): elif isinstance(v, Millisatoshi):
self.millisatoshis = v.millisatoshis self.millisatoshis = v.millisatoshis
elif int(v) == v: elif int(v) == v:
self.millisatoshis = int(v) self.millisatoshis = int(v)
else: else:
raise TypeError("Millisatoshi must be string with msat/sat/btc suffix or int") raise TypeError(
"Millisatoshi must be string with msat/sat/btc suffix or int"
)
if self.millisatoshis < 0: if self.millisatoshis < 0:
raise ValueError("Millisatoshi must be >= 0") raise ValueError("Millisatoshi must be >= 0")
def __repr__(self): def __repr__(self) -> str:
""" """
Appends the 'msat' as expected for this type. Appends the 'msat' as expected for this type.
""" """
return str(self.millisatoshis) + "msat" return str(self.millisatoshis) + "msat"
def to_satoshi(self): def to_satoshi(self) -> Decimal:
""" """
Return a Decimal representing the number of satoshis. Return a Decimal representing the number of satoshis.
""" """
return Decimal(self.millisatoshis) / 1000 return Decimal(self.millisatoshis) / 1000
def to_btc(self): def to_btc(self) -> Decimal:
""" """
Return a Decimal representing the number of bitcoin. Return a Decimal representing the number of bitcoin.
""" """
return Decimal(self.millisatoshis) / 1000 / 10**8 return Decimal(self.millisatoshis) / 1000 / 10**8
def to_satoshi_str(self): def to_satoshi_str(self) -> str:
""" """
Return a string of form 1234sat or 1234.567sat. Return a string of form 1234sat or 1234.567sat.
""" """
@ -92,7 +103,7 @@ class Millisatoshi:
else: else:
return '{:.0f}sat'.format(self.to_satoshi()) return '{:.0f}sat'.format(self.to_satoshi())
def to_btc_str(self): def to_btc_str(self) -> str:
""" """
Return a string of form 12.34567890btc or 12.34567890123btc. Return a string of form 12.34567890btc or 12.34567890123btc.
""" """
@ -101,13 +112,14 @@ class Millisatoshi:
else: else:
return '{:.8f}btc'.format(self.to_btc()) return '{:.8f}btc'.format(self.to_btc())
def to_approx_str(self, digits: int = 3): def to_approx_str(self, digits: int = 3) -> str:
"""Returns the shortmost string using common units representation. """Returns the shortmost string using common units representation.
Rounds to significant `digits`. Default: 3 Rounds to significant `digits`. Default: 3
""" """
round_to_n = lambda x, n: round(x, -int(floor(log10(x))) + (n - 1)) def round_to_n(x: int, n: int) -> float:
result = None return round(x, -int(floor(log10(x))) + (n - 1))
result = self.to_satoshi_str()
# we try to increase digits to check if we did loose out on precision # we try to increase digits to check if we did loose out on precision
# without gaining a shorter string, since this is a rarely used UI # without gaining a shorter string, since this is a rarely used UI
@ -132,46 +144,51 @@ class Millisatoshi:
else: else:
return result return result
def to_json(self): def to_json(self) -> str:
return self.__repr__() return self.__repr__()
def __int__(self): def __int__(self) -> int:
return self.millisatoshis return self.millisatoshis
def __lt__(self, other): def __lt__(self, other: 'Millisatoshi') -> bool:
return self.millisatoshis < other.millisatoshis return self.millisatoshis < other.millisatoshis
def __le__(self, other): def __le__(self, other: 'Millisatoshi') -> bool:
return self.millisatoshis <= other.millisatoshis return self.millisatoshis <= other.millisatoshis
def __eq__(self, other): def __eq__(self, other: object) -> bool:
if isinstance(other, Millisatoshi):
return self.millisatoshis == other.millisatoshis return self.millisatoshis == other.millisatoshis
elif isinstance(other, int):
return self.millisatoshis == other
else:
return False
def __gt__(self, other): def __gt__(self, other: 'Millisatoshi') -> bool:
return self.millisatoshis > other.millisatoshis return self.millisatoshis > other.millisatoshis
def __ge__(self, other): def __ge__(self, other: 'Millisatoshi') -> bool:
return self.millisatoshis >= other.millisatoshis return self.millisatoshis >= other.millisatoshis
def __add__(self, other): def __add__(self, other: 'Millisatoshi') -> 'Millisatoshi':
return Millisatoshi(int(self) + int(other)) return Millisatoshi(int(self) + int(other))
def __sub__(self, other): def __sub__(self, other: 'Millisatoshi') -> 'Millisatoshi':
return Millisatoshi(int(self) - int(other)) return Millisatoshi(int(self) - int(other))
def __mul__(self, other): def __mul__(self, other: int) -> 'Millisatoshi':
return Millisatoshi(int(int(self) * other)) return Millisatoshi(self.millisatoshis * other)
def __truediv__(self, other): def __truediv__(self, other: Union[int, float]) -> 'Millisatoshi':
return Millisatoshi(int(int(self) / other)) return Millisatoshi(int(self.millisatoshis / other))
def __floordiv__(self, other): def __floordiv__(self, other: Union[int, float]) -> 'Millisatoshi':
return Millisatoshi(int(self) // other) return Millisatoshi(int(self.millisatoshis // float(other)))
def __mod__(self, other): def __mod__(self, other: Union[float, int]) -> 'Millisatoshi':
return Millisatoshi(int(self) % other) return Millisatoshi(int(self.millisatoshis % other))
def __radd__(self, other): def __radd__(self, other: 'Millisatoshi') -> 'Millisatoshi':
return Millisatoshi(int(self) + int(other)) return Millisatoshi(int(self) + int(other))
@ -188,17 +205,17 @@ class UnixSocket(object):
""" """
def __init__(self, path): def __init__(self, path: str):
self.path = path self.path = path
self.sock = None self.sock: Optional[socket.SocketType] = None
self.connect() self.connect()
def connect(self): def connect(self) -> None:
try: try:
self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
return self.sock.connect(self.path) self.sock.connect(self.path)
except OSError as e: except OSError as e:
self.sock.close() self.close()
if (e.args[0] == "AF_UNIX path too long" and os.uname()[0] == "Linux"): if (e.args[0] == "AF_UNIX path too long" and os.uname()[0] == "Linux"):
# If this is a Linux system we may be able to work around this # If this is a Linux system we may be able to work around this
@ -216,29 +233,29 @@ class UnixSocket(object):
dirfd = os.open(dirname, os.O_DIRECTORY | os.O_RDONLY) dirfd = os.open(dirname, os.O_DIRECTORY | os.O_RDONLY)
short_path = "/proc/self/fd/%d/%s" % (dirfd, basename) short_path = "/proc/self/fd/%d/%s" % (dirfd, basename)
self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
return self.sock.connect(short_path) self.sock.connect(short_path)
else: else:
# There is no good way to recover from this. # There is no good way to recover from this.
raise raise
def close(self): def close(self) -> None:
if self.sock is not None: if self.sock is not None:
self.sock.close() self.sock.close()
self.sock = None self.sock = None
def sendall(self, b): def sendall(self, b: bytes) -> None:
if self.sock is None: if self.sock is None:
raise socket.error("not connected") raise socket.error("not connected")
self.sock.sendall(b) self.sock.sendall(b)
def recv(self, length): def recv(self, length: int) -> bytes:
if self.sock is None: if self.sock is None:
raise socket.error("not connected") raise socket.error("not connected")
return self.sock.recv(length) return self.sock.recv(length)
def __del__(self): def __del__(self) -> None:
self.close() self.close()

Loading…
Cancel
Save