Browse Source

pyln: Add type annotations to lightning.py

travis-experimental
Christian Decker 4 years ago
committed by Rusty Russell
parent
commit
8ecb157af6
  1. 111
      contrib/pyln-client/pyln/client/lightning.py

111
contrib/pyln-client/pyln/client/lightning.py

@ -1,5 +1,6 @@
from decimal import Decimal
from math import floor, log10
from typing import Optional, Union
import json
import logging
import os
@ -23,9 +24,12 @@ def monkey_patch_json(patch=True):
class RpcError(ValueError):
def __init__(self, method, payload, error):
super(ValueError, self).__init__("RPC call failed: method: {}, payload: {}, error: {}"
.format(method, payload, error))
def __init__(self, method: str, payload: dict, error: str):
super(ValueError, self).__init__(
"RPC call failed: method: {}, payload: {}, error: {}".format(
method, payload, error
)
)
self.method = method
self.payload = payload
@ -36,10 +40,10 @@ class Millisatoshi:
"""
A subtype to represent thousandths of a satoshi.
Many JSON API fields are expressed in millisatoshis: these automatically get
turned into Millisatoshi types. Converts to and from int.
Many JSON API fields are expressed in millisatoshis: these automatically
get turned into Millisatoshi types. Converts to and from int.
"""
def __init__(self, v):
def __init__(self, v: Union[int, str, Decimal]):
"""
Takes either a string ending in 'msat', 'sat', 'btc' or an integer.
"""
@ -47,43 +51,50 @@ class Millisatoshi:
if v.endswith("msat"):
self.millisatoshis = int(v[0:-4])
elif v.endswith("sat"):
self.millisatoshis = Decimal(v[0:-3]) * 1000
self.millisatoshis = int(v[0:-3]) * 1000
elif v.endswith("btc"):
self.millisatoshis = Decimal(v[0:-3]) * 1000 * 10**8
self.millisatoshis = int(v[0:-3]) * 1000 * 10**8
else:
raise TypeError("Millisatoshi must be string with msat/sat/btc suffix or int")
raise TypeError(
"Millisatoshi must be string with msat/sat/btc suffix or"
" int"
)
if self.millisatoshis != int(self.millisatoshis):
raise ValueError("Millisatoshi must be a whole number")
self.millisatoshis = int(self.millisatoshis)
elif isinstance(v, Millisatoshi):
self.millisatoshis = v.millisatoshis
elif int(v) == v:
self.millisatoshis = int(v)
else:
raise TypeError("Millisatoshi must be string with msat/sat/btc suffix or int")
raise TypeError(
"Millisatoshi must be string with msat/sat/btc suffix or int"
)
if self.millisatoshis < 0:
raise ValueError("Millisatoshi must be >= 0")
def __repr__(self):
def __repr__(self) -> str:
"""
Appends the 'msat' as expected for this type.
"""
return str(self.millisatoshis) + "msat"
def to_satoshi(self):
def to_satoshi(self) -> Decimal:
"""
Return a Decimal representing the number of satoshis.
"""
return Decimal(self.millisatoshis) / 1000
def to_btc(self):
def to_btc(self) -> Decimal:
"""
Return a Decimal representing the number of bitcoin.
"""
return Decimal(self.millisatoshis) / 1000 / 10**8
def to_satoshi_str(self):
def to_satoshi_str(self) -> str:
"""
Return a string of form 1234sat or 1234.567sat.
"""
@ -92,7 +103,7 @@ class Millisatoshi:
else:
return '{:.0f}sat'.format(self.to_satoshi())
def to_btc_str(self):
def to_btc_str(self) -> str:
"""
Return a string of form 12.34567890btc or 12.34567890123btc.
"""
@ -101,13 +112,14 @@ class Millisatoshi:
else:
return '{:.8f}btc'.format(self.to_btc())
def to_approx_str(self, digits: int = 3):
def to_approx_str(self, digits: int = 3) -> str:
"""Returns the shortmost string using common units representation.
Rounds to significant `digits`. Default: 3
"""
round_to_n = lambda x, n: round(x, -int(floor(log10(x))) + (n - 1))
result = None
def round_to_n(x: int, n: int) -> float:
return round(x, -int(floor(log10(x))) + (n - 1))
result = self.to_satoshi_str()
# we try to increase digits to check if we did loose out on precision
# without gaining a shorter string, since this is a rarely used UI
@ -132,46 +144,51 @@ class Millisatoshi:
else:
return result
def to_json(self):
def to_json(self) -> str:
return self.__repr__()
def __int__(self):
def __int__(self) -> int:
return self.millisatoshis
def __lt__(self, other):
def __lt__(self, other: 'Millisatoshi') -> bool:
return self.millisatoshis < other.millisatoshis
def __le__(self, other):
def __le__(self, other: 'Millisatoshi') -> bool:
return self.millisatoshis <= other.millisatoshis
def __eq__(self, other):
return self.millisatoshis == other.millisatoshis
def __eq__(self, other: object) -> bool:
if isinstance(other, Millisatoshi):
return self.millisatoshis == other.millisatoshis
elif isinstance(other, int):
return self.millisatoshis == other
else:
return False
def __gt__(self, other):
def __gt__(self, other: 'Millisatoshi') -> bool:
return self.millisatoshis > other.millisatoshis
def __ge__(self, other):
def __ge__(self, other: 'Millisatoshi') -> bool:
return self.millisatoshis >= other.millisatoshis
def __add__(self, other):
def __add__(self, other: 'Millisatoshi') -> 'Millisatoshi':
return Millisatoshi(int(self) + int(other))
def __sub__(self, other):
def __sub__(self, other: 'Millisatoshi') -> 'Millisatoshi':
return Millisatoshi(int(self) - int(other))
def __mul__(self, other):
return Millisatoshi(int(int(self) * other))
def __mul__(self, other: int) -> 'Millisatoshi':
return Millisatoshi(self.millisatoshis * other)
def __truediv__(self, other):
return Millisatoshi(int(int(self) / other))
def __truediv__(self, other: Union[int, float]) -> 'Millisatoshi':
return Millisatoshi(int(self.millisatoshis / other))
def __floordiv__(self, other):
return Millisatoshi(int(self) // other)
def __floordiv__(self, other: Union[int, float]) -> 'Millisatoshi':
return Millisatoshi(int(self.millisatoshis // float(other)))
def __mod__(self, other):
return Millisatoshi(int(self) % other)
def __mod__(self, other: Union[float, int]) -> 'Millisatoshi':
return Millisatoshi(int(self.millisatoshis % other))
def __radd__(self, other):
def __radd__(self, other: 'Millisatoshi') -> 'Millisatoshi':
return Millisatoshi(int(self) + int(other))
@ -188,17 +205,17 @@ class UnixSocket(object):
"""
def __init__(self, path):
def __init__(self, path: str):
self.path = path
self.sock = None
self.sock: Optional[socket.SocketType] = None
self.connect()
def connect(self):
def connect(self) -> None:
try:
self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
return self.sock.connect(self.path)
self.sock.connect(self.path)
except OSError as e:
self.sock.close()
self.close()
if (e.args[0] == "AF_UNIX path too long" and os.uname()[0] == "Linux"):
# If this is a Linux system we may be able to work around this
@ -216,29 +233,29 @@ class UnixSocket(object):
dirfd = os.open(dirname, os.O_DIRECTORY | os.O_RDONLY)
short_path = "/proc/self/fd/%d/%s" % (dirfd, basename)
self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
return self.sock.connect(short_path)
self.sock.connect(short_path)
else:
# There is no good way to recover from this.
raise
def close(self):
def close(self) -> None:
if self.sock is not None:
self.sock.close()
self.sock = None
def sendall(self, b):
def sendall(self, b: bytes) -> None:
if self.sock is None:
raise socket.error("not connected")
self.sock.sendall(b)
def recv(self, length):
def recv(self, length: int) -> bytes:
if self.sock is None:
raise socket.error("not connected")
return self.sock.recv(length)
def __del__(self):
def __del__(self) -> None:
self.close()

Loading…
Cancel
Save