Browse Source

make every request a database transaction.

aiosqlite
fiatjaf 4 years ago
committed by fiatjaf
parent
commit
4855e2cd3d
  1. 17
      lnbits/app.py
  2. 295
      lnbits/core/crud.py
  3. 4
      lnbits/core/services.py
  4. 2
      lnbits/core/views/api.py
  5. 15
      lnbits/db.py

17
lnbits/app.py

@ -1,6 +1,6 @@
import importlib import importlib
from flask import Flask from flask import Flask, g
from flask_assets import Bundle # type: ignore from flask_assets import Bundle # type: ignore
from flask_cors import CORS # type: ignore from flask_cors import CORS # type: ignore
from flask_talisman import Talisman # type: ignore from flask_talisman import Talisman # type: ignore
@ -8,6 +8,7 @@ from werkzeug.middleware.proxy_fix import ProxyFix
from .commands import flask_migrate from .commands import flask_migrate
from .core import core_app from .core import core_app
from .db import open_db
from .ext import assets, compress from .ext import assets, compress
from .helpers import get_valid_extensions from .helpers import get_valid_extensions
@ -24,6 +25,7 @@ def create_app(config_object="lnbits.settings") -> Flask:
register_blueprints(app) register_blueprints(app)
register_filters(app) register_filters(app)
register_commands(app) register_commands(app)
register_request_hooks(app)
return app return app
@ -73,3 +75,16 @@ def register_filters(app):
app.jinja_env.globals["DEBUG"] = app.config["DEBUG"] app.jinja_env.globals["DEBUG"] = app.config["DEBUG"]
app.jinja_env.globals["EXTENSIONS"] = get_valid_extensions() app.jinja_env.globals["EXTENSIONS"] = get_valid_extensions()
app.jinja_env.globals["SITE_TITLE"] = app.config["LNBITS_SITE_TITLE"] app.jinja_env.globals["SITE_TITLE"] = app.config["LNBITS_SITE_TITLE"]
def register_request_hooks(app):
"""Open the core db for each request so everything happens in a big transaction"""
@app.before_request
def before_request():
g.db = open_db()
@app.teardown_request
def after_request(exc):
print("after", exc)
g.db.__exit__(type(exc), exc, None)

295
lnbits/core/crud.py

