mirror of https://github.com/lukechilds/lnbits.git
Browse Source
a big refactor that: - fixes some issues that might have happened (or not) with asynchronous reactions to payments; - paves the way to https://github.com/lnbits/lnbits/issues/121; - uses more async/await notation which just looks nice; and - makes it simple(r?) for one extension to modify stuff from other extensions.livestream
fiatjaf
4 years ago
68 changed files with 976 additions and 1080 deletions
@ -1,66 +1,85 @@ |
|||||
import os |
import os |
||||
import sqlite3 |
from typing import Tuple, Optional, Any |
||||
|
from sqlalchemy_aio import TRIO_STRATEGY # type: ignore |
||||
|
from sqlalchemy import create_engine # type: ignore |
||||
|
from quart import g |
||||
|
|
||||
from .settings import LNBITS_DATA_FOLDER |
from .settings import LNBITS_DATA_FOLDER |
||||
|
|
||||
|
|
||||
class Database: |
class Database: |
||||
def __init__(self, db_path: str): |
def __init__(self, db_name: str): |
||||
self.path = db_path |
self.db_name = db_name |
||||
self.connection = sqlite3.connect(db_path) |
db_path = os.path.join(LNBITS_DATA_FOLDER, f"{db_name}.sqlite3") |
||||
self.connection.row_factory = sqlite3.Row |
self.engine = create_engine(f"sqlite:///{db_path}", strategy=TRIO_STRATEGY) |
||||
self.cursor = self.connection.cursor() |
|
||||
self.closed = False |
|
||||
|
|
||||
def close(self): |
|
||||
self.__exit__(None, None, None) |
|
||||
|
|
||||
def __enter__(self): |
|
||||
return self |
|
||||
|
|
||||
def __exit__(self, exc_type, exc_val, exc_tb): |
|
||||
if self.closed: |
|
||||
return |
|
||||
|
|
||||
if exc_val: |
|
||||
self.connection.rollback() |
|
||||
self.cursor.close() |
|
||||
self.connection.close() |
|
||||
else: |
|
||||
self.connection.commit() |
|
||||
self.cursor.close() |
|
||||
self.connection.close() |
|
||||
|
|
||||
self.closed = True |
|
||||
|
|
||||
def commit(self): |
|
||||
self.connection.commit() |
|
||||
|
|
||||
def rollback(self): |
|
||||
self.connection.rollback() |
|
||||
|
|
||||
def fetchall(self, query: str, values: tuple = ()) -> list: |
|
||||
"""Given a query, return cursor.fetchall() rows.""" |
|
||||
self.execute(query, values) |
|
||||
return self.cursor.fetchall() |
|
||||
|
|
||||
def fetchone(self, query: str, values: tuple = ()): |
def connect(self): |
||||
self.execute(query, values) |
return self.engine.connect() |
||||
return self.cursor.fetchone() |
|
||||
|
|
||||
def execute(self, query: str, values: tuple = ()) -> None: |
def session_connection(self) -> Tuple[Optional[Any], Optional[Any]]: |
||||
"""Given a query, cursor.execute() it.""" |
|
||||
try: |
try: |
||||
self.cursor.execute(query, values) |
return getattr(g, f"{self.db_name}_conn", None), getattr(g, f"{self.db_name}_txn", None) |
||||
except sqlite3.Error as exc: |
except RuntimeError: |
||||
self.connection.rollback() |
return None, None |
||||
raise exc |
|
||||
|
|
||||
|
|
||||
def open_db(db_name: str = "database") -> Database: |
|
||||
db_path = os.path.join(LNBITS_DATA_FOLDER, f"{db_name}.sqlite3") |
|
||||
return Database(db_path=db_path) |
|
||||
|
|
||||
|
async def begin(self): |
||||
|
conn, _ = self.session_connection() |
||||
|
if conn: |
||||
|
return |
||||
|
|
||||
def open_ext_db(extension_name: str) -> Database: |
conn = await self.engine.connect() |
||||
return open_db(f"ext_{extension_name}") |
setattr(g, f"{self.db_name}_conn", conn) |
||||
|
txn = await conn.begin() |
||||
|
setattr(g, f"{self.db_name}_txn", txn) |
||||
|
|
||||
|
async def fetchall(self, query: str, values: tuple = ()) -> list: |
||||
|
conn, _ = self.session_connection() |
||||
|
if conn: |
||||
|
result = await conn.execute(query, values) |
||||
|
return await result.fetchall() |
||||
|
|
||||
|
async with self.connect() as conn: |
||||
|
result = await conn.execute(query, values) |
||||
|
return await result.fetchall() |
||||
|
|
||||
|
async def fetchone(self, query: str, values: tuple = ()): |
||||
|
conn, _ = self.session_connection() |
||||
|
if conn: |
||||
|
result = await conn.execute(query, values) |
||||
|
row = await result.fetchone() |
||||
|
await result.close() |
||||
|
return row |
||||
|
|
||||
|
async with self.connect() as conn: |
||||
|
result = await conn.execute(query, values) |
||||
|
row = await result.fetchone() |
||||
|
await result.close() |
||||
|
return row |
||||
|
|
||||
|
async def execute(self, query: str, values: tuple = ()): |
||||
|
conn, _ = self.session_connection() |
||||
|
if conn: |
||||
|
return await conn.execute(query, values) |
||||
|
|
||||
|
async with self.connect() as conn: |
||||
|
return await conn.execute(query, values) |
||||
|
|
||||
|
async def commit(self): |
||||
|
conn, txn = self.session_connection() |
||||
|
if conn and txn: |
||||
|
await txn.commit() |
||||
|
await self.close_session() |
||||
|
|
||||
|
async def rollback(self): |
||||
|
conn, txn = self.session_connection() |
||||
|
if conn and txn: |
||||
|
await txn.rollback() |
||||
|
await self.close_session() |
||||
|
|
||||
|
async def close_session(self): |
||||
|
conn, txn = self.session_connection() |
||||
|
if conn and txn: |
||||
|
await txn.close() |
||||
|
await conn.close() |
||||
|
delattr(g, f"{self.db_name}_conn") |
||||
|
delattr(g, f"{self.db_name}_txn") |
||||
|
@ -1,2 +1,2 @@ |
|||||
def migrate(): |
async def migrate(): |
||||
pass |
pass |
||||
|
Loading…
Reference in new issue