mirror of https://github.com/lukechilds/lnbits.git
Browse Source
a big refactor that: - fixes some issues that might have happened (or not) with asynchronous reactions to payments; - paves the way to https://github.com/lnbits/lnbits/issues/121; - uses more async/await notation which just looks nice; and - makes it simple(r?) for one extension to modify stuff from other extensions.livestream
fiatjaf
4 years ago
68 changed files with 976 additions and 1080 deletions
@ -1,66 +1,85 @@ |
|||
import os |
|||
import sqlite3 |
|||
from typing import Tuple, Optional, Any |
|||
from sqlalchemy_aio import TRIO_STRATEGY # type: ignore |
|||
from sqlalchemy import create_engine # type: ignore |
|||
from quart import g |
|||
|
|||
from .settings import LNBITS_DATA_FOLDER |
|||
|
|||
|
|||
class Database: |
|||
def __init__(self, db_path: str): |
|||
self.path = db_path |
|||
self.connection = sqlite3.connect(db_path) |
|||
self.connection.row_factory = sqlite3.Row |
|||
self.cursor = self.connection.cursor() |
|||
self.closed = False |
|||
|
|||
def close(self): |
|||
self.__exit__(None, None, None) |
|||
|
|||
def __enter__(self): |
|||
return self |
|||
|
|||
def __exit__(self, exc_type, exc_val, exc_tb): |
|||
if self.closed: |
|||
return |
|||
|
|||
if exc_val: |
|||
self.connection.rollback() |
|||
self.cursor.close() |
|||
self.connection.close() |
|||
else: |
|||
self.connection.commit() |
|||
self.cursor.close() |
|||
self.connection.close() |
|||
|
|||
self.closed = True |
|||
|
|||
def commit(self): |
|||
self.connection.commit() |
|||
|
|||
def rollback(self): |
|||
self.connection.rollback() |
|||
|
|||
def fetchall(self, query: str, values: tuple = ()) -> list: |
|||
"""Given a query, return cursor.fetchall() rows.""" |
|||
self.execute(query, values) |
|||
return self.cursor.fetchall() |
|||
def __init__(self, db_name: str): |
|||
self.db_name = db_name |
|||
db_path = os.path.join(LNBITS_DATA_FOLDER, f"{db_name}.sqlite3") |
|||
self.engine = create_engine(f"sqlite:///{db_path}", strategy=TRIO_STRATEGY) |
|||
|
|||
def fetchone(self, query: str, values: tuple = ()): |
|||
self.execute(query, values) |
|||
return self.cursor.fetchone() |
|||
def connect(self): |
|||
return self.engine.connect() |
|||
|
|||
def execute(self, query: str, values: tuple = ()) -> None: |
|||
"""Given a query, cursor.execute() it.""" |
|||
def session_connection(self) -> Tuple[Optional[Any], Optional[Any]]: |
|||
try: |
|||
self.cursor.execute(query, values) |
|||
except sqlite3.Error as exc: |
|||
self.connection.rollback() |
|||
raise exc |
|||
|
|||
|
|||
def open_db(db_name: str = "database") -> Database: |
|||
db_path = os.path.join(LNBITS_DATA_FOLDER, f"{db_name}.sqlite3") |
|||
return Database(db_path=db_path) |
|||
return getattr(g, f"{self.db_name}_conn", None), getattr(g, f"{self.db_name}_txn", None) |
|||
except RuntimeError: |
|||
return None, None |
|||
|
|||
async def begin(self): |
|||
conn, _ = self.session_connection() |
|||
if conn: |
|||
return |
|||
|
|||
def open_ext_db(extension_name: str) -> Database: |
|||
return open_db(f"ext_{extension_name}") |
|||
conn = await self.engine.connect() |
|||
setattr(g, f"{self.db_name}_conn", conn) |
|||
txn = await conn.begin() |
|||
setattr(g, f"{self.db_name}_txn", txn) |
|||
|
|||
async def fetchall(self, query: str, values: tuple = ()) -> list: |
|||
conn, _ = self.session_connection() |
|||
if conn: |
|||
result = await conn.execute(query, values) |
|||
return await result.fetchall() |
|||
|
|||
async with self.connect() as conn: |
|||
result = await conn.execute(query, values) |
|||
return await result.fetchall() |
|||
|
|||
async def fetchone(self, query: str, values: tuple = ()): |
|||
conn, _ = self.session_connection() |
|||
if conn: |
|||
result = await conn.execute(query, values) |
|||
row = await result.fetchone() |
|||
await result.close() |
|||
return row |
|||
|
|||
async with self.connect() as conn: |
|||
result = await conn.execute(query, values) |
|||
row = await result.fetchone() |
|||
await result.close() |
|||
return row |
|||
|
|||
async def execute(self, query: str, values: tuple = ()): |
|||
conn, _ = self.session_connection() |
|||
if conn: |
|||
return await conn.execute(query, values) |
|||
|
|||
async with self.connect() as conn: |
|||
return await conn.execute(query, values) |
|||
|
|||
async def commit(self): |
|||
conn, txn = self.session_connection() |
|||
if conn and txn: |
|||
await txn.commit() |
|||
await self.close_session() |
|||
|
|||
async def rollback(self): |
|||
conn, txn = self.session_connection() |
|||
if conn and txn: |
|||
await txn.rollback() |
|||
await self.close_session() |
|||
|
|||
async def close_session(self): |
|||
conn, txn = self.session_connection() |
|||
if conn and txn: |
|||
await txn.close() |
|||
await conn.close() |
|||
delattr(g, f"{self.db_name}_conn") |
|||
delattr(g, f"{self.db_name}_txn") |
|||
|
@ -1,2 +1,2 @@ |
|||
def migrate(): |
|||
async def migrate(): |
|||
pass |
|||
|
Loading…
Reference in new issue