Browse Source

make all dbs and crud calls async

(with aiosqlite at least for now).
core and lnurlp almost ported. other extensions waiting.
aiosqlite
fiatjaf 4 years ago
parent
commit
1f2fd2280a
  1. 1
      Pipfile
  2. 5
      lnbits/__main__.py
  3. 46
      lnbits/app.py
  4. 23
      lnbits/commands.py
  5. 72
      lnbits/core/crud.py
  6. 20
      lnbits/core/models.py
  7. 32
      lnbits/core/services.py
  8. 20
      lnbits/core/views/api.py
  9. 17
      lnbits/core/views/generic.py
  10. 6
      lnbits/core/views/lnurl.py
  11. 73
      lnbits/db.py
  12. 6
      lnbits/decorators.py
  13. 119
      lnbits/extensions/lnurlp/crud.py
  14. 4
      lnbits/extensions/lnurlp/lnurl.py
  15. 4
      lnbits/extensions/lnurlp/migrations.py
  16. 4
      lnbits/extensions/lnurlp/views.py
  17. 15
      lnbits/extensions/lnurlp/views_api.py

1
Pipfile

@ -22,6 +22,7 @@ secure = "*"
typing-extensions = "*"
aiohttp = "*"
aiohttp-sse-client = "*"
aiosqlite = "*"
[dev-packages]
black = "==20.8b1"

5
lnbits/__main__.py

@ -1,9 +1,4 @@
from .app import create_app
from .commands import migrate_databases, transpile_scss, bundle_vendored
migrate_databases()
transpile_scss()
bundle_vendored()
app = create_app()
app.run(host=app.config["HOST"], port=app.config["PORT"])

46
lnbits/app.py

@ -1,14 +1,14 @@
import importlib
import asyncio
from quart import Quart, g
from quart import Quart, Blueprint, g
from quart_cors import cors # type: ignore
from quart_compress import Compress # type: ignore
from secure import SecureHeaders # type: ignore
from .commands import db_migrate
from .core import core_app
from .db import open_db
from .db import open_db, open_ext_db
from .helpers import get_valid_extensions, get_js_vendored, get_css_vendored, url_for_vendored
from .proxy_fix import ASGIProxyFix
@ -26,6 +26,7 @@ def create_app(config_object="lnbits.settings") -> Quart:
cors(app)
Compress(app)
register_preparation_actions(app)
register_assets(app)
register_blueprints(app)
register_filters(app)
@ -36,6 +37,19 @@ def create_app(config_object="lnbits.settings") -> Quart:
return app
def register_preparation_actions(app):
"""Actions we will perform before serving, but in the main event loop."""
from .commands import migrate_databases, transpile_scss, bundle_vendored
@app.before_serving
async def preparation_tasks():
await migrate_databases()
transpile_scss()
bundle_vendored()
def register_blueprints(app: Quart) -> None:
"""Register Flask blueprints / LNbits extensions."""
app.register_blueprint(core_app)
@ -43,11 +57,25 @@ def register_blueprints(app: Quart) -> None:
for ext in get_valid_extensions():
try:
ext_module = importlib.import_module(f"lnbits.extensions.{ext.code}")
app.register_blueprint(getattr(ext_module, f"{ext.code}_ext"), url_prefix=f"/{ext.code}")
bp = getattr(ext_module, f"{ext.code}_ext")
register_request_hooks(bp)
app.register_blueprint(bp, url_prefix=f"/{ext.code}")
except Exception:
raise ImportError(f"Please make sure that the extension `{ext.code}` follows conventions.")
def register_blueprint_hooks(bp: Blueprint) -> None:
@bp.before_request
async def before_request():
g.ext_db = await open_ext_db(bp.name)
@bp.teardown_request
async def after_request(exc):
await g.ext_db.__aexit__(type(exc), exc, None)
def register_commands(app: Quart):
"""Register Click commands."""
app.cli.add_command(db_migrate)
@ -73,20 +101,18 @@ def register_filters(app: Quart):
def register_request_hooks(app: Quart):
"""Open the core db for each request so everything happens in a big transaction"""
@app.before_request
async def before_request():
g.db = open_db()
@app.after_request
async def set_secure_headers(response):
secure_headers.quart(response)
return response
@app.before_request
async def before_request():
g.db = await open_db()
@app.teardown_request
async def after_request(exc):
g.db.__exit__(type(exc), exc, None)
await g.db.__aexit__(type(exc), exc, None)
def register_async_tasks(app):

23
lnbits/commands.py

