import logging
from datetime import datetime, timedelta
from django.db.models import Count, Avg, Q
from django.db.models.functions import TruncDate, Extract
from django.utils import timezone
from django_filters.rest_framework import DjangoFilterBackend
from rest_framework import viewsets, status
from rest_framework.decorators import action
from rest_framework.filters import SearchFilter, OrderingFilter
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from apps.companies.models import Company
from apps.calls.constants import TRANSFER_STATUS, BotType
from apps.calls.filters import CallFilter
from apps.calls.models import Call
from apps.calls.pagination import CallLimitOffsetPagination
from apps.calls.serializers import (
    CallDetailSerializer,
    DateRangeSerializer
)
from django.utils import timezone


# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class CurrentUserCallViewSet(viewsets.ModelViewSet):
    """Call API restricted to the requesting user's calls in their active company.

    ``get_queryset`` limits results to calls of the user's active company that
    were transferred to the user. Besides the standard ``ModelViewSet`` routes,
    four read-only reporting actions are exposed (``summary``,
    ``daily-calls``, ``monthly-stats`` and ``transfer-percentage``). Each runs
    on ``filter_queryset(get_queryset())`` and therefore honours the same
    filter/search/ordering query parameters as the list endpoint. The
    service-bot restriction in ``get_queryset`` applies to ``list`` only.
    """

    serializer_class = CallDetailSerializer
    queryset = Call.objects.all()
    permission_classes = [IsAuthenticated]
    pagination_class = CallLimitOffsetPagination
    filter_backends = [DjangoFilterBackend, SearchFilter, OrderingFilter]
    filterset_class = CallFilter
    search_fields = ['from_number', 'to_number', 'transcript', 'summary']
    ordering_fields = ['created_at', 'duration', 'cost']
    ordering = ['-created_at']

    # ----------------------------------------------------------------- helpers

    def _get_company_timezone(self):
        """Return the active company's tzinfo, or Django's default time zone.

        Falls back to ``timezone.get_default_timezone()`` in three cases: the
        user has no active company, the company has no ``timezone`` value, or
        the value cannot be resolved by ``zoneinfo`` (a warning is logged).
        Previously this went through ``django.utils.timezone.pytz``, which is
        not an attribute of that module on Django 4.0+, so the fallback was
        taken silently every time.
        """
        default_tz = timezone.get_default_timezone()
        company = getattr(self.request.user, "active_company", None)
        tz_name = getattr(company, "timezone", None)
        if not tz_name:
            return default_tz
        try:
            from zoneinfo import ZoneInfo  # stdlib, Python 3.9+

            # str() also handles tzinfo objects whose str() is their zone key.
            return ZoneInfo(str(tz_name))
        except (ImportError, KeyError, ValueError, OSError):
            # KeyError covers ZoneInfoNotFoundError; ValueError/OSError cover
            # malformed or path-like keys.
            logger.warning(
                "Could not resolve time zone %r for the active company; "
                "using default %s.",
                tz_name,
                default_tz,
            )
            return default_tz

    @staticmethod
    def _as_date(value):
        """Return the calendar date of *value*, which may be a date or a datetime."""
        # datetime is a subclass of date, so it must be tested for explicitly.
        return value.date() if isinstance(value, datetime) else value

    # -------------------------------------------------------------- overrides

    def get_queryset(self):
        """Return the calls visible to the requesting user.

        The user's active company is resolved to a ``Company`` row by phone
        number. If the user has no active company, or no ``Company`` matches,
        an empty queryset is returned. Results are further limited to calls
        transferred to this user and, for the ``list`` action only, to
        service-bot calls.
        """
        user = self.request.user
        active_company = user.active_company
        if not active_company:
            return self.queryset.none()

        # NOTE(review): the company is re-resolved by phone instead of being
        # used directly; ``get`` raises MultipleObjectsReturned if two
        # companies share a phone number — confirm phone is unique.
        try:
            company = Company.objects.get(phone=active_company.phone)
        except Company.DoesNotExist:
            return self.queryset.none()

        qs = self.queryset.filter(company=company)
        qs = qs.filter(transfer_user=user)
        if self.action == "list":
            qs = qs.filter(bot_type=BotType.SERVICE_BOT.value)

        return qs.order_by('-id')

    def get_serializer_class(self):
        """Return the serializer class for the request (same for every action)."""
        if self.action == 'list':
            return CallDetailSerializer

        return self.serializer_class

    def perform_create(self, serializer):
        """Create a new call."""
        serializer.save()

    # ---------------------------------------------------------------- reports

    @action(detail=False, methods=['get'], url_path='summary')
    def summary(self, request):
        """Return aggregate call statistics for the filtered calls.

        Response keys:

        * ``total_calls``
        * ``total_missed_calls``: calls with transfer status FAILED
        * ``total_attended_calls``: calls with transfer status SUCCESSFUL
        * ``average_call_time_seconds``: mean of ``duration``
        * ``unique_callers``: distinct ``from_number`` values
        * ``repeated_callers``: non-null ``from_number`` count minus the
          distinct count

        Aggregates that come back empty/None are reported as 0.
        """
        filtered_queryset = self.filter_queryset(self.get_queryset())

        # NOTE(review): "missed" here means FAILED, while monthly_stats counts
        # transfer_status=0 as missed — confirm which definition is intended.
        summary = filtered_queryset.aggregate(
            total_calls=Count('id'),
            total_missed_calls=Count('id', filter=Q(transfer_status=TRANSFER_STATUS.FAILED.value)),
            total_attended_calls=Count('id', filter=Q(transfer_status=TRANSFER_STATUS.SUCCESSFUL.value)),
            average_call_time=Avg('duration'),
            unique_callers=Count('from_number', distinct=True),
            repeated_callers=Count('from_number') - Count('from_number', distinct=True),
        )

        report_data = {
            'total_calls': summary['total_calls'] or 0,
            'total_missed_calls': summary['total_missed_calls'] or 0,
            'total_attended_calls': summary['total_attended_calls'] or 0,
            'average_call_time_seconds': summary['average_call_time'] or 0,
            'unique_callers': summary['unique_callers'] or 0,
            'repeated_callers': summary['repeated_callers'] or 0,
        }

        return Response(report_data, status=status.HTTP_200_OK)

    @action(detail=False, methods=['get'], url_path='daily-calls')
    def daily_calls(self, request):
        """Return per-day unique and repetitive call counts for a date range.

        ``start_date``/``end_date`` are validated by ``DateRangeSerializer``.
        Only their calendar day is used, and both days are included. Day
        boundaries are computed in the active company's time zone (default
        time zone as fallback).

        Every day in the range appears in the output, with zero counts when
        there were no calls. ``repetitive`` is total calls minus distinct
        callers for that day. Dates are formatted as ``"%b-%d"``.
        """
        serializer = DateRangeSerializer(data=request.query_params)
        serializer.is_valid(raise_exception=True)
        validated = serializer.validated_data

        # Drop any time component; accepts both date and datetime values.
        start_date = self._as_date(validated["start_date"])
        end_date = self._as_date(validated["end_date"])

        company_tz = self._get_company_timezone()

        # Half-open window [start 00:00, (end + 1 day) 00:00) in company time.
        start_dt = timezone.make_aware(datetime.combine(start_date, datetime.min.time()), company_tz)
        end_dt_exclusive = timezone.make_aware(
            datetime.combine(end_date + timedelta(days=1), datetime.min.time()),
            company_tz,
        )

        queryset_in_range = self.filter_queryset(self.get_queryset()).filter(
            created_at__gte=start_dt,
            created_at__lt=end_dt_exclusive,
        )

        daily_agg = (
            queryset_in_range
            .annotate(day=TruncDate("created_at", tzinfo=company_tz))
            .values("day")
            .annotate(
                total_calls=Count("id"),
                unique_callers=Count("from_number", distinct=True),
            )
            .order_by("day")
        )

        # Pre-fill every day so days without calls still show up as zeros.
        days_range = {
            start_date + timedelta(days=i): {"unique": 0, "repetitive": 0}
            for i in range((end_date - start_date).days + 1)
        }

        for row in daily_agg:
            unique = row["unique_callers"] or 0
            total = row["total_calls"] or 0
            # row["day"] is already truncated in the company time zone.
            days_range[row["day"]] = {"unique": unique, "repetitive": max(total - unique, 0)}

        result = [
            {"date": day.strftime("%b-%d"), **days_range[day]}
            for day in sorted(days_range)
        ]

        return Response({"daily_calls": result}, status=status.HTTP_200_OK)

    @action(detail=False, methods=['get'], url_path='monthly-stats')
    def monthly_stats(self, request):
        """Return per-month call totals and missed percentage for the current year.

        The "current year" and the month buckets are both evaluated in the
        active company's time zone. Only months that had calls are listed,
        ordered January to December. Each entry has ``month`` (English name),
        ``calls`` (total) and ``missed`` (rounded percentage string such as
        ``"12%"``).
        """
        filtered_queryset = self.filter_queryset(self.get_queryset())
        company_tz = self._get_company_timezone()

        # Use the company's local year; timezone.now().year is the UTC year and
        # is wrong for the company around New Year.
        current_year = timezone.localtime(timezone.now(), company_tz).year
        start_of_year = timezone.make_aware(datetime(current_year, 1, 1), company_tz)
        end_of_year = timezone.make_aware(datetime(current_year + 1, 1, 1), company_tz)

        current_year_queryset = filtered_queryset.filter(
            created_at__gte=start_of_year,
            created_at__lt=end_of_year,
        )

        # NOTE(review): "missed" here is transfer_status=0 (NOT_TRANSFERRED per
        # transfer_percentage), while summary uses FAILED — confirm intent.
        # GROUP BY only yields months that have at least one call.
        monthly_agg = (
            current_year_queryset
            .annotate(
                year=Extract('created_at', 'year', tzinfo=company_tz),
                month=Extract('created_at', 'month', tzinfo=company_tz),
            )
            .values('year', 'month')
            .annotate(
                total_calls=Count('id'),
                missed_calls=Count('id', filter=Q(transfer_status=0)),
            )
            .order_by('month')
        )

        month_names = {
            1: "January", 2: "February", 3: "March", 4: "April",
            5: "May", 6: "June", 7: "July", 8: "August",
            9: "September", 10: "October", 11: "November", 12: "December",
        }

        result = []
        for row in monthly_agg:
            month_num = int(row['month'])
            total_calls = row['total_calls'] or 0
            missed_calls = row['missed_calls'] or 0

            missed_percentage = "0%"
            if total_calls > 0:
                missed_percentage = f"{round((missed_calls / total_calls) * 100)}%"

            result.append({
                "month": month_names.get(month_num, f"Month {month_num}"),
                "calls": total_calls,
                "missed": missed_percentage,
            })

        return Response(result, status=status.HTTP_200_OK)

    @action(detail=False, methods=['get'], url_path='transfer-percentage')
    def transfer_percentage(self, request):
        """Return transfer-status percentages for calls in a date range.

        Status codes: 1 = SUCCESSFUL (transferred), 0 = NOT_TRANSFERRED,
        2 = FAILED. Any other value, including NULL, counts as unknown.

        Percentages are rounded to 2 decimals. ``total_calls`` is always
        included, so the response has the same shape whether or not the range
        is empty.
        """
        serializer = DateRangeSerializer(data=request.query_params)
        serializer.is_valid(raise_exception=True)
        validated = serializer.validated_data

        start_date = validated["start_date"]
        end_date = validated["end_date"]

        # NOTE(review): end_date is compared with <= against created_at; if the
        # serializer yields midnight datetimes, later calls on end_date are
        # excluded (daily_calls treats the end day as fully inclusive).
        calls_in_range = self.filter_queryset(self.get_queryset()).filter(
            created_at__gte=start_date,
            created_at__lte=end_date,
        )

        # One query with conditional counts instead of five separate COUNTs.
        counts = calls_in_range.aggregate(
            total=Count('id'),
            transferred=Count('id', filter=Q(transfer_status=1)),  # SUCCESSFUL
            not_transferred=Count('id', filter=Q(transfer_status=0)),  # NOT_TRANSFERRED
            failed=Count('id', filter=Q(transfer_status=2)),  # FAILED
        )
        total_calls = counts['total'] or 0

        if total_calls == 0:
            return Response({
                "transferred_percentage": 0.0,
                "not_transferred_percentage": 0.0,
                "failed_percentage": 0.0,
                "unknown_percentage": 0.0,
                "total_calls": 0,
            }, status=status.HTTP_200_OK)

        transferred_count = counts['transferred'] or 0
        not_transferred_count = counts['not_transferred'] or 0
        failed_count = counts['failed'] or 0
        # Rows not counted above (other status values or NULL) are "unknown".
        unknown_count = total_calls - transferred_count - not_transferred_count - failed_count

        def as_percentage(count):
            return round((count / total_calls) * 100, 2)

        return Response({
            "transferred_percentage": as_percentage(transferred_count),
            "not_transferred_percentage": as_percentage(not_transferred_count),
            "failed_percentage": as_percentage(failed_count),
            "unknown_percentage": as_percentage(unknown_count),
            "total_calls": total_calls,
        }, status=status.HTTP_200_OK)
