Browse Source

fix: mypy errors

Login
Eneko Illarramendi 5 years ago
committed by Sebastian Geisler
parent
commit
c3e337a319
  1. 8
      lnbits/__init__.py
  2. 2
      lnbits/bolt11.py
  3. 2
      lnbits/core/__init__.py
  4. 17
      lnbits/core/crud.py
  5. 10
      lnbits/core/models.py
  6. 9
      lnbits/core/services.py
  7. 4
      lnbits/core/views/lnurl.py
  8. 2
      lnbits/decorators.py
  9. 2
      lnbits/extensions/amilk/__init__.py
  10. 2
      lnbits/extensions/diagonalley/__init__.py
  11. 2
      lnbits/extensions/events/__init__.py
  12. 2
      lnbits/extensions/example/__init__.py
  13. 2
      lnbits/extensions/tpos/__init__.py
  14. 2
      lnbits/extensions/withdraw/__init__.py
  15. 11
      lnbits/helpers.py
  16. 9
      lnbits/wallets/clightning.py
  17. 78
      lnbits/wallets/lndgrpc.py
  18. 5
      lnbits/wallets/lndrest.py

8
lnbits/__init__.py

@ -1,9 +1,9 @@
import importlib import importlib
from flask import Flask from flask import Flask
from flask_assets import Environment, Bundle from flask_assets import Environment, Bundle # type: ignore
from flask_compress import Compress from flask_compress import Compress # type: ignore
from flask_talisman import Talisman from flask_talisman import Talisman # type: ignore
from os import getenv from os import getenv
from werkzeug.middleware.proxy_fix import ProxyFix from werkzeug.middleware.proxy_fix import ProxyFix
@ -15,7 +15,7 @@ from .settings import FORCE_HTTPS
disabled_extensions = getenv("LNBITS_DISABLED_EXTENSIONS", "").split(",") disabled_extensions = getenv("LNBITS_DISABLED_EXTENSIONS", "").split(",")
app = Flask(__name__) app = Flask(__name__)
app.wsgi_app = ProxyFix(app.wsgi_app, x_proto=1, x_host=1) app.wsgi_app = ProxyFix(app.wsgi_app, x_proto=1, x_host=1) # type: ignore
valid_extensions = [ext for ext in ExtensionManager(disabled=disabled_extensions).extensions if ext.is_valid] valid_extensions = [ext for ext in ExtensionManager(disabled=disabled_extensions).extensions if ext.is_valid]

2
lnbits/bolt11.py

@ -1,3 +1,5 @@
# type: ignore
import bitstring import bitstring
import re import re
from binascii import hexlify from binascii import hexlify

2
lnbits/core/__init__.py

@ -1,7 +1,7 @@
from flask import Blueprint from flask import Blueprint
core_app = Blueprint("core", __name__, template_folder="templates", static_folder="static") core_app: Blueprint = Blueprint("core", __name__, template_folder="templates", static_folder="static")
from .views.api import * # noqa from .views.api import * # noqa

17
lnbits/core/crud.py

