diff --git a/src/c3nav/asgi.py b/src/c3nav/asgi.py index 41a6340a..6d5f9561 100644 --- a/src/c3nav/asgi.py +++ b/src/c3nav/asgi.py @@ -3,7 +3,7 @@ from contextlib import suppress from channels.auth import AuthMiddlewareStack from channels.routing import ProtocolTypeRouter, URLRouter -from channels.security.websocket import AllowedHostsOriginValidator +from channels.security.websocket import OriginValidator from django.core.asgi import get_asgi_application os.environ.setdefault("DJANGO_SETTINGS_MODULE", "c3nav.settings") @@ -13,9 +13,37 @@ django_asgi = get_asgi_application() from c3nav.control.middleware import UserPermissionsChannelMiddleware # noqa from c3nav.urls import websocket_urlpatterns # noqa + +class OriginValidatorWithAllowNone(OriginValidator): + def valid_origin(self, parsed_origin): + """ + Checks parsed origin is None. + We want to allow None because browsers always send the Origin header and non-browser clients do not need CORS + + Pass control to the validate_origin function. + + Returns ``True`` if validation function was successful, ``False`` otherwise. + """ + # None is not allowed unless all hosts are allowed + if parsed_origin is None: + return True + return self.validate_origin(parsed_origin) + + +def AllowedHostsOriginValidatorWithAllowNone(app): + """ + Factory function which returns an OriginValidatorWithAllowNone configured to use + settings.ALLOWED_HOSTS. + """ + allowed_hosts = settings.ALLOWED_HOSTS + if settings.DEBUG and not allowed_hosts: + allowed_hosts = ["localhost", "127.0.0.1", "[::1]"] + return OriginValidatorWithAllowNone(app, allowed_hosts) + + application = ProtocolTypeRouter({ "http": django_asgi, - "websocket": AllowedHostsOriginValidator( + "websocket": AllowedHostsOriginValidatorWithAllowNone( AuthMiddlewareStack( UserPermissionsChannelMiddleware( URLRouter(websocket_urlpatterns),