Browse Source

make every request a database transaction.

aiosqlite
fiatjaf 4 years ago
committed by fiatjaf
parent
commit
4855e2cd3d
  1. 17
      lnbits/app.py
  2. 53
      lnbits/core/crud.py
  3. 4
      lnbits/core/services.py
  4. 2
      lnbits/core/views/api.py
  5. 11
      lnbits/db.py

17
lnbits/app.py

@ -1,6 +1,6 @@
import importlib import importlib
from flask import Flask from flask import Flask, g
from flask_assets import Bundle # type: ignore from flask_assets import Bundle # type: ignore
from flask_cors import CORS # type: ignore from flask_cors import CORS # type: ignore
from flask_talisman import Talisman # type: ignore from flask_talisman import Talisman # type: ignore
@ -8,6 +8,7 @@ from werkzeug.middleware.proxy_fix import ProxyFix
from .commands import flask_migrate from .commands import flask_migrate
from .core import core_app from .core import core_app
from .db import open_db
from .ext import assets, compress from .ext import assets, compress
from .helpers import get_valid_extensions from .helpers import get_valid_extensions
@ -24,6 +25,7 @@ def create_app(config_object="lnbits.settings") -> Flask:
register_blueprints(app) register_blueprints(app)
register_filters(app) register_filters(app)
register_commands(app) register_commands(app)
register_request_hooks(app)
return app return app
@ -73,3 +75,16 @@ def register_filters(app):
app.jinja_env.globals["DEBUG"] = app.config["DEBUG"] app.jinja_env.globals["DEBUG"] = app.config["DEBUG"]
app.jinja_env.globals["EXTENSIONS"] = get_valid_extensions() app.jinja_env.globals["EXTENSIONS"] = get_valid_extensions()
app.jinja_env.globals["SITE_TITLE"] = app.config["LNBITS_SITE_TITLE"] app.jinja_env.globals["SITE_TITLE"] = app.config["LNBITS_SITE_TITLE"]
def register_request_hooks(app):
"""Open the core db for each request so everything happens in a big transaction"""
@app.before_request
def before_request():
g.db = open_db()
@app.teardown_request
def after_request(exc):
print("after", exc)
g.db.__exit__(type(exc), exc, None)

53
lnbits/core/crud.py

