From 4855e2cd3dea8eb908f44c7780aad388f9638159 Mon Sep 17 00:00:00 2001 From: fiatjaf Date: Mon, 7 Sep 2020 00:47:13 -0300 Subject: [PATCH] make every request a database transaction. --- lnbits/app.py | 17 ++- lnbits/core/crud.py | 295 +++++++++++++++++++-------------------- lnbits/core/services.py | 4 +- lnbits/core/views/api.py | 2 + lnbits/db.py | 15 +- 5 files changed, 174 insertions(+), 159 deletions(-) diff --git a/lnbits/app.py b/lnbits/app.py index 95dc331..e281524 100644 --- a/lnbits/app.py +++ b/lnbits/app.py @@ -1,6 +1,6 @@ import importlib -from flask import Flask +from flask import Flask, g from flask_assets import Bundle # type: ignore from flask_cors import CORS # type: ignore from flask_talisman import Talisman # type: ignore @@ -8,6 +8,7 @@ from werkzeug.middleware.proxy_fix import ProxyFix from .commands import flask_migrate from .core import core_app +from .db import open_db from .ext import assets, compress from .helpers import get_valid_extensions @@ -24,6 +25,7 @@ def create_app(config_object="lnbits.settings") -> Flask: register_blueprints(app) register_filters(app) register_commands(app) + register_request_hooks(app) return app @@ -73,3 +75,16 @@ def register_filters(app): app.jinja_env.globals["DEBUG"] = app.config["DEBUG"] app.jinja_env.globals["EXTENSIONS"] = get_valid_extensions() app.jinja_env.globals["SITE_TITLE"] = app.config["LNBITS_SITE_TITLE"] + + +def register_request_hooks(app): + """Open the core db for each request so everything happens in a big transaction""" + + @app.before_request + def before_request(): + g.db = open_db() + + @app.teardown_request + def after_request(exc): + print("after", exc) + g.db.__exit__(type(exc), exc, None) diff --git a/lnbits/core/crud.py b/lnbits/core/crud.py index 6758624..0963416 100644 --- a/lnbits/core/crud.py +++ b/lnbits/core/crud.py @@ -2,8 +2,8 @@ import json import datetime from uuid import uuid4 from typing import List, Optional, Dict +from flask import g -from lnbits.db import open_db from lnbits import bolt11 from lnbits.settings import DEFAULT_WALLET_NAME @@ -15,9 +15,8 @@ from .models import User, Wallet, Payment def create_account() -> User: - with open_db() as db: - user_id = uuid4().hex - db.execute("INSERT INTO accounts (id) VALUES (?)", (user_id,)) + user_id = uuid4().hex + g.db.execute("INSERT INTO accounts (id) VALUES (?)", (user_id,)) new_account = get_account(user_id=user_id) assert new_account, "Newly created account couldn't be retrieved" @@ -26,26 +25,24 @@ def create_account() -> User: def get_account(user_id: str) -> Optional[User]: - with open_db() as db: - row = db.fetchone("SELECT id, email, pass as password FROM accounts WHERE id = ?", (user_id,)) + row = g.db.fetchone("SELECT id, email, pass as password FROM accounts WHERE id = ?", (user_id,)) return User(**row) if row else None def get_user(user_id: str) -> Optional[User]: - with open_db() as db: - user = db.fetchone("SELECT id, email FROM accounts WHERE id = ?", (user_id,)) - - if user: - extensions = db.fetchall("SELECT extension FROM extensions WHERE user = ? AND active = 1", (user_id,)) - wallets = db.fetchall( - """ - SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat - FROM wallets - WHERE user = ? - """, - (user_id,), - ) + user = g.db.fetchone("SELECT id, email FROM accounts WHERE id = ?", (user_id,)) + + if user: + extensions = g.db.fetchall("SELECT extension FROM extensions WHERE user = ? AND active = 1", (user_id,)) + wallets = g.db.fetchall( + """ + SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat + FROM wallets + WHERE user = ? + """, + (user_id,), + ) return ( User(**{**user, **{"extensions": [e[0] for e in extensions], "wallets": [Wallet(**w) for w in wallets]}}) @@ -55,14 +52,13 @@ def get_user(user_id: str) -> Optional[User]: def update_user_extension(*, user_id: str, extension: str, active: int) -> None: - with open_db() as db: - db.execute( - """ - INSERT OR REPLACE INTO extensions (user, extension, active) - VALUES (?, ?, ?) - """, - (user_id, extension, active), - ) + g.db.execute( + """ + INSERT OR REPLACE INTO extensions (user, extension, active) + VALUES (?, ?, ?) + """, + (user_id, extension, active), + ) # wallets @@ -70,15 +66,14 @@ def update_user_extension(*, user_id: str, extension: str, active: int) -> None: def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> Wallet: - with open_db() as db: - wallet_id = uuid4().hex - db.execute( - """ - INSERT INTO wallets (id, name, user, adminkey, inkey) - VALUES (?, ?, ?, ?, ?) - """, - (wallet_id, wallet_name or DEFAULT_WALLET_NAME, user_id, uuid4().hex, uuid4().hex), - ) + wallet_id = uuid4().hex + g.db.execute( + """ + INSERT INTO wallets (id, name, user, adminkey, inkey) + VALUES (?, ?, ?, ?, ?) + """, + (wallet_id, wallet_name or DEFAULT_WALLET_NAME, user_id, uuid4().hex, uuid4().hex), + ) new_wallet = get_wallet(wallet_id=wallet_id) assert new_wallet, "Newly created wallet couldn't be retrieved" @@ -87,52 +82,49 @@ def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> Wallet: def delete_wallet(*, user_id: str, wallet_id: str) -> None: - with open_db() as db: - db.execute( - """ - UPDATE wallets AS w - SET - user = 'del:' || w.user, - adminkey = 'del:' || w.adminkey, - inkey = 'del:' || w.inkey - WHERE id = ? AND user = ? - """, - (wallet_id, user_id), - ) + g.db.execute( + """ + UPDATE wallets AS w + SET + user = 'del:' || w.user, + adminkey = 'del:' || w.adminkey, + inkey = 'del:' || w.inkey + WHERE id = ? AND user = ? + """, + (wallet_id, user_id), + ) def get_wallet(wallet_id: str) -> Optional[Wallet]: - with open_db() as db: - row = db.fetchone( - """ - SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat - FROM wallets - WHERE id = ? - """, - (wallet_id,), - ) + row = g.db.fetchone( + """ + SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat + FROM wallets + WHERE id = ? + """, + (wallet_id,), + ) return Wallet(**row) if row else None def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wallet]: - with open_db() as db: - row = db.fetchone( - """ - SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat - FROM wallets - WHERE adminkey = ? OR inkey = ? - """, - (key, key), - ) + row = g.db.fetchone( + """ + SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat + FROM wallets + WHERE adminkey = ? OR inkey = ? + """, + (key, key), + ) - if not row: - return None + if not row: + return None - if key_type == "admin" and row["adminkey"] != key: - return None + if key_type == "admin" and row["adminkey"] != key: + return None - return Wallet(**row) + return Wallet(**row) # wallet payments @@ -140,15 +132,14 @@ def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wallet]: def get_wallet_payment(wallet_id: str, payment_hash: str) -> Optional[Payment]: - with open_db() as db: - row = db.fetchone( - """ - SELECT * - FROM apipayments - WHERE wallet = ? AND hash = ? - """, - (wallet_id, payment_hash), - ) + row = g.db.fetchone( + """ + SELECT * + FROM apipayments + WHERE wallet = ? AND hash = ? + """, + (wallet_id, payment_hash), + ) return Payment.from_row(row) if row else None @@ -179,46 +170,44 @@ def get_wallet_payments( else: raise TypeError("at least one of [outgoing, incoming] must be True.") - with open_db() as db: - rows = db.fetchall( - f""" - SELECT * - FROM apipayments - WHERE wallet = ? {clause} - ORDER BY time DESC - """, - (wallet_id,), - ) + rows = g.db.fetchall( + f""" + SELECT * + FROM apipayments + WHERE wallet = ? {clause} + ORDER BY time DESC + """, + (wallet_id,), + ) return [Payment.from_row(row) for row in rows] def delete_expired_invoices() -> None: - with open_db() as db: - rows = db.fetchall( - """ - SELECT bolt11 - FROM apipayments - WHERE pending = 1 AND amount > 0 AND time < strftime('%s', 'now') - 86400 + rows = g.db.fetchall( """ - ) - for (payment_request,) in rows: - try: - invoice = bolt11.decode(payment_request) - except: - continue + SELECT bolt11 + FROM apipayments + WHERE pending = 1 AND amount > 0 AND time < strftime('%s', 'now') - 86400 + """ + ) + for (payment_request,) in rows: + try: + invoice = bolt11.decode(payment_request) + except: + continue - expiration_date = datetime.datetime.fromtimestamp(invoice.date + invoice.expiry) - if expiration_date > datetime.datetime.utcnow(): - continue + expiration_date = datetime.datetime.fromtimestamp(invoice.date + invoice.expiry) + if expiration_date > datetime.datetime.utcnow(): + continue - db.execute( - """ - DELETE FROM apipayments - WHERE pending = 1 AND hash = ? - """, - (invoice.payment_hash,), - ) + g.db.execute( + """ + DELETE FROM apipayments + WHERE pending = 1 AND hash = ? + """, + (invoice.payment_hash,), + ) # payments @@ -238,27 +227,26 @@ def create_payment( pending: bool = True, extra: Optional[Dict] = None, ) -> Payment: - with open_db() as db: - db.execute( - """ - INSERT INTO apipayments - (wallet, checking_id, bolt11, hash, preimage, - amount, pending, memo, fee, extra) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - wallet_id, - checking_id, - payment_request, - payment_hash, - preimage, - amount, - int(pending), - memo, - fee, - json.dumps(extra) if extra and extra != {} and type(extra) is dict else None, - ), - ) + g.db.execute( + """ + INSERT INTO apipayments + (wallet, checking_id, bolt11, hash, preimage, + amount, pending, memo, fee, extra) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + wallet_id, + checking_id, + payment_request, + payment_hash, + preimage, + amount, + int(pending), + memo, + fee, + json.dumps(extra) if extra and extra != {} and type(extra) is dict else None, + ), + ) new_payment = get_wallet_payment(wallet_id, payment_hash) assert new_payment, "Newly created payment couldn't be retrieved" @@ -267,31 +255,28 @@ def create_payment( def update_payment_status(checking_id: str, pending: bool) -> None: - with open_db() as db: - db.execute( - "UPDATE apipayments SET pending = ? WHERE checking_id = ?", - ( - int(pending), - checking_id, - ), - ) + g.db.execute( + "UPDATE apipayments SET pending = ? WHERE checking_id = ?", + ( + int(pending), + checking_id, + ), + ) def delete_payment(checking_id: str) -> None: - with open_db() as db: - db.execute("DELETE FROM apipayments WHERE checking_id = ?", (checking_id,)) + g.db.execute("DELETE FROM apipayments WHERE checking_id = ?", (checking_id,)) def check_internal(payment_hash: str) -> Optional[str]: - with open_db() as db: - row = db.fetchone( - """ - SELECT checking_id FROM apipayments - WHERE hash = ? AND pending AND amount > 0 - """, - (payment_hash,), - ) - if not row: - return None - else: - return row["checking_id"] + row = g.db.fetchone( + """ + SELECT checking_id FROM apipayments + WHERE hash = ? AND pending AND amount > 0 + """, + (payment_hash,), + ) + if not row: + return None + else: + return row["checking_id"] diff --git a/lnbits/core/services.py b/lnbits/core/services.py index 17ea8f1..eb82bd7 100644 --- a/lnbits/core/services.py +++ b/lnbits/core/services.py @@ -1,4 +1,5 @@ from typing import Optional, Tuple, Dict +from flask import g try: from typing import TypedDict # type: ignore @@ -94,6 +95,7 @@ def pay_invoice( wallet = get_wallet(wallet_id) assert wallet, "invalid wallet id" if wallet.balance_msat < 0: + g.db.rollback() raise PermissionError("Insufficient balance.") if internal: @@ -108,7 +110,7 @@ def pay_invoice( create_payment(checking_id=checking_id, fee=fee_msat, **payment_kwargs) delete_payment(temp_id) else: - raise Exception(error_message or "Unexpected backend error.") + raise Exception(error_message or "Failed to pay_invoice on backend.") return invoice.payment_hash diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py index 345aa51..b9da9ff 100644 --- a/lnbits/core/views/api.py +++ b/lnbits/core/views/api.py @@ -48,6 +48,7 @@ def api_payments_create_invoice(): wallet_id=g.wallet.id, amount=g.data["amount"], memo=memo, description_hash=description_hash ) except Exception as e: + g.db.rollback() return jsonify({"message": str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR invoice = bolt11.decode(payment_request) @@ -75,6 +76,7 @@ def api_payments_pay_invoice(): return jsonify({"message": str(e)}), HTTPStatus.FORBIDDEN except Exception as e: print(e) + g.db.rollback() return jsonify({"message": str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR return ( diff --git a/lnbits/db.py b/lnbits/db.py index 316bb21..ec26d69 100644 --- a/lnbits/db.py +++ b/lnbits/db.py @@ -15,9 +15,20 @@ class Database: return self def __exit__(self, exc_type, exc_val, exc_tb): + if exc_val: + self.connection.rollback() + self.cursor.close() + self.cursor.close() + else: + self.connection.commit() + self.cursor.close() + self.connection.close() + + def commit(self): self.connection.commit() - self.cursor.close() - self.connection.close() + + def rollback(self): + self.connection.rollback() def fetchall(self, query: str, values: tuple = ()) -> list: """Given a query, return cursor.fetchall() rows."""