from enum import Enum
from typing import Any, Dict, List, Optional, Union

from redis.utils import experimental

try:
    from typing import Self  # Py 3.11+
except ImportError:
    from typing_extensions import Self

from redis.commands.search.aggregation import Limit, Reducer
from redis.commands.search.query import Filter, SortbyField


@experimental
class HybridSearchQuery:
    """
    Text-search part of a hybrid query.

    Stores a query string together with an optional scorer and an optional
    score alias, and renders them as the ``SEARCH`` arguments through
    ``get_args()``.
    """

    def __init__(
        self,
        query_string: str,
        scorer: Optional[str] = None,
        yield_score_as: Optional[str] = None,
    ) -> None:
        """
        Build the text-search part of a hybrid query.

        Args:
            query_string: The text query to run.
            scorer: Optional name of the scoring algorithm, for example
                "TFIDF" or "BM25". The value is not validated here.
            yield_score_as: Optional name of the field the score is yielded as.
        """
        self._yield_score_as = yield_score_as
        self._scorer = scorer
        self._query_string = query_string

    def query_string(self) -> str:
        """Return the text query string stored on this object."""
        return self._query_string

    def scorer(self, scorer: str) -> "HybridSearchQuery":
        """
        Set the scoring algorithm used for the text search query.

        Documented values include "TFIDF", "DISMAX", "DOCSCORE" and "BM25".
        For more information about supported scoring algorithms, see
        https://redis.io/docs/latest/develop/ai/search-and-query/advanced-concepts/scoring/

        Args:
            scorer: Name of the scoring algorithm.

        Returns:
            This query object, so that calls can be chained.
        """
        self._scorer = scorer
        return self

    def yield_score_as(self, alias: str) -> "HybridSearchQuery":
        """
        Set the field name the text search score is yielded as.

        Args:
            alias: Name of the score field.

        Returns:
            This query object, so that calls can be chained.
        """
        self._yield_score_as = alias
        return self

    def get_args(self) -> List[str]:
        """
        Render this query as a flat argument list.

        The list always starts with ``SEARCH`` and the query string.
        ``SCORER`` and ``YIELD_SCORE_AS`` follow, in that order, only when the
        corresponding value is set to something truthy.
        """
        args = ["SEARCH", self._query_string]
        optional_clauses = (
            ("SCORER", self._scorer),
            ("YIELD_SCORE_AS", self._yield_score_as),
        )
        for keyword, value in optional_clauses:
            # Unset (None) and empty values are left out entirely.
            if value:
                args += [keyword, value]
        return args


class VectorSearchMethods(Enum):
    """Vector search methods that can be selected for a hybrid vsim query."""

    # Each member's value is the literal keyword placed in the query arguments.
    KNN = "KNN"
    RANGE = "RANGE"


