Browse Source

make every request a database transaction.

aiosqlite
fiatjaf 4 years ago
committed by fiatjaf
parent
commit
4855e2cd3d
  1. 17
      lnbits/app.py
  2. 295
      lnbits/core/crud.py
  3. 4
      lnbits/core/services.py
  4. 2
      lnbits/core/views/api.py
  5. 15
      lnbits/db.py

17
lnbits/app.py

@ -1,6 +1,6 @@
import importlib
from flask import Flask
from flask import Flask, g
from flask_assets import Bundle # type: ignore
from flask_cors import CORS # type: ignore
from flask_talisman import Talisman # type: ignore
@ -8,6 +8,7 @@ from werkzeug.middleware.proxy_fix import ProxyFix
from .commands import flask_migrate
from .core import core_app
from .db import open_db
from .ext import assets, compress
from .helpers import get_valid_extensions
@ -24,6 +25,7 @@ def create_app(config_object="lnbits.settings") -> Flask:
register_blueprints(app)
register_filters(app)
register_commands(app)
register_request_hooks(app)
return app
@ -73,3 +75,16 @@ def register_filters(app):
app.jinja_env.globals["DEBUG"] = app.config["DEBUG"]
app.jinja_env.globals["EXTENSIONS"] = get_valid_extensions()
app.jinja_env.globals["SITE_TITLE"] = app.config["LNBITS_SITE_TITLE"]
def register_request_hooks(app):
"""Open the core db for each request so everything happens in a big transaction"""
@app.before_request
def before_request():
g.db = open_db()
@app.teardown_request
def after_request(exc):
print("after", exc)
g.db.__exit__(type(exc), exc, None)

295
lnbits/core/crud.py

