import logging
from typing import Any, cast
from urllib.parse import parse_qs

from channels.db import database_sync_to_async
from channels.middleware import BaseMiddleware


logger = logging.getLogger(__name__)

class JWTAuthMiddleware(BaseMiddleware):
    """Channels middleware that authenticates a connection from a JWT.

    The token is read from the ``token`` parameter of the connection's query
    string (e.g. ``/ws/path/?token=<jwt>``) and validated with SimpleJWT's
    ``AccessToken``. On success the matching user instance is stored in
    ``scope["user"]``; when no token is supplied, or validation / user lookup
    fails, ``scope["user"]`` is set to ``None``.
    """

    @staticmethod
    def _extract_token(scope):
        """Return the ``token`` query parameter from *scope*, or ``None``.

        If the parameter appears more than once, the last occurrence wins.
        """
        raw_query = scope.get("query_string", b"")
        if isinstance(raw_query, bytes):
            # Malformed bytes must not abort the connection; they simply
            # produce a token that fails validation later on.
            raw_query = raw_query.decode("utf-8", errors="replace")
        # parse_qs percent-decodes values and keeps any "=" inside a value,
        # both of which a manual split("=") gets wrong. Blank values are kept
        # so that a trailing empty "token=" still overrides earlier ones.
        values = parse_qs(raw_query, keep_blank_values=True).get("token")
        if not values:
            return None
        return values[-1] or None

    async def __call__(self, scope, receive, send):
        """Populate ``scope["user"]`` and delegate to the wrapped application.

        Args:
            scope: ASGI connection scope; its ``"user"`` key is overwritten.
            receive: ASGI receive callable, passed through unchanged.
            send: ASGI send callable, passed through unchanged.

        Returns:
            Whatever ``BaseMiddleware.__call__`` returns.
        """
        # Kept function-local as in the original code; presumably so this
        # module can be imported before Django apps are loaded - TODO confirm.
        from rest_framework_simplejwt.tokens import AccessToken
        from django.contrib.auth import get_user_model

        user_model = get_user_model()

        # Default to "unauthenticated"; only a fully validated token whose
        # user exists replaces this value.
        scope["user"] = None

        token = self._extract_token(scope)
        if token:
            try:
                validated = AccessToken(cast(Any, token))
                user_id = validated["user_id"]
                scope["user"] = cast(
                    Any,
                    await database_sync_to_async(user_model.objects.get)(id=user_id),
                )
            except Exception as exc:
                # Deliberately broad: any failure (bad/expired token, missing
                # claim, unknown user) must leave the connection
                # unauthenticated rather than crash it.
                logger.warning("JWTAuthMiddleware: Invalid token - %s", exc)
                scope["user"] = None

        return await super().__call__(scope, receive, send)