@ -2,8 +2,8 @@ import json
import datetime import datetime
from uuid import uuid4 from uuid import uuid4
from typing import List, Optional, Dict from typing import List, Optional, Dict
from flask import g
from lnbits.db import open_db
from lnbits import bolt11 from lnbits import bolt11
from lnbits.settings import DEFAULT_WALLET_NAME from lnbits.settings import DEFAULT_WALLET_NAME
@ -15,9 +15,8 @@ from .models import User, Wallet, Payment
def create_account() -> User: def create_account() -> User:
with open_db() as db: user_id = uuid4().hex
user_id = uuid4().hex g.db.execute("INSERT INTO accounts (id) VALUES (?)", (user_id,))
db.execute("INSERT INTO accounts (id) VALUES (?)", (user_id,))
new_account = get_account(user_id=user_id) new_account = get_account(user_id=user_id)
assert new_account, "Newly created account couldn't be retrieved" assert new_account, "Newly created account couldn't be retrieved"
@ -26,26 +25,24 @@ def create_account() -> User:
def get_account(user_id: str) -> Optional[User]: def get_account(user_id: str) -> Optional[User]:
with open_db() as db: row = g.db.fetchone("SELECT id, email, pass as password FROM accounts WHERE id = ?", (user_id,))
row = db.fetchone("SELECT id, email, pass as password FROM accounts WHERE id = ?", (user_id,))
return User(**row) if row else None return User(**row) if row else None
def get_user(user_id: str) -> Optional[User]: def get_user(user_id: str) -> Optional[User]:
with open_db() as db: user = g.db.fetchone("SELECT id, email FROM accounts WHERE id = ?", (user_id,))
user = db.fetchone("SELECT id, email FROM accounts WHERE id = ?", (user_id,))
if user:
if user: extensions = g.db.fetchall("SELECT extension FROM extensions WHERE user = ? AND active = 1", (user_id,))
extensions = db.fetchall("SELECT extension FROM extensions WHERE user = ? AND active = 1", (user_id,)) wallets = g.db.fetchall(
wallets = db.fetchall( """
""" SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat FROM wallets
FROM wallets WHERE user = ?
WHERE user = ? """,
""", (user_id,),
(user_id,), )
)
return ( return (
User(**{**user, **{"extensions": [e[0] for e in extensions], "wallets": [Wallet(**w) for w in wallets]}}) User(**{**user, **{"extensions": [e[0] for e in extensions], "wallets": [Wallet(**w) for w in wallets]}})
@ -55,14 +52,13 @@ def get_user(user_id: str) -> Optional[User]:
def update_user_extension(*, user_id: str, extension: str, active: int) -> None: def update_user_extension(*, user_id: str, extension: str, active: int) -> None:
with open_db() as db: g.db.execute(
db.execute( """
""" INSERT OR REPLACE INTO extensions (user, extension, active)
INSERT OR REPLACE INTO extensions (user, extension, active) VALUES (?, ?, ?)
VALUES (?, ?, ?) """,
""", (user_id, extension, active),
(user_id, extension, active), )
)
# wallets # wallets
@ -70,15 +66,14 @@ def update_user_extension(*, user_id: str, extension: str, active: int) -> None:
def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> Wallet: def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> Wallet:
with open_db() as db: wallet_id = uuid4().hex
wallet_id = uuid4().hex g.db.execute(
db.execute( """
""" INSERT INTO wallets (id, name, user, adminkey, inkey)
INSERT INTO wallets (id, name, user, adminkey, inkey) VALUES (?, ?, ?, ?, ?)
VALUES (?, ?, ?, ?, ?) """,
""", (wallet_id, wallet_name or DEFAULT_WALLET_NAME, user_id, uuid4().hex, uuid4().hex),
(wallet_id, wallet_name or DEFAULT_WALLET_NAME, user_id, uuid4().hex, uuid4().hex), )
)
new_wallet = get_wallet(wallet_id=wallet_id) new_wallet = get_wallet(wallet_id=wallet_id)
assert new_wallet, "Newly created wallet couldn't be retrieved" assert new_wallet, "Newly created wallet couldn't be retrieved"
@ -87,52 +82,49 @@ def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> Wallet:
def delete_wallet(*, user_id: str, wallet_id: str) -> None: def delete_wallet(*, user_id: str, wallet_id: str) -> None:
with open_db() as db: g.db.execute(
db.execute( """
""" UPDATE wallets AS w
UPDATE wallets AS w SET
SET user = 'del:' || w.user,
user = 'del:' || w.user, adminkey = 'del:' || w.adminkey,
adminkey = 'del:' || w.adminkey, inkey = 'del:' || w.inkey
inkey = 'del:' || w.inkey WHERE id = ? AND user = ?
WHERE id = ? AND user = ? """,
""", (wallet_id, user_id),
(wallet_id, user_id), )
)
def get_wallet(wallet_id: str) -> Optional[Wallet]: def get_wallet(wallet_id: str) -> Optional[Wallet]:
with open_db() as db: row = g.db.fetchone(
row = db.fetchone( """
""" SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat FROM wallets
FROM wallets WHERE id = ?
WHERE id = ? """,
""", (wallet_id,),
(wallet_id,), )
)
return Wallet(**row) if row else None return Wallet(**row) if row else None
def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wallet]: def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wallet]:
with open_db() as db: row = g.db.fetchone(
row = db.fetchone( """
""" SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat FROM wallets
FROM wallets WHERE adminkey = ? OR inkey = ?
WHERE adminkey = ? OR inkey = ? """,
""", (key, key),
(key, key), )
)
if not row: if not row:
return None return None
if key_type == "admin" and row["adminkey"] != key: if key_type == "admin" and row["adminkey"] != key:
return None return None
return Wallet(**row) return Wallet(**row)
# wallet payments # wallet payments
@ -140,15 +132,14 @@ def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wallet]:
def get_wallet_payment(wallet_id: str, payment_hash: str) -> Optional[Payment]: def get_wallet_payment(wallet_id: str, payment_hash: str) -> Optional[Payment]:
with open_db() as db: row = g.db.fetchone(
row = db.fetchone( """
""" SELECT *
SELECT * FROM apipayments
FROM apipayments WHERE wallet = ? AND hash = ?
WHERE wallet = ? AND hash = ? """,
""", (wallet_id, payment_hash),
(wallet_id, payment_hash), )
)
return Payment.from_row(row) if row else None return Payment.from_row(row) if row else None
@ -179,46 +170,44 @@ def get_wallet_payments(
else: else:
raise TypeError("at least one of [outgoing, incoming] must be True.") raise TypeError("at least one of [outgoing, incoming] must be True.")
with open_db() as db: rows = g.db.fetchall(
rows = db.fetchall( f"""
f""" SELECT *
SELECT * FROM apipayments
FROM apipayments WHERE wallet = ? {clause}
WHERE wallet = ? {clause} ORDER BY time DESC
ORDER BY time DESC """,
""", (wallet_id,),
(wallet_id,), )
)
return [Payment.from_row(row) for row in rows] return [Payment.from_row(row) for row in rows]
def delete_expired_invoices() -> None: def delete_expired_invoices() -> None:
with open_db() as db: rows = g.db.fetchall(
rows = db.fetchall(
"""
SELECT bolt11
FROM apipayments
WHERE pending = 1 AND amount > 0 AND time < strftime('%s', 'now') - 86400
""" """
) SELECT bolt11
for (payment_request,) in rows: FROM apipayments
try: WHERE pending = 1 AND amount > 0 AND time < strftime('%s', 'now') - 86400
invoice = bolt11.decode(payment_request) """
except: )
continue for (payment_request,) in rows:
try:
invoice = bolt11.decode(payment_request)
except:
continue
expiration_date = datetime.datetime.fromtimestamp(invoice.date + invoice.expiry) expiration_date = datetime.datetime.fromtimestamp(invoice.date + invoice.expiry)
if expiration_date > datetime.datetime.utcnow(): if expiration_date > datetime.datetime.utcnow():
continue continue
db.execute( g.db.execute(
""" """
DELETE FROM apipayments DELETE FROM apipayments
WHERE pending = 1 AND hash = ? WHERE pending = 1 AND hash = ?
""", """,
(invoice.payment_hash,), (invoice.payment_hash,),
) )
# payments # payments
@ -238,27 +227,26 @@ def create_payment(
pending: bool = True, pending: bool = True,
extra: Optional[Dict] = None, extra: Optional[Dict] = None,
) -> Payment: ) -> Payment:
with open_db() as db: g.db.execute(
db.execute( """
""" INSERT INTO apipayments
INSERT INTO apipayments (wallet, checking_id, bolt11, hash, preimage,
(wallet, checking_id, bolt11, hash, preimage, amount, pending, memo, fee, extra)
amount, pending, memo, fee, extra) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """,
""", (
( wallet_id,
wallet_id, checking_id,
checking_id, payment_request,
payment_request, payment_hash,
payment_hash, preimage,
preimage, amount,
amount, int(pending),
int(pending), memo,
memo, fee,
fee, json.dumps(extra) if extra and extra != {} and type(extra) is dict else None,
json.dumps(extra) if extra and extra != {} and type(extra) is dict else None, ),
), )
)
new_payment = get_wallet_payment(wallet_id, payment_hash) new_payment = get_wallet_payment(wallet_id, payment_hash)
assert new_payment, "Newly created payment couldn't be retrieved" assert new_payment, "Newly created payment couldn't be retrieved"
@ -267,31 +255,28 @@ def create_payment(
def update_payment_status(checking_id: str, pending: bool) -> None: def update_payment_status(checking_id: str, pending: bool) -> None:
with open_db() as db: g.db.execute(
db.execute( "UPDATE apipayments SET pending = ? WHERE checking_id = ?",
"UPDATE apipayments SET pending = ? WHERE checking_id = ?", (
( int(pending),
int(pending), checking_id,
checking_id, ),
), )
)
def delete_payment(checking_id: str) -> None: def delete_payment(checking_id: str) -> None:
with open_db() as db: g.db.execute("DELETE FROM apipayments WHERE checking_id = ?", (checking_id,))
db.execute("DELETE FROM apipayments WHERE checking_id = ?", (checking_id,))
def check_internal(payment_hash: str) -> Optional[str]: def check_internal(payment_hash: str) -> Optional[str]:
with open_db() as db: row = g.db.fetchone(
row = db.fetchone( """
""" SELECT checking_id FROM apipayments
SELECT checking_id FROM apipayments WHERE hash = ? AND pending AND amount > 0
WHERE hash = ? AND pending AND amount > 0 """,
""", (payment_hash,),
(payment_hash,), )
) if not row:
if not row: return None
return None else:
else: return row["checking_id"]
return row["checking_id"]

4
lnbits/core/services.py

@ -1,4 +1,5 @@
from typing import Optional, Tuple, Dict from typing import Optional, Tuple, Dict
from flask import g
try: try:
from typing import TypedDict # type: ignore from typing import TypedDict # type: ignore
@ -94,6 +95,7 @@ def pay_invoice(
wallet = get_wallet(wallet_id) wallet = get_wallet(wallet_id)
assert wallet, "invalid wallet id" assert wallet, "invalid wallet id"
if wallet.balance_msat < 0: if wallet.balance_msat < 0:
g.db.rollback()
raise PermissionError("Insufficient balance.") raise PermissionError("Insufficient balance.")
if internal: if internal:
@ -108,7 +110,7 @@ def pay_invoice(
create_payment(checking_id=checking_id, fee=fee_msat, **payment_kwargs) create_payment(checking_id=checking_id, fee=fee_msat, **payment_kwargs)
delete_payment(temp_id) delete_payment(temp_id)
else: else:
raise Exception(error_message or "Unexpected backend error.") raise Exception(error_message or "Failed to pay_invoice on backend.")
return invoice.payment_hash return invoice.payment_hash

2
lnbits/core/views/api.py

@ -48,6 +48,7 @@ def api_payments_create_invoice():
wallet_id=g.wallet.id, amount=g.data["amount"], memo=memo, description_hash=description_hash wallet_id=g.wallet.id, amount=g.data["amount"], memo=memo, description_hash=description_hash
) )
except Exception as e: except Exception as e:
g.db.rollback()
return jsonify({"message": str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR return jsonify({"message": str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR
invoice = bolt11.decode(payment_request) invoice = bolt11.decode(payment_request)
@ -75,6 +76,7 @@ def api_payments_pay_invoice():
return jsonify({"message": str(e)}), HTTPStatus.FORBIDDEN return jsonify({"message": str(e)}), HTTPStatus.FORBIDDEN
except Exception as e: except Exception as e:
print(e) print(e)
g.db.rollback()
return jsonify({"message": str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR return jsonify({"message": str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR
return ( return (

15
lnbits/db.py

@ -15,9 +15,20 @@ class Database:
return self return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
if exc_val:
self.connection.rollback()
self.cursor.close()
self.cursor.close()
else:
self.connection.commit()
self.cursor.close()
self.connection.close()
def commit(self):
self.connection.commit() self.connection.commit()
self.cursor.close()
self.connection.close() def rollback(self):
self.connection.rollback()
def fetchall(self, query: str, values: tuple = ()) -> list: def fetchall(self, query: str, values: tuple = ()) -> list:
"""Given a query, return cursor.fetchall() rows.""" """Given a query, return cursor.fetchall() rows."""

Loading…
Cancel
Save