diff --git a/Pipfile b/Pipfile index fed4b71..a875303 100644 --- a/Pipfile +++ b/Pipfile @@ -22,6 +22,7 @@ secure = "*" typing-extensions = "*" aiohttp = "*" aiohttp-sse-client = "*" +aiosqlite = "*" [dev-packages] black = "==20.8b1" diff --git a/lnbits/__main__.py b/lnbits/__main__.py index 6932fba..d8f0e12 100644 --- a/lnbits/__main__.py +++ b/lnbits/__main__.py @@ -1,9 +1,4 @@ from .app import create_app -from .commands import migrate_databases, transpile_scss, bundle_vendored - -migrate_databases() -transpile_scss() -bundle_vendored() app = create_app() app.run(host=app.config["HOST"], port=app.config["PORT"]) diff --git a/lnbits/app.py b/lnbits/app.py index a25fb49..d5bdf31 100644 --- a/lnbits/app.py +++ b/lnbits/app.py @@ -1,14 +1,14 @@ import importlib import asyncio -from quart import Quart, g +from quart import Quart, Blueprint, g from quart_cors import cors # type: ignore from quart_compress import Compress # type: ignore from secure import SecureHeaders # type: ignore from .commands import db_migrate from .core import core_app -from .db import open_db +from .db import open_db, open_ext_db from .helpers import get_valid_extensions, get_js_vendored, get_css_vendored, url_for_vendored from .proxy_fix import ASGIProxyFix @@ -26,6 +26,7 @@ def create_app(config_object="lnbits.settings") -> Quart: cors(app) Compress(app) + register_preparation_actions(app) register_assets(app) register_blueprints(app) register_filters(app) @@ -36,6 +37,19 @@ def create_app(config_object="lnbits.settings") -> Quart: return app +def register_preparation_actions(app): + """Actions we will perform before serving, but in the main event loop.""" + + from .commands import migrate_databases, transpile_scss, bundle_vendored + + @app.before_serving + async def preparation_tasks(): + await migrate_databases() + + transpile_scss() + bundle_vendored() + + def register_blueprints(app: Quart) -> None: """Register Flask blueprints / LNbits extensions.""" app.register_blueprint(core_app) @@ -43,11 +57,25 @@ def register_blueprints(app: Quart) -> None: for ext in get_valid_extensions(): try: ext_module = importlib.import_module(f"lnbits.extensions.{ext.code}") - app.register_blueprint(getattr(ext_module, f"{ext.code}_ext"), url_prefix=f"/{ext.code}") + bp = getattr(ext_module, f"{ext.code}_ext") + + register_request_hooks(bp) + + app.register_blueprint(bp, url_prefix=f"/{ext.code}") except Exception: raise ImportError(f"Please make sure that the extension `{ext.code}` follows conventions.") +def register_blueprint_hooks(bp: Blueprint) -> None: + @bp.before_request + async def before_request(): + g.ext_db = await open_ext_db(bp.name) + + @bp.teardown_request + async def after_request(exc): + await g.ext_db.__aexit__(type(exc), exc, None) + + def register_commands(app: Quart): """Register Click commands.""" app.cli.add_command(db_migrate) @@ -73,20 +101,18 @@ def register_filters(app: Quart): def register_request_hooks(app: Quart): - """Open the core db for each request so everything happens in a big transaction""" - - @app.before_request - async def before_request(): - g.db = open_db() - @app.after_request async def set_secure_headers(response): secure_headers.quart(response) return response + @app.before_request + async def before_request(): + g.db = await open_db() + @app.teardown_request async def after_request(exc): - g.db.__exit__(type(exc), exc, None) + await g.db.__aexit__(type(exc), exc, None) def register_async_tasks(app): diff --git a/lnbits/commands.py b/lnbits/commands.py index 653175f..8826b09 100644 --- a/lnbits/commands.py +++ b/lnbits/commands.py @@ -1,3 +1,4 @@ +import asyncio import click import importlib import re @@ -14,7 +15,7 @@ from .settings import LNBITS_PATH @click.command("migrate") def db_migrate(): - migrate_databases() + asyncio.run(migrate_databases()) @click.command("assets") @@ -42,21 +43,21 @@ def bundle_vendored(): f.write(output) -def migrate_databases(): +async def migrate_databases(): """Creates the necessary databases if they don't exist already; or migrates them.""" - with open_db() as core_db: + async with await open_db() as core_db: try: - rows = core_db.fetchall("SELECT * FROM dbversions") + rows = await core_db.fetchall("SELECT * FROM dbversions") except sqlite3.OperationalError: # migration 3 wasn't ran core_migrations.m000_create_migrations_table(core_db) - rows = core_db.fetchall("SELECT * FROM dbversions") + rows = await core_db.fetchall("SELECT * FROM dbversions") current_versions = {row["db"]: row["version"] for row in rows} matcher = re.compile(r"^m(\d\d\d)_") - def run_migration(db, migrations_module): + async def run_migration(db, migrations_module): db_name = migrations_module.__name__.split(".")[-2] for key, run_migration in migrations_module.__dict__.items(): match = match = matcher.match(key) @@ -64,17 +65,17 @@ def migrate_databases(): version = int(match.group(1)) if version > current_versions.get(db_name, 0): print(f"running migration {db_name}.{version}") - run_migration(db) - core_db.execute( + await run_migration(db) + await core_db.execute( "INSERT OR REPLACE INTO dbversions (db, version) VALUES (?, ?)", (db_name, version) ) - run_migration(core_db, core_migrations) + await run_migration(core_db, core_migrations) for ext in get_valid_extensions(): try: ext_migrations = importlib.import_module(f"lnbits.extensions.{ext.code}.migrations") - with open_ext_db(ext.code) as db: - run_migration(db, ext_migrations) + async with await open_ext_db(ext.code) as db: + await run_migration(db, ext_migrations) except ImportError: raise ImportError(f"Please make sure that the extension `{ext.code}` has a migrations file.") diff --git a/lnbits/core/crud.py b/lnbits/core/crud.py index 984492a..2bf8e55 100644 --- a/lnbits/core/crud.py +++ b/lnbits/core/crud.py @@ -14,9 +14,9 @@ from .models import User, Wallet, Payment # -------- -def create_account() -> User: +async def create_account() -> User: user_id = uuid4().hex - g.db.execute("INSERT INTO accounts (id) VALUES (?)", (user_id,)) + await g.db.execute("INSERT INTO accounts (id) VALUES (?)", (user_id,)) new_account = get_account(user_id=user_id) assert new_account, "Newly created account couldn't be retrieved" @@ -24,18 +24,18 @@ def create_account() -> User: return new_account -def get_account(user_id: str) -> Optional[User]: - row = g.db.fetchone("SELECT id, email, pass as password FROM accounts WHERE id = ?", (user_id,)) +async def get_account(user_id: str) -> Optional[User]: + row = await g.db.fetchone("SELECT id, email, pass as password FROM accounts WHERE id = ?", (user_id,)) return User(**row) if row else None -def get_user(user_id: str) -> Optional[User]: - user = g.db.fetchone("SELECT id, email FROM accounts WHERE id = ?", (user_id,)) +async def get_user(user_id: str) -> Optional[User]: + user = await g.db.fetchone("SELECT id, email FROM accounts WHERE id = ?", (user_id,)) if user: - extensions = g.db.fetchall("SELECT extension FROM extensions WHERE user = ? AND active = 1", (user_id,)) - wallets = g.db.fetchall( + extensions = await g.db.fetchall("SELECT extension FROM extensions WHERE user = ? AND active = 1", (user_id,)) + wallets = await g.db.fetchall( """ SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat FROM wallets @@ -51,8 +51,8 @@ def get_user(user_id: str) -> Optional[User]: ) -def update_user_extension(*, user_id: str, extension: str, active: int) -> None: - g.db.execute( +async def update_user_extension(*, user_id: str, extension: str, active: int) -> None: + await g.db.execute( """ INSERT OR REPLACE INTO extensions (user, extension, active) VALUES (?, ?, ?) @@ -65,9 +65,9 @@ def update_user_extension(*, user_id: str, extension: str, active: int) -> None: # ------- -def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> Wallet: +async def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> Wallet: wallet_id = uuid4().hex - g.db.execute( + await g.db.execute( """ INSERT INTO wallets (id, name, user, adminkey, inkey) VALUES (?, ?, ?, ?, ?) @@ -81,8 +81,8 @@ def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> Wallet: return new_wallet -def delete_wallet(*, user_id: str, wallet_id: str) -> None: - g.db.execute( +async def delete_wallet(*, user_id: str, wallet_id: str) -> None: + await g.db.execute( """ UPDATE wallets AS w SET @@ -95,8 +95,8 @@ def delete_wallet(*, user_id: str, wallet_id: str) -> None: ) -def get_wallet(wallet_id: str) -> Optional[Wallet]: - row = g.db.fetchone( +async def get_wallet(wallet_id: str) -> Optional[Wallet]: + row = await g.db.fetchone( """ SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat FROM wallets @@ -108,8 +108,8 @@ def get_wallet(wallet_id: str) -> Optional[Wallet]: return Wallet(**row) if row else None -def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wallet]: - row = g.db.fetchone( +async def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wallet]: + row = await g.db.fetchone( """ SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat FROM wallets @@ -131,8 +131,8 @@ def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wallet]: # --------------- -def get_standalone_payment(checking_id: str) -> Optional[Payment]: - row = g.db.fetchone( +async def get_standalone_payment(checking_id: str) -> Optional[Payment]: + row = await g.db.fetchone( """ SELECT * FROM apipayments @@ -144,8 +144,8 @@ def get_standalone_payment(checking_id: str) -> Optional[Payment]: return Payment.from_row(row) if row else None -def get_wallet_payment(wallet_id: str, payment_hash: str) -> Optional[Payment]: - row = g.db.fetchone( +async def get_wallet_payment(wallet_id: str, payment_hash: str) -> Optional[Payment]: + row = await g.db.fetchone( """ SELECT * FROM apipayments @@ -157,7 +157,7 @@ def get_wallet_payment(wallet_id: str, payment_hash: str) -> Optional[Payment]: return Payment.from_row(row) if row else None -def get_wallet_payments( +async def get_wallet_payments( wallet_id: str, *, complete: bool = False, @@ -197,7 +197,7 @@ def get_wallet_payments( clause += "AND checking_id NOT LIKE 'temp_%' " clause += "AND checking_id NOT LIKE 'internal_%' " - rows = g.db.fetchall( + rows = await g.db.fetchall( f""" SELECT * FROM apipayments @@ -210,8 +210,8 @@ def get_wallet_payments( return [Payment.from_row(row) for row in rows] -def delete_expired_invoices() -> None: - rows = g.db.fetchall( +async def delete_expired_invoices() -> None: + rows = await g.db.fetchall( """ SELECT bolt11 FROM apipayments @@ -228,7 +228,7 @@ def delete_expired_invoices() -> None: if expiration_date > datetime.datetime.utcnow(): continue - g.db.execute( + await g.db.execute( """ DELETE FROM apipayments WHERE pending = 1 AND hash = ? @@ -241,7 +241,7 @@ def delete_expired_invoices() -> None: # -------- -def create_payment( +async def create_payment( *, wallet_id: str, checking_id: str, @@ -254,7 +254,7 @@ def create_payment( pending: bool = True, extra: Optional[Dict] = None, ) -> Payment: - g.db.execute( + await g.db.execute( """ INSERT INTO apipayments (wallet, checking_id, bolt11, hash, preimage, @@ -275,14 +275,14 @@ def create_payment( ), ) - new_payment = get_wallet_payment(wallet_id, payment_hash) + new_payment = await get_wallet_payment(wallet_id, payment_hash) assert new_payment, "Newly created payment couldn't be retrieved" return new_payment -def update_payment_status(checking_id: str, pending: bool) -> None: - g.db.execute( +async def update_payment_status(checking_id: str, pending: bool) -> None: + await g.db.execute( "UPDATE apipayments SET pending = ? WHERE checking_id = ?", ( int(pending), @@ -291,12 +291,12 @@ def update_payment_status(checking_id: str, pending: bool) -> None: ) -def delete_payment(checking_id: str) -> None: - g.db.execute("DELETE FROM apipayments WHERE checking_id = ?", (checking_id,)) +async def delete_payment(checking_id: str) -> None: + await g.db.execute("DELETE FROM apipayments WHERE checking_id = ?", (checking_id,)) -def check_internal(payment_hash: str) -> Optional[str]: - row = g.db.fetchone( +async def check_internal(payment_hash: str) -> Optional[str]: + row = await g.db.fetchone( """ SELECT checking_id FROM apipayments WHERE hash = ? AND pending AND amount > 0 diff --git a/lnbits/core/models.py b/lnbits/core/models.py index 243f934..12961e4 100644 --- a/lnbits/core/models.py +++ b/lnbits/core/models.py @@ -33,12 +33,12 @@ class Wallet(NamedTuple): def balance(self) -> int: return self.balance_msat // 1000 - def get_payment(self, payment_hash: str) -> Optional["Payment"]: + async def get_payment(self, payment_hash: str) -> Optional["Payment"]: from .crud import get_wallet_payment - return get_wallet_payment(self.id, payment_hash) + return await get_wallet_payment(self.id, payment_hash) - def get_payments( + async def get_payments( self, *, complete: bool = True, @@ -49,7 +49,7 @@ class Wallet(NamedTuple): ) -> List["Payment"]: from .crud import get_wallet_payments - return get_wallet_payments( + return await get_wallet_payments( self.id, complete=complete, pending=pending, @@ -110,12 +110,12 @@ class Payment(NamedTuple): def is_uncheckable(self) -> bool: return self.checking_id.startswith("temp_") or self.checking_id.startswith("internal_") - def set_pending(self, pending: bool) -> None: + async def set_pending(self, pending: bool) -> None: from .crud import update_payment_status - update_payment_status(self.checking_id, pending) + await update_payment_status(self.checking_id, pending) - def check_pending(self) -> None: + async def check_pending(self) -> None: if self.is_uncheckable: return @@ -124,9 +124,9 @@ class Payment(NamedTuple): else: pending = WALLET.get_invoice_status(self.checking_id) - self.set_pending(pending.pending) + await self.set_pending(pending.pending) - def delete(self) -> None: + async def delete(self) -> None: from .crud import delete_payment - delete_payment(self.checking_id) + await delete_payment(self.checking_id) diff --git a/lnbits/core/services.py b/lnbits/core/services.py index e16b2f2..6025392 100644 --- a/lnbits/core/services.py +++ b/lnbits/core/services.py @@ -14,7 +14,7 @@ from lnbits.wallets.base import PaymentStatus from .crud import get_wallet, create_payment, delete_payment, check_internal, update_payment_status, get_wallet_payment -def create_invoice( +async def create_invoice( *, wallet_id: str, amount: int, @@ -34,7 +34,7 @@ def create_invoice( invoice = bolt11.decode(payment_request) amount_msat = amount * 1000 - create_payment( + await create_payment( wallet_id=wallet_id, checking_id=checking_id, payment_request=payment_request, @@ -44,11 +44,11 @@ def create_invoice( extra=extra, ) - g.db.commit() + await g.db.commit() return invoice.payment_hash, payment_request -def pay_invoice( +async def pay_invoice( *, wallet_id: str, payment_request: str, max_sat: Optional[int] = None, extra: Optional[Dict] = None ) -> str: temp_id = f"temp_{urlsafe_short_hash()}" @@ -82,45 +82,45 @@ def pay_invoice( ) # check_internal() returns the checking_id of the invoice we're waiting for - internal = check_internal(invoice.payment_hash) + internal = await check_internal(invoice.payment_hash) if internal: # create a new payment from this wallet - create_payment(checking_id=internal_id, fee=0, pending=False, **payment_kwargs) + await create_payment(checking_id=internal_id, fee=0, pending=False, **payment_kwargs) else: # create a temporary payment here so we can check if # the balance is enough in the next step fee_reserve = max(1000, int(invoice.amount_msat * 0.01)) - create_payment(checking_id=temp_id, fee=-fee_reserve, **payment_kwargs) + await create_payment(checking_id=temp_id, fee=-fee_reserve, **payment_kwargs) # do the balance check - wallet = get_wallet(wallet_id) + wallet = await get_wallet(wallet_id) assert wallet, "invalid wallet id" if wallet.balance_msat < 0: - g.db.rollback() + await g.db.rollback() raise PermissionError("Insufficient balance.") else: - g.db.commit() + await g.db.commit() if internal: # mark the invoice from the other side as not pending anymore # so the other side only has access to his new money when we are sure # the payer has enough to deduct from - update_payment_status(checking_id=internal, pending=False) + await update_payment_status(checking_id=internal, pending=False) else: # actually pay the external invoice ok, checking_id, fee_msat, error_message = WALLET.pay_invoice(payment_request) if ok: - create_payment(checking_id=checking_id, fee=fee_msat, **payment_kwargs) - delete_payment(temp_id) + await create_payment(checking_id=checking_id, fee=fee_msat, **payment_kwargs) + await delete_payment(temp_id) else: raise Exception(error_message or "Failed to pay_invoice on backend.") - g.db.commit() + await g.db.commit() return invoice.payment_hash -def check_invoice_status(wallet_id: str, payment_hash: str) -> PaymentStatus: - payment = get_wallet_payment(wallet_id, payment_hash) +async def check_invoice_status(wallet_id: str, payment_hash: str) -> PaymentStatus: + payment = await get_wallet_payment(wallet_id, payment_hash) if not payment: return PaymentStatus(None) diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py index 70d20bd..6414563 100644 --- a/lnbits/core/views/api.py +++ b/lnbits/core/views/api.py @@ -13,12 +13,12 @@ from lnbits.decorators import api_check_wallet_key, api_validate_post_request @api_check_wallet_key("invoice") async def api_payments(): if "check_pending" in request.args: - delete_expired_invoices() + await delete_expired_invoices() - for payment in g.wallet.get_payments(complete=False, pending=True, exclude_uncheckable=True): - payment.check_pending() + for payment in await g.wallet.get_payments(complete=False, pending=True, exclude_uncheckable=True): + await payment.check_pending() - return jsonify(g.wallet.get_payments()), HTTPStatus.OK + return jsonify(await g.wallet.get_payments()), HTTPStatus.OK @api_check_wallet_key("invoice") @@ -38,11 +38,11 @@ async def api_payments_create_invoice(): memo = g.data["memo"] try: - payment_hash, payment_request = create_invoice( + payment_hash, payment_request = await create_invoice( wallet_id=g.wallet.id, amount=g.data["amount"], memo=memo, description_hash=description_hash ) except Exception as e: - g.db.rollback() + await g.db.rollback() return jsonify({"message": str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR invoice = bolt11.decode(payment_request) @@ -63,14 +63,14 @@ async def api_payments_create_invoice(): @api_validate_post_request(schema={"bolt11": {"type": "string", "empty": False, "required": True}}) async def api_payments_pay_invoice(): try: - payment_hash = pay_invoice(wallet_id=g.wallet.id, payment_request=g.data["bolt11"]) + payment_hash = await pay_invoice(wallet_id=g.wallet.id, payment_request=g.data["bolt11"]) except ValueError as e: return jsonify({"message": str(e)}), HTTPStatus.BAD_REQUEST except PermissionError as e: return jsonify({"message": str(e)}), HTTPStatus.FORBIDDEN except Exception as e: print(e) - g.db.rollback() + await g.db.rollback() return jsonify({"message": str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR return ( @@ -96,7 +96,7 @@ async def api_payments_create(): @core_app.route("/api/v1/payments/", methods=["GET"]) @api_check_wallet_key("invoice") async def api_payment(payment_hash): - payment = g.wallet.get_payment(payment_hash) + payment = await g.wallet.get_payment(payment_hash) if not payment: return jsonify({"message": "Payment does not exist."}), HTTPStatus.NOT_FOUND @@ -104,7 +104,7 @@ async def api_payment(payment_hash): return jsonify({"paid": True}), HTTPStatus.OK try: - payment.check_pending() + await payment.check_pending() except Exception: return jsonify({"paid": False}), HTTPStatus.OK diff --git a/lnbits/core/views/generic.py b/lnbits/core/views/generic.py index 36720d9..13dc91c 100644 --- a/lnbits/core/views/generic.py +++ b/lnbits/core/views/generic.py @@ -36,11 +36,11 @@ async def extensions(): abort(HTTPStatus.BAD_REQUEST, "You can either `enable` or `disable` an extension.") if extension_to_enable: - update_user_extension(user_id=g.user.id, extension=extension_to_enable, active=1) + await update_user_extension(user_id=g.user.id, extension=extension_to_enable, active=1) elif extension_to_disable: - update_user_extension(user_id=g.user.id, extension=extension_to_disable, active=0) + await update_user_extension(user_id=g.user.id, extension=extension_to_disable, active=0) - return await render_template("core/extensions.html", user=get_user(g.user.id)) + return await render_template("core/extensions.html", user=g.user) @core_app.route("/wallet") @@ -58,9 +58,12 @@ async def wallet(): # nothing: create everything if not user_id: - user = get_user(create_account().id) + account = await create_account() + user = await get_user(account.id) else: - user = get_user(user_id) or abort(HTTPStatus.NOT_FOUND, "User does not exist.") + user = await get_user(user_id) + if not user: + abort(HTTPStatus.NOT_FOUND, "User does not exist.") if LNBITS_ALLOWED_USERS and user_id not in LNBITS_ALLOWED_USERS: abort(HTTPStatus.UNAUTHORIZED, "User not authorized.") @@ -69,7 +72,7 @@ async def wallet(): if user.wallets and not wallet_name: wallet = user.wallets[0] else: - wallet = create_wallet(user_id=user.id, wallet_name=wallet_name) + wallet = await create_wallet(user_id=user.id, wallet_name=wallet_name) return redirect(url_for("core.wallet", usr=user.id, wal=wallet.id)) @@ -91,7 +94,7 @@ async def deletewallet(): if wallet_id not in user_wallet_ids: abort(HTTPStatus.FORBIDDEN, "Not your wallet.") else: - delete_wallet(user_id=g.user.id, wallet_id=wallet_id) + await delete_wallet(user_id=g.user.id, wallet_id=wallet_id) user_wallet_ids.remove(wallet_id) if user_wallet_ids: diff --git a/lnbits/core/views/lnurl.py b/lnbits/core/views/lnurl.py index 0d0ac12..f94bc12 100644 --- a/lnbits/core/views/lnurl.py +++ b/lnbits/core/views/lnurl.py @@ -51,9 +51,9 @@ async def lnurlwallet(): continue break - user = get_user(create_account().id) - wallet = create_wallet(user_id=user.id) - create_payment( + user = await get_user(await create_account().id) + wallet = await create_wallet(user_id=user.id) + await create_payment( wallet_id=wallet.id, checking_id=checking_id, amount=withdraw_res.max_sats * 1000, diff --git a/lnbits/db.py b/lnbits/db.py index ec26d69..73efeea 100644 --- a/lnbits/db.py +++ b/lnbits/db.py @@ -1,5 +1,6 @@ import os -import sqlite3 +import aiosqlite +from typing import Dict from .settings import LNBITS_DATA_FOLDER @@ -7,51 +8,57 @@ from .settings import LNBITS_DATA_FOLDER class Database: def __init__(self, db_path: str): self.path = db_path - self.connection = sqlite3.connect(db_path) - self.connection.row_factory = sqlite3.Row - self.cursor = self.connection.cursor() - def __enter__(self): + async def connect(self): + self.connection = await aiosqlite.connect(self.path) + self.connection.row_factory = aiosqlite.Row + self.cursor = await self.connection.cursor() return self - def __exit__(self, exc_type, exc_val, exc_tb): + async def __aenter__(self): + self.cursor = await self.connection.cursor() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): if exc_val: - self.connection.rollback() - self.cursor.close() - self.cursor.close() + await self.connection.rollback() else: - self.connection.commit() - self.cursor.close() - self.connection.close() + await self.connection.commit() - def commit(self): - self.connection.commit() + async def commit(self): + await self.connection.commit() - def rollback(self): - self.connection.rollback() + async def rollback(self): + await self.connection.rollback() - def fetchall(self, query: str, values: tuple = ()) -> list: - """Given a query, return cursor.fetchall() rows.""" - self.execute(query, values) - return self.cursor.fetchall() + async def fetchall(self, query: str, values: tuple = ()) -> list: + await self.execute(query, values) + return await self.cursor.fetchall() - def fetchone(self, query: str, values: tuple = ()): - self.execute(query, values) - return self.cursor.fetchone() + async def fetchone(self, query: str, values: tuple = ()): + await self.execute(query, values) + return await self.cursor.fetchone() - def execute(self, query: str, values: tuple = ()) -> None: - """Given a query, cursor.execute() it.""" + async def execute(self, query: str, values: tuple = ()) -> None: try: - self.cursor.execute(query, values) - except sqlite3.Error as exc: - self.connection.rollback() + await self.cursor.execute(query, values) + except aiosqlite.Error as exc: + print("sqlite error", exc) + await self.rollback() raise exc -def open_db(db_name: str = "database") -> Database: - db_path = os.path.join(LNBITS_DATA_FOLDER, f"{db_name}.sqlite3") - return Database(db_path=db_path) +_db_objects: Dict[str, Database] = {} + + +async def open_db(db_name: str = "database") -> Database: + try: + return _db_objects[db_name] + except KeyError: + db_path = os.path.join(LNBITS_DATA_FOLDER, f"{db_name}.sqlite3") + _db_objects[db_name] = await Database(db_path).connect() + return _db_objects[db_name] -def open_ext_db(extension_name: str) -> Database: - return open_db(f"ext_{extension_name}") +async def open_ext_db(extension_name: str) -> Database: + return await open_db(f"ext_{extension_name}") diff --git a/lnbits/decorators.py b/lnbits/decorators.py index ac73e4e..f53c7d6 100644 --- a/lnbits/decorators.py +++ b/lnbits/decorators.py @@ -14,7 +14,7 @@ def api_check_wallet_key(key_type: str = "invoice"): @wraps(view) async def wrapped_view(**kwargs): try: - g.wallet = get_wallet_for_key(request.headers["X-Api-Key"], key_type) + g.wallet = await get_wallet_for_key(request.headers["X-Api-Key"], key_type) except KeyError: return ( jsonify({"message": "`X-Api-Key` header missing."}), @@ -62,7 +62,9 @@ def check_user_exists(param: str = "usr"): def wrap(view): @wraps(view) async def wrapped_view(**kwargs): - g.user = get_user(request.args.get(param, type=str)) or abort(HTTPStatus.NOT_FOUND, "User does not exist.") + g.user = await get_user(request.args.get(param, type=str)) + if not g.user: + abort(HTTPStatus.NOT_FOUND, "User does not exist.") if LNBITS_ALLOWED_USERS and g.user.id not in LNBITS_ALLOWED_USERS: abort(HTTPStatus.UNAUTHORIZED, "User not authorized.") diff --git a/lnbits/extensions/lnurlp/crud.py b/lnbits/extensions/lnurlp/crud.py index adebb84..ab5f0be 100644 --- a/lnbits/extensions/lnurlp/crud.py +++ b/lnbits/extensions/lnurlp/crud.py @@ -1,109 +1,102 @@ from typing import List, Optional, Union +from quart import g + from lnbits import bolt11 -from lnbits.db import open_ext_db from .models import PayLink -def create_pay_link(*, wallet_id: str, description: str, amount: int, webhook_url: str) -> Optional[PayLink]: - with open_ext_db("lnurlp") as db: - db.execute( - """ - INSERT INTO pay_links ( - wallet, - description, - amount, - served_meta, - served_pr, - webhook_url - ) - VALUES (?, ?, ?, 0, 0, ?) - """, - (wallet_id, description, amount, webhook_url), +async def create_pay_link(*, wallet_id: str, description: str, amount: int, webhook_url: str) -> Optional[PayLink]: + await g.db.execute( + """ + INSERT INTO pay_links ( + wallet, + description, + amount, + served_meta, + served_pr, + webhook_url ) - link_id = db.cursor.lastrowid - return get_pay_link(link_id) + VALUES (?, ?, ?, 0, 0, ?) + """, + (wallet_id, description, amount, webhook_url), + ) + + link_id = g.ext_db.cursor.lastrowid + return await get_pay_link(link_id) -def get_pay_link(link_id: int) -> Optional[PayLink]: - with open_ext_db("lnurlp") as db: - row = db.fetchone("SELECT * FROM pay_links WHERE id = ?", (link_id,)) +async def get_pay_link(link_id: int) -> Optional[PayLink]: + row = await g.ext_db.fetchone("SELECT * FROM pay_links WHERE id = ?", (link_id,)) return PayLink.from_row(row) if row else None -def get_pay_link_by_invoice(payment_hash: str) -> Optional[PayLink]: +async def get_pay_link_by_invoice(payment_hash: str) -> Optional[PayLink]: # this excludes invoices with webhooks that have been sent already - with open_ext_db("lnurlp") as db: - row = db.fetchone( - """ - SELECT pay_links.* FROM pay_links - INNER JOIN invoices ON invoices.pay_link = pay_links.id - WHERE payment_hash = ? AND webhook_sent IS NULL - """, - (payment_hash,), - ) + row = await g.db.fetchone( + """ + SELECT pay_links.* FROM pay_links + INNER JOIN invoices ON invoices.pay_link = pay_links.id + WHERE payment_hash = ? AND webhook_sent IS NULL + """, + (payment_hash,), + ) return PayLink.from_row(row) if row else None -def get_pay_links(wallet_ids: Union[str, List[str]]) -> List[PayLink]: +async def get_pay_links(wallet_ids: Union[str, List[str]]) -> List[PayLink]: if isinstance(wallet_ids, str): wallet_ids = [wallet_ids] - with open_ext_db("lnurlp") as db: - q = ",".join(["?"] * len(wallet_ids)) - rows = db.fetchall(f"SELECT * FROM pay_links WHERE wallet IN ({q})", (*wallet_ids,)) + q = ",".join(["?"] * len(wallet_ids)) + rows = await g.ext_db.fetchall(f"SELECT * FROM pay_links WHERE wallet IN ({q})", (*wallet_ids,)) return [PayLink.from_row(row) for row in rows] -def update_pay_link(link_id: int, **kwargs) -> Optional[PayLink]: +async def update_pay_link(link_id: int, **kwargs) -> Optional[PayLink]: q = ", ".join([f"{field[0]} = ?" for field in kwargs.items()]) - with open_ext_db("lnurlp") as db: - db.execute(f"UPDATE pay_links SET {q} WHERE id = ?", (*kwargs.values(), link_id)) - row = db.fetchone("SELECT * FROM pay_links WHERE id = ?", (link_id,)) + await g.ext_db.execute(f"UPDATE pay_links SET {q} WHERE id = ?", (*kwargs.values(), link_id)) + row = await g.ext_db.fetchone("SELECT * FROM pay_links WHERE id = ?", (link_id,)) return PayLink.from_row(row) if row else None -def increment_pay_link(link_id: int, **kwargs) -> Optional[PayLink]: +async def increment_pay_link(link_id: int, **kwargs) -> Optional[PayLink]: q = ", ".join([f"{field[0]} = {field[0]} + ?" for field in kwargs.items()]) - with open_ext_db("lnurlp") as db: - db.execute(f"UPDATE pay_links SET {q} WHERE id = ?", (*kwargs.values(), link_id)) - row = db.fetchone("SELECT * FROM pay_links WHERE id = ?", (link_id,)) + await g.ext_db.execute(f"UPDATE pay_links SET {q} WHERE id = ?", (*kwargs.values(), link_id)) + row = await g.ext_db.fetchone("SELECT * FROM pay_links WHERE id = ?", (link_id,)) return PayLink.from_row(row) if row else None -def delete_pay_link(link_id: int) -> None: - with open_ext_db("lnurlp") as db: - db.execute("DELETE FROM pay_links WHERE id = ?", (link_id,)) +async def delete_pay_link(link_id: int) -> None: + await g.ext_db.execute("DELETE FROM pay_links WHERE id = ?", (link_id,)) -def save_link_invoice(link_id: int, payment_request: str) -> None: +async def save_link_invoice(link_id: int, payment_request: str) -> None: inv = bolt11.decode(payment_request) - with open_ext_db("lnurlp") as db: - db.execute( - """ - INSERT INTO invoices (pay_link, payment_hash, expiry) - VALUES (?, ?, ?) - """, - (link_id, inv.payment_hash, inv.expiry), - ) + await g.db.execute( + """ + INSERT INTO invoices (pay_link, payment_hash, expiry) + VALUES (?, ?, ?) + """, + (link_id, inv.payment_hash, inv.expiry), + ) def mark_webhook_sent(payment_hash: str, status: int) -> None: - with open_ext_db("lnurlp") as db: - db.execute( - """ - UPDATE invoices SET webhook_sent = ? - WHERE payment_hash = ? - """, - (status, payment_hash), - ) + await g.db.execute( + """ + UPDATE invoices SET webhook_sent = ? + WHERE payment_hash = ? + """, + (status, payment_hash), + ) diff --git a/lnbits/extensions/lnurlp/lnurl.py b/lnbits/extensions/lnurlp/lnurl.py index 747e5ba..f2441d0 100644 --- a/lnbits/extensions/lnurlp/lnurl.py +++ b/lnbits/extensions/lnurlp/lnurl.py @@ -12,7 +12,7 @@ from .crud import increment_pay_link, save_link_invoice @lnurlp_ext.route("/api/v1/lnurl/", methods=["GET"]) async def api_lnurl_response(link_id): - link = increment_pay_link(link_id, served_meta=1) + link = await increment_pay_link(link_id, served_meta=1) if not link: return jsonify({"status": "ERROR", "reason": "LNURL-pay not found."}), HTTPStatus.OK @@ -30,7 +30,7 @@ async def api_lnurl_response(link_id): @lnurlp_ext.route("/api/v1/lnurl/cb/", methods=["GET"]) async def api_lnurl_callback(link_id): - link = increment_pay_link(link_id, served_pr=1) + link = await increment_pay_link(link_id, served_pr=1) if not link: return jsonify({"status": "ERROR", "reason": "LNURL-pay not found."}), HTTPStatus.OK diff --git a/lnbits/extensions/lnurlp/migrations.py b/lnbits/extensions/lnurlp/migrations.py index d9c61d3..a129309 100644 --- a/lnbits/extensions/lnurlp/migrations.py +++ b/lnbits/extensions/lnurlp/migrations.py @@ -1,8 +1,8 @@ -def m001_initial(db): +async def m001_initial(db): """ Initial pay table. """ - db.execute( + await db.execute( """ CREATE TABLE IF NOT EXISTS pay_links ( id INTEGER PRIMARY KEY AUTOINCREMENT, diff --git a/lnbits/extensions/lnurlp/views.py b/lnbits/extensions/lnurlp/views.py index 25d02e9..137fc5b 100644 --- a/lnbits/extensions/lnurlp/views.py +++ b/lnbits/extensions/lnurlp/views.py @@ -16,11 +16,11 @@ async def index(): @lnurlp_ext.route("/") async def display(link_id): - link = get_pay_link(link_id) or abort(HTTPStatus.NOT_FOUND, "Pay link does not exist.") + link = await get_pay_link(link_id) or abort(HTTPStatus.NOT_FOUND, "Pay link does not exist.") return await render_template("lnurlp/display.html", link=link) @lnurlp_ext.route("/print/") async def print_qr(link_id): - link = get_pay_link(link_id) or abort(HTTPStatus.NOT_FOUND, "Pay link does not exist.") + link = await get_pay_link(link_id) or abort(HTTPStatus.NOT_FOUND, "Pay link does not exist.") return await render_template("lnurlp/print_qr.html", link=link) diff --git a/lnbits/extensions/lnurlp/views_api.py b/lnbits/extensions/lnurlp/views_api.py index e548547..f4fad5b 100644 --- a/lnbits/extensions/lnurlp/views_api.py +++ b/lnbits/extensions/lnurlp/views_api.py @@ -25,7 +25,7 @@ async def api_links(): try: return ( - jsonify([{**link._asdict(), **{"lnurl": link.lnurl}} for link in get_pay_links(wallet_ids)]), + jsonify([{**link._asdict(), **{"lnurl": link.lnurl}} for link in await get_pay_links(wallet_ids)]), HTTPStatus.OK, ) except LnurlInvalidUrl: @@ -38,7 +38,7 @@ async def api_links(): @lnurlp_ext.route("/api/v1/links/", methods=["GET"]) @api_check_wallet_key("invoice") async def api_link_retrieve(link_id): - link = get_pay_link(link_id) + link = await get_pay_link(link_id) if not link: return jsonify({"message": "Pay link does not exist."}), HTTPStatus.NOT_FOUND @@ -61,7 +61,7 @@ async def api_link_retrieve(link_id): ) async def api_link_create_or_update(link_id=None): if link_id: - link = get_pay_link(link_id) + link = await get_pay_link(link_id) if not link: return jsonify({"message": "Pay link does not exist."}), HTTPStatus.NOT_FOUND @@ -69,9 +69,9 @@ async def api_link_create_or_update(link_id=None): if link.wallet != g.wallet.id: return jsonify({"message": "Not your pay link."}), HTTPStatus.FORBIDDEN - link = update_pay_link(link_id, **g.data) + link = await update_pay_link(link_id, **g.data) else: - link = create_pay_link(wallet_id=g.wallet.id, **g.data) + link = await create_pay_link(wallet_id=g.wallet.id, **g.data) return jsonify({**link._asdict(), **{"lnurl": link.lnurl}}), HTTPStatus.OK if link_id else HTTPStatus.CREATED @@ -79,7 +79,7 @@ async def api_link_create_or_update(link_id=None): @lnurlp_ext.route("/api/v1/links/", methods=["DELETE"]) @api_check_wallet_key("invoice") async def api_link_delete(link_id): - link = get_pay_link(link_id) + link = await get_pay_link(link_id) if not link: return jsonify({"message": "Pay link does not exist."}), HTTPStatus.NOT_FOUND @@ -87,6 +87,5 @@ async def api_link_delete(link_id): if link.wallet != g.wallet.id: return jsonify({"message": "Not your pay link."}), HTTPStatus.FORBIDDEN - delete_pay_link(link_id) - + await delete_pay_link(link_id) return "", HTTPStatus.NO_CONTENT