mirror of https://github.com/lukechilds/lnbits.git
Browse Source
a big refactor that: - fixes some issues that might have happened (or not) with asynchronous reactions to payments; - paves the way to https://github.com/lnbits/lnbits/issues/121; - uses more async/await notation which just looks nice; and - makes it simple(r?) for one extension to modify stuff from other extensions.livestream
fiatjaf
4 years ago
68 changed files with 976 additions and 1080 deletions
@ -1,66 +1,85 @@ |
|||
import os |
|||
import sqlite3 |
|||
from typing import Tuple, Optional, Any |
|||
from sqlalchemy_aio import TRIO_STRATEGY # type: ignore |
|||
from sqlalchemy import create_engine # type: ignore |
|||
from quart import g |
|||
|
|||
from .settings import LNBITS_DATA_FOLDER |
|||
|
|||
|
|||
class Database: |
|||
def __init__(self, db_path: str): |
|||
self.path = db_path |
|||
self.connection = sqlite3.connect(db_path) |
|||
self.connection.row_factory = sqlite3.Row |
|||
self.cursor = self.connection.cursor() |
|||
self.closed = False |
|||
def __init__(self, db_name: str): |
|||
self.db_name = db_name |
|||
db_path = os.path.join(LNBITS_DATA_FOLDER, f"{db_name}.sqlite3") |
|||
self.engine = create_engine(f"sqlite:///{db_path}", strategy=TRIO_STRATEGY) |
|||
|
|||
def close(self): |
|||
self.__exit__(None, None, None) |
|||
def connect(self): |
|||
return self.engine.connect() |
|||
|
|||
def __enter__(self): |
|||
return self |
|||
|
|||
def __exit__(self, exc_type, exc_val, exc_tb): |
|||
if self.closed: |
|||
return |
|||
|
|||
if exc_val: |
|||
self.connection.rollback() |
|||
self.cursor.close() |
|||
self.connection.close() |
|||
else: |
|||
self.connection.commit() |
|||
self.cursor.close() |
|||
self.connection.close() |
|||
|
|||
self.closed = True |
|||
|
|||
def commit(self): |
|||
self.connection.commit() |
|||
|
|||
def rollback(self): |
|||
self.connection.rollback() |
|||
|
|||
def fetchall(self, query: str, values: tuple = ()) -> list: |
|||
"""Given a query, return cursor.fetchall() rows.""" |
|||
self.execute(query, values) |
|||
return self.cursor.fetchall() |
|||
|
|||
def fetchone(self, query: str, values: tuple = ()): |
|||
self.execute(query, values) |
|||
return self.cursor.fetchone() |
|||
|
|||
def execute(self, query: str, values: tuple = ()) -> None: |
|||
"""Given a query, cursor.execute() it.""" |
|||
def session_connection(self) -> Tuple[Optional[Any], Optional[Any]]: |
|||
try: |
|||
self.cursor.execute(query, values) |
|||
except sqlite3.Error as exc: |
|||
self.connection.rollback() |
|||
raise exc |
|||
|
|||
|
|||
def open_db(db_name: str = "database") -> Database: |
|||
db_path = os.path.join(LNBITS_DATA_FOLDER, f"{db_name}.sqlite3") |
|||
return Database(db_path=db_path) |
|||
return getattr(g, f"{self.db_name}_conn", None), getattr(g, f"{self.db_name}_txn", None) |
|||
except RuntimeError: |
|||
return None, None |
|||
|
|||
async def begin(self): |
|||
conn, _ = self.session_connection() |
|||
if conn: |
|||
return |
|||
|
|||
def open_ext_db(extension_name: str) -> Database: |
|||
return open_db(f"ext_{extension_name}") |
|||
conn = await self.engine.connect() |
|||
setattr(g, f"{self.db_name}_conn", conn) |
|||
txn = await conn.begin() |
|||
setattr(g, f"{self.db_name}_txn", txn) |
|||
|
|||
async def fetchall(self, query: str, values: tuple = ()) -> list: |
|||
conn, _ = self.session_connection() |
|||
if conn: |
|||
result = await conn.execute(query, values) |
|||
return await result.fetchall() |
|||
|
|||
async with self.connect() as conn: |
|||
result = await conn.execute(query, values) |
|||
return await result.fetchall() |
|||
|
|||
async def fetchone(self, query: str, values: tuple = ()): |
|||
conn, _ = self.session_connection() |
|||
if conn: |
|||
result = await conn.execute(query, values) |
|||
row = await result.fetchone() |
|||
await result.close() |
|||
return row |
|||
|
|||
async with self.connect() as conn: |
|||
result = await conn.execute(query, values) |
|||
row = await result.fetchone() |
|||
await result.close() |
|||
return row |
|||
|
|||
async def execute(self, query: str, values: tuple = ()): |
|||
conn, _ = self.session_connection() |
|||
if conn: |
|||
return await conn.execute(query, values) |
|||
|
|||
async with self.connect() as conn: |
|||
return await conn.execute(query, values) |
|||
|
|||
async def commit(self): |
|||
conn, txn = self.session_connection() |
|||
if conn and txn: |
|||
await txn.commit() |
|||
await self.close_session() |
|||
|
|||
async def rollback(self): |
|||
conn, txn = self.session_connection() |
|||
if conn and txn: |
|||
await txn.rollback() |
|||
await self.close_session() |
|||
|
|||
async def close_session(self): |
|||
conn, txn = self.session_connection() |
|||
if conn and txn: |
|||
await txn.close() |
|||
await conn.close() |
|||
delattr(g, f"{self.db_name}_conn") |
|||
delattr(g, f"{self.db_name}_txn") |
|||
|
@ -1,2 +1,2 @@ |
|||
def migrate(): |
|||
async def migrate(): |
|||
pass |
|||
|
@ -1,45 +1,43 @@ |
|||
from typing import List, Optional, Union |
|||
|
|||
from lnbits.db import open_ext_db |
|||
from lnbits.helpers import urlsafe_short_hash |
|||
|
|||
from . import db |
|||
from .models import Paywall |
|||
|
|||
|
|||
def create_paywall( |
|||
async def create_paywall( |
|||
*, wallet_id: str, url: str, memo: str, description: Optional[str] = None, amount: int = 0, remembers: bool = True |
|||
) -> Paywall: |
|||
with open_ext_db("paywall") as db: |
|||
paywall_id = urlsafe_short_hash() |
|||
db.execute( |
|||
""" |
|||
INSERT INTO paywalls (id, wallet, url, memo, description, amount, remembers) |
|||
VALUES (?, ?, ?, ?, ?, ?, ?) |
|||
""", |
|||
(paywall_id, wallet_id, url, memo, description, amount, int(remembers)), |
|||
) |
|||
paywall_id = urlsafe_short_hash() |
|||
await db.execute( |
|||
""" |
|||
INSERT INTO paywalls (id, wallet, url, memo, description, amount, remembers) |
|||
VALUES (?, ?, ?, ?, ?, ?, ?) |
|||
""", |
|||
(paywall_id, wallet_id, url, memo, description, amount, int(remembers)), |
|||
) |
|||
|
|||
return get_paywall(paywall_id) |
|||
paywall = await get_paywall(paywall_id) |
|||
assert paywall, "Newly created paywall couldn't be retrieved" |
|||
return paywall |
|||
|
|||
|
|||
def get_paywall(paywall_id: str) -> Optional[Paywall]: |
|||
with open_ext_db("paywall") as db: |
|||
row = db.fetchone("SELECT * FROM paywalls WHERE id = ?", (paywall_id,)) |
|||
async def get_paywall(paywall_id: str) -> Optional[Paywall]: |
|||
row = await db.fetchone("SELECT * FROM paywalls WHERE id = ?", (paywall_id,)) |
|||
|
|||
return Paywall.from_row(row) if row else None |
|||
|
|||
|
|||
def get_paywalls(wallet_ids: Union[str, List[str]]) -> List[Paywall]: |
|||
async def get_paywalls(wallet_ids: Union[str, List[str]]) -> List[Paywall]: |
|||
if isinstance(wallet_ids, str): |
|||
wallet_ids = [wallet_ids] |
|||
|
|||
with open_ext_db("paywall") as db: |
|||
q = ",".join(["?"] * len(wallet_ids)) |
|||
rows = db.fetchall(f"SELECT * FROM paywalls WHERE wallet IN ({q})", (*wallet_ids,)) |
|||
q = ",".join(["?"] * len(wallet_ids)) |
|||
rows = await db.fetchall(f"SELECT * FROM paywalls WHERE wallet IN ({q})", (*wallet_ids,)) |
|||
|
|||
return [Paywall.from_row(row) for row in rows] |
|||
|
|||
|
|||
def delete_paywall(paywall_id: str) -> None: |
|||
with open_ext_db("paywall") as db: |
|||
db.execute("DELETE FROM paywalls WHERE id = ?", (paywall_id,)) |
|||
async def delete_paywall(paywall_id: str) -> None: |
|||
await db.execute("DELETE FROM paywalls WHERE id = ?", (paywall_id,)) |
|||
|
@ -1,43 +1,40 @@ |
|||
from typing import List, Optional, Union |
|||
|
|||
from lnbits.db import open_ext_db |
|||
from lnbits.helpers import urlsafe_short_hash |
|||
|
|||
from . import db |
|||
from .models import TPoS |
|||
|
|||
|
|||
def create_tpos(*, wallet_id: str, name: str, currency: str) -> TPoS: |
|||
with open_ext_db("tpos") as db: |
|||
tpos_id = urlsafe_short_hash() |
|||
db.execute( |
|||
""" |
|||
INSERT INTO tposs (id, wallet, name, currency) |
|||
VALUES (?, ?, ?, ?) |
|||
""", |
|||
(tpos_id, wallet_id, name, currency), |
|||
) |
|||
async def create_tpos(*, wallet_id: str, name: str, currency: str) -> TPoS: |
|||
tpos_id = urlsafe_short_hash() |
|||
await db.execute( |
|||
""" |
|||
INSERT INTO tposs (id, wallet, name, currency) |
|||
VALUES (?, ?, ?, ?) |
|||
""", |
|||
(tpos_id, wallet_id, name, currency), |
|||
) |
|||
|
|||
return get_tpos(tpos_id) |
|||
tpos = await get_tpos(tpos_id) |
|||
assert tpos, "Newly created tpos couldn't be retrieved" |
|||
return tpos |
|||
|
|||
|
|||
def get_tpos(tpos_id: str) -> Optional[TPoS]: |
|||
with open_ext_db("tpos") as db: |
|||
row = db.fetchone("SELECT * FROM tposs WHERE id = ?", (tpos_id,)) |
|||
|
|||
async def get_tpos(tpos_id: str) -> Optional[TPoS]: |
|||
row = await db.fetchone("SELECT * FROM tposs WHERE id = ?", (tpos_id,)) |
|||
return TPoS.from_row(row) if row else None |
|||
|
|||
|
|||
def get_tposs(wallet_ids: Union[str, List[str]]) -> List[TPoS]: |
|||
async def get_tposs(wallet_ids: Union[str, List[str]]) -> List[TPoS]: |
|||
if isinstance(wallet_ids, str): |
|||
wallet_ids = [wallet_ids] |
|||
|
|||
with open_ext_db("tpos") as db: |
|||
q = ",".join(["?"] * len(wallet_ids)) |
|||
rows = db.fetchall(f"SELECT * FROM tposs WHERE wallet IN ({q})", (*wallet_ids,)) |
|||
q = ",".join(["?"] * len(wallet_ids)) |
|||
rows = await db.fetchall(f"SELECT * FROM tposs WHERE wallet IN ({q})", (*wallet_ids,)) |
|||
|
|||
return [TPoS.from_row(row) for row in rows] |
|||
|
|||
|
|||
def delete_tpos(tpos_id: str) -> None: |
|||
with open_ext_db("tpos") as db: |
|||
db.execute("DELETE FROM tposs WHERE id = ?", (tpos_id,)) |
|||
async def delete_tpos(tpos_id: str) -> None: |
|||
await db.execute("DELETE FROM tposs WHERE id = ?", (tpos_id,)) |
|||
|
Loading…
Reference in new issue