@ -1,3 +1,4 @@
import asyncio
import click
import importlib
import re
@ -14,7 +15,7 @@ from .settings import LNBITS_PATH
@click.command("migrate")
def db_migrate():
migrate_databases()
asyncio.run(migrate_databases())
@click.command("assets")
@ -42,21 +43,21 @@ def bundle_vendored():
f.write(output)
def migrate_databases():
async def migrate_databases():
"""Creates the necessary databases if they don't exist already; or migrates them."""
with open_db() as core_db:
async with await open_db() as core_db:
try:
rows = core_db.fetchall("SELECT * FROM dbversions")
rows = await core_db.fetchall("SELECT * FROM dbversions")
except sqlite3.OperationalError:
# migration 3 wasn't ran
core_migrations.m000_create_migrations_table(core_db)
rows = core_db.fetchall("SELECT * FROM dbversions")
rows = await core_db.fetchall("SELECT * FROM dbversions")
current_versions = {row["db"]: row["version"] for row in rows}
matcher = re.compile(r"^m(\d\d\d)_")
def run_migration(db, migrations_module):
async def run_migration(db, migrations_module):
db_name = migrations_module.__name__.split(".")[-2]
for key, run_migration in migrations_module.__dict__.items():
match = match = matcher.match(key)
@ -64,17 +65,17 @@ def migrate_databases():
version = int(match.group(1))
if version > current_versions.get(db_name, 0):
print(f"running migration {db_name}.{version}")
run_migration(db)
core_db.execute(
await run_migration(db)
await core_db.execute(
"INSERT OR REPLACE INTO dbversions (db, version) VALUES (?, ?)", (db_name, version)
)
run_migration(core_db, core_migrations)
await run_migration(core_db, core_migrations)
for ext in get_valid_extensions():
try:
ext_migrations = importlib.import_module(f"lnbits.extensions.{ext.code}.migrations")
with open_ext_db(ext.code) as db:
run_migration(db, ext_migrations)
async with await open_ext_db(ext.code) as db:
await run_migration(db, ext_migrations)
except ImportError:
raise ImportError(f"Please make sure that the extension `{ext.code}` has a migrations file.")

72
lnbits/core/crud.py

@ -14,9 +14,9 @@ from .models import User, Wallet, Payment
# --------
def create_account() -> User:
async def create_account() -> User:
user_id = uuid4().hex
g.db.execute("INSERT INTO accounts (id) VALUES (?)", (user_id,))
await g.db.execute("INSERT INTO accounts (id) VALUES (?)", (user_id,))
new_account = get_account(user_id=user_id)
assert new_account, "Newly created account couldn't be retrieved"
@ -24,18 +24,18 @@ def create_account() -> User:
return new_account
def get_account(user_id: str) -> Optional[User]:
row = g.db.fetchone("SELECT id, email, pass as password FROM accounts WHERE id = ?", (user_id,))
async def get_account(user_id: str) -> Optional[User]:
row = await g.db.fetchone("SELECT id, email, pass as password FROM accounts WHERE id = ?", (user_id,))
return User(**row) if row else None
def get_user(user_id: str) -> Optional[User]:
user = g.db.fetchone("SELECT id, email FROM accounts WHERE id = ?", (user_id,))
async def get_user(user_id: str) -> Optional[User]:
user = await g.db.fetchone("SELECT id, email FROM accounts WHERE id = ?", (user_id,))
if user:
extensions = g.db.fetchall("SELECT extension FROM extensions WHERE user = ? AND active = 1", (user_id,))
wallets = g.db.fetchall(
extensions = await g.db.fetchall("SELECT extension FROM extensions WHERE user = ? AND active = 1", (user_id,))
wallets = await g.db.fetchall(
"""
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat
FROM wallets
@ -51,8 +51,8 @@ def get_user(user_id: str) -> Optional[User]:
)
def update_user_extension(*, user_id: str, extension: str, active: int) -> None:
g.db.execute(
async def update_user_extension(*, user_id: str, extension: str, active: int) -> None:
await g.db.execute(
"""
INSERT OR REPLACE INTO extensions (user, extension, active)
VALUES (?, ?, ?)
@ -65,9 +65,9 @@ def update_user_extension(*, user_id: str, extension: str, active: int) -> None:
# -------
def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> Wallet:
async def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> Wallet:
wallet_id = uuid4().hex
g.db.execute(
await g.db.execute(
"""
INSERT INTO wallets (id, name, user, adminkey, inkey)
VALUES (?, ?, ?, ?, ?)
@ -81,8 +81,8 @@ def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> Wallet:
return new_wallet
def delete_wallet(*, user_id: str, wallet_id: str) -> None:
g.db.execute(
async def delete_wallet(*, user_id: str, wallet_id: str) -> None:
await g.db.execute(
"""
UPDATE wallets AS w
SET
@ -95,8 +95,8 @@ def delete_wallet(*, user_id: str, wallet_id: str) -> None:
)
def get_wallet(wallet_id: str) -> Optional[Wallet]:
row = g.db.fetchone(
async def get_wallet(wallet_id: str) -> Optional[Wallet]:
row = await g.db.fetchone(
"""
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat
FROM wallets
@ -108,8 +108,8 @@ def get_wallet(wallet_id: str) -> Optional[Wallet]:
return Wallet(**row) if row else None
def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wallet]:
row = g.db.fetchone(
async def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wallet]:
row = await g.db.fetchone(
"""
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat
FROM wallets
@ -131,8 +131,8 @@ def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wallet]:
# ---------------
def get_standalone_payment(checking_id: str) -> Optional[Payment]:
row = g.db.fetchone(
async def get_standalone_payment(checking_id: str) -> Optional[Payment]:
row = await g.db.fetchone(
"""
SELECT *
FROM apipayments
@ -144,8 +144,8 @@ def get_standalone_payment(checking_id: str) -> Optional[Payment]:
return Payment.from_row(row) if row else None
def get_wallet_payment(wallet_id: str, payment_hash: str) -> Optional[Payment]:
row = g.db.fetchone(
async def get_wallet_payment(wallet_id: str, payment_hash: str) -> Optional[Payment]:
row = await g.db.fetchone(
"""
SELECT *
FROM apipayments
@ -157,7 +157,7 @@ def get_wallet_payment(wallet_id: str, payment_hash: str) -> Optional[Payment]:
return Payment.from_row(row) if row else None
def get_wallet_payments(
async def get_wallet_payments(
wallet_id: str,
*,
complete: bool = False,
@ -197,7 +197,7 @@ def get_wallet_payments(
clause += "AND checking_id NOT LIKE 'temp_%' "
clause += "AND checking_id NOT LIKE 'internal_%' "
rows = g.db.fetchall(
rows = await g.db.fetchall(
f"""
SELECT *
FROM apipayments
@ -210,8 +210,8 @@ def get_wallet_payments(
return [Payment.from_row(row) for row in rows]
def delete_expired_invoices() -> None:
rows = g.db.fetchall(
async def delete_expired_invoices() -> None:
rows = await g.db.fetchall(
"""
SELECT bolt11
FROM apipayments
@ -228,7 +228,7 @@ def delete_expired_invoices() -> None:
if expiration_date > datetime.datetime.utcnow():
continue
g.db.execute(
await g.db.execute(
"""
DELETE FROM apipayments
WHERE pending = 1 AND hash = ?
@ -241,7 +241,7 @@ def delete_expired_invoices() -> None:
# --------
def create_payment(
async def create_payment(
*,
wallet_id: str,
checking_id: str,
@ -254,7 +254,7 @@ def create_payment(
pending: bool = True,
extra: Optional[Dict] = None,
) -> Payment:
g.db.execute(
await g.db.execute(
"""
INSERT INTO apipayments
(wallet, checking_id, bolt11, hash, preimage,
@ -275,14 +275,14 @@ def create_payment(
),
)
new_payment = get_wallet_payment(wallet_id, payment_hash)
new_payment = await get_wallet_payment(wallet_id, payment_hash)
assert new_payment, "Newly created payment couldn't be retrieved"
return new_payment
def update_payment_status(checking_id: str, pending: bool) -> None:
g.db.execute(
async def update_payment_status(checking_id: str, pending: bool) -> None:
await g.db.execute(
"UPDATE apipayments SET pending = ? WHERE checking_id = ?",
(
int(pending),
@ -291,12 +291,12 @@ def update_payment_status(checking_id: str, pending: bool) -> None:
)
def delete_payment(checking_id: str) -> None:
g.db.execute("DELETE FROM apipayments WHERE checking_id = ?", (checking_id,))
async def delete_payment(checking_id: str) -> None:
await g.db.execute("DELETE FROM apipayments WHERE checking_id = ?", (checking_id,))
def check_internal(payment_hash: str) -> Optional[str]:
row = g.db.fetchone(
async def check_internal(payment_hash: str) -> Optional[str]:
row = await g.db.fetchone(
"""
SELECT checking_id FROM apipayments
WHERE hash = ? AND pending AND amount > 0

20
lnbits/core/models.py

@ -33,12 +33,12 @@ class Wallet(NamedTuple):
def balance(self) -> int:
return self.balance_msat // 1000
def get_payment(self, payment_hash: str) -> Optional["Payment"]:
async def get_payment(self, payment_hash: str) -> Optional["Payment"]:
from .crud import get_wallet_payment
return get_wallet_payment(self.id, payment_hash)
return await get_wallet_payment(self.id, payment_hash)
def get_payments(
async def get_payments(
self,
*,
complete: bool = True,
@ -49,7 +49,7 @@ class Wallet(NamedTuple):
) -> List["Payment"]:
from .crud import get_wallet_payments
return get_wallet_payments(
return await get_wallet_payments(
self.id,
complete=complete,
pending=pending,
@ -110,12 +110,12 @@ class Payment(NamedTuple):
def is_uncheckable(self) -> bool:
return self.checking_id.startswith("temp_") or self.checking_id.startswith("internal_")
def set_pending(self, pending: bool) -> None:
async def set_pending(self, pending: bool) -> None:
from .crud import update_payment_status
update_payment_status(self.checking_id, pending)
await update_payment_status(self.checking_id, pending)
def check_pending(self) -> None:
async def check_pending(self) -> None:
if self.is_uncheckable:
return
@ -124,9 +124,9 @@ class Payment(NamedTuple):
else:
pending = WALLET.get_invoice_status(self.checking_id)
self.set_pending(pending.pending)
await self.set_pending(pending.pending)
def delete(self) -> None:
async def delete(self) -> None:
from .crud import delete_payment
delete_payment(self.checking_id)
await delete_payment(self.checking_id)

32
lnbits/core/services.py

@ -14,7 +14,7 @@ from lnbits.wallets.base import PaymentStatus
from .crud import get_wallet, create_payment, delete_payment, check_internal, update_payment_status, get_wallet_payment
def create_invoice(
async def create_invoice(
*,
wallet_id: str,
amount: int,
@ -34,7 +34,7 @@ def create_invoice(
invoice = bolt11.decode(payment_request)
amount_msat = amount * 1000
create_payment(
await create_payment(
wallet_id=wallet_id,
checking_id=checking_id,
payment_request=payment_request,
@ -44,11 +44,11 @@ def create_invoice(
extra=extra,
)
g.db.commit()
await g.db.commit()
return invoice.payment_hash, payment_request
def pay_invoice(
async def pay_invoice(
*, wallet_id: str, payment_request: str, max_sat: Optional[int] = None, extra: Optional[Dict] = None
) -> str:
temp_id = f"temp_{urlsafe_short_hash()}"
@ -82,45 +82,45 @@ def pay_invoice(
)
# check_internal() returns the checking_id of the invoice we're waiting for
internal = check_internal(invoice.payment_hash)
internal = await check_internal(invoice.payment_hash)
if internal:
# create a new payment from this wallet
create_payment(checking_id=internal_id, fee=0, pending=False, **payment_kwargs)
await create_payment(checking_id=internal_id, fee=0, pending=False, **payment_kwargs)
else:
# create a temporary payment here so we can check if
# the balance is enough in the next step
fee_reserve = max(1000, int(invoice.amount_msat * 0.01))
create_payment(checking_id=temp_id, fee=-fee_reserve, **payment_kwargs)
await create_payment(checking_id=temp_id, fee=-fee_reserve, **payment_kwargs)
# do the balance check
wallet = get_wallet(wallet_id)
wallet = await get_wallet(wallet_id)
assert wallet, "invalid wallet id"
if wallet.balance_msat < 0:
g.db.rollback()
await g.db.rollback()
raise PermissionError("Insufficient balance.")
else:
g.db.commit()
await g.db.commit()
if internal:
# mark the invoice from the other side as not pending anymore
# so the other side only has access to his new money when we are sure
# the payer has enough to deduct from
update_payment_status(checking_id=internal, pending=False)
await update_payment_status(checking_id=internal, pending=False)
else:
# actually pay the external invoice
ok, checking_id, fee_msat, error_message = WALLET.pay_invoice(payment_request)
if ok:
create_payment(checking_id=checking_id, fee=fee_msat, **payment_kwargs)
delete_payment(temp_id)
await create_payment(checking_id=checking_id, fee=fee_msat, **payment_kwargs)
await delete_payment(temp_id)
else:
raise Exception(error_message or "Failed to pay_invoice on backend.")
g.db.commit()
await g.db.commit()
return invoice.payment_hash
def check_invoice_status(wallet_id: str, payment_hash: str) -> PaymentStatus:
payment = get_wallet_payment(wallet_id, payment_hash)
async def check_invoice_status(wallet_id: str, payment_hash: str) -> PaymentStatus:
payment = await get_wallet_payment(wallet_id, payment_hash)
if not payment:
return PaymentStatus(None)

20
lnbits/core/views/api.py

@ -13,12 +13,12 @@ from lnbits.decorators import api_check_wallet_key, api_validate_post_request
@api_check_wallet_key("invoice")
async def api_payments():
if "check_pending" in request.args:
delete_expired_invoices()
await delete_expired_invoices()
for payment in g.wallet.get_payments(complete=False, pending=True, exclude_uncheckable=True):
payment.check_pending()
for payment in await g.wallet.get_payments(complete=False, pending=True, exclude_uncheckable=True):
await payment.check_pending()
return jsonify(g.wallet.get_payments()), HTTPStatus.OK
return jsonify(await g.wallet.get_payments()), HTTPStatus.OK
@api_check_wallet_key("invoice")
@ -38,11 +38,11 @@ async def api_payments_create_invoice():
memo = g.data["memo"]
try:
payment_hash, payment_request = create_invoice(
payment_hash, payment_request = await create_invoice(
wallet_id=g.wallet.id, amount=g.data["amount"], memo=memo, description_hash=description_hash
)
except Exception as e:
g.db.rollback()
await g.db.rollback()
return jsonify({"message": str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR
invoice = bolt11.decode(payment_request)
@ -63,14 +63,14 @@ async def api_payments_create_invoice():
@api_validate_post_request(schema={"bolt11": {"type": "string", "empty": False, "required": True}})
async def api_payments_pay_invoice():
try:
payment_hash = pay_invoice(wallet_id=g.wallet.id, payment_request=g.data["bolt11"])
payment_hash = await pay_invoice(wallet_id=g.wallet.id, payment_request=g.data["bolt11"])
except ValueError as e:
return jsonify({"message": str(e)}), HTTPStatus.BAD_REQUEST
except PermissionError as e:
return jsonify({"message": str(e)}), HTTPStatus.FORBIDDEN
except Exception as e:
print(e)
g.db.rollback()
await g.db.rollback()
return jsonify({"message": str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR
return (
@ -96,7 +96,7 @@ async def api_payments_create():
@core_app.route("/api/v1/payments/<payment_hash>", methods=["GET"])
@api_check_wallet_key("invoice")
async def api_payment(payment_hash):
payment = g.wallet.get_payment(payment_hash)
payment = await g.wallet.get_payment(payment_hash)
if not payment:
return jsonify({"message": "Payment does not exist."}), HTTPStatus.NOT_FOUND
@ -104,7 +104,7 @@ async def api_payment(payment_hash):
return jsonify({"paid": True}), HTTPStatus.OK
try:
payment.check_pending()
await payment.check_pending()
except Exception:
return jsonify({"paid": False}), HTTPStatus.OK

17
lnbits/core/views/generic.py

@ -36,11 +36,11 @@ async def extensions():
abort(HTTPStatus.BAD_REQUEST, "You can either `enable` or `disable` an extension.")
if extension_to_enable:
update_user_extension(user_id=g.user.id, extension=extension_to_enable, active=1)
await update_user_extension(user_id=g.user.id, extension=extension_to_enable, active=1)
elif extension_to_disable:
update_user_extension(user_id=g.user.id, extension=extension_to_disable, active=0)
await update_user_extension(user_id=g.user.id, extension=extension_to_disable, active=0)
return await render_template("core/extensions.html", user=get_user(g.user.id))
return await render_template("core/extensions.html", user=g.user)
@core_app.route("/wallet")
@ -58,9 +58,12 @@ async def wallet():
# nothing: create everything
if not user_id:
user = get_user(create_account().id)
account = await create_account()
user = await get_user(account.id)
else:
user = get_user(user_id) or abort(HTTPStatus.NOT_FOUND, "User does not exist.")
user = await get_user(user_id)
if not user:
abort(HTTPStatus.NOT_FOUND, "User does not exist.")
if LNBITS_ALLOWED_USERS and user_id not in LNBITS_ALLOWED_USERS:
abort(HTTPStatus.UNAUTHORIZED, "User not authorized.")
@ -69,7 +72,7 @@ async def wallet():
if user.wallets and not wallet_name:
wallet = user.wallets[0]
else:
wallet = create_wallet(user_id=user.id, wallet_name=wallet_name)
wallet = await create_wallet(user_id=user.id, wallet_name=wallet_name)
return redirect(url_for("core.wallet", usr=user.id, wal=wallet.id))
@ -91,7 +94,7 @@ async def deletewallet():
if wallet_id not in user_wallet_ids:
abort(HTTPStatus.FORBIDDEN, "Not your wallet.")
else:
delete_wallet(user_id=g.user.id, wallet_id=wallet_id)
await delete_wallet(user_id=g.user.id, wallet_id=wallet_id)
user_wallet_ids.remove(wallet_id)
if user_wallet_ids:

6
lnbits/core/views/lnurl.py

@ -51,9 +51,9 @@ async def lnurlwallet():
continue
break
user = get_user(create_account().id)
wallet = create_wallet(user_id=user.id)
create_payment(
user = await get_user(await create_account().id)
wallet = await create_wallet(user_id=user.id)
await create_payment(
wallet_id=wallet.id,
checking_id=checking_id,
amount=withdraw_res.max_sats * 1000,

73
lnbits/db.py

@ -1,5 +1,6 @@
import os
import sqlite3
import aiosqlite
from typing import Dict
from .settings import LNBITS_DATA_FOLDER
@ -7,51 +8,57 @@ from .settings import LNBITS_DATA_FOLDER
class Database:
def __init__(self, db_path: str):
self.path = db_path
self.connection = sqlite3.connect(db_path)
self.connection.row_factory = sqlite3.Row
self.cursor = self.connection.cursor()
def __enter__(self):
async def connect(self):
self.connection = await aiosqlite.connect(self.path)
self.connection.row_factory = aiosqlite.Row
self.cursor = await self.connection.cursor()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
async def __aenter__(self):
self.cursor = await self.connection.cursor()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if exc_val:
self.connection.rollback()
self.cursor.close()
self.cursor.close()
await self.connection.rollback()
else:
self.connection.commit()
self.cursor.close()
self.connection.close()
await self.connection.commit()
def commit(self):
self.connection.commit()
async def commit(self):
await self.connection.commit()
def rollback(self):
self.connection.rollback()
async def rollback(self):
await self.connection.rollback()
def fetchall(self, query: str, values: tuple = ()) -> list:
"""Given a query, return cursor.fetchall() rows."""
self.execute(query, values)
return self.cursor.fetchall()
async def fetchall(self, query: str, values: tuple = ()) -> list:
await self.execute(query, values)
return await self.cursor.fetchall()
def fetchone(self, query: str, values: tuple = ()):
self.execute(query, values)
return self.cursor.fetchone()
async def fetchone(self, query: str, values: tuple = ()):
await self.execute(query, values)
return await self.cursor.fetchone()
def execute(self, query: str, values: tuple = ()) -> None:
"""Given a query, cursor.execute() it."""
async def execute(self, query: str, values: tuple = ()) -> None:
try:
self.cursor.execute(query, values)
except sqlite3.Error as exc:
self.connection.rollback()
await self.cursor.execute(query, values)
except aiosqlite.Error as exc:
print("sqlite error", exc)
await self.rollback()
raise exc
def open_db(db_name: str = "database") -> Database:
db_path = os.path.join(LNBITS_DATA_FOLDER, f"{db_name}.sqlite3")
return Database(db_path=db_path)
_db_objects: Dict[str, Database] = {}
async def open_db(db_name: str = "database") -> Database:
try:
return _db_objects[db_name]
except KeyError:
db_path = os.path.join(LNBITS_DATA_FOLDER, f"{db_name}.sqlite3")
_db_objects[db_name] = await Database(db_path).connect()
return _db_objects[db_name]
def open_ext_db(extension_name: str) -> Database:
return open_db(f"ext_{extension_name}")
async def open_ext_db(extension_name: str) -> Database:
return await open_db(f"ext_{extension_name}")

6
lnbits/decorators.py

@ -14,7 +14,7 @@ def api_check_wallet_key(key_type: str = "invoice"):
@wraps(view)
async def wrapped_view(**kwargs):
try:
g.wallet = get_wallet_for_key(request.headers["X-Api-Key"], key_type)
g.wallet = await get_wallet_for_key(request.headers["X-Api-Key"], key_type)
except KeyError:
return (
jsonify({"message": "`X-Api-Key` header missing."}),
@ -62,7 +62,9 @@ def check_user_exists(param: str = "usr"):
def wrap(view):
@wraps(view)
async def wrapped_view(**kwargs):
g.user = get_user(request.args.get(param, type=str)) or abort(HTTPStatus.NOT_FOUND, "User does not exist.")
g.user = await get_user(request.args.get(param, type=str))
if not g.user:
abort(HTTPStatus.NOT_FOUND, "User does not exist.")
if LNBITS_ALLOWED_USERS and g.user.id not in LNBITS_ALLOWED_USERS:
abort(HTTPStatus.UNAUTHORIZED, "User not authorized.")

119
lnbits/extensions/lnurlp/crud.py

@ -1,109 +1,102 @@
from typing import List, Optional, Union
from quart import g
from lnbits import bolt11
from lnbits.db import open_ext_db
from .models import PayLink
def create_pay_link(*, wallet_id: str, description: str, amount: int, webhook_url: str) -> Optional[PayLink]:
with open_ext_db("lnurlp") as db:
db.execute(
"""
INSERT INTO pay_links (
wallet,
description,
amount,
served_meta,
served_pr,
webhook_url
)
VALUES (?, ?, ?, 0, 0, ?)
""",
(wallet_id, description, amount, webhook_url),
async def create_pay_link(*, wallet_id: str, description: str, amount: int, webhook_url: str) -> Optional[PayLink]:
await g.db.execute(
"""
INSERT INTO pay_links (
wallet,
description,
amount,
served_meta,
served_pr,
webhook_url
)
link_id = db.cursor.lastrowid
return get_pay_link(link_id)
VALUES (?, ?, ?, 0, 0, ?)
""",
(wallet_id, description, amount, webhook_url),
)
link_id = g.ext_db.cursor.lastrowid
return await get_pay_link(link_id)
def get_pay_link(link_id: int) -> Optional[PayLink]:
with open_ext_db("lnurlp") as db:
row = db.fetchone("SELECT * FROM pay_links WHERE id = ?", (link_id,))
async def get_pay_link(link_id: int) -> Optional[PayLink]:
row = await g.ext_db.fetchone("SELECT * FROM pay_links WHERE id = ?", (link_id,))
return PayLink.from_row(row) if row else None
def get_pay_link_by_invoice(payment_hash: str) -> Optional[PayLink]:
async def get_pay_link_by_invoice(payment_hash: str) -> Optional[PayLink]:
# this excludes invoices with webhooks that have been sent already
with open_ext_db("lnurlp") as db:
row = db.fetchone(
"""
SELECT pay_links.* FROM pay_links
INNER JOIN invoices ON invoices.pay_link = pay_links.id
WHERE payment_hash = ? AND webhook_sent IS NULL
""",
(payment_hash,),
)
row = await g.db.fetchone(
"""
SELECT pay_links.* FROM pay_links
INNER JOIN invoices ON invoices.pay_link = pay_links.id
WHERE payment_hash = ? AND webhook_sent IS NULL
""",
(payment_hash,),
)
return PayLink.from_row(row) if row else None
def get_pay_links(wallet_ids: Union[str, List[str]]) -> List[PayLink]:
async def get_pay_links(wallet_ids: Union[str, List[str]]) -> List[PayLink]:
if isinstance(wallet_ids, str):
wallet_ids = [wallet_ids]
with open_ext_db("lnurlp") as db:
q = ",".join(["?"] * len(wallet_ids))
rows = db.fetchall(f"SELECT * FROM pay_links WHERE wallet IN ({q})", (*wallet_ids,))
q = ",".join(["?"] * len(wallet_ids))
rows = await g.ext_db.fetchall(f"SELECT * FROM pay_links WHERE wallet IN ({q})", (*wallet_ids,))
return [PayLink.from_row(row) for row in rows]
def update_pay_link(link_id: int, **kwargs) -> Optional[PayLink]:
async def update_pay_link(link_id: int, **kwargs) -> Optional[PayLink]:
q = ", ".join([f"{field[0]} = ?" for field in kwargs.items()])
with open_ext_db("lnurlp") as db:
db.execute(f"UPDATE pay_links SET {q} WHERE id = ?", (*kwargs.values(), link_id))
row = db.fetchone("SELECT * FROM pay_links WHERE id = ?", (link_id,))
await g.ext_db.execute(f"UPDATE pay_links SET {q} WHERE id = ?", (*kwargs.values(), link_id))
row = await g.ext_db.fetchone("SELECT * FROM pay_links WHERE id = ?", (link_id,))
return PayLink.from_row(row) if row else None
def increment_pay_link(link_id: int, **kwargs) -> Optional[PayLink]:
async def increment_pay_link(link_id: int, **kwargs) -> Optional[PayLink]:
q = ", ".join([f"{field[0]} = {field[0]} + ?" for field in kwargs.items()])
with open_ext_db("lnurlp") as db:
db.execute(f"UPDATE pay_links SET {q} WHERE id = ?", (*kwargs.values(), link_id))
row = db.fetchone("SELECT * FROM pay_links WHERE id = ?", (link_id,))
await g.ext_db.execute(f"UPDATE pay_links SET {q} WHERE id = ?", (*kwargs.values(), link_id))
row = await g.ext_db.fetchone("SELECT * FROM pay_links WHERE id = ?", (link_id,))
return PayLink.from_row(row) if row else None
def delete_pay_link(link_id: int) -> None:
with open_ext_db("lnurlp") as db:
db.execute("DELETE FROM pay_links WHERE id = ?", (link_id,))
async def delete_pay_link(link_id: int) -> None:
await g.ext_db.execute("DELETE FROM pay_links WHERE id = ?", (link_id,))
def save_link_invoice(link_id: int, payment_request: str) -> None:
async def save_link_invoice(link_id: int, payment_request: str) -> None:
inv = bolt11.decode(payment_request)
with open_ext_db("lnurlp") as db:
db.execute(
"""
INSERT INTO invoices (pay_link, payment_hash, expiry)
VALUES (?, ?, ?)
""",
(link_id, inv.payment_hash, inv.expiry),
)
await g.db.execute(
"""
INSERT INTO invoices (pay_link, payment_hash, expiry)
VALUES (?, ?, ?)
""",
(link_id, inv.payment_hash, inv.expiry),
)
def mark_webhook_sent(payment_hash: str, status: int) -> None:
with open_ext_db("lnurlp") as db:
db.execute(
"""
UPDATE invoices SET webhook_sent = ?
WHERE payment_hash = ?
""",
(status, payment_hash),
)
await g.db.execute(
"""
UPDATE invoices SET webhook_sent = ?
WHERE payment_hash = ?
""",
(status, payment_hash),
)

4
lnbits/extensions/lnurlp/lnurl.py

@ -12,7 +12,7 @@ from .crud import increment_pay_link, save_link_invoice
@lnurlp_ext.route("/api/v1/lnurl/<link_id>", methods=["GET"])
async def api_lnurl_response(link_id):
link = increment_pay_link(link_id, served_meta=1)
link = await increment_pay_link(link_id, served_meta=1)
if not link:
return jsonify({"status": "ERROR", "reason": "LNURL-pay not found."}), HTTPStatus.OK
@ -30,7 +30,7 @@ async def api_lnurl_response(link_id):
@lnurlp_ext.route("/api/v1/lnurl/cb/<link_id>", methods=["GET"])
async def api_lnurl_callback(link_id):
link = increment_pay_link(link_id, served_pr=1)
link = await increment_pay_link(link_id, served_pr=1)
if not link:
return jsonify({"status": "ERROR", "reason": "LNURL-pay not found."}), HTTPStatus.OK

4
lnbits/extensions/lnurlp/migrations.py

@ -1,8 +1,8 @@
def m001_initial(db):
async def m001_initial(db):
"""
Initial pay table.
"""
db.execute(
await db.execute(
"""
CREATE TABLE IF NOT EXISTS pay_links (
id INTEGER PRIMARY KEY AUTOINCREMENT,

4
lnbits/extensions/lnurlp/views.py

@ -16,11 +16,11 @@ async def index():
@lnurlp_ext.route("/<link_id>")
async def display(link_id):
link = get_pay_link(link_id) or abort(HTTPStatus.NOT_FOUND, "Pay link does not exist.")
link = await get_pay_link(link_id) or abort(HTTPStatus.NOT_FOUND, "Pay link does not exist.")
return await render_template("lnurlp/display.html", link=link)
@lnurlp_ext.route("/print/<link_id>")
async def print_qr(link_id):
link = get_pay_link(link_id) or abort(HTTPStatus.NOT_FOUND, "Pay link does not exist.")
link = await get_pay_link(link_id) or abort(HTTPStatus.NOT_FOUND, "Pay link does not exist.")
return await render_template("lnurlp/print_qr.html", link=link)

15
lnbits/extensions/lnurlp/views_api.py

@ -25,7 +25,7 @@ async def api_links():
try:
return (
jsonify([{**link._asdict(), **{"lnurl": link.lnurl}} for link in get_pay_links(wallet_ids)]),
jsonify([{**link._asdict(), **{"lnurl": link.lnurl}} for link in await get_pay_links(wallet_ids)]),
HTTPStatus.OK,
)
except LnurlInvalidUrl:
@ -38,7 +38,7 @@ async def api_links():
@lnurlp_ext.route("/api/v1/links/<link_id>", methods=["GET"])
@api_check_wallet_key("invoice")
async def api_link_retrieve(link_id):
link = get_pay_link(link_id)
link = await get_pay_link(link_id)
if not link:
return jsonify({"message": "Pay link does not exist."}), HTTPStatus.NOT_FOUND
@ -61,7 +61,7 @@ async def api_link_retrieve(link_id):
)
async def api_link_create_or_update(link_id=None):
if link_id:
link = get_pay_link(link_id)
link = await get_pay_link(link_id)
if not link:
return jsonify({"message": "Pay link does not exist."}), HTTPStatus.NOT_FOUND
@ -69,9 +69,9 @@ async def api_link_create_or_update(link_id=None):
if link.wallet != g.wallet.id:
return jsonify({"message": "Not your pay link."}), HTTPStatus.FORBIDDEN
link = update_pay_link(link_id, **g.data)
link = await update_pay_link(link_id, **g.data)
else:
link = create_pay_link(wallet_id=g.wallet.id, **g.data)
link = await create_pay_link(wallet_id=g.wallet.id, **g.data)
return jsonify({**link._asdict(), **{"lnurl": link.lnurl}}), HTTPStatus.OK if link_id else HTTPStatus.CREATED
@ -79,7 +79,7 @@ async def api_link_create_or_update(link_id=None):
@lnurlp_ext.route("/api/v1/links/<link_id>", methods=["DELETE"])
@api_check_wallet_key("invoice")
async def api_link_delete(link_id):
link = get_pay_link(link_id)
link = await get_pay_link(link_id)
if not link:
return jsonify({"message": "Pay link does not exist."}), HTTPStatus.NOT_FOUND
@ -87,6 +87,5 @@ async def api_link_delete(link_id):
if link.wallet != g.wallet.id:
return jsonify({"message": "Not your pay link."}), HTTPStatus.FORBIDDEN
delete_pay_link(link_id)
await delete_pay_link(link_id)
return "", HTTPStatus.NO_CONTENT

Loading…
Cancel
Save