diff --git a/lnbits/app.py b/lnbits/app.py index 8184b04..5df599d 100644 --- a/lnbits/app.py +++ b/lnbits/app.py @@ -1,4 +1,5 @@ import importlib +import asyncio from quart import Quart, g from quart_cors import cors # type: ignore @@ -30,6 +31,7 @@ def create_app(config_object="lnbits.settings") -> Quart: register_filters(app) register_commands(app) register_request_hooks(app) + register_async_tasks(app) return app @@ -86,3 +88,20 @@ def register_request_hooks(app: Quart): @app.teardown_request async def after_request(exc): g.db.__exit__(type(exc), exc, None) + + +def register_async_tasks(app): + from lnbits.core.tasks import invoice_listener, webhook_handler + + @app.route("/wallet/webhook") + async def webhook_listener(): + return await webhook_handler() + + @app.before_serving + async def listeners(): + loop = asyncio.get_event_loop() + loop.create_task(invoice_listener(app)) + + @app.after_serving + async def stop_listeners(): + pass diff --git a/lnbits/core/crud.py b/lnbits/core/crud.py index 6d19c90..984492a 100644 --- a/lnbits/core/crud.py +++ b/lnbits/core/crud.py @@ -131,6 +131,19 @@ def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wallet]: # --------------- +def get_standalone_payment(checking_id: str) -> Optional[Payment]: + row = g.db.fetchone( + """ + SELECT * + FROM apipayments + WHERE checking_id = ? + """, + (checking_id,), + ) + + return Payment.from_row(row) if row else None + + def get_wallet_payment(wallet_id: str, payment_hash: str) -> Optional[Payment]: row = g.db.fetchone( """ diff --git a/lnbits/core/models.py b/lnbits/core/models.py index 24d7649..243f934 100644 --- a/lnbits/core/models.py +++ b/lnbits/core/models.py @@ -2,6 +2,8 @@ import json from typing import List, NamedTuple, Optional, Dict from sqlite3 import Row +from lnbits.settings import WALLET + class User(NamedTuple): id: str @@ -113,6 +115,17 @@ class Payment(NamedTuple): update_payment_status(self.checking_id, pending) + def check_pending(self) -> None: + if self.is_uncheckable: + return + + if self.is_out: + pending = WALLET.get_payment_status(self.checking_id) + else: + pending = WALLET.get_invoice_status(self.checking_id) + + self.set_pending(pending.pending) + def delete(self) -> None: from .crud import delete_payment diff --git a/lnbits/core/tasks.py b/lnbits/core/tasks.py index be48d8a..6ab684a 100644 --- a/lnbits/core/tasks.py +++ b/lnbits/core/tasks.py @@ -1,9 +1,13 @@ import asyncio -from typing import Optional, Awaitable +from typing import Optional, List, Awaitable, Tuple, Callable from quart import Quart, Request, g from werkzeug.datastructures import Headers -from lnbits.db import open_db +from lnbits.db import open_db, open_ext_db +from lnbits.settings import WALLET + +from .models import Payment +from .crud import get_standalone_payment main_app: Optional[Quart] = None @@ -31,3 +35,37 @@ def run_on_pseudo_request(awaitable: Awaitable): loop = asyncio.get_event_loop() loop.create_task(run(awaitable)) + + +invoice_listeners: List[Tuple[str, Callable[[Payment], Awaitable[None]]]] = [] + + +def register_invoice_listener(ext_name: str, callback: Callable[[Payment], Awaitable[None]]): + """ + A method intended for extensions to call when they want to be notified about + new invoice payments incoming. + """ + print("registering callback", callback) + invoice_listeners.append((ext_name, callback)) + + +async def webhook_handler(): + handler = getattr(WALLET, "webhook_listener", None) + if handler: + await handler() + + +async def invoice_listener(app): + run_on_pseudo_request(_invoice_listener()) + + +async def _invoice_listener(): + async for checking_id in WALLET.paid_invoices_stream(): + # do this just so the g object is available + g.db = await open_db() + payment = await get_standalone_payment(checking_id) + if payment.is_in: + await payment.set_pending(False) + for ext_name, cb in invoice_listeners: + g.ext_db = await open_ext_db(ext_name) + cb(payment) diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py index a4fefc4..72c23c7 100644 --- a/lnbits/core/views/api.py +++ b/lnbits/core/views/api.py @@ -7,7 +7,6 @@ from lnbits.core import core_app from lnbits.core.services import create_invoice, pay_invoice from lnbits.core.crud import delete_expired_invoices from lnbits.decorators import api_check_wallet_key, api_validate_post_request -from lnbits.settings import WALLET @core_app.route("/api/v1/wallet", methods=["GET"]) @@ -32,10 +31,7 @@ async def api_payments(): delete_expired_invoices() for payment in g.wallet.get_payments(complete=False, pending=True, exclude_uncheckable=True): - if payment.is_out: - payment.set_pending(WALLET.get_payment_status(payment.checking_id).pending) - else: - payment.set_pending(WALLET.get_invoice_status(payment.checking_id).pending) + payment.check_pending() return jsonify(g.wallet.get_payments(pending=True)), HTTPStatus.OK @@ -123,17 +119,8 @@ async def api_payment(payment_hash): return jsonify({"paid": True}), HTTPStatus.OK try: - if payment.is_uncheckable: - pass - elif payment.is_out: - is_paid = not WALLET.get_payment_status(payment.checking_id).pending - elif payment.is_in: - is_paid = not WALLET.get_invoice_status(payment.checking_id).pending + payment.check_pending() except Exception: return jsonify({"paid": False}), HTTPStatus.OK - if is_paid: - payment.set_pending(False) - return jsonify({"paid": True}), HTTPStatus.OK - - return jsonify({"paid": False}), HTTPStatus.OK + return jsonify({"paid": not payment.pending}), HTTPStatus.OK diff --git a/lnbits/extensions/lnurlp/__init__.py b/lnbits/extensions/lnurlp/__init__.py index 4fb6466..319c256 100644 --- a/lnbits/extensions/lnurlp/__init__.py +++ b/lnbits/extensions/lnurlp/__init__.py @@ -6,3 +6,8 @@ lnurlp_ext: Blueprint = Blueprint("lnurlp", __name__, static_folder="static", te from .views_api import * # noqa from .views import * # noqa +from .tasks import on_invoice_paid + +from lnbits.core.tasks import register_invoice_listener + +register_invoice_listener("lnurlp", on_invoice_paid) diff --git a/lnbits/extensions/lnurlp/migrations.py b/lnbits/extensions/lnurlp/migrations.py index cdb8e9a..e5569df 100644 --- a/lnbits/extensions/lnurlp/migrations.py +++ b/lnbits/extensions/lnurlp/migrations.py @@ -14,3 +14,21 @@ def m001_initial(db): ); """ ) + + +# def m002_webhooks_and_success_actions(db): +# """ +# Webhooks and success actions. +# """ +# db.execute("ALTER TABLE pay_links ADD COLUMN webhook_url TEXT;") +# db.execute("ALTER TABLE pay_links ADD COLUMN success_text TEXT;") +# db.execute("ALTER TABLE pay_links ADD COLUMN success_url TEXT;") +# db.execute( +# """ +# CREATE TABLE invoices ( +# payment_hash PRIMARY KEY, +# link_id INTEGER NOT NULL REFERENCES pay_links (id), +# webhook_sent BOOLEAN NOT NULL DEFAULT false +# ); +# """ +# ) diff --git a/lnbits/extensions/lnurlp/models.py b/lnbits/extensions/lnurlp/models.py index 6cac5af..e376cf7 100644 --- a/lnbits/extensions/lnurlp/models.py +++ b/lnbits/extensions/lnurlp/models.py @@ -7,12 +7,15 @@ from typing import NamedTuple class PayLink(NamedTuple): - id: str + id: int wallet: str description: str amount: int served_meta: int served_pr: int + webhook_url: str + success_text: str + success_url: str @classmethod def from_row(cls, row: Row) -> "PayLink": @@ -27,3 +30,9 @@ class PayLink(NamedTuple): @property def lnurlpay_metadata(self) -> LnurlPayMetadata: return LnurlPayMetadata(json.dumps([["text/plain", self.description]])) + + +class Invoice(NamedTuple): + payment_hash: str + link_id: int + webhook_sent: bool diff --git a/lnbits/extensions/lnurlp/tasks.py b/lnbits/extensions/lnurlp/tasks.py new file mode 100644 index 0000000..3e986b2 --- /dev/null +++ b/lnbits/extensions/lnurlp/tasks.py @@ -0,0 +1,12 @@ +import aiohttp + +from lnbits.core.models import Payment + + +async def on_invoice_paid(payment: Payment) -> None: + islnurlp = "lnurlp" in payment.extra.get("tags", {}) + print("invoice paid on lnurlp?", islnurlp) + if islnurlp: + print("dispatching webhook") + async with aiohttp.ClientSession() as session: + await session.post("https://fiatjaf.free.beeceptor.com", json=payment) diff --git a/lnbits/extensions/lnurlp/views_api.py b/lnbits/extensions/lnurlp/views_api.py index 852a68e..3ad9ac0 100644 --- a/lnbits/extensions/lnurlp/views_api.py +++ b/lnbits/extensions/lnurlp/views_api.py @@ -4,6 +4,7 @@ from http import HTTPStatus from lnurl import LnurlPayResponse, LnurlPayActionResponse from lnurl.exceptions import InvalidUrl as LnurlInvalidUrl +from lnbits import bolt11 from lnbits.core.crud import get_user from lnbits.core.services import create_invoice from lnbits.decorators import api_check_wallet_key, api_validate_post_request @@ -126,6 +127,10 @@ async def api_lnurl_callback(link_id): description_hash=hashlib.sha256(link.lnurlpay_metadata.encode("utf-8")).digest(), extra={"tag": "lnurlp"}, ) + + inv = bolt11.decode(payment_request) + inv.payment_hash + resp = LnurlPayActionResponse(pr=payment_request, success_action=None, routes=[]) return jsonify(resp.dict()), HTTPStatus.OK diff --git a/lnbits/wallets/lnpay.py b/lnbits/wallets/lnpay.py index db0d417..53be5ec 100644 --- a/lnbits/wallets/lnpay.py +++ b/lnbits/wallets/lnpay.py @@ -1,6 +1,9 @@ +import asyncio +import aiohttp from os import getenv -from typing import Optional, Dict +from typing import Optional, Dict, AsyncGenerator from requests import get, post +from quart import request from .base import InvoiceResponse, PaymentResponse, PaymentStatus, Wallet @@ -15,9 +18,13 @@ class LNPayWallet(Wallet): self.auth_invoice = getenv("LNPAY_INVOICE_KEY") self.auth_read = getenv("LNPAY_READ_KEY") self.auth_api = {"X-Api-Key": getenv("LNPAY_API_KEY")} + self.queue = asyncio.Queue() def create_invoice( - self, amount: int, memo: Optional[str] = None, description_hash: Optional[bytes] = None + self, + amount: int, + memo: Optional[str] = None, + description_hash: Optional[bytes] = None, ) -> InvoiceResponse: data: Dict = {"num_satoshis": f"{amount}"} if description_hash: @@ -30,7 +37,12 @@ class LNPayWallet(Wallet): headers=self.auth_api, json=data, ) - ok, checking_id, payment_request, error_message = r.status_code == 201, None, None, r.text + ok, checking_id, payment_request, error_message = ( + r.status_code == 201, + None, + None, + r.text, + ) if ok: data = r.json() @@ -55,10 +67,30 @@ class LNPayWallet(Wallet): return self.get_payment_status(checking_id) def get_payment_status(self, checking_id: str) -> PaymentStatus: - r = get(url=f"{self.endpoint}/user/lntx/{checking_id}", headers=self.auth_api) + r = get( + url=f"{self.endpoint}/user/lntx/{checking_id}?fields=settled", + headers=self.auth_api, + ) if not r.ok: return PaymentStatus(None) statuses = {0: None, 1: True, -1: False} return PaymentStatus(statuses[r.json()["settled"]]) + + async def paid_invoices_stream(self) -> AsyncGenerator[str, None]: + while True: + yield await self.queue.get() + self.queue.task_done() + + async def webhook_listener(self): + data = await request.get_json() + if "event" not in data or data["event"].get("name") != "wallet_receive": + return "" + + lntx_id = data["data"]["wtx"]["lnTx"]["id"] + async with aiohttp.ClientSession() as session: + async with session.get(f"{self.endpoint}/user/lntx/{lntx_id}?fields=settled") as resp: + data = await resp.json() + if data["settled"]: + self.queue.put_nowait(lntx_id) diff --git a/lnbits/wallets/spark.py b/lnbits/wallets/spark.py index 2021634..c4e0708 100644 --- a/lnbits/wallets/spark.py +++ b/lnbits/wallets/spark.py @@ -1,7 +1,9 @@ import random import requests +import json +from aiohttp_sse_client import client as sse_client from os import getenv -from typing import Optional +from typing import Optional, AsyncGenerator from .base import InvoiceResponse, PaymentResponse, PaymentStatus, Wallet @@ -16,7 +18,7 @@ class UnknownError(Exception): class SparkWallet(Wallet): def __init__(self): - self.url = getenv("SPARK_URL") + self.url = getenv("SPARK_URL").replace("/rpc", "") self.token = getenv("SPARK_TOKEN") def __getattr__(self, key): @@ -28,7 +30,9 @@ class SparkWallet(Wallet): elif kwargs: params = kwargs - r = requests.post(self.url, headers={"X-Access": self.token}, json={"method": key, "params": params}) + r = requests.post( + self.url + "/rpc", headers={"X-Access": self.token}, json={"method": key, "params": params} + ) try: data = r.json() except: @@ -91,3 +95,15 @@ class SparkWallet(Wallet): return PaymentStatus(False) return PaymentStatus(None) raise KeyError("supplied an invalid checking_id") + + async def paid_invoices_stream(self) -> AsyncGenerator[str, None]: + url = self.url + "/stream?access-key=" + self.token + conn = sse_client.EventSource(url) + async with conn as es: + async for event in es: + try: + if event.type == "inv-paid": + data = json.loads(event.data) + yield data["label"] + except ConnectionError: + pass