Browse Source

fix proxyfix.

aiosqlite
fiatjaf 4 years ago
parent
commit
49baa07141
  1. 4
      lnbits/app.py
  2. 73
      lnbits/proxy_fix.py

4
lnbits/app.py

@ -9,7 +9,7 @@ from .commands import db_migrate
from .core import core_app from .core import core_app
from .db import open_db from .db import open_db
from .helpers import get_valid_extensions, get_js_vendored, get_css_vendored, url_for_vendored from .helpers import get_valid_extensions, get_js_vendored, get_css_vendored, url_for_vendored
from .proxy_fix import ProxyFix from .proxy_fix import ASGIProxyFix
secure_headers = SecureHeaders(hsts=False) secure_headers = SecureHeaders(hsts=False)
@ -20,10 +20,10 @@ def create_app(config_object="lnbits.settings") -> Quart:
""" """
app = Quart(__name__, static_folder="static") app = Quart(__name__, static_folder="static")
app.config.from_object(config_object) app.config.from_object(config_object)
app.asgi_http_class = ASGIProxyFix
cors(app) cors(app)
Compress(app) Compress(app)
ProxyFix(app, x_proto=1, x_host=1)
register_assets(app) register_assets(app)
register_blueprints(app) register_blueprints(app)

73
lnbits/proxy_fix.py

@ -1,48 +1,46 @@
from typing import Optional, List from typing import Optional, List, Callable
from functools import partial
from urllib.request import parse_http_list as _parse_list_header from urllib.request import parse_http_list as _parse_list_header
from urllib.parse import urlparse
from werkzeug.datastructures import Headers
from quart import request from quart import Request
from quart.asgi import ASGIHTTPConnection
class ProxyFix: class ASGIProxyFix(ASGIHTTPConnection):
def __init__(self, app=None, x_for: int = 1, x_proto: int = 1, x_host: int = 0, x_port: int = 0, x_prefix: int = 0): def _create_request_from_scope(self, send: Callable) -> Request:
self.app = app headers = Headers()
self.x_for = x_for headers["Remote-Addr"] = (self.scope.get("client") or ["<local>"])[0]
self.x_proto = x_proto for name, value in self.scope["headers"]:
self.x_host = x_host headers.add(name.decode("latin1").title(), value.decode("latin1"))
self.x_port = x_port if self.scope["http_version"] < "1.1":
self.x_prefix = x_prefix headers.setdefault("Host", self.app.config["SERVER_NAME"] or "")
if app: path = self.scope["path"]
self.init_app(app) path = path if path[0] == "/" else urlparse(path).path
def init_app(self, app): x_proto = self._get_real_value(1, headers.get("X-Forwarded-Proto"))
@app.before_request
async def before_request():
x_for = self._get_real_value(self.x_for, request.headers.get("X-Forwarded-For"))
if x_for:
request.headers["Remote-Addr"] = x_for
x_proto = self._get_real_value(self.x_proto, request.headers.get("X-Forwarded-Proto"))
if x_proto: if x_proto:
request.scheme = x_proto self.scope["scheme"] = x_proto
x_host = self._get_real_value(self.x_host, request.headers.get("X-Forwarded-Host")) x_host = self._get_real_value(1, headers.get("X-Forwarded-Host"))
if x_host: if x_host:
request.headers["host"] = x_host.lower() headers["host"] = x_host.lower()
parts = x_host.split(":", 1)
# environ["SERVER_NAME"] = parts[0] return self.app.request_class(
# if len(parts) == 2: self.scope["method"],
# environ["SERVER_PORT"] = parts[1] self.scope["scheme"],
path,
x_port = self._get_real_value(self.x_port, request.headers.get("X-Forwarded-Port")) self.scope["query_string"],
if x_port: headers,
host = request.host self.scope.get("root_path", ""),
if host: self.scope["http_version"],
parts = host.split(":", 1) max_content_length=self.app.config["MAX_CONTENT_LENGTH"],
host = parts[0] if len(parts) == 2 else host body_timeout=self.app.config["BODY_TIMEOUT"],
request.headers["host"] = f"{host}:{x_port}" send_push_promise=partial(self._send_push_promise, send),
# environ["SERVER_PORT"] = x_port scope=self.scope,
)
def _get_real_value(self, trusted: int, value: Optional[str]) -> Optional[str]: def _get_real_value(self, trusted: int, value: Optional[str]) -> Optional[str]:
"""Get the real value from a list header based on the configured """Get the real value from a list header based on the configured
@ -95,6 +93,3 @@ class ProxyFix:
if not is_filename or value[:2] != "\\\\": if not is_filename or value[:2] != "\\\\":
return value.replace("\\\\", "\\").replace('\\"', '"') return value.replace("\\\\", "\\").replace('\\"', '"')
return value return value
# host, request.root_path, subdomain, request.scheme, request.method, request.path, request.query_string.decode(),

Loading…
Cancel
Save