@experimental
class HybridVsimQuery:
    """
    Vector-similarity part of a hybrid query.

    Stores the vector field, the vector data and the optional search method
    clause, filter and score alias, and renders them as the ``VSIM``
    arguments through ``get_args()``.
    """

    def __init__(
        self,
        vector_field_name: str,
        vector_data: Union[bytes, str],
        vsim_search_method: Optional[VectorSearchMethods] = None,
        vsim_search_method_params: Optional[Dict[str, Any]] = None,
        filter: Optional["Filter"] = None,
        yield_score_as: Optional[str] = None,
    ) -> None:
        """
        Build the vector-similarity part of a hybrid query.

        Args:
            vector_field_name: Name of the vector field.

            vector_data: Vector data to search with.

            vsim_search_method: Search method used for the vsim search.

            vsim_search_method_params: Parameters of the search method, given
                as a mapping from parameter name to parameter value.
                Example for KNN: {"K": 10, "EF_RUNTIME": 100}
                    where K is mandatory and defines the number of results
                    and EF_RUNTIME is optional and defines the exploration
                    factor.
                Example for RANGE: {"RADIUS": 10, "EPSILON": 0.1}
                    where RADIUS is mandatory and defines the radius of the
                    search and EPSILON is optional and defines the accuracy
                    of the search.

            filter: If given, a filter applied to the vsim query results.

            yield_score_as: Name of the field the score is yielded as.
        """
        self._vector_field = vector_field_name
        self._vector_data = vector_data
        # A method clause is stored only when both the method and a non-empty
        # parameter mapping are supplied; otherwise no clause is emitted.
        if not (vsim_search_method and vsim_search_method_params):
            self._vsim_method_params = None
        else:
            self.vsim_method_params(vsim_search_method, **vsim_search_method_params)
        self._filter = filter
        self._yield_score_as = yield_score_as

    def vector_field(self) -> str:
        """Return the name of the vector field stored on this object."""
        return self._vector_field

    def vector_data(self) -> Union[bytes, str]:
        """Return the vector data stored on this object."""
        return self._vector_data

    def vsim_method_params(
        self,
        method: VectorSearchMethods,
        **kwargs,
    ) -> "HybridVsimQuery":
        """
        Set the search method and its parameters, replacing any earlier ones.

        Args:
            method: Vector search method, KNN or RANGE.
            kwargs: Search method parameters, passed as ``NAME=value`` keyword
                arguments, for example ``K=10, EF_RUNTIME=100``.

        Returns:
            This query object, so that calls can be chained.
        """
        clause: List[Union[str, int]] = [method.value]
        if kwargs:
            # The count covers every name and every value: two per parameter.
            clause.append(2 * len(kwargs))
            for name_and_value in kwargs.items():
                clause.extend(name_and_value)
        self._vsim_method_params = clause

        return self

    def filter(self, flt: "HybridFilter") -> "HybridVsimQuery":
        """
        Set the filter of this query, replacing any earlier one.

        Args:
            flt: A HybridFilter object, used on a corresponding field.

        Returns:
            This query object, so that calls can be chained.
        """
        self._filter = flt
        return self

    def yield_score_as(self, alias: str) -> "HybridVsimQuery":
        """
        Yield the score as a field named `alias`.

        Returns:
            This query object, so that calls can be chained.
        """
        self._yield_score_as = alias
        return self

    def get_args(self) -> List[str]:
        """
        Render this query as a flat argument list.

        The list starts with ``VSIM``, the vector field and the vector data,
        followed by the search method clause, the filter arguments and
        ``YIELD_SCORE_AS`` for each of those that is set. Besides strings the
        list can hold the vector data as given and the integer parameter count.
        """
        args = ["VSIM", self._vector_field, self._vector_data]
        method_clause = self._vsim_method_params
        if method_clause:
            args += method_clause
        if self._filter:
            args += self._filter.args
        if self._yield_score_as:
            args += ["YIELD_SCORE_AS", self._yield_score_as]

        return args


class HybridQuery:
    """Pair of a text search query and a vector similarity query."""

    def __init__(
        self,
        search_query: HybridSearchQuery,
        vector_similarity_query: HybridVsimQuery,
    ) -> None:
        """
        Build a hybrid query from its two parts.

        Args:
            search_query: HybridSearchQuery object holding the text query.
            vector_similarity_query: HybridVsimQuery object holding the vector
                similarity query.
        """
        self._vector_similarity_query = vector_similarity_query
        self._search_query = search_query

    def get_args(self) -> List[str]:
        """
        Return the arguments of both parts as one new flat list.

        The text search arguments come first, followed by the vector
        similarity arguments.
        """
        return [
            *self._search_query.get_args(),
            *self._vector_similarity_query.get_args(),
        ]


class CombinationMethods(Enum):
    """Methods available for combining the results of a hybrid query."""

    # Each member's value is the literal keyword emitted after COMBINE.
    RRF = "RRF"
    LINEAR = "LINEAR"


@experimental
class CombineResultsMethod:
    """
    Combination method of a hybrid query together with its parameters.

    Renders itself as the ``COMBINE`` arguments through ``get_args()``.
    """

    def __init__(self, method: CombinationMethods, **kwargs) -> None:
        """
        Build a combine-results method.

        Args:
            method: The combine method to use - RRF or LINEAR.
            kwargs: Additional combine parameters.
                    For RRF, supported parameters (at least one should be provided):
                                WINDOW: Limits fusion scope.
                                CONSTANT: Controls decay of rank influence.
                                YIELD_SCORE_AS: The name of the field to yield the calculated score as.
                    For LINEAR, supported parameters (at least one should be provided):
                                ALPHA: The weight of the first query.
                                BETA: The weight of the second query.
                                YIELD_SCORE_AS: The name of the field to yield the calculated score as.

                    The parameters are not validated and are passed to the server as is.
                    Provide the parameter names and values as keyword arguments, for example:
                        CombineResultsMethod(CombinationMethods.RRF, WINDOW=3, CONSTANT=0.5)
                        CombineResultsMethod(CombinationMethods.LINEAR, ALPHA=0.5, BETA=0.5)
        """
        self._kwargs = kwargs
        self._method = method

    def get_args(self) -> List[Union[str, int]]:
        """
        Render the combine method as a flat argument list.

        The list starts with ``COMBINE`` and the method keyword. When
        parameters were given, it continues with the number of following
        items and then each parameter name followed by its value.
        """
        args: List[Union[str, int]] = ["COMBINE", self._method.value]
        params = self._kwargs
        if not params:
            return args
        # The count covers every name and every value: two per parameter.
        args.append(2 * len(params))
        for name_and_value in params.items():
            args.extend(name_and_value)
        return args