@ -2,8 +2,8 @@ import json
import datetime import datetime
from uuid import uuid4 from uuid import uuid4
from typing import List, Optional, Dict from typing import List, Optional, Dict
from flask import g
from lnbits.db import open_db
from lnbits import bolt11 from lnbits import bolt11
from lnbits.settings import DEFAULT_WALLET_NAME from lnbits.settings import DEFAULT_WALLET_NAME
@ -15,9 +15,8 @@ from .models import User, Wallet, Payment
def create_account() -> User: def create_account() -> User:
with open_db() as db:
user_id = uuid4().hex user_id = uuid4().hex
db.execute("INSERT INTO accounts (id) VALUES (?)", (user_id,)) g.db.execute("INSERT INTO accounts (id) VALUES (?)", (user_id,))
new_account = get_account(user_id=user_id) new_account = get_account(user_id=user_id)
assert new_account, "Newly created account couldn't be retrieved" assert new_account, "Newly created account couldn't be retrieved"
@ -26,19 +25,17 @@ def create_account() -> User:
def get_account(user_id: str) -> Optional[User]: def get_account(user_id: str) -> Optional[User]:
with open_db() as db: row = g.db.fetchone("SELECT id, email, pass as password FROM accounts WHERE id = ?", (user_id,))
row = db.fetchone("SELECT id, email, pass as password FROM accounts WHERE id = ?", (user_id,))
return User(**row) if row else None return User(**row) if row else None
def get_user(user_id: str) -> Optional[User]: def get_user(user_id: str) -> Optional[User]:
with open_db() as db: user = g.db.fetchone("SELECT id, email FROM accounts WHERE id = ?", (user_id,))
user = db.fetchone("SELECT id, email FROM accounts WHERE id = ?", (user_id,))
if user: if user:
extensions = db.fetchall("SELECT extension FROM extensions WHERE user = ? AND active = 1", (user_id,)) extensions = g.db.fetchall("SELECT extension FROM extensions WHERE user = ? AND active = 1", (user_id,))
wallets = db.fetchall( wallets = g.db.fetchall(
""" """
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat
FROM wallets FROM wallets
@ -55,8 +52,7 @@ def get_user(user_id: str) -> Optional[User]:
def update_user_extension(*, user_id: str, extension: str, active: int) -> None: def update_user_extension(*, user_id: str, extension: str, active: int) -> None:
with open_db() as db: g.db.execute(
db.execute(
""" """
INSERT OR REPLACE INTO extensions (user, extension, active) INSERT OR REPLACE INTO extensions (user, extension, active)
VALUES (?, ?, ?) VALUES (?, ?, ?)
@ -70,9 +66,8 @@ def update_user_extension(*, user_id: str, extension: str, active: int) -> None:
def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> Wallet: def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> Wallet:
with open_db() as db:
wallet_id = uuid4().hex wallet_id = uuid4().hex
db.execute( g.db.execute(
""" """
INSERT INTO wallets (id, name, user, adminkey, inkey) INSERT INTO wallets (id, name, user, adminkey, inkey)
VALUES (?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?)
@ -87,8 +82,7 @@ def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> Wallet:
def delete_wallet(*, user_id: str, wallet_id: str) -> None: def delete_wallet(*, user_id: str, wallet_id: str) -> None:
with open_db() as db: g.db.execute(
db.execute(
""" """
UPDATE wallets AS w UPDATE wallets AS w
SET SET
@ -102,8 +96,7 @@ def delete_wallet(*, user_id: str, wallet_id: str) -> None:
def get_wallet(wallet_id: str) -> Optional[Wallet]: def get_wallet(wallet_id: str) -> Optional[Wallet]:
with open_db() as db: row = g.db.fetchone(
row = db.fetchone(
""" """
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat
FROM wallets FROM wallets
@ -116,8 +109,7 @@ def get_wallet(wallet_id: str) -> Optional[Wallet]:
def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wallet]: def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wallet]:
with open_db() as db: row = g.db.fetchone(
row = db.fetchone(
""" """
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat
FROM wallets FROM wallets
@ -140,8 +132,7 @@ def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wallet]:
def get_wallet_payment(wallet_id: str, payment_hash: str) -> Optional[Payment]: def get_wallet_payment(wallet_id: str, payment_hash: str) -> Optional[Payment]:
with open_db() as db: row = g.db.fetchone(
row = db.fetchone(
""" """
SELECT * SELECT *
FROM apipayments FROM apipayments
@ -179,8 +170,7 @@ def get_wallet_payments(
else: else:
raise TypeError("at least one of [outgoing, incoming] must be True.") raise TypeError("at least one of [outgoing, incoming] must be True.")
with open_db() as db: rows = g.db.fetchall(
rows = db.fetchall(
f""" f"""
SELECT * SELECT *
FROM apipayments FROM apipayments
@ -194,8 +184,7 @@ def get_wallet_payments(
def delete_expired_invoices() -> None: def delete_expired_invoices() -> None:
with open_db() as db: rows = g.db.fetchall(
rows = db.fetchall(
""" """
SELECT bolt11 SELECT bolt11
FROM apipayments FROM apipayments
@ -212,7 +201,7 @@ def delete_expired_invoices() -> None:
if expiration_date > datetime.datetime.utcnow(): if expiration_date > datetime.datetime.utcnow():
continue continue
db.execute( g.db.execute(
""" """
DELETE FROM apipayments DELETE FROM apipayments
WHERE pending = 1 AND hash = ? WHERE pending = 1 AND hash = ?
@ -238,8 +227,7 @@ def create_payment(
pending: bool = True, pending: bool = True,
extra: Optional[Dict] = None, extra: Optional[Dict] = None,
) -> Payment: ) -> Payment:
with open_db() as db: g.db.execute(
db.execute(
""" """
INSERT INTO apipayments INSERT INTO apipayments
(wallet, checking_id, bolt11, hash, preimage, (wallet, checking_id, bolt11, hash, preimage,
@ -267,8 +255,7 @@ def create_payment(
def update_payment_status(checking_id: str, pending: bool) -> None: def update_payment_status(checking_id: str, pending: bool) -> None:
with open_db() as db: g.db.execute(
db.execute(
"UPDATE apipayments SET pending = ? WHERE checking_id = ?", "UPDATE apipayments SET pending = ? WHERE checking_id = ?",
( (
int(pending), int(pending),
@ -278,13 +265,11 @@ def update_payment_status(checking_id: str, pending: bool) -> None:
def delete_payment(checking_id: str) -> None: def delete_payment(checking_id: str) -> None:
with open_db() as db: g.db.execute("DELETE FROM apipayments WHERE checking_id = ?", (checking_id,))
db.execute("DELETE FROM apipayments WHERE checking_id = ?", (checking_id,))
def check_internal(payment_hash: str) -> Optional[str]: def check_internal(payment_hash: str) -> Optional[str]:
with open_db() as db: row = g.db.fetchone(
row = db.fetchone(
""" """
SELECT checking_id FROM apipayments SELECT checking_id FROM apipayments
WHERE hash = ? AND pending AND amount > 0 WHERE hash = ? AND pending AND amount > 0

4
lnbits/core/services.py

@ -1,4 +1,5 @@
from typing import Optional, Tuple, Dict from typing import Optional, Tuple, Dict
from flask import g
try: try:
from typing import TypedDict # type: ignore from typing import TypedDict # type: ignore
@ -94,6 +95,7 @@ def pay_invoice(
wallet = get_wallet(wallet_id) wallet = get_wallet(wallet_id)
assert wallet, "invalid wallet id" assert wallet, "invalid wallet id"
if wallet.balance_msat < 0: if wallet.balance_msat < 0:
g.db.rollback()
raise PermissionError("Insufficient balance.") raise PermissionError("Insufficient balance.")
if internal: if internal:
@ -108,7 +110,7 @@ def pay_invoice(
create_payment(checking_id=checking_id, fee=fee_msat, **payment_kwargs) create_payment(checking_id=checking_id, fee=fee_msat, **payment_kwargs)
delete_payment(temp_id) delete_payment(temp_id)
else: else:
raise Exception(error_message or "Unexpected backend error.") raise Exception(error_message or "Failed to pay_invoice on backend.")
return invoice.payment_hash return invoice.payment_hash

2
lnbits/core/views/api.py

@ -48,6 +48,7 @@ def api_payments_create_invoice():
wallet_id=g.wallet.id, amount=g.data["amount"], memo=memo, description_hash=description_hash wallet_id=g.wallet.id, amount=g.data["amount"], memo=memo, description_hash=description_hash
) )
except Exception as e: except Exception as e:
g.db.rollback()
return jsonify({"message": str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR return jsonify({"message": str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR
invoice = bolt11.decode(payment_request) invoice = bolt11.decode(payment_request)
@ -75,6 +76,7 @@ def api_payments_pay_invoice():
return jsonify({"message": str(e)}), HTTPStatus.FORBIDDEN return jsonify({"message": str(e)}), HTTPStatus.FORBIDDEN
except Exception as e: except Exception as e:
print(e) print(e)
g.db.rollback()
return jsonify({"message": str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR return jsonify({"message": str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR
return ( return (

11
lnbits/db.py

@ -15,10 +15,21 @@ class Database:
return self return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
if exc_val:
self.connection.rollback()
self.cursor.close()
self.cursor.close()
else:
self.connection.commit() self.connection.commit()
self.cursor.close() self.cursor.close()
self.connection.close() self.connection.close()
def commit(self):
self.connection.commit()
def rollback(self):
self.connection.rollback()
def fetchall(self, query: str, values: tuple = ()) -> list: def fetchall(self, query: str, values: tuple = ()) -> list:
"""Given a query, return cursor.fetchall() rows.""" """Given a query, return cursor.fetchall() rows."""
self.execute(query, values) self.execute(query, values)

Loading…
Cancel
Save