From c3e337a3195f1c2597fb11dc3ad3859c52634484 Mon Sep 17 00:00:00 2001 From: Eneko Illarramendi Date: Sun, 26 Apr 2020 13:28:19 +0200 Subject: [PATCH] fix: mypy errors --- lnbits/__init__.py | 8 +-- lnbits/bolt11.py | 2 + lnbits/core/__init__.py | 2 +- lnbits/core/crud.py | 17 +++-- lnbits/core/models.py | 10 +-- lnbits/core/services.py | 9 +-- lnbits/core/views/lnurl.py | 4 +- lnbits/decorators.py | 2 +- lnbits/extensions/amilk/__init__.py | 2 +- lnbits/extensions/diagonalley/__init__.py | 2 +- lnbits/extensions/events/__init__.py | 2 +- lnbits/extensions/example/__init__.py | 2 +- lnbits/extensions/tpos/__init__.py | 2 +- lnbits/extensions/withdraw/__init__.py | 2 +- lnbits/helpers.py | 11 ++- lnbits/wallets/clightning.py | 13 ++-- lnbits/wallets/lndgrpc.py | 86 +++++++++-------------- lnbits/wallets/lndrest.py | 9 ++- 18 files changed, 91 insertions(+), 94 deletions(-) diff --git a/lnbits/__init__.py b/lnbits/__init__.py index b2274d1..e332e0e 100644 --- a/lnbits/__init__.py +++ b/lnbits/__init__.py @@ -1,9 +1,9 @@ import importlib from flask import Flask -from flask_assets import Environment, Bundle -from flask_compress import Compress -from flask_talisman import Talisman +from flask_assets import Environment, Bundle # type: ignore +from flask_compress import Compress # type: ignore +from flask_talisman import Talisman # type: ignore from os import getenv from werkzeug.middleware.proxy_fix import ProxyFix @@ -15,7 +15,7 @@ from .settings import FORCE_HTTPS disabled_extensions = getenv("LNBITS_DISABLED_EXTENSIONS", "").split(",") app = Flask(__name__) -app.wsgi_app = ProxyFix(app.wsgi_app, x_proto=1, x_host=1) +app.wsgi_app = ProxyFix(app.wsgi_app, x_proto=1, x_host=1) # type: ignore valid_extensions = [ext for ext in ExtensionManager(disabled=disabled_extensions).extensions if ext.is_valid] diff --git a/lnbits/bolt11.py b/lnbits/bolt11.py index 32ab5ee..bc3a797 100644 --- a/lnbits/bolt11.py +++ b/lnbits/bolt11.py @@ -1,3 +1,5 @@ +# type: ignore + import bitstring import re from binascii import hexlify diff --git a/lnbits/core/__init__.py b/lnbits/core/__init__.py index a2e89d8..5af72f9 100644 --- a/lnbits/core/__init__.py +++ b/lnbits/core/__init__.py @@ -1,7 +1,7 @@ from flask import Blueprint -core_app = Blueprint("core", __name__, template_folder="templates", static_folder="static") +core_app: Blueprint = Blueprint("core", __name__, template_folder="templates", static_folder="static") from .views.api import * # noqa diff --git a/lnbits/core/crud.py b/lnbits/core/crud.py index d31356f..195748a 100644 --- a/lnbits/core/crud.py +++ b/lnbits/core/crud.py @@ -16,7 +16,10 @@ def create_account() -> User: user_id = uuid4().hex db.execute("INSERT INTO accounts (id) VALUES (?)", (user_id,)) - return get_account(user_id=user_id) + new_account = get_account(user_id=user_id) + assert new_account, "Newly created account couldn't be retrieved" + + return new_account def get_account(user_id: str) -> Optional[User]: @@ -74,7 +77,10 @@ def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> Wallet: (wallet_id, wallet_name or DEFAULT_WALLET_NAME, user_id, uuid4().hex, uuid4().hex), ) - return get_wallet(wallet_id=wallet_id) + new_wallet = get_wallet(wallet_id=wallet_id) + assert new_wallet, "Newly created wallet couldn't be retrieved" + + return new_wallet def delete_wallet(*, user_id: str, wallet_id: str) -> None: @@ -175,7 +181,7 @@ def delete_wallet_payments_expired(wallet_id: str, *, seconds: int = 86400) -> N def create_payment( - *, wallet_id: str, checking_id: str, amount: str, memo: str, fee: int = 0, pending: bool = True + *, wallet_id: str, checking_id: str, amount: int, memo: str, fee: int = 0, pending: bool = True ) -> Payment: with open_db() as db: db.execute( @@ -186,7 +192,10 @@ def create_payment( (wallet_id, checking_id, amount, int(pending), memo, fee), ) - return get_wallet_payment(wallet_id, checking_id) + new_payment = get_wallet_payment(wallet_id, checking_id) + assert new_payment, "Newly created payment couldn't be retrieved" + + return new_payment def update_payment_status(checking_id: str, pending: bool) -> None: diff --git a/lnbits/core/models.py b/lnbits/core/models.py index c57c281..5ee0868 100644 --- a/lnbits/core/models.py +++ b/lnbits/core/models.py @@ -4,8 +4,8 @@ from typing import List, NamedTuple, Optional class User(NamedTuple): id: str email: str - extensions: Optional[List[str]] = [] - wallets: Optional[List["Wallet"]] = [] + extensions: List[str] = [] + wallets: List["Wallet"] = [] password: Optional[str] = None @property @@ -27,9 +27,9 @@ class Wallet(NamedTuple): @property def balance(self) -> int: - return int(self.balance / 1000) + return self.balance // 1000 - def get_payment(self, checking_id: str) -> "Payment": + def get_payment(self, checking_id: str) -> Optional["Payment"]: from .crud import get_wallet_payment return get_wallet_payment(self.id, checking_id) @@ -59,7 +59,7 @@ class Payment(NamedTuple): @property def sat(self) -> int: - return self.amount / 1000 + return self.amount // 1000 @property def is_in(self) -> bool: diff --git a/lnbits/core/services.py b/lnbits/core/services.py index 6e43aca..d5924c0 100644 --- a/lnbits/core/services.py +++ b/lnbits/core/services.py @@ -1,6 +1,6 @@ from typing import Optional, Tuple -from lnbits.bolt11 import decode as bolt11_decode +from lnbits.bolt11 import decode as bolt11_decode # type: ignore from lnbits.helpers import urlsafe_short_hash from lnbits.settings import WALLET @@ -24,7 +24,6 @@ def create_invoice(*, wallet_id: str, amount: int, memo: str) -> Tuple[str, str] def pay_invoice(*, wallet_id: str, bolt11: str, max_sat: Optional[int] = None) -> str: temp_id = f"temp_{urlsafe_short_hash()}" - try: invoice = bolt11_decode(bolt11) @@ -34,7 +33,7 @@ def pay_invoice(*, wallet_id: str, bolt11: str, max_sat: Optional[int] = None) - if max_sat and invoice.amount_msat > max_sat * 1000: raise ValueError("Amount in invoice is too high.") - fee_reserve = max(1000, invoice.amount_msat * 0.01) + fee_reserve = max(1000, int(invoice.amount_msat * 0.01)) create_payment( wallet_id=wallet_id, checking_id=temp_id, @@ -43,7 +42,9 @@ def pay_invoice(*, wallet_id: str, bolt11: str, max_sat: Optional[int] = None) - memo=temp_id, ) - if get_wallet(wallet_id).balance_msat < 0: + wallet = get_wallet(wallet_id) + assert wallet, "invalid wallet id" + if wallet.balance_msat < 0: raise PermissionError("Insufficient balance.") ok, checking_id, fee_msat, error_message = WALLET.pay_invoice(bolt11) diff --git a/lnbits/core/views/lnurl.py b/lnbits/core/views/lnurl.py index ab4fc96..929649e 100644 --- a/lnbits/core/views/lnurl.py +++ b/lnbits/core/views/lnurl.py @@ -1,8 +1,8 @@ import requests from flask import abort, redirect, request, url_for -from lnurl import LnurlWithdrawResponse, handle as handle_lnurl -from lnurl.exceptions import LnurlException +from lnurl import LnurlWithdrawResponse, handle as handle_lnurl # type: ignore +from lnurl.exceptions import LnurlException # type: ignore from time import sleep from lnbits.core import core_app diff --git a/lnbits/decorators.py b/lnbits/decorators.py index ee5fbcc..fd62f0e 100644 --- a/lnbits/decorators.py +++ b/lnbits/decorators.py @@ -1,4 +1,4 @@ -from cerberus import Validator +from cerberus import Validator # type: ignore from flask import g, abort, jsonify, request from functools import wraps from typing import List, Union diff --git a/lnbits/extensions/amilk/__init__.py b/lnbits/extensions/amilk/__init__.py index 175dcb4..ea93c98 100644 --- a/lnbits/extensions/amilk/__init__.py +++ b/lnbits/extensions/amilk/__init__.py @@ -1,7 +1,7 @@ from flask import Blueprint -amilk_ext = Blueprint("amilk", __name__, static_folder="static", template_folder="templates") +amilk_ext: Blueprint = Blueprint("amilk", __name__, static_folder="static", template_folder="templates") from .views_api import * # noqa diff --git a/lnbits/extensions/diagonalley/__init__.py b/lnbits/extensions/diagonalley/__init__.py index 44fbeef..c3eaf52 100644 --- a/lnbits/extensions/diagonalley/__init__.py +++ b/lnbits/extensions/diagonalley/__init__.py @@ -1,7 +1,7 @@ from flask import Blueprint -diagonalley_ext = Blueprint("diagonalley", __name__, static_folder="static", template_folder="templates") +diagonalley_ext: Blueprint = Blueprint("diagonalley", __name__, static_folder="static", template_folder="templates") from .views_api import * # noqa diff --git a/lnbits/extensions/events/__init__.py b/lnbits/extensions/events/__init__.py index 3b76e2f..52d499e 100644 --- a/lnbits/extensions/events/__init__.py +++ b/lnbits/extensions/events/__init__.py @@ -1,7 +1,7 @@ from flask import Blueprint -events_ext = Blueprint("events", __name__, static_folder="static", template_folder="templates") +events_ext: Blueprint = Blueprint("events", __name__, static_folder="static", template_folder="templates") from .views_api import * # noqa diff --git a/lnbits/extensions/example/__init__.py b/lnbits/extensions/example/__init__.py index 1950e6c..f8ef9ab 100644 --- a/lnbits/extensions/example/__init__.py +++ b/lnbits/extensions/example/__init__.py @@ -1,7 +1,7 @@ from flask import Blueprint -example_ext = Blueprint("example", __name__, static_folder="static", template_folder="templates") +example_ext: Blueprint = Blueprint("example", __name__, static_folder="static", template_folder="templates") from .views_api import * # noqa diff --git a/lnbits/extensions/tpos/__init__.py b/lnbits/extensions/tpos/__init__.py index b661fd2..514e663 100644 --- a/lnbits/extensions/tpos/__init__.py +++ b/lnbits/extensions/tpos/__init__.py @@ -1,7 +1,7 @@ from flask import Blueprint -tpos_ext = Blueprint("tpos", __name__, static_folder="static", template_folder="templates") +tpos_ext: Blueprint = Blueprint("tpos", __name__, static_folder="static", template_folder="templates") from .views_api import * # noqa diff --git a/lnbits/extensions/withdraw/__init__.py b/lnbits/extensions/withdraw/__init__.py index f1e1c26..f89fec8 100644 --- a/lnbits/extensions/withdraw/__init__.py +++ b/lnbits/extensions/withdraw/__init__.py @@ -1,7 +1,7 @@ from flask import Blueprint -withdraw_ext = Blueprint("withdraw", __name__, static_folder="static", template_folder="templates") +withdraw_ext: Blueprint = Blueprint("withdraw", __name__, static_folder="static", template_folder="templates") from .views_api import * # noqa diff --git a/lnbits/helpers.py b/lnbits/helpers.py index 01e653a..daa5dbd 100644 --- a/lnbits/helpers.py +++ b/lnbits/helpers.py @@ -1,6 +1,6 @@ import json import os -import shortuuid +import shortuuid # type: ignore from typing import List, NamedTuple, Optional @@ -34,7 +34,14 @@ class ExtensionManager: config = {} is_valid = False - output.append(Extension(**{**{"code": extension, "is_valid": is_valid}, **config})) + output.append(Extension( + extension, + is_valid, + config.get('name'), + config.get('short_description'), + config.get('icon'), + config.get('contributors') + )) return output diff --git a/lnbits/wallets/clightning.py b/lnbits/wallets/clightning.py index 62c69c9..f17d343 100644 --- a/lnbits/wallets/clightning.py +++ b/lnbits/wallets/clightning.py @@ -1,23 +1,22 @@ -from requests import get, post from os import getenv from .base import InvoiceResponse, PaymentResponse, PaymentStatus, Wallet -from lightning import LightningRpc +from lightning import LightningRpc # type: ignore import random class CLightningWallet(Wallet): def __init__(self): self.l1 = LightningRpc(getenv("CLIGHTNING_RPC")) - + def create_invoice(self, amount: int, memo: str = "") -> InvoiceResponse: - label = "lbl{}".format(random.random()) + label = "lbl{}".format(random.random()) r = self.l1.invoice(amount*1000, label, memo, exposeprivatechannels=True) ok, checking_id, payment_request, error_message = True, r["payment_hash"], r["bolt11"], None return InvoiceResponse(ok, checking_id, payment_request, error_message) def pay_invoice(self, bolt11: str) -> PaymentResponse: r = self.l1.pay(bolt11) - ok, checking_id, fee_msat, error_message = True, None, None, None + ok, checking_id, fee_msat, error_message = True, None, 0, None return PaymentResponse(ok, checking_id, fee_msat, error_message) def get_invoice_status(self, checking_id: str) -> PaymentStatus: @@ -29,8 +28,8 @@ class CLightningWallet(Wallet): def get_payment_status(self, checking_id: str) -> PaymentStatus: r = self.l1.listsendpays(checking_id) if not r.ok: - return PaymentStatus(r, None) - payments = [p for p in r.json()["payments"] if p["payment_hash"] == payment_hash] + return PaymentStatus(None) + payments = [p for p in r.json()["payments"] if p["payment_hash"] == checking_id] payment = payments[0] if payments else None statuses = {"UNKNOWN": None, "IN_FLIGHT": None, "SUCCEEDED": True, "FAILED": False} return PaymentStatus(statuses[payment["status"]] if payment else None) diff --git a/lnbits/wallets/lndgrpc.py b/lnbits/wallets/lndgrpc.py index 4079f94..1005229 100644 --- a/lnbits/wallets/lndgrpc.py +++ b/lnbits/wallets/lndgrpc.py @@ -1,14 +1,13 @@ -from os import getenv -import os import base64 -import lnd_grpc # https://github.com/willcl-ark/lnd_grpc +import lnd_grpc # type: ignore + +from os import getenv + from .base import InvoiceResponse, PaymentResponse, PaymentStatus, Wallet class LndWallet(Wallet): - def __init__(self): - endpoint = getenv("LND_GRPC_ENDPOINT") self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint self.port = getenv("LND_GRPC_PORT") @@ -18,85 +17,66 @@ class LndWallet(Wallet): self.auth_cert = getenv("LND_CERT") lnd_rpc = lnd_grpc.Client( - lnd_dir = None, - tls_cert_path = self.auth_cert, - network = 'mainnet', - grpc_host = self.endpoint, - grpc_port = self.port + lnd_dir=None, tls_cert_path=self.auth_cert, network="mainnet", grpc_host=self.endpoint, grpc_port=self.port ) def create_invoice(self, amount: int, memo: str = "") -> InvoiceResponse: - lnd_rpc = lnd_grpc.Client( - lnd_dir = None, - macaroon_path = self.auth_invoice, - tls_cert_path = self.auth_cert, - network = 'mainnet', - grpc_host = self.endpoint, - grpc_port = self.port + lnd_dir=None, + macaroon_path=self.auth_invoice, + tls_cert_path=self.auth_cert, + network="mainnet", + grpc_host=self.endpoint, + grpc_port=self.port, ) - lndResponse = lnd_rpc.add_invoice( - memo = memo, - value = amount, - expiry = 600, - private = True - ) - decoded_hash = base64.b64encode(lndResponse.r_hash).decode('utf-8').replace("/","_") + lndResponse = lnd_rpc.add_invoice(memo=memo, value=amount, expiry=600, private=True) + decoded_hash = base64.b64encode(lndResponse.r_hash).decode("utf-8").replace("/", "_") print(lndResponse.r_hash) - ok, checking_id, payment_request, error_message = True, decoded_hash, str(lndResponse.payment_request), None + ok, checking_id, payment_request, error_message = True, decoded_hash, str(lndResponse.payment_request), None return InvoiceResponse(ok, checking_id, payment_request, error_message) def pay_invoice(self, bolt11: str) -> PaymentResponse: lnd_rpc = lnd_grpc.Client( - lnd_dir = None, - macaroon_path = self.auth_admin, - tls_cert_path = self.auth_cert, - network = 'mainnet', - grpc_host = self.endpoint, - grpc_port = self.port + lnd_dir=None, + macaroon_path=self.auth_admin, + tls_cert_path=self.auth_cert, + network="mainnet", + grpc_host=self.endpoint, + grpc_port=self.port, ) - payinvoice = lnd_rpc.pay_invoice( - payment_request = bolt11, - ) + payinvoice = lnd_rpc.pay_invoice(payment_request=bolt11,) ok, checking_id, fee_msat, error_message = True, None, 0, None - + if payinvoice.payment_error: ok, error_message = False, payinvoice.payment_error else: - checking_id = base64.b64encode(payinvoice.payment_hash).decode('utf-8').replace("/","_") + checking_id = base64.b64encode(payinvoice.payment_hash).decode("utf-8").replace("/", "_") return PaymentResponse(ok, checking_id, fee_msat, error_message) def get_invoice_status(self, checking_id: str) -> PaymentStatus: - - check_id = base64.b64decode(checking_id.replace("_","/")) + + check_id = base64.b64decode(checking_id.replace("_", "/")) print(check_id) lnd_rpc = lnd_grpc.Client( - lnd_dir = None, - macaroon_path = self.auth_invoice, - tls_cert_path = self.auth_cert, - network = 'mainnet', - grpc_host = self.endpoint, - grpc_port = self.port + lnd_dir=None, + macaroon_path=self.auth_invoice, + tls_cert_path=self.auth_cert, + network="mainnet", + grpc_host=self.endpoint, + grpc_port=self.port, ) for _response in lnd_rpc.subscribe_single_invoice(check_id): - if _response.state == 1: - return PaymentStatus(True) - invoiceThread = threading.Thread( - target=detectPayment, - args=[lndResponse.check_id, ], - daemon=True - ) - invoiceThread.start() + return PaymentStatus(None) def get_payment_status(self, checking_id: str) -> PaymentStatus: - + return PaymentStatus(True) diff --git a/lnbits/wallets/lndrest.py b/lnbits/wallets/lndrest.py index 9cf532f..f9c98a2 100644 --- a/lnbits/wallets/lndrest.py +++ b/lnbits/wallets/lndrest.py @@ -56,7 +56,7 @@ class LndRestWallet(Wallet): checking_id = r.json()["payment_hash"] else: error_message = r.json()["error"] - + return PaymentResponse(ok, checking_id, fee_msat, error_message) @@ -71,11 +71,10 @@ class LndRestWallet(Wallet): return PaymentStatus(r.json()["settled"]) def get_payment_status(self, checking_id: str) -> PaymentStatus: - - r = get(url=f"{self.endpoint}/v1/payments", headers=self.auth_admin, verify=self.auth_cert, params={"include_incomplete": True, "max_payments": "20"}) - + r = get(url=f"{self.endpoint}/v1/payments", headers=self.auth_admin, verify=self.auth_cert, params={"include_incomplete": "True", "max_payments": "20"}) + if not r.ok: - return PaymentStatus(r, None) + return PaymentStatus(None) payments = [p for p in r.json()["payments"] if p["payment_hash"] == checking_id] print(checking_id)