@experimental
class HybridPostProcessingConfig:
    def __init__(self) -> None:
        """
        Create a new hybrid post processing configuration object.
        """
        self._load_statements = []
        self._apply_statements = []
        self._groupby_statements = []
        self._sortby_fields = []
        self._filter = None
        self._limit = None

    def load(self, *fields: str) -> Self:
        """
        Add load statement parameters to the query.
        """
        if fields:
            fields_str = " ".join(fields)
            fields_list = fields_str.split(" ")
            self._load_statements.extend(("LOAD", len(fields_list), *fields_list))
        return self

    def group_by(self, fields: List[str], *reducers: Reducer) -> Self:
        """
        Specify by which fields to group the aggregation.

        Args:
            fields: Fields to group by. This can either be a single string or a list
                of strings. In both cases, the field should be specified as `@field`.
            reducers: One or more reducers. Reducers may be found in the
                `aggregation` module.
        """

        fields = [fields] if isinstance(fields, str) else fields

        ret = ["GROUPBY", str(len(fields)), *fields]
        for reducer in reducers:
            ret.extend(("REDUCE", reducer.NAME, str(len(reducer.args))))
            ret.extend(reducer.args)
            if reducer._alias is not None:
                ret.extend(("AS", reducer._alias))

        self._groupby_statements.extend(ret)
        return self

    def apply(self, **kwexpr) -> Self:
        """
        Specify one or more projection expressions to add to each result.

        Args:
            kwexpr: One or more key-value pairs for a projection. The key is
                the alias for the projection, and the value is the projection
                expression itself, for example `apply(square_root="sqrt(@foo)")`.
        """
        apply_args = []
        for alias, expr in kwexpr.items():
            ret = ["APPLY", expr]
            if alias is not None:
                ret.extend(("AS", alias))
            apply_args.extend(ret)

        self._apply_statements.extend(apply_args)

        return self

    def sort_by(self, *sortby: "SortbyField") -> Self:
        """
        Add sortby parameters to the query.
        """
        self._sortby_fields = [*sortby]
        return self

    def filter(self, filter: "HybridFilter") -> Self:
        """
        Add a numeric or string filter to the query.

        Currently, only one of each filter is supported by the engine.

        Args:
            filter: A NumericFilter or GeoFilter object, used on a corresponding field.
        """
        self._filter = filter
        return self

    def limit(self, offset: int, num: int) -> Self:
        """
        Add limit parameters to the query.
        """
        self._limit = Limit(offset, num)
        return self

    def build_args(self) -> List[str]:
        args = []
        if self._load_statements:
            args.extend(self._load_statements)
        if self._groupby_statements:
            args.extend(self._groupby_statements)
        if self._apply_statements:
            args.extend(self._apply_statements)
        if self._sortby_fields:
            sortby_args = []
            for f in self._sortby_fields:
                sortby_args.extend(f.args)
            args.extend(("SORTBY", len(sortby_args), *sortby_args))
        if self._filter:
            args.extend(self._filter.args)
        if self._limit:
            args.extend(self._limit.build_args())

        return args


@experimental
class HybridFilter(Filter):
    """Filter for hybrid queries, built from a single conditions string."""

    def __init__(
        self,
        conditions: str,
    ) -> None:
        """
        Build a hybrid filter.

        Args:
            conditions: Filter conditions.
        """
        # The base initializer receives the "FILTER" keyword and the conditions.
        Filter.__init__(self, "FILTER", conditions)


@experimental
class HybridCursorQuery:
    """Cursor options of a hybrid query, rendered as ``WITHCURSOR`` arguments."""

    def __init__(self, count: int = 0, max_idle: int = 0) -> None:
        """
        Build a hybrid cursor query.

        Args:
            count: Number of results to return per cursor iteration.
            max_idle: Maximum idle time for the cursor.
        """
        self.max_idle = max_idle
        self.count = count

    def build_args(self):
        """
        Render the cursor options as a flat list of strings.

        The list always starts with ``WITHCURSOR``. ``COUNT`` and ``MAXIDLE``
        follow, in that order, only when their value is truthy, so the
        default of 0 leaves the option out.
        """
        args = ["WITHCURSOR"]
        optional_clauses = (
            ("COUNT", self.count),
            ("MAXIDLE", self.max_idle),
        )
        for keyword, value in optional_clauses:
            if value:
                args.extend((keyword, str(value)))
        return args