@ -16,7 +16,10 @@ def create_account() -> User:
user_id = uuid4().hex user_id = uuid4().hex
db.execute("INSERT INTO accounts (id) VALUES (?)", (user_id,)) db.execute("INSERT INTO accounts (id) VALUES (?)", (user_id,))
return get_account(user_id=user_id) new_account = get_account(user_id=user_id)
assert new_account, "Newly created account couldn't be retrieved"
return new_account
def get_account(user_id: str) -> Optional[User]: def get_account(user_id: str) -> Optional[User]:
@ -74,7 +77,10 @@ def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> Wallet:
(wallet_id, wallet_name or DEFAULT_WALLET_NAME, user_id, uuid4().hex, uuid4().hex), (wallet_id, wallet_name or DEFAULT_WALLET_NAME, user_id, uuid4().hex, uuid4().hex),
) )
return get_wallet(wallet_id=wallet_id) new_wallet = get_wallet(wallet_id=wallet_id)
assert new_wallet, "Newly created wallet couldn't be retrieved"
return new_wallet
def delete_wallet(*, user_id: str, wallet_id: str) -> None: def delete_wallet(*, user_id: str, wallet_id: str) -> None:
@ -175,7 +181,7 @@ def delete_wallet_payments_expired(wallet_id: str, *, seconds: int = 86400) -> N
def create_payment( def create_payment(
*, wallet_id: str, checking_id: str, amount: str, memo: str, fee: int = 0, pending: bool = True *, wallet_id: str, checking_id: str, amount: int, memo: str, fee: int = 0, pending: bool = True
) -> Payment: ) -> Payment:
with open_db() as db: with open_db() as db:
db.execute( db.execute(
@ -186,7 +192,10 @@ def create_payment(
(wallet_id, checking_id, amount, int(pending), memo, fee), (wallet_id, checking_id, amount, int(pending), memo, fee),
) )
return get_wallet_payment(wallet_id, checking_id) new_payment = get_wallet_payment(wallet_id, checking_id)
assert new_payment, "Newly created payment couldn't be retrieved"
return new_payment
def update_payment_status(checking_id: str, pending: bool) -> None: def update_payment_status(checking_id: str, pending: bool) -> None:

10
lnbits/core/models.py

@ -4,8 +4,8 @@ from typing import List, NamedTuple, Optional
class User(NamedTuple): class User(NamedTuple):
id: str id: str
email: str email: str
extensions: Optional[List[str]] = [] extensions: List[str] = []
wallets: Optional[List["Wallet"]] = [] wallets: List["Wallet"] = []
password: Optional[str] = None password: Optional[str] = None
@property @property
@ -27,9 +27,9 @@ class Wallet(NamedTuple):
@property @property
def balance(self) -> int: def balance(self) -> int:
return int(self.balance / 1000) return self.balance // 1000
def get_payment(self, checking_id: str) -> "Payment": def get_payment(self, checking_id: str) -> Optional["Payment"]:
from .crud import get_wallet_payment from .crud import get_wallet_payment
return get_wallet_payment(self.id, checking_id) return get_wallet_payment(self.id, checking_id)
@ -59,7 +59,7 @@ class Payment(NamedTuple):
@property @property
def sat(self) -> int: def sat(self) -> int:
return self.amount / 1000 return self.amount // 1000
@property @property
def is_in(self) -> bool: def is_in(self) -> bool:

9
lnbits/core/services.py

@ -1,6 +1,6 @@
from typing import Optional, Tuple from typing import Optional, Tuple
from lnbits.bolt11 import decode as bolt11_decode from lnbits.bolt11 import decode as bolt11_decode # type: ignore
from lnbits.helpers import urlsafe_short_hash from lnbits.helpers import urlsafe_short_hash
from lnbits.settings import WALLET from lnbits.settings import WALLET
@ -24,7 +24,6 @@ def create_invoice(*, wallet_id: str, amount: int, memo: str) -> Tuple[str, str]
def pay_invoice(*, wallet_id: str, bolt11: str, max_sat: Optional[int] = None) -> str: def pay_invoice(*, wallet_id: str, bolt11: str, max_sat: Optional[int] = None) -> str:
temp_id = f"temp_{urlsafe_short_hash()}" temp_id = f"temp_{urlsafe_short_hash()}"
try: try:
invoice = bolt11_decode(bolt11) invoice = bolt11_decode(bolt11)
@ -34,7 +33,7 @@ def pay_invoice(*, wallet_id: str, bolt11: str, max_sat: Optional[int] = None) -
if max_sat and invoice.amount_msat > max_sat * 1000: if max_sat and invoice.amount_msat > max_sat * 1000:
raise ValueError("Amount in invoice is too high.") raise ValueError("Amount in invoice is too high.")
fee_reserve = max(1000, invoice.amount_msat * 0.01) fee_reserve = max(1000, int(invoice.amount_msat * 0.01))
create_payment( create_payment(
wallet_id=wallet_id, wallet_id=wallet_id,
checking_id=temp_id, checking_id=temp_id,
@ -43,7 +42,9 @@ def pay_invoice(*, wallet_id: str, bolt11: str, max_sat: Optional[int] = None) -
memo=temp_id, memo=temp_id,
) )
if get_wallet(wallet_id).balance_msat < 0: wallet = get_wallet(wallet_id)
assert wallet, "invalid wallet id"
if wallet.balance_msat < 0:
raise PermissionError("Insufficient balance.") raise PermissionError("Insufficient balance.")
ok, checking_id, fee_msat, error_message = WALLET.pay_invoice(bolt11) ok, checking_id, fee_msat, error_message = WALLET.pay_invoice(bolt11)

4
lnbits/core/views/lnurl.py

@ -1,8 +1,8 @@
import requests import requests
from flask import abort, redirect, request, url_for from flask import abort, redirect, request, url_for
from lnurl import LnurlWithdrawResponse, handle as handle_lnurl from lnurl import LnurlWithdrawResponse, handle as handle_lnurl # type: ignore
from lnurl.exceptions import LnurlException from lnurl.exceptions import LnurlException # type: ignore
from time import sleep from time import sleep
from lnbits.core import core_app from lnbits.core import core_app

2
lnbits/decorators.py

@ -1,4 +1,4 @@
from cerberus import Validator from cerberus import Validator # type: ignore
from flask import g, abort, jsonify, request from flask import g, abort, jsonify, request
from functools import wraps from functools import wraps
from typing import List, Union from typing import List, Union

2
lnbits/extensions/amilk/__init__.py

@ -1,7 +1,7 @@
from flask import Blueprint from flask import Blueprint
amilk_ext = Blueprint("amilk", __name__, static_folder="static", template_folder="templates") amilk_ext: Blueprint = Blueprint("amilk", __name__, static_folder="static", template_folder="templates")
from .views_api import * # noqa from .views_api import * # noqa

2
lnbits/extensions/diagonalley/__init__.py

@ -1,7 +1,7 @@
from flask import Blueprint from flask import Blueprint
diagonalley_ext = Blueprint("diagonalley", __name__, static_folder="static", template_folder="templates") diagonalley_ext: Blueprint = Blueprint("diagonalley", __name__, static_folder="static", template_folder="templates")
from .views_api import * # noqa from .views_api import * # noqa

2
lnbits/extensions/events/__init__.py

@ -1,7 +1,7 @@
from flask import Blueprint from flask import Blueprint
events_ext = Blueprint("events", __name__, static_folder="static", template_folder="templates") events_ext: Blueprint = Blueprint("events", __name__, static_folder="static", template_folder="templates")
from .views_api import * # noqa from .views_api import * # noqa

2
lnbits/extensions/example/__init__.py

@ -1,7 +1,7 @@
from flask import Blueprint from flask import Blueprint
example_ext = Blueprint("example", __name__, static_folder="static", template_folder="templates") example_ext: Blueprint = Blueprint("example", __name__, static_folder="static", template_folder="templates")
from .views_api import * # noqa from .views_api import * # noqa

2
lnbits/extensions/tpos/__init__.py

@ -1,7 +1,7 @@
from flask import Blueprint from flask import Blueprint
tpos_ext = Blueprint("tpos", __name__, static_folder="static", template_folder="templates") tpos_ext: Blueprint = Blueprint("tpos", __name__, static_folder="static", template_folder="templates")
from .views_api import * # noqa from .views_api import * # noqa

2
lnbits/extensions/withdraw/__init__.py

@ -1,7 +1,7 @@
from flask import Blueprint from flask import Blueprint
withdraw_ext = Blueprint("withdraw", __name__, static_folder="static", template_folder="templates") withdraw_ext: Blueprint = Blueprint("withdraw", __name__, static_folder="static", template_folder="templates")
from .views_api import * # noqa from .views_api import * # noqa

11
lnbits/helpers.py

@ -1,6 +1,6 @@
import json import json
import os import os
import shortuuid import shortuuid # type: ignore
from typing import List, NamedTuple, Optional from typing import List, NamedTuple, Optional
@ -34,7 +34,14 @@ class ExtensionManager:
config = {} config = {}
is_valid = False is_valid = False
output.append(Extension(**{**{"code": extension, "is_valid": is_valid}, **config})) output.append(Extension(
extension,
is_valid,
config.get('name'),
config.get('short_description'),
config.get('icon'),
config.get('contributors')
))
return output return output

9
lnbits/wallets/clightning.py

@ -1,7 +1,6 @@
from requests import get, post
from os import getenv from os import getenv
from .base import InvoiceResponse, PaymentResponse, PaymentStatus, Wallet from .base import InvoiceResponse, PaymentResponse, PaymentStatus, Wallet
from lightning import LightningRpc from lightning import LightningRpc # type: ignore
import random import random
class CLightningWallet(Wallet): class CLightningWallet(Wallet):
@ -17,7 +16,7 @@ class CLightningWallet(Wallet):
def pay_invoice(self, bolt11: str) -> PaymentResponse: def pay_invoice(self, bolt11: str) -> PaymentResponse:
r = self.l1.pay(bolt11) r = self.l1.pay(bolt11)
ok, checking_id, fee_msat, error_message = True, None, None, None ok, checking_id, fee_msat, error_message = True, None, 0, None
return PaymentResponse(ok, checking_id, fee_msat, error_message) return PaymentResponse(ok, checking_id, fee_msat, error_message)
def get_invoice_status(self, checking_id: str) -> PaymentStatus: def get_invoice_status(self, checking_id: str) -> PaymentStatus:
@ -29,8 +28,8 @@ class CLightningWallet(Wallet):
def get_payment_status(self, checking_id: str) -> PaymentStatus: def get_payment_status(self, checking_id: str) -> PaymentStatus:
r = self.l1.listsendpays(checking_id) r = self.l1.listsendpays(checking_id)
if not r.ok: if not r.ok:
return PaymentStatus(r, None) return PaymentStatus(None)
payments = [p for p in r.json()["payments"] if p["payment_hash"] == payment_hash] payments = [p for p in r.json()["payments"] if p["payment_hash"] == checking_id]
payment = payments[0] if payments else None payment = payments[0] if payments else None
statuses = {"UNKNOWN": None, "IN_FLIGHT": None, "SUCCEEDED": True, "FAILED": False} statuses = {"UNKNOWN": None, "IN_FLIGHT": None, "SUCCEEDED": True, "FAILED": False}
return PaymentStatus(statuses[payment["status"]] if payment else None) return PaymentStatus(statuses[payment["status"]] if payment else None)

78
lnbits/wallets/lndgrpc.py

@ -1,14 +1,13 @@
from os import getenv
import os
import base64 import base64
import lnd_grpc # https://github.com/willcl-ark/lnd_grpc import lnd_grpc # type: ignore
from os import getenv
from .base import InvoiceResponse, PaymentResponse, PaymentStatus, Wallet from .base import InvoiceResponse, PaymentResponse, PaymentStatus, Wallet
class LndWallet(Wallet): class LndWallet(Wallet):
def __init__(self): def __init__(self):
endpoint = getenv("LND_GRPC_ENDPOINT") endpoint = getenv("LND_GRPC_ENDPOINT")
self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
self.port = getenv("LND_GRPC_PORT") self.port = getenv("LND_GRPC_PORT")
@ -18,31 +17,21 @@ class LndWallet(Wallet):
self.auth_cert = getenv("LND_CERT") self.auth_cert = getenv("LND_CERT")
lnd_rpc = lnd_grpc.Client( lnd_rpc = lnd_grpc.Client(
lnd_dir = None, lnd_dir=None, tls_cert_path=self.auth_cert, network="mainnet", grpc_host=self.endpoint, grpc_port=self.port
tls_cert_path = self.auth_cert,
network = 'mainnet',
grpc_host = self.endpoint,
grpc_port = self.port
) )
def create_invoice(self, amount: int, memo: str = "") -> InvoiceResponse: def create_invoice(self, amount: int, memo: str = "") -> InvoiceResponse:
lnd_rpc = lnd_grpc.Client( lnd_rpc = lnd_grpc.Client(
lnd_dir = None, lnd_dir=None,
macaroon_path = self.auth_invoice, macaroon_path=self.auth_invoice,
tls_cert_path = self.auth_cert, tls_cert_path=self.auth_cert,
network = 'mainnet', network="mainnet",
grpc_host = self.endpoint, grpc_host=self.endpoint,
grpc_port = self.port grpc_port=self.port,
) )
lndResponse = lnd_rpc.add_invoice( lndResponse = lnd_rpc.add_invoice(memo=memo, value=amount, expiry=600, private=True)
memo = memo, decoded_hash = base64.b64encode(lndResponse.r_hash).decode("utf-8").replace("/", "_")
value = amount,
expiry = 600,
private = True
)
decoded_hash = base64.b64encode(lndResponse.r_hash).decode('utf-8').replace("/","_")
print(lndResponse.r_hash) print(lndResponse.r_hash)
ok, checking_id, payment_request, error_message = True, decoded_hash, str(lndResponse.payment_request), None ok, checking_id, payment_request, error_message = True, decoded_hash, str(lndResponse.payment_request), None
return InvoiceResponse(ok, checking_id, payment_request, error_message) return InvoiceResponse(ok, checking_id, payment_request, error_message)
@ -50,52 +39,43 @@ class LndWallet(Wallet):
def pay_invoice(self, bolt11: str) -> PaymentResponse: def pay_invoice(self, bolt11: str) -> PaymentResponse:
lnd_rpc = lnd_grpc.Client( lnd_rpc = lnd_grpc.Client(
lnd_dir = None, lnd_dir=None,
macaroon_path = self.auth_admin, macaroon_path=self.auth_admin,
tls_cert_path = self.auth_cert, tls_cert_path=self.auth_cert,
network = 'mainnet', network="mainnet",
grpc_host = self.endpoint, grpc_host=self.endpoint,
grpc_port = self.port grpc_port=self.port,
) )
payinvoice = lnd_rpc.pay_invoice( payinvoice = lnd_rpc.pay_invoice(payment_request=bolt11,)
payment_request = bolt11,
)
ok, checking_id, fee_msat, error_message = True, None, 0, None ok, checking_id, fee_msat, error_message = True, None, 0, None
if payinvoice.payment_error: if payinvoice.payment_error:
ok, error_message = False, payinvoice.payment_error ok, error_message = False, payinvoice.payment_error
else: else:
checking_id = base64.b64encode(payinvoice.payment_hash).decode('utf-8').replace("/","_") checking_id = base64.b64encode(payinvoice.payment_hash).decode("utf-8").replace("/", "_")
return PaymentResponse(ok, checking_id, fee_msat, error_message) return PaymentResponse(ok, checking_id, fee_msat, error_message)
def get_invoice_status(self, checking_id: str) -> PaymentStatus: def get_invoice_status(self, checking_id: str) -> PaymentStatus:
check_id = base64.b64decode(checking_id.replace("_","/")) check_id = base64.b64decode(checking_id.replace("_", "/"))
print(check_id) print(check_id)
lnd_rpc = lnd_grpc.Client( lnd_rpc = lnd_grpc.Client(
lnd_dir = None, lnd_dir=None,
macaroon_path = self.auth_invoice, macaroon_path=self.auth_invoice,
tls_cert_path = self.auth_cert, tls_cert_path=self.auth_cert,
network = 'mainnet', network="mainnet",
grpc_host = self.endpoint, grpc_host=self.endpoint,
grpc_port = self.port grpc_port=self.port,
) )
for _response in lnd_rpc.subscribe_single_invoice(check_id): for _response in lnd_rpc.subscribe_single_invoice(check_id):
if _response.state == 1: if _response.state == 1:
return PaymentStatus(True) return PaymentStatus(True)
invoiceThread = threading.Thread( return PaymentStatus(None)
target=detectPayment,
args=[lndResponse.check_id, ],
daemon=True
)
invoiceThread.start()
def get_payment_status(self, checking_id: str) -> PaymentStatus: def get_payment_status(self, checking_id: str) -> PaymentStatus:

5
lnbits/wallets/lndrest.py

@ -71,11 +71,10 @@ class LndRestWallet(Wallet):
return PaymentStatus(r.json()["settled"]) return PaymentStatus(r.json()["settled"])
def get_payment_status(self, checking_id: str) -> PaymentStatus: def get_payment_status(self, checking_id: str) -> PaymentStatus:
r = get(url=f"{self.endpoint}/v1/payments", headers=self.auth_admin, verify=self.auth_cert, params={"include_incomplete": "True", "max_payments": "20"})
r = get(url=f"{self.endpoint}/v1/payments", headers=self.auth_admin, verify=self.auth_cert, params={"include_incomplete": True, "max_payments": "20"})
if not r.ok: if not r.ok:
return PaymentStatus(r, None) return PaymentStatus(None)
payments = [p for p in r.json()["payments"] if p["payment_hash"] == checking_id] payments = [p for p in r.json()["payments"] if p["payment_hash"] == checking_id]
print(checking_id) print(checking_id)

Loading…
Cancel
Save