from collections import OrderedDict

from django.conf import settings
from django.core.exceptions import FieldDoesNotExist
from django.http import Http404
from django.shortcuts import redirect
from django.urls import path, reverse
from modelcluster.fields import ParentalKey
from rest_framework import status
from rest_framework.renderers import BrowsableAPIRenderer, JSONRenderer
from rest_framework.response import Response
from rest_framework.viewsets import GenericViewSet

from wagtail.api import APIField
from wagtail.models import Page, PageViewRestriction, Site

from .filters import (
    AncestorOfFilter,
    ChildOfFilter,
    DescendantOfFilter,
    FieldsFilter,
    LocaleFilter,
    OrderingFilter,
    SearchFilter,
    TranslationOfFilter,
)
from .pagination import WagtailPagination
from .serializers import BaseSerializer, PageSerializer, get_serializer_class
from .utils import (
    BadRequestError,
    get_object_detail_url,
    page_models_from_string,
    parse_fields_parameter,
)


class BaseAPIViewSet(GenericViewSet):
    # Generic API endpoint for a single model. It exposes a listing view, a
    # detail view and a "find" view (see get_urlpatterns) and builds its
    # serializer class per request from the field lists declared below.
    #
    # NOTE(review): documented with comments rather than a class docstring;
    # DRF presumably surfaces a view class docstring as the endpoint
    # description, so adding one could change rendered output - confirm
    # before converting this to a docstring.

    # Responses can be rendered as JSON or through the browsable API.
    renderer_classes = [JSONRenderer, BrowsableAPIRenderer]

    pagination_class = WagtailPagination
    # Passed as the ``base`` of every serializer class generated by
    # _get_serializer_class.
    base_serializer_class = BaseSerializer
    filter_backends = []
    model = None  # Set on subclass

    # Query parameters accepted in addition to the model's database-backed
    # API fields; anything else is rejected by check_query_parameters.
    known_query_parameters = frozenset(
        [
            "limit",
            "offset",
            "fields",
            "order",
            "search",
            "search_operator",
            # Used by jQuery for cache-busting. See #1671
            "_",
            # Required by BrowsableAPIRenderer
            "format",
        ]
    )
    # Fields always available on this endpoint; get_body_fields extends these
    # with the model's ``api_fields``.
    body_fields = ["id"]
    # Fields passed to the serializer factory as ``meta_fields``;
    # get_meta_fields extends these with the model's ``api_meta_fields``.
    meta_fields = ["type", "detail_url"]
    # Fields returned by default in listings (no ``fields`` parameter given).
    listing_default_fields = ["id", "type", "detail_url"]
    # Fields returned by default when this model is serialised as a nested
    # (related) object of another endpoint.
    nested_default_fields = ["id", "type", "detail_url"]
    # Fields that are only available when serialising with show_details=True.
    detail_only_fields = []
    name = None  # Set on subclass.

    def __init__(self, *args, **kwargs):
        """Initialise the viewset and start with an empty record of seen types."""
        super().__init__(*args, **kwargs)

        # Ordered mapping of type name ("app_label.ModelName") -> model class.
        # Models are recorded here as their objects get serialised, so that the
        # Admin API can append a summary of the types used to its response.
        seen = OrderedDict()
        self.seen_types = seen

    def get_queryset(self):
        """Return every instance of ``self.model``, ordered by ``id``."""
        every_object = self.model.objects.all()
        return every_object.order_by("id")

    def listing_view(self, request):
        # Listing endpoint: validate the query parameters against the base
        # queryset, then filter, paginate and serialise the resulting page.
        base_queryset = self.get_queryset()
        self.check_query_parameters(base_queryset)

        current_page = self.paginate_queryset(self.filter_queryset(base_queryset))
        listing_serializer = self.get_serializer(current_page, many=True)
        return self.get_paginated_response(listing_serializer.data)

    def detail_view(self, request, pk):
        # Detail endpoint: the object lookup is delegated to get_object(), so
        # ``pk`` is not read directly here.
        detail_serializer = self.get_serializer(self.get_object())
        return Response(detail_serializer.data)

    def find_view(self, request):
        # Find endpoint: locate a single object via find_object() and redirect
        # to its detail URL. Responds with Http404 when nothing matches.
        candidates = self.get_queryset()

        try:
            match = self.find_object(candidates, request)
        except self.model.DoesNotExist:
            match = None

        # A missing object and a "no find method applied" result are treated
        # the same way.
        if match is None:
            raise Http404("not found")

        # Generate redirect
        detail_url = get_object_detail_url(
            self.request.wagtailapi_router, request, self.model, match.pk
        )
        if detail_url is not None:
            return redirect(detail_url)

        # Shouldn't happen unless this endpoint isn't actually installed in the router
        endpoint_name = self.__class__.__name__
        raise Exception(
            "Cannot generate URL to detail view. Is '{}' installed in the API router?".format(
                endpoint_name
            )
        )

    def find_object(self, queryset, request):
        """
        Look up a single object from ``queryset`` using the request's query
        parameters. Override this to implement more find methods.

        Returns the object matching the ``id`` query parameter, or None when
        no ``id`` parameter was supplied.

        Raises the model's DoesNotExist when no object has that id, and
        BadRequestError when the supplied id cannot be interpreted as a value
        for the id field (e.g. ``?id=abc`` on an integer key).
        """
        if "id" not in request.GET:
            return None

        raw_id = request.GET["id"]
        try:
            return queryset.get(id=raw_id)
        except ValueError as error:
            # The lookup value could not be converted for the id field; report
            # it as a client error instead of letting it surface as a 500.
            raise BadRequestError(
                "'%s' is not a valid value for id" % raw_id
            ) from error

    def handle_exception(self, exc):
        """
        Turn Http404 and BadRequestError into ``{"message": ...}`` responses
        with status 404 and 400 respectively; anything else is handled by the
        parent class.
        """
        if isinstance(exc, Http404):
            error_status = status.HTTP_404_NOT_FOUND
        elif isinstance(exc, BadRequestError):
            error_status = status.HTTP_400_BAD_REQUEST
        else:
            return super().handle_exception(exc)

        return Response({"message": str(exc)}, status=error_status)

    @classmethod
    def _convert_api_fields(cls, fields):
        """Return ``fields`` as a list in which every entry is an APIField."""
        converted = []
        for entry in fields:
            # Anything that is not already an APIField gets wrapped in one.
            if not isinstance(entry, APIField):
                entry = APIField(entry)
            converted.append(entry)
        return converted

    @classmethod
    def get_body_fields(cls, model):
        """
        Return the endpoint's ``body_fields`` followed by the model's
        ``api_fields`` (if any), all as APIField instances.
        """
        model_fields = getattr(model, "api_fields", ())
        combined = cls.body_fields + list(model_fields)
        return cls._convert_api_fields(combined)

    @classmethod
    def get_body_fields_names(cls, model):
        """Return the names of the body fields available for ``model``."""
        names = []
        for api_field in cls.get_body_fields(model):
            names.append(api_field.name)
        return names

    @classmethod
    def get_meta_fields(cls, model):
        """
        Return the endpoint's ``meta_fields`` followed by the model's
        ``api_meta_fields`` (if any), all as APIField instances.
        """
        model_meta_fields = getattr(model, "api_meta_fields", ())
        combined = cls.meta_fields + list(model_meta_fields)
        return cls._convert_api_fields(combined)

    @classmethod
    def get_meta_fields_names(cls, model):
        """Return the names of the meta fields available for ``model``."""
        names = []
        for api_field in cls.get_meta_fields(model):
            names.append(api_field.name)
        return names

    @classmethod
    def get_field_serializer_overrides(cls, model):
        """
        Return a mapping of field name to serializer for every body or meta
        field of ``model`` that declares its own serializer.
        """
        overrides = {}
        # Body fields come first, so a meta field with the same name wins.
        for api_field in cls.get_body_fields(model) + cls.get_meta_fields(model):
            if api_field.serializer is None:
                continue
            overrides[api_field.name] = api_field.serializer
        return overrides

    @classmethod
    def get_available_fields(cls, model, db_fields_only=False):
        """
        Returns the names of all fields (body fields, then meta fields) that
        can be used in the API for the specified model class.

        With db_fields_only=True, only names that match a field on the model
        (by ``name`` or, where present, ``attname``) are kept. This drops
        entries such as type/detail_url and any custom fields that are
        callables.
        """
        names = cls.get_body_fields_names(model) + cls.get_meta_fields_names(model)

        if not db_fields_only:
            return names

        # Collect every name under which the model exposes a field
        model_fields = model._meta.get_fields()
        database_names = {model_field.name for model_field in model_fields}
        database_names.update(
            model_field.attname
            for model_field in model_fields
            if hasattr(model_field, "attname")
        )

        return [name for name in names if name in database_names]

    @classmethod
    def get_detail_default_fields(cls, model):
        """Return the default field names for detail views: every available field."""
        available = cls.get_available_fields(model)
        return available

    @classmethod
    def get_listing_default_fields(cls, model):
        """Return a copy of ``listing_default_fields`` (``model`` is not consulted)."""
        defaults = cls.listing_default_fields
        return defaults[:]

    @classmethod
    def get_nested_default_fields(cls, model):
        """Return a copy of ``nested_default_fields`` (``model`` is not consulted)."""
        defaults = cls.nested_default_fields
        return defaults[:]

    def check_query_parameters(self, queryset):
        """
        Reject requests that carry query parameters this endpoint does not
        understand.

        A parameter is accepted when it names a database-backed API field of
        the queryset's model or appears in ``known_query_parameters``.
        Otherwise a BadRequestError listing the offending names is raised.
        """
        supplied = set(self.request.GET.keys())

        # All query parameters must be either a database field or an operation
        db_field_names = self.get_available_fields(queryset.model, db_fields_only=True)
        permitted = set(db_field_names) | set(self.known_query_parameters)

        unrecognised = supplied - permitted
        if not unrecognised:
            return

        raise BadRequestError(
            "query parameter is not an operation or a recognised field: %s"
            % ", ".join(sorted(unrecognised))
        )

    @classmethod
    def _get_serializer_class(
        cls, router, model, fields_config, show_details=False, nested=False
    ):
        """
        Build the serializer class used to represent ``model``.

        ``fields_config`` is a sequence of ``(field_name, negated, sub_fields)``
        entries. A leading ``*`` entry starts from all fields and a leading
        ``_`` entry starts from none; otherwise the selection starts from the
        detail, nested or listing defaults depending on ``show_details`` and
        ``nested``. Relational fields get a nested serializer class obtained
        from the related model's endpoint in ``router`` (falling back to
        BaseAPIViewSet).

        Raises BadRequestError for unknown field names and for sub-fields
        requested on a non-relational field.
        """
        body_fields = cls.get_body_fields_names(model)
        meta_fields = cls.get_meta_fields_names(model)

        # Every available field, de-duplicated while keeping first-seen order
        all_fields = list(OrderedDict.fromkeys(body_fields + meta_fields))

        if not show_details:
            # Detail-only fields are not available outside of detail views
            all_fields = [
                name for name in all_fields if name not in cls.detail_only_fields
            ]

        # Starting selection: the defaults for the kind of view being rendered
        if show_details:
            defaults = cls.get_detail_default_fields(model)
        elif nested:
            defaults = cls.get_nested_default_fields(model)
        else:
            defaults = cls.get_listing_default_fields(model)
        fields = set(defaults)

        # A leading '*' restarts from all fields, a leading '_' from no fields
        if fields_config:
            leading_name = fields_config[0][0]
            if leading_name == "*":
                fields = set(all_fields)
                fields_config = fields_config[1:]
            elif leading_name == "_":
                fields = set()
                fields_config = fields_config[1:]

        mentioned_fields = set()
        sub_fields = {}

        for field_name, negated, field_sub_fields in fields_config:
            mentioned_fields.add(field_name)

            if negated:
                fields.discard(field_name)
                continue

            fields.add(field_name)
            if field_sub_fields:
                sub_fields[field_name] = field_sub_fields

        unknown_fields = mentioned_fields.difference(all_fields)
        if unknown_fields:
            raise BadRequestError(
                "unknown fields: %s" % ", ".join(sorted(unknown_fields))
            )

        # Build nested serialisers
        child_serializer_classes = {}

        for field_name in fields:
            try:
                django_field = model._meta.get_field(field_name)
            except FieldDoesNotExist:
                django_field = None

            if not (django_field and django_field.is_relation):
                if field_name in sub_fields:
                    # Sub fields were given for a non-related field
                    raise BadRequestError(
                        "'%s' does not support nested fields" % field_name
                    )
                continue

            child_sub_fields = sub_fields.get(field_name, [])

            # Inline (aka "child") models should display all fields by default,
            # unless the request already chose an explicit starting point
            if isinstance(getattr(django_field, "field", None), ParentalKey):
                if not child_sub_fields or child_sub_fields[0][0] not in ["*", "_"]:
                    child_sub_fields = [("*", False, None), *child_sub_fields]

            # Get a serializer class for the related object from its endpoint
            child_model = django_field.related_model
            endpoint = router.get_model_endpoint(child_model)
            child_viewset = endpoint[1] if endpoint else BaseAPIViewSet
            child_serializer_classes[field_name] = child_viewset._get_serializer_class(
                router, child_model, child_sub_fields, nested=True
            )

        # Put the selected fields back into the order of all_fields
        ordered_fields = [name for name in all_fields if name in fields]

        field_serializer_overrides = {
            name: serializer
            for name, serializer in cls.get_field_serializer_overrides(model).items()
            if name in ordered_fields
        }

        return get_serializer_class(
            model,
            ordered_fields,
            meta_fields=meta_fields,
            field_serializer_overrides=field_serializer_overrides,
            child_serializer_classes=child_serializer_classes,
            base=cls.base_serializer_class,
        )

    def get_serializer_class(self):
        """
        Build the serializer class for the current request.

        The model comes from the queryset for listings and from the fetched
        object otherwise. The optional ``fields`` query parameter selects the
        fields; a value that cannot be parsed raises BadRequestError.
        """
        request = self.request
        is_listing = self.action == "listing_view"

        # Get model
        if is_listing:
            model = self.get_queryset().model
        else:
            model = type(self.get_object())

        # Fields: an empty configuration means "use the default fields"
        fields_config = []
        if "fields" in request.GET:
            try:
                fields_config = parse_fields_parameter(request.GET["fields"])
            except ValueError as e:
                raise BadRequestError("fields error: %s" % str(e))

        # "detail_only" (eg parent) fields are allowed everywhere but listings
        return self._get_serializer_class(
            self.request.wagtailapi_router,
            model,
            fields_config,
            show_details=not is_listing,
        )

    def get_serializer_context(self):
        """
        Return the context handed to serializers: the current request, this
        view and the API router attached to the request.
        """
        request = self.request
        context = {"request": request, "view": self}
        context["router"] = request.wagtailapi_router
        return context

    def get_renderer_context(self):
        """Return the parent renderer context with ``indent`` set to 4."""
        renderer_context = super().get_renderer_context()
        renderer_context.update(indent=4)
        return renderer_context

    @classmethod
    def get_urlpatterns(cls):
        """
        Return the URL patterns for this endpoint: the listing at the root,
        the detail view under ``<pk>/`` and the find view under ``find/``.
        """
        listing = path("", cls.as_view({"get": "listing_view"}), name="listing")
        detail = path("<int:pk>/", cls.as_view({"get": "detail_view"}), name="detail")
        find = path("find/", cls.as_view({"get": "find_view"}), name="find")
        return [listing, detail, find]

    @classmethod
    def get_model_listing_urlpath(cls, model, namespace=""):
        """
        Reverse the URL path of the listing view, optionally inside
        ``namespace``. ``model`` is not consulted here.
        """
        url_name = namespace + ":listing" if namespace else "listing"
        return reverse(url_name)

    @classmethod
    def get_object_detail_urlpath(cls, model, pk, namespace=""):
        """
        Reverse the URL path of the detail view for ``pk``, optionally inside
        ``namespace``. ``model`` is not consulted here.
        """
        url_name = namespace + ":detail" if namespace else "detail"
        return reverse(url_name, args=(pk,))