@ -2,8 +2,8 @@ import json
import datetime
from uuid import uuid4
from typing import List, Optional, Dict
from flask import g
from lnbits.db import open_db
from lnbits import bolt11
from lnbits.settings import DEFAULT_WALLET_NAME
@ -15,9 +15,8 @@ from .models import User, Wallet, Payment
def create_account() -> User:
with open_db() as db:
user_id = uuid4().hex
db.execute("INSERT INTO accounts (id) VALUES (?)", (user_id,))
user_id = uuid4().hex
g.db.execute("INSERT INTO accounts (id) VALUES (?)", (user_id,))
new_account = get_account(user_id=user_id)
assert new_account, "Newly created account couldn't be retrieved"
@ -26,26 +25,24 @@ def create_account() -> User:
def get_account(user_id: str) -> Optional[User]:
with open_db() as db:
row = db.fetchone("SELECT id, email, pass as password FROM accounts WHERE id = ?", (user_id,))
row = g.db.fetchone("SELECT id, email, pass as password FROM accounts WHERE id = ?", (user_id,))
return User(**row) if row else None
def get_user(user_id: str) -> Optional[User]:
with open_db() as db:
user = db.fetchone("SELECT id, email FROM accounts WHERE id = ?", (user_id,))
if user:
extensions = db.fetchall("SELECT extension FROM extensions WHERE user = ? AND active = 1", (user_id,))
wallets = db.fetchall(
"""
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat
FROM wallets
WHERE user = ?
""",
(user_id,),
)
user = g.db.fetchone("SELECT id, email FROM accounts WHERE id = ?", (user_id,))
if user:
extensions = g.db.fetchall("SELECT extension FROM extensions WHERE user = ? AND active = 1", (user_id,))
wallets = g.db.fetchall(
"""
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat
FROM wallets
WHERE user = ?
""",
(user_id,),
)
return (
User(**{**user, **{"extensions": [e[0] for e in extensions], "wallets": [Wallet(**w) for w in wallets]}})
@ -55,14 +52,13 @@ def get_user(user_id: str) -> Optional[User]:
def update_user_extension(*, user_id: str, extension: str, active: int) -> None:
with open_db() as db:
db.execute(
"""
INSERT OR REPLACE INTO extensions (user, extension, active)
VALUES (?, ?, ?)
""",
(user_id, extension, active),
)
g.db.execute(
"""
INSERT OR REPLACE INTO extensions (user, extension, active)
VALUES (?, ?, ?)
""",
(user_id, extension, active),
)
# wallets
@ -70,15 +66,14 @@ def update_user_extension(*, user_id: str, extension: str, active: int) -> None:
def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> Wallet:
with open_db() as db:
wallet_id = uuid4().hex
db.execute(
"""
INSERT INTO wallets (id, name, user, adminkey, inkey)
VALUES (?, ?, ?, ?, ?)
""",
(wallet_id, wallet_name or DEFAULT_WALLET_NAME, user_id, uuid4().hex, uuid4().hex),
)
wallet_id = uuid4().hex
g.db.execute(
"""
INSERT INTO wallets (id, name, user, adminkey, inkey)
VALUES (?, ?, ?, ?, ?)
""",
(wallet_id, wallet_name or DEFAULT_WALLET_NAME, user_id, uuid4().hex, uuid4().hex),
)
new_wallet = get_wallet(wallet_id=wallet_id)
assert new_wallet, "Newly created wallet couldn't be retrieved"
@ -87,52 +82,49 @@ def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> Wallet:
def delete_wallet(*, user_id: str, wallet_id: str) -> None:
with open_db() as db:
db.execute(
"""
UPDATE wallets AS w
SET
user = 'del:' || w.user,
adminkey = 'del:' || w.adminkey,
inkey = 'del:' || w.inkey
WHERE id = ? AND user = ?
""",
(wallet_id, user_id),
)
g.db.execute(
"""
UPDATE wallets AS w
SET
user = 'del:' || w.user,
adminkey = 'del:' || w.adminkey,
inkey = 'del:' || w.inkey
WHERE id = ? AND user = ?
""",
(wallet_id, user_id),
)
def get_wallet(wallet_id: str) -> Optional[Wallet]:
with open_db() as db:
row = db.fetchone(
"""
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat
FROM wallets
WHERE id = ?
""",
(wallet_id,),
)
row = g.db.fetchone(
"""
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat
FROM wallets
WHERE id = ?
""",
(wallet_id,),
)
return Wallet(**row) if row else None
def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wallet]:
with open_db() as db:
row = db.fetchone(
"""
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat
FROM wallets
WHERE adminkey = ? OR inkey = ?
""",
(key, key),
)
row = g.db.fetchone(
"""
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat
FROM wallets
WHERE adminkey = ? OR inkey = ?
""",
(key, key),
)
if not row:
return None
if not row:
return None
if key_type == "admin" and row["adminkey"] != key:
return None
if key_type == "admin" and row["adminkey"] != key:
return None
return Wallet(**row)
return Wallet(**row)
# wallet payments
@ -140,15 +132,14 @@ def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wallet]:
def get_wallet_payment(wallet_id: str, payment_hash: str) -> Optional[Payment]:
with open_db() as db:
row = db.fetchone(
"""
SELECT *
FROM apipayments
WHERE wallet = ? AND hash = ?
""",
(wallet_id, payment_hash),
)
row = g.db.fetchone(
"""
SELECT *
FROM apipayments
WHERE wallet = ? AND hash = ?
""",
(wallet_id, payment_hash),
)
return Payment.from_row(row) if row else None
@ -179,46 +170,44 @@ def get_wallet_payments(
else:
raise TypeError("at least one of [outgoing, incoming] must be True.")
with open_db() as db:
rows = db.fetchall(
f"""
SELECT *
FROM apipayments
WHERE wallet = ? {clause}
ORDER BY time DESC
""",
(wallet_id,),
)
rows = g.db.fetchall(
f"""
SELECT *
FROM apipayments
WHERE wallet = ? {clause}
ORDER BY time DESC
""",
(wallet_id,),
)
return [Payment.from_row(row) for row in rows]
def delete_expired_invoices() -> None:
with open_db() as db:
rows = db.fetchall(
"""
SELECT bolt11
FROM apipayments
WHERE pending = 1 AND amount > 0 AND time < strftime('%s', 'now') - 86400
rows = g.db.fetchall(
"""
)
for (payment_request,) in rows:
try:
invoice = bolt11.decode(payment_request)
except:
continue
SELECT bolt11
FROM apipayments
WHERE pending = 1 AND amount > 0 AND time < strftime('%s', 'now') - 86400
"""
)
for (payment_request,) in rows:
try:
invoice = bolt11.decode(payment_request)
except:
continue
expiration_date = datetime.datetime.fromtimestamp(invoice.date + invoice.expiry)
if expiration_date > datetime.datetime.utcnow():
continue
expiration_date = datetime.datetime.fromtimestamp(invoice.date + invoice.expiry)
if expiration_date > datetime.datetime.utcnow():
continue
db.execute(
"""
DELETE FROM apipayments
WHERE pending = 1 AND hash = ?
""",
(invoice.payment_hash,),
)
g.db.execute(
"""
DELETE FROM apipayments
WHERE pending = 1 AND hash = ?
""",
(invoice.payment_hash,),
)
# payments
@ -238,27 +227,26 @@ def create_payment(
pending: bool = True,
extra: Optional[Dict] = None,
) -> Payment:
with open_db() as db:
db.execute(
"""
INSERT INTO apipayments
(wallet, checking_id, bolt11, hash, preimage,
amount, pending, memo, fee, extra)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
wallet_id,
checking_id,
payment_request,
payment_hash,
preimage,
amount,
int(pending),
memo,
fee,
json.dumps(extra) if extra and extra != {} and type(extra) is dict else None,
),
)
g.db.execute(
"""
INSERT INTO apipayments
(wallet, checking_id, bolt11, hash, preimage,
amount, pending, memo, fee, extra)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
wallet_id,
checking_id,
payment_request,
payment_hash,
preimage,
amount,
int(pending),
memo,
fee,
json.dumps(extra) if extra and extra != {} and type(extra) is dict else None,
),
)
new_payment = get_wallet_payment(wallet_id, payment_hash)
assert new_payment, "Newly created payment couldn't be retrieved"
@ -267,31 +255,28 @@ def create_payment(
def update_payment_status(checking_id: str, pending: bool) -> None:
with open_db() as db:
db.execute(
"UPDATE apipayments SET pending = ? WHERE checking_id = ?",
(
int(pending),
checking_id,
),
)
g.db.execute(
"UPDATE apipayments SET pending = ? WHERE checking_id = ?",
(
int(pending),
checking_id,
),
)
def delete_payment(checking_id: str) -> None:
with open_db() as db:
db.execute("DELETE FROM apipayments WHERE checking_id = ?", (checking_id,))
g.db.execute("DELETE FROM apipayments WHERE checking_id = ?", (checking_id,))
def check_internal(payment_hash: str) -> Optional[str]:
with open_db() as db:
row = db.fetchone(
"""
SELECT checking_id FROM apipayments
WHERE hash = ? AND pending AND amount > 0
""",
(payment_hash,),
)
if not row:
return None
else:
return row["checking_id"]
row = g.db.fetchone(
"""
SELECT checking_id FROM apipayments
WHERE hash = ? AND pending AND amount > 0
""",
(payment_hash,),
)
if not row:
return None
else:
return row["checking_id"]

4
lnbits/core/services.py

@ -1,4 +1,5 @@
from typing import Optional, Tuple, Dict
from flask import g
try:
from typing import TypedDict # type: ignore
@ -94,6 +95,7 @@ def pay_invoice(
wallet = get_wallet(wallet_id)
assert wallet, "invalid wallet id"
if wallet.balance_msat < 0:
g.db.rollback()
raise PermissionError("Insufficient balance.")
if internal:
@ -108,7 +110,7 @@ def pay_invoice(
create_payment(checking_id=checking_id, fee=fee_msat, **payment_kwargs)
delete_payment(temp_id)
else:
raise Exception(error_message or "Unexpected backend error.")
raise Exception(error_message or "Failed to pay_invoice on backend.")
return invoice.payment_hash

2
lnbits/core/views/api.py

@ -48,6 +48,7 @@ def api_payments_create_invoice():
wallet_id=g.wallet.id, amount=g.data["amount"], memo=memo, description_hash=description_hash
)
except Exception as e:
g.db.rollback()
return jsonify({"message": str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR
invoice = bolt11.decode(payment_request)
@ -75,6 +76,7 @@ def api_payments_pay_invoice():
return jsonify({"message": str(e)}), HTTPStatus.FORBIDDEN
except Exception as e:
print(e)
g.db.rollback()
return jsonify({"message": str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR
return (

15
lnbits/db.py

@ -15,9 +15,20 @@ class Database:
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_val:
self.connection.rollback()
self.cursor.close()
self.cursor.close()
else:
self.connection.commit()
self.cursor.close()
self.connection.close()
def commit(self):
self.connection.commit()
self.cursor.close()
self.connection.close()
def rollback(self):
self.connection.rollback()
def fetchall(self, query: str, values: tuple = ()) -> list:
"""Given a query, return cursor.fetchall() rows."""

Loading…
Cancel
Save