diff --git a/lnbits/app.py b/lnbits/app.py index a7f41ea..53f6f4b 100644 --- a/lnbits/app.py +++ b/lnbits/app.py @@ -1,5 +1,5 @@ -import importlib import asyncio +import importlib from quart import Quart, g from quart_cors import cors # type: ignore @@ -8,7 +8,7 @@ from secure import SecureHeaders # type: ignore from .commands import db_migrate, handle_assets from .core import core_app -from .db import open_db +from .db import open_db, open_ext_db from .helpers import get_valid_extensions, get_js_vendored, get_css_vendored, url_for_vendored from .proxy_fix import ASGIProxyFix @@ -43,7 +43,17 @@ def register_blueprints(app: Quart) -> None: for ext in get_valid_extensions(): try: ext_module = importlib.import_module(f"lnbits.extensions.{ext.code}") - app.register_blueprint(getattr(ext_module, f"{ext.code}_ext"), url_prefix=f"/{ext.code}") + bp = getattr(ext_module, f"{ext.code}_ext") + + @bp.before_request + async def before_request(): + g.ext_db = open_ext_db(ext.code) + + @bp.teardown_request + async def after_request(exc): + g.ext_db.__exit__(type(exc), exc, None) + + app.register_blueprint(bp, url_prefix=f"/{ext.code}") except Exception: raise ImportError(f"Please make sure that the extension `{ext.code}` follows conventions.") @@ -99,8 +109,8 @@ def register_async_tasks(app): @app.before_serving async def listeners(): - loop = asyncio.get_event_loop() - loop.create_task(invoice_listener(app)) + loop = asyncio.get_running_loop() + loop.create_task(invoice_listener()) @app.after_serving async def stop_listeners(): diff --git a/lnbits/core/tasks.py b/lnbits/core/tasks.py index b13f9f5..1dff052 100644 --- a/lnbits/core/tasks.py +++ b/lnbits/core/tasks.py @@ -31,8 +31,8 @@ def run_on_pseudo_request(awaitable: Awaitable): send_push_promise=lambda x, h: None, ) async with main_app.request_context(fk): - g.db = open_db() - await awaitable + with open_db() as g.db: + await awaitable loop = asyncio.get_event_loop() loop.create_task(run(awaitable)) @@ -57,16 +57,15 @@ async def webhook_handler(): return "", HTTPStatus.NO_CONTENT -async def invoice_listener(app): - run_on_pseudo_request(_invoice_listener()) +async def invoice_listener(): + async for checking_id in WALLET.paid_invoices_stream(): + run_on_pseudo_request(invoice_callback_dispatcher(checking_id)) -async def _invoice_listener(): - async for checking_id in WALLET.paid_invoices_stream(): - g.db = open_db() - payment = get_standalone_payment(checking_id) - if payment.is_in: - payment.set_pending(False) - for ext_name, cb in invoice_listeners: - g.ext_db = open_ext_db(ext_name) +async def invoice_callback_dispatcher(checking_id: str): + payment = get_standalone_payment(checking_id) + if payment and payment.is_in: + payment.set_pending(False) + for ext_name, cb in invoice_listeners: + with open_ext_db(ext_name) as g.ext_db: # type: ignore await cb(payment)