|
|
@ -1,48 +1,46 @@ |
|
|
|
from typing import Optional, List |
|
|
|
from typing import Optional, List, Callable |
|
|
|
from functools import partial |
|
|
|
from urllib.request import parse_http_list as _parse_list_header |
|
|
|
|
|
|
|
from quart import request |
|
|
|
|
|
|
|
|
|
|
|
class ProxyFix: |
|
|
|
def __init__(self, app=None, x_for: int = 1, x_proto: int = 1, x_host: int = 0, x_port: int = 0, x_prefix: int = 0): |
|
|
|
self.app = app |
|
|
|
self.x_for = x_for |
|
|
|
self.x_proto = x_proto |
|
|
|
self.x_host = x_host |
|
|
|
self.x_port = x_port |
|
|
|
self.x_prefix = x_prefix |
|
|
|
|
|
|
|
if app: |
|
|
|
self.init_app(app) |
|
|
|
|
|
|
|
def init_app(self, app): |
|
|
|
@app.before_request |
|
|
|
async def before_request(): |
|
|
|
x_for = self._get_real_value(self.x_for, request.headers.get("X-Forwarded-For")) |
|
|
|
if x_for: |
|
|
|
request.headers["Remote-Addr"] = x_for |
|
|
|
|
|
|
|
x_proto = self._get_real_value(self.x_proto, request.headers.get("X-Forwarded-Proto")) |
|
|
|
if x_proto: |
|
|
|
request.scheme = x_proto |
|
|
|
|
|
|
|
x_host = self._get_real_value(self.x_host, request.headers.get("X-Forwarded-Host")) |
|
|
|
if x_host: |
|
|
|
request.headers["host"] = x_host.lower() |
|
|
|
parts = x_host.split(":", 1) |
|
|
|
# environ["SERVER_NAME"] = parts[0] |
|
|
|
# if len(parts) == 2: |
|
|
|
# environ["SERVER_PORT"] = parts[1] |
|
|
|
|
|
|
|
x_port = self._get_real_value(self.x_port, request.headers.get("X-Forwarded-Port")) |
|
|
|
if x_port: |
|
|
|
host = request.host |
|
|
|
if host: |
|
|
|
parts = host.split(":", 1) |
|
|
|
host = parts[0] if len(parts) == 2 else host |
|
|
|
request.headers["host"] = f"{host}:{x_port}" |
|
|
|
# environ["SERVER_PORT"] = x_port |
|
|
|
from urllib.parse import urlparse |
|
|
|
from werkzeug.datastructures import Headers |
|
|
|
|
|
|
|
from quart import Request |
|
|
|
from quart.asgi import ASGIHTTPConnection |
|
|
|
|
|
|
|
|
|
|
|
class ASGIProxyFix(ASGIHTTPConnection): |
|
|
|
def _create_request_from_scope(self, send: Callable) -> Request: |
|
|
|
headers = Headers() |
|
|
|
headers["Remote-Addr"] = (self.scope.get("client") or ["<local>"])[0] |
|
|
|
for name, value in self.scope["headers"]: |
|
|
|
headers.add(name.decode("latin1").title(), value.decode("latin1")) |
|
|
|
if self.scope["http_version"] < "1.1": |
|
|
|
headers.setdefault("Host", self.app.config["SERVER_NAME"] or "") |
|
|
|
|
|
|
|
path = self.scope["path"] |
|
|
|
path = path if path[0] == "/" else urlparse(path).path |
|
|
|
|
|
|
|
x_proto = self._get_real_value(1, headers.get("X-Forwarded-Proto")) |
|
|
|
if x_proto: |
|
|
|
self.scope["scheme"] = x_proto |
|
|
|
|
|
|
|
x_host = self._get_real_value(1, headers.get("X-Forwarded-Host")) |
|
|
|
if x_host: |
|
|
|
headers["host"] = x_host.lower() |
|
|
|
|
|
|
|
return self.app.request_class( |
|
|
|
self.scope["method"], |
|
|
|
self.scope["scheme"], |
|
|
|
path, |
|
|
|
self.scope["query_string"], |
|
|
|
headers, |
|
|
|
self.scope.get("root_path", ""), |
|
|
|
self.scope["http_version"], |
|
|
|
max_content_length=self.app.config["MAX_CONTENT_LENGTH"], |
|
|
|
body_timeout=self.app.config["BODY_TIMEOUT"], |
|
|
|
send_push_promise=partial(self._send_push_promise, send), |
|
|
|
scope=self.scope, |
|
|
|
) |
|
|
|
|
|
|
|
def _get_real_value(self, trusted: int, value: Optional[str]) -> Optional[str]: |
|
|
|
"""Get the real value from a list header based on the configured |
|
|
@ -95,6 +93,3 @@ class ProxyFix: |
|
|
|
if not is_filename or value[:2] != "\\\\": |
|
|
|
return value.replace("\\\\", "\\").replace('\\"', '"') |
|
|
|
return value |
|
|
|
|
|
|
|
|
|
|
|
# host, request.root_path, subdomain, request.scheme, request.method, request.path, request.query_string.decode(), |
|
|
|