class PagesAPIViewSet(BaseAPIViewSet):
    # API endpoint for pages. Extends the base endpoint with page-specific
    # filters, query parameters and fields, and restricts results to the pages
    # returned by get_base_queryset.
    #
    # NOTE(review): documented with comments rather than a class docstring;
    # DRF presumably surfaces a view class docstring as the endpoint
    # description - confirm before converting this to a docstring.
    base_serializer_class = PageSerializer
    filter_backends = [
        FieldsFilter,
        ChildOfFilter,
        AncestorOfFilter,
        DescendantOfFilter,
        OrderingFilter,
        TranslationOfFilter,
        LocaleFilter,
        SearchFilter,  # needs to be last, as SearchResults querysets cannot be filtered further
    ]
    # Page-specific query parameters accepted on top of the base set.
    known_query_parameters = BaseAPIViewSet.known_query_parameters.union(
        [
            "type",
            "child_of",
            "ancestor_of",
            "descendant_of",
            "translation_of",
            "locale",
            "site",
        ]
    )
    body_fields = BaseAPIViewSet.body_fields + [
        "title",
    ]
    meta_fields = BaseAPIViewSet.meta_fields + [
        "html_url",
        "slug",
        "show_in_menus",
        "seo_title",
        "search_description",
        "first_published_at",
        "alias_of",
        "parent",
        "locale",
    ]
    # "locale" is appended to the listing defaults at runtime when
    # WAGTAIL_I18N_ENABLED is set (see get_listing_default_fields).
    listing_default_fields = BaseAPIViewSet.listing_default_fields + [
        "title",
        "html_url",
        "slug",
        "first_published_at",
    ]
    nested_default_fields = BaseAPIViewSet.nested_default_fields + [
        "title",
    ]
    # "parent" is only offered when serialising with show_details=True.
    detail_only_fields = ["parent"]
    name = "pages"
    model = Page

    @classmethod
    def get_detail_default_fields(cls, model):
        """
        Return the default field names for detail views of ``model``.

        Starts from the parent class defaults and drops "locale" when the
        WAGTAIL_I18N_ENABLED setting is missing or falsy.
        """
        detail_default_fields = super().get_detail_default_fields(model)

        # When i18n is disabled, remove "locale" from default fields. The
        # membership check keeps subclasses that omit "locale" from their
        # field lists working: an unconditional list.remove() would raise
        # ValueError for them.
        i18n_enabled = getattr(settings, "WAGTAIL_I18N_ENABLED", False)
        if not i18n_enabled and "locale" in detail_default_fields:
            detail_default_fields.remove("locale")

        return detail_default_fields

    @classmethod
    def get_listing_default_fields(cls, model):
        """
        Return the default field names for listings of ``model``.

        "locale" is appended to the parent class defaults when the
        WAGTAIL_I18N_ENABLED setting is truthy.
        """
        defaults = super().get_listing_default_fields(model)

        i18n_enabled = getattr(settings, "WAGTAIL_I18N_ENABLED", False)
        if i18n_enabled:
            defaults.append("locale")

        return defaults

    def get_root_page(self):
        """
        Returns the page that is used when the `&child_of=root` filter is used:
        the root page of the site found for the current request.
        """
        current_site = Site.find_for_request(self.request)
        return current_site.root_page

    def get_base_queryset(self):
        """
        Returns a queryset containing all pages that can be seen by this user.

        This is used as the base for get_queryset and is also used to find the
        parent pages when using the child_of and descendant_of filters as well.

        Only live pages are included. Pages under a view restriction that does
        not accept the current request are excluded together with their
        descendants. The result is limited to the tree of the site named by
        the ``site`` query parameter (``hostname`` or ``hostname:port``) or,
        without that parameter, the site found for the request. When i18n is
        enabled, the trees of the site root's translations are included too.
        The queryset is empty when no site can be determined.

        Raises BadRequestError when the ``site`` parameter matches several
        sites or does not match any site.
        """

        request = self.request

        # Get all live pages
        queryset = Page.objects.all().live()

        # Exclude pages that the user doesn't have access to
        restricted_pages = [
            restriction.page
            for restriction in PageViewRestriction.objects.all().select_related("page")
            if not restriction.accept_request(request)
        ]

        # Exclude the restricted pages and their descendants from the queryset
        for restricted_page in restricted_pages:
            queryset = queryset.not_descendant_of(restricted_page, inclusive=True)

        # Check if we have a specific site to look for
        if "site" in request.GET:
            # Optionally allow querying by port
            if ":" in request.GET["site"]:
                (hostname, port) = request.GET["site"].split(":", 1)
                query = {
                    "hostname": hostname,
                    "port": port,
                }
            else:
                query = {
                    "hostname": request.GET["site"],
                }
            try:
                site = Site.objects.get(**query)
            except Site.MultipleObjectsReturned:
                raise BadRequestError(
                    "Your query returned multiple sites. Try adding a port number to your site filter."
                )
            except (Site.DoesNotExist, ValueError):
                # An unknown hostname/port, or a port that is not a number,
                # cannot match any site. Report it as a client error rather
                # than letting the lookup failure surface as a server error.
                raise BadRequestError("site doesn't exist")
        else:
            # Otherwise, find the site from the request
            site = Site.find_for_request(request)

        if site:
            base_queryset = queryset
            queryset = base_queryset.descendant_of(site.root_page, inclusive=True)

            # If internationalisation is enabled, include pages from other language trees
            if getattr(settings, "WAGTAIL_I18N_ENABLED", False):
                for translation in site.root_page.get_translations():
                    queryset |= base_queryset.descendant_of(translation, inclusive=True)

        else:
            # No sites configured
            queryset = queryset.none()

        return queryset

    def get_queryset(self):
        """
        Return the pages visible to this request, optionally narrowed by the
        ``type`` query parameter.

        With no type (or a single type) the queryset is based on the relevant
        page model so its own fields can be filtered on; with several types
        the Page-based queryset is narrowed to those types. An unresolvable
        ``type`` value raises BadRequestError.
        """
        requested_type = self.request.GET.get("type", None)

        # Allow pages to be filtered to a specific type
        try:
            models = page_models_from_string(requested_type) if requested_type else []
        except (LookupError, ValueError):
            raise BadRequestError("type doesn't exist")

        if not models:
            if self.model == Page:
                return self.get_base_queryset()
            target_model = self.model
        elif len(models) == 1:
            # If a single page type has been specified, swap out the Page-based queryset for one based on
            # the specific page model so that we can filter on any custom APIFields defined on that model
            target_model = models[0]
        else:  # len(models) > 1
            return self.get_base_queryset().type(*models)

        visible_pks = self.get_base_queryset().values_list("pk", flat=True)
        return target_model.objects.filter(pk__in=visible_pks)

    def get_object(self):
        """Return the ``specific`` version of the object found by the parent lookup."""
        return super().get_object().specific

    def find_object(self, queryset, request):
        """
        Find a page by the ``html_path`` query parameter, falling back to the
        parent class lookup.

        The path is routed from the root page of the request's site. Returns
        None when routing raises Http404. A routed page that is not part of
        ``queryset`` is ignored and the parent lookup is used instead.
        """
        site = Site.find_for_request(request)
        if site is None or "html_path" not in request.GET:
            return super().find_object(queryset, request)

        # Split the path into its non-empty components
        html_path = request.GET["html_path"]
        components = list(filter(None, html_path.split("/")))

        try:
            page, _, _ = site.root_page.specific.route(request, components)
        except Http404:
            return None

        if queryset.filter(id=page.id).exists():
            return page

        return super().find_object(queryset, request)

    def get_serializer_context(self):
        """
        Return the parent serializer context extended with ``base_queryset``,
        the queryset produced by get_base_queryset().
        """
        serializer_context = super().get_serializer_context()
        serializer_context.update(base_queryset=self.get_base_queryset())
        return serializer_context
