diff --git a/lnbits/app.py b/lnbits/app.py index 4a8c1cc..2fe2827 100644 --- a/lnbits/app.py +++ b/lnbits/app.py @@ -9,7 +9,7 @@ from .commands import db_migrate from .core import core_app from .db import open_db from .helpers import get_valid_extensions, get_js_vendored, get_css_vendored, url_for_vendored -from .proxy_fix import ProxyFix +from .proxy_fix import ASGIProxyFix secure_headers = SecureHeaders(hsts=False) @@ -20,10 +20,10 @@ def create_app(config_object="lnbits.settings") -> Quart: """ app = Quart(__name__, static_folder="static") app.config.from_object(config_object) + app.asgi_http_class = ASGIProxyFix cors(app) Compress(app) - ProxyFix(app, x_proto=1, x_host=1) register_assets(app) register_blueprints(app) diff --git a/lnbits/proxy_fix.py b/lnbits/proxy_fix.py index b3751d8..9b77dc1 100644 --- a/lnbits/proxy_fix.py +++ b/lnbits/proxy_fix.py @@ -1,48 +1,46 @@ -from typing import Optional, List +from typing import Optional, List, Callable +from functools import partial from urllib.request import parse_http_list as _parse_list_header - -from quart import request - - -class ProxyFix: - def __init__(self, app=None, x_for: int = 1, x_proto: int = 1, x_host: int = 0, x_port: int = 0, x_prefix: int = 0): - self.app = app - self.x_for = x_for - self.x_proto = x_proto - self.x_host = x_host - self.x_port = x_port - self.x_prefix = x_prefix - - if app: - self.init_app(app) - - def init_app(self, app): - @app.before_request - async def before_request(): - x_for = self._get_real_value(self.x_for, request.headers.get("X-Forwarded-For")) - if x_for: - request.headers["Remote-Addr"] = x_for - - x_proto = self._get_real_value(self.x_proto, request.headers.get("X-Forwarded-Proto")) - if x_proto: - request.scheme = x_proto - - x_host = self._get_real_value(self.x_host, request.headers.get("X-Forwarded-Host")) - if x_host: - request.headers["host"] = x_host.lower() - parts = x_host.split(":", 1) - # environ["SERVER_NAME"] = parts[0] - # if len(parts) == 2: - # environ["SERVER_PORT"] = parts[1] - - x_port = self._get_real_value(self.x_port, request.headers.get("X-Forwarded-Port")) - if x_port: - host = request.host - if host: - parts = host.split(":", 1) - host = parts[0] if len(parts) == 2 else host - request.headers["host"] = f"{host}:{x_port}" - # environ["SERVER_PORT"] = x_port +from urllib.parse import urlparse +from werkzeug.datastructures import Headers + +from quart import Request +from quart.asgi import ASGIHTTPConnection + + +class ASGIProxyFix(ASGIHTTPConnection): + def _create_request_from_scope(self, send: Callable) -> Request: + headers = Headers() + headers["Remote-Addr"] = (self.scope.get("client") or [""])[0] + for name, value in self.scope["headers"]: + headers.add(name.decode("latin1").title(), value.decode("latin1")) + if self.scope["http_version"] < "1.1": + headers.setdefault("Host", self.app.config["SERVER_NAME"] or "") + + path = self.scope["path"] + path = path if path[0] == "/" else urlparse(path).path + + x_proto = self._get_real_value(1, headers.get("X-Forwarded-Proto")) + if x_proto: + self.scope["scheme"] = x_proto + + x_host = self._get_real_value(1, headers.get("X-Forwarded-Host")) + if x_host: + headers["host"] = x_host.lower() + + return self.app.request_class( + self.scope["method"], + self.scope["scheme"], + path, + self.scope["query_string"], + headers, + self.scope.get("root_path", ""), + self.scope["http_version"], + max_content_length=self.app.config["MAX_CONTENT_LENGTH"], + body_timeout=self.app.config["BODY_TIMEOUT"], + send_push_promise=partial(self._send_push_promise, send), + scope=self.scope, + ) def _get_real_value(self, trusted: int, value: Optional[str]) -> Optional[str]: """Get the real value from a list header based on the configured @@ -95,6 +93,3 @@ class ProxyFix: if not is_filename or value[:2] != "\\\\": return value.replace("\\\\", "\\").replace('\\"', '"') return value - - -# host, request.root_path, subdomain, request.scheme, request.method, request.path, request.query_string.decode(),