mirror of https://github.com/lukechilds/lnbits.git
Browse Source
a big refactor that: - fixes some issues that might have happened (or not) with asynchronous reactions to payments; - paves the way to https://github.com/lnbits/lnbits/issues/121; - uses more async/await notation which just looks nice; and - makes it simple(r?) for one extension to modify stuff from other extensions.livestream
fiatjaf
4 years ago
68 changed files with 976 additions and 1080 deletions
@ -1,66 +1,85 @@ |
|||||
import os |
import os |
||||
import sqlite3 |
from typing import Tuple, Optional, Any |
||||
|
from sqlalchemy_aio import TRIO_STRATEGY # type: ignore |
||||
|
from sqlalchemy import create_engine # type: ignore |
||||
|
from quart import g |
||||
|
|
||||
from .settings import LNBITS_DATA_FOLDER |
from .settings import LNBITS_DATA_FOLDER |
||||
|
|
||||
|
|
||||
class Database: |
class Database: |
||||
def __init__(self, db_path: str): |
def __init__(self, db_name: str): |
||||
self.path = db_path |
self.db_name = db_name |
||||
self.connection = sqlite3.connect(db_path) |
db_path = os.path.join(LNBITS_DATA_FOLDER, f"{db_name}.sqlite3") |
||||
self.connection.row_factory = sqlite3.Row |
self.engine = create_engine(f"sqlite:///{db_path}", strategy=TRIO_STRATEGY) |
||||
self.cursor = self.connection.cursor() |
|
||||
self.closed = False |
|
||||
|
|
||||
def close(self): |
def connect(self): |
||||
self.__exit__(None, None, None) |
return self.engine.connect() |
||||
|
|
||||
def __enter__(self): |
def session_connection(self) -> Tuple[Optional[Any], Optional[Any]]: |
||||
return self |
|
||||
|
|
||||
def __exit__(self, exc_type, exc_val, exc_tb): |
|
||||
if self.closed: |
|
||||
return |
|
||||
|
|
||||
if exc_val: |
|
||||
self.connection.rollback() |
|
||||
self.cursor.close() |
|
||||
self.connection.close() |
|
||||
else: |
|
||||
self.connection.commit() |
|
||||
self.cursor.close() |
|
||||
self.connection.close() |
|
||||
|
|
||||
self.closed = True |
|
||||
|
|
||||
def commit(self): |
|
||||
self.connection.commit() |
|
||||
|
|
||||
def rollback(self): |
|
||||
self.connection.rollback() |
|
||||
|
|
||||
def fetchall(self, query: str, values: tuple = ()) -> list: |
|
||||
"""Given a query, return cursor.fetchall() rows.""" |
|
||||
self.execute(query, values) |
|
||||
return self.cursor.fetchall() |
|
||||
|
|
||||
def fetchone(self, query: str, values: tuple = ()): |
|
||||
self.execute(query, values) |
|
||||
return self.cursor.fetchone() |
|
||||
|
|
||||
def execute(self, query: str, values: tuple = ()) -> None: |
|
||||
"""Given a query, cursor.execute() it.""" |
|
||||
try: |
try: |
||||
self.cursor.execute(query, values) |
return getattr(g, f"{self.db_name}_conn", None), getattr(g, f"{self.db_name}_txn", None) |
||||
except sqlite3.Error as exc: |
except RuntimeError: |
||||
self.connection.rollback() |
return None, None |
||||
raise exc |
|
||||
|
|
||||
|
|
||||
def open_db(db_name: str = "database") -> Database: |
|
||||
db_path = os.path.join(LNBITS_DATA_FOLDER, f"{db_name}.sqlite3") |
|
||||
return Database(db_path=db_path) |
|
||||
|
|
||||
|
async def begin(self): |
||||
|
conn, _ = self.session_connection() |
||||
|
if conn: |
||||
|
return |
||||
|
|
||||
def open_ext_db(extension_name: str) -> Database: |
conn = await self.engine.connect() |
||||
return open_db(f"ext_{extension_name}") |
setattr(g, f"{self.db_name}_conn", conn) |
||||
|
txn = await conn.begin() |
||||
|
setattr(g, f"{self.db_name}_txn", txn) |
||||
|
|
||||
|
async def fetchall(self, query: str, values: tuple = ()) -> list: |
||||
|
conn, _ = self.session_connection() |
||||
|
if conn: |
||||
|
result = await conn.execute(query, values) |
||||
|
return await result.fetchall() |
||||
|
|
||||
|
async with self.connect() as conn: |
||||
|
result = await conn.execute(query, values) |
||||
|
return await result.fetchall() |
||||
|
|
||||
|
async def fetchone(self, query: str, values: tuple = ()): |
||||
|
conn, _ = self.session_connection() |
||||
|
if conn: |
||||
|
result = await conn.execute(query, values) |
||||
|
row = await result.fetchone() |
||||
|
await result.close() |
||||
|
return row |
||||
|
|
||||
|
async with self.connect() as conn: |
||||
|
result = await conn.execute(query, values) |
||||
|
row = await result.fetchone() |
||||
|
await result.close() |
||||
|
return row |
||||
|
|
||||
|
async def execute(self, query: str, values: tuple = ()): |
||||
|
conn, _ = self.session_connection() |
||||
|
if conn: |
||||
|
return await conn.execute(query, values) |
||||
|
|
||||
|
async with self.connect() as conn: |
||||
|
return await conn.execute(query, values) |
||||
|
|
||||
|
async def commit(self): |
||||
|
conn, txn = self.session_connection() |
||||
|
if conn and txn: |
||||
|
await txn.commit() |
||||
|
await self.close_session() |
||||
|
|
||||
|
async def rollback(self): |
||||
|
conn, txn = self.session_connection() |
||||
|
if conn and txn: |
||||
|
await txn.rollback() |
||||
|
await self.close_session() |
||||
|
|
||||
|
async def close_session(self): |
||||
|
conn, txn = self.session_connection() |
||||
|
if conn and txn: |
||||
|
await txn.close() |
||||
|
await conn.close() |
||||
|
delattr(g, f"{self.db_name}_conn") |
||||
|
delattr(g, f"{self.db_name}_txn") |
||||
|
@ -1,2 +1,2 @@ |
|||||
def migrate(): |
async def migrate(): |
||||
pass |
pass |
||||
|
@ -1,45 +1,43 @@ |
|||||
from typing import List, Optional, Union |
from typing import List, Optional, Union |
||||
|
|
||||
from lnbits.db import open_ext_db |
|
||||
from lnbits.helpers import urlsafe_short_hash |
from lnbits.helpers import urlsafe_short_hash |
||||
|
|
||||
|
from . import db |
||||
from .models import Paywall |
from .models import Paywall |
||||
|
|
||||
|
|
||||
def create_paywall( |
async def create_paywall( |
||||
*, wallet_id: str, url: str, memo: str, description: Optional[str] = None, amount: int = 0, remembers: bool = True |
*, wallet_id: str, url: str, memo: str, description: Optional[str] = None, amount: int = 0, remembers: bool = True |
||||
) -> Paywall: |
) -> Paywall: |
||||
with open_ext_db("paywall") as db: |
paywall_id = urlsafe_short_hash() |
||||
paywall_id = urlsafe_short_hash() |
await db.execute( |
||||
db.execute( |
""" |
||||
""" |
INSERT INTO paywalls (id, wallet, url, memo, description, amount, remembers) |
||||
INSERT INTO paywalls (id, wallet, url, memo, description, amount, remembers) |
VALUES (?, ?, ?, ?, ?, ?, ?) |
||||
VALUES (?, ?, ?, ?, ?, ?, ?) |
""", |
||||
""", |
(paywall_id, wallet_id, url, memo, description, amount, int(remembers)), |
||||
(paywall_id, wallet_id, url, memo, description, amount, int(remembers)), |
) |
||||
) |
|
||||
|
|
||||
return get_paywall(paywall_id) |
paywall = await get_paywall(paywall_id) |
||||
|
assert paywall, "Newly created paywall couldn't be retrieved" |
||||
|
return paywall |
||||
|
|
||||
|
|
||||
def get_paywall(paywall_id: str) -> Optional[Paywall]: |
async def get_paywall(paywall_id: str) -> Optional[Paywall]: |
||||
with open_ext_db("paywall") as db: |
row = await db.fetchone("SELECT * FROM paywalls WHERE id = ?", (paywall_id,)) |
||||
row = db.fetchone("SELECT * FROM paywalls WHERE id = ?", (paywall_id,)) |
|
||||
|
|
||||
return Paywall.from_row(row) if row else None |
return Paywall.from_row(row) if row else None |
||||
|
|
||||
|
|
||||
def get_paywalls(wallet_ids: Union[str, List[str]]) -> List[Paywall]: |
async def get_paywalls(wallet_ids: Union[str, List[str]]) -> List[Paywall]: |
||||
if isinstance(wallet_ids, str): |
if isinstance(wallet_ids, str): |
||||
wallet_ids = [wallet_ids] |
wallet_ids = [wallet_ids] |
||||
|
|
||||
with open_ext_db("paywall") as db: |
q = ",".join(["?"] * len(wallet_ids)) |
||||
q = ",".join(["?"] * len(wallet_ids)) |
rows = await db.fetchall(f"SELECT * FROM paywalls WHERE wallet IN ({q})", (*wallet_ids,)) |
||||
rows = db.fetchall(f"SELECT * FROM paywalls WHERE wallet IN ({q})", (*wallet_ids,)) |
|
||||
|
|
||||
return [Paywall.from_row(row) for row in rows] |
return [Paywall.from_row(row) for row in rows] |
||||
|
|
||||
|
|
||||
def delete_paywall(paywall_id: str) -> None: |
async def delete_paywall(paywall_id: str) -> None: |
||||
with open_ext_db("paywall") as db: |
await db.execute("DELETE FROM paywalls WHERE id = ?", (paywall_id,)) |
||||
db.execute("DELETE FROM paywalls WHERE id = ?", (paywall_id,)) |
|
||||
|
@ -1,43 +1,40 @@ |
|||||
from typing import List, Optional, Union |
from typing import List, Optional, Union |
||||
|
|
||||
from lnbits.db import open_ext_db |
|
||||
from lnbits.helpers import urlsafe_short_hash |
from lnbits.helpers import urlsafe_short_hash |
||||
|
|
||||
|
from . import db |
||||
from .models import TPoS |
from .models import TPoS |
||||
|
|
||||
|
|
||||
def create_tpos(*, wallet_id: str, name: str, currency: str) -> TPoS: |
async def create_tpos(*, wallet_id: str, name: str, currency: str) -> TPoS: |
||||
with open_ext_db("tpos") as db: |
tpos_id = urlsafe_short_hash() |
||||
tpos_id = urlsafe_short_hash() |
await db.execute( |
||||
db.execute( |
""" |
||||
""" |
INSERT INTO tposs (id, wallet, name, currency) |
||||
INSERT INTO tposs (id, wallet, name, currency) |
VALUES (?, ?, ?, ?) |
||||
VALUES (?, ?, ?, ?) |
""", |
||||
""", |
(tpos_id, wallet_id, name, currency), |
||||
(tpos_id, wallet_id, name, currency), |
) |
||||
) |
|
||||
|
|
||||
return get_tpos(tpos_id) |
tpos = await get_tpos(tpos_id) |
||||
|
assert tpos, "Newly created tpos couldn't be retrieved" |
||||
|
return tpos |
||||
|
|
||||
|
|
||||
def get_tpos(tpos_id: str) -> Optional[TPoS]: |
async def get_tpos(tpos_id: str) -> Optional[TPoS]: |
||||
with open_ext_db("tpos") as db: |
row = await db.fetchone("SELECT * FROM tposs WHERE id = ?", (tpos_id,)) |
||||
row = db.fetchone("SELECT * FROM tposs WHERE id = ?", (tpos_id,)) |
|
||||
|
|
||||
return TPoS.from_row(row) if row else None |
return TPoS.from_row(row) if row else None |
||||
|
|
||||
|
|
||||
def get_tposs(wallet_ids: Union[str, List[str]]) -> List[TPoS]: |
async def get_tposs(wallet_ids: Union[str, List[str]]) -> List[TPoS]: |
||||
if isinstance(wallet_ids, str): |
if isinstance(wallet_ids, str): |
||||
wallet_ids = [wallet_ids] |
wallet_ids = [wallet_ids] |
||||
|
|
||||
with open_ext_db("tpos") as db: |
q = ",".join(["?"] * len(wallet_ids)) |
||||
q = ",".join(["?"] * len(wallet_ids)) |
rows = await db.fetchall(f"SELECT * FROM tposs WHERE wallet IN ({q})", (*wallet_ids,)) |
||||
rows = db.fetchall(f"SELECT * FROM tposs WHERE wallet IN ({q})", (*wallet_ids,)) |
|
||||
|
|
||||
return [TPoS.from_row(row) for row in rows] |
return [TPoS.from_row(row) for row in rows] |
||||
|
|
||||
|
|
||||
def delete_tpos(tpos_id: str) -> None: |
async def delete_tpos(tpos_id: str) -> None: |
||||
with open_ext_db("tpos") as db: |
await db.execute("DELETE FROM tposs WHERE id = ?", (tpos_id,)) |
||||
db.execute("DELETE FROM tposs WHERE id = ?", (tpos_id,)) |
|
||||
|
Loading…
Reference in new issue