From fadddc995a7cdb0aff816c76fd10ca28514dd69a Mon Sep 17 00:00:00 2001
From: fiatjaf <fiatjaf@alhur.es>
Date: Mon, 31 Aug 2020 00:24:57 -0300
Subject: [PATCH] get_wallet_payments with more fine-grained, explicit filters.

---
 lnbits/core/crud.py                   | 34 +++++++++++++++++++++------
 lnbits/core/models.py                 |  6 +++--
 lnbits/core/views/api.py              |  2 +-
 lnbits/extensions/usermanager/crud.py |  6 ++---
 4 files changed, 34 insertions(+), 14 deletions(-)

diff --git a/lnbits/core/crud.py b/lnbits/core/crud.py
index 94ead32..1b73786 100644
--- a/lnbits/core/crud.py
+++ b/lnbits/core/crud.py
@@ -150,18 +150,38 @@ def get_wallet_payment(wallet_id: str, checking_id: str) -> Optional[Payment]:
     return Payment(**row) if row else None
 
 
-def get_wallet_payments(wallet_id: str, *, include_all_pending: bool = False) -> List[Payment]:
-    with open_db() as db:
-        if include_all_pending:
-            clause = "pending = 1"
-        else:
-            clause = "((amount > 0 AND pending = 0) OR amount < 0)"
+def get_wallet_payments(
+    wallet_id: str, *, complete: bool = False, pending: bool = False, outgoing: bool = False, incoming: bool = False
+) -> List[Payment]:
+    """
+    Filters payments to be returned by complete | pending | outgoing | incoming.
+    """
+
+    clause = ""
+    if complete and pending:
+        clause += ""
+    elif complete:
+        clause += "AND ((amount > 0 AND pending = 0) OR amount < 0)"
+    elif pending:
+        clause += "AND pending = 1"
+    else:
+        raise TypeError("at least one of [complete, pending] must be True.")
+
+    if outgoing and incoming:
+        clause += ""
+    elif outgoing:
+        clause += "AND amount < 0"
+    elif incoming:
+        clause += "AND amount > 0"
+    else:
+        raise TypeError("at least one of [outgoing, incoming] must be True.")
 
+    with open_db() as db:
         rows = db.fetchall(
             f"""
             SELECT payhash as checking_id, amount, fee, pending, memo, time
             FROM apipayments
-            WHERE wallet = ? AND {clause}
+            WHERE wallet = ? {clause}
             ORDER BY time DESC
             """,
             (wallet_id,),
diff --git a/lnbits/core/models.py b/lnbits/core/models.py
index f3d5fbd..10a87ad 100644
--- a/lnbits/core/models.py
+++ b/lnbits/core/models.py
@@ -34,10 +34,12 @@ class Wallet(NamedTuple):
 
         return get_wallet_payment(self.id, checking_id)
 
-    def get_payments(self, *, include_all_pending: bool = False) -> List["Payment"]:
+    def get_payments(
+        self, *, complete: bool = True, pending: bool = False, outgoing: bool = True, incoming: bool = True
+    ) -> List["Payment"]:
         from .crud import get_wallet_payments
 
-        return get_wallet_payments(self.id, include_all_pending=include_all_pending)
+        return get_wallet_payments(self.id, complete=complete, pending=pending, outgoing=outgoing, incoming=incoming)
 
     def delete_expired_payments(self, seconds: int = 86400) -> None:
         from .crud import delete_wallet_payments_expired
diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py
index c66d873..47bf875 100644
--- a/lnbits/core/views/api.py
+++ b/lnbits/core/views/api.py
@@ -15,7 +15,7 @@ def api_payments():
     if "check_pending" in request.args:
         g.wallet.delete_expired_payments()
 
-        for payment in g.wallet.get_payments(include_all_pending=True):
+        for payment in g.wallet.get_payments(pending=True):
             if payment.is_out:
                 payment.set_pending(WALLET.get_payment_status(payment.checking_id).pending)
             else:
diff --git a/lnbits/extensions/usermanager/crud.py b/lnbits/extensions/usermanager/crud.py
index 8119f3d..db9556b 100644
--- a/lnbits/extensions/usermanager/crud.py
+++ b/lnbits/extensions/usermanager/crud.py
@@ -1,12 +1,10 @@
 from lnbits.db import open_ext_db
-from lnbits.settings import WALLET
 from .models import Users, Wallets
-from typing import List, Optional, Union
+from typing import Optional
 
 from ...core.crud import (
     create_account,
     get_user,
-    update_user_extension,
     get_wallet_payments,
     create_wallet,
     delete_wallet,
@@ -103,7 +101,7 @@ def get_usermanager_wallets(user_id: str) -> Wallets:
 
 
 def get_usermanager_wallet_transactions(wallet_id: str) -> Users:
-    return get_wallet_payments(wallet_id=wallet_id, include_all_pending=False)
+    return get_wallet_payments(wallet_id=wallet_id, complete=True, pending=False, outgoing=True, incoming=True)
 
 
 def get_usermanager_wallet_balances(user_id: str) -> Users: