import json, copy, types
import os
from enum import Enum
import time
from typing import Callable, Optional, Any, Union
import litellm
from litellm.utils import ModelResponse, get_secret, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt
import httpx


class BedrockError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


class AmazonTitanConfig:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1

    Supported Params for the Amazon Titan models:

    - `maxTokenCount` (integer) max tokens,
    - `stopSequences` (string[]) list of stop sequence strings
    - `temperature` (float) temperature for model,
    - `topP` (int) top p for model
    """

    maxTokenCount: Optional[int] = None
    stopSequences: Optional[list] = None
    temperature: Optional[float] = None
    topP: Optional[int] = None

    def __init__(
        self,
        maxTokenCount: Optional[int] = None,
        stopSequences: Optional[list] = None,
        temperature: Optional[float] = None,
        topP: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }


class AmazonAnthropicConfig:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude

    Supported Params for the Amazon / Anthropic models:

    - `max_tokens_to_sample` (integer) max tokens,
    - `temperature` (float) model temperature,
    - `top_k` (integer) top k,
    - `top_p` (integer) top p,
    - `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"],
    - `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
    """

    max_tokens_to_sample: Optional[int] = litellm.max_tokens
    stop_sequences: Optional[list] = None
    temperature: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[int] = None
    anthropic_version: Optional[str] = None

    def __init__(
        self,
        max_tokens_to_sample: Optional[int] = None,
        stop_sequences: Optional[list] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[int] = None,
        anthropic_version: Optional[str] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }


class AmazonCohereConfig:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command

    Supported Params for the Amazon / Cohere models:

    - `max_tokens` (integer) max tokens,
    - `temperature` (float) model temperature,
    - `return_likelihood` (string) n/a
    """

    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    return_likelihood: Optional[str] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        return_likelihood: Optional[str] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }


class AmazonAI21Config:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra

    Supported Params for the Amazon / AI21 models:

    - `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`.

    - `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding.

    - `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass.

    - `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional.

    - `frequencyPenalty` (object): Placeholder for frequency penalty object.

    - `presencePenalty` (object): Placeholder for presence penalty object.

    - `countPenalty` (object): Placeholder for count penalty object.
    """

    maxTokens: Optional[int] = None
    temperature: Optional[float] = None
    topP: Optional[float] = None
    stopSequences: Optional[list] = None
    frequencePenalty: Optional[dict] = None
    presencePenalty: Optional[dict] = None
    countPenalty: Optional[dict] = None

    def __init__(
        self,
        maxTokens: Optional[int] = None,
        temperature: Optional[float] = None,
        topP: Optional[float] = None,
        stopSequences: Optional[list] = None,
        frequencePenalty: Optional[dict] = None,
        presencePenalty: Optional[dict] = None,
        countPenalty: Optional[dict] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }


class AnthropicConstants(Enum):
    HUMAN_PROMPT = "\n\nHuman: "
    AI_PROMPT = "\n\nAssistant: "


class AmazonLlamaConfig:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1

    Supported Params for the Amazon / Meta Llama models:

    - `max_gen_len` (integer) max tokens,
    - `temperature` (float) temperature for model,
    - `top_p` (float) top p for model
    """

    max_gen_len: Optional[int] = None
    temperature: Optional[float] = None
    topP: Optional[float] = None

    def __init__(
        self,
        maxTokenCount: Optional[int] = None,
        temperature: Optional[float] = None,
        topP: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }


def init_bedrock_client(
    region_name=None,
    aws_access_key_id: Optional[str] = None,
    aws_secret_access_key: Optional[str] = None,
    aws_region_name: Optional[str] = None,
    aws_bedrock_runtime_endpoint: Optional[str] = None,
):
    # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
    litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
    standard_aws_region_name = get_secret("AWS_REGION", None)

    ## CHECK IS  'os.environ/' passed in
    # Define the list of parameters to check
    params_to_check = [
        aws_access_key_id,
        aws_secret_access_key,
        aws_region_name,
        aws_bedrock_runtime_endpoint,
    ]

    # Iterate over parameters and update if needed
    for i, param in enumerate(params_to_check):
        if param and param.startswith("os.environ/"):
            params_to_check[i] = get_secret(param)
    # Assign updated values back to parameters
    (
        aws_access_key_id,
        aws_secret_access_key,
        aws_region_name,
        aws_bedrock_runtime_endpoint,
    ) = params_to_check
    if region_name:
        pass
    elif aws_region_name:
        region_name = aws_region_name
    elif litellm_aws_region_name:
        region_name = litellm_aws_region_name
    elif standard_aws_region_name:
        region_name = standard_aws_region_name
    else:
        raise BedrockError(
            message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
            status_code=401,
        )

    # check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client
    env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
    if aws_bedrock_runtime_endpoint:
        endpoint_url = aws_bedrock_runtime_endpoint
    elif env_aws_bedrock_runtime_endpoint:
        endpoint_url = env_aws_bedrock_runtime_endpoint
    else:
        endpoint_url = f"https://bedrock-runtime.{region_name}.amazonaws.com"

    import boto3

    if aws_access_key_id != None:
        # uses auth params passed to completion
        # aws_access_key_id is not None, assume user is trying to auth using litellm.completion

        client = boto3.client(
            service_name="bedrock-runtime",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=region_name,
            endpoint_url=endpoint_url,
        )
    else:
        # aws_access_key_id is None, assume user is trying to auth using env variables
        # boto3 automatically reads env variables

        client = boto3.client(
            service_name="bedrock-runtime",
            region_name=region_name,
            endpoint_url=endpoint_url,
        )

    return client


def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
    # handle anthropic prompts using anthropic constants
    if provider == "anthropic":
        if model in custom_prompt_dict:
            # check if the model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[model]
            prompt = custom_prompt(
                role_dict=model_prompt_details["roles"],
                initial_prompt_value=model_prompt_details["initial_prompt_value"],
                final_prompt_value=model_prompt_details["final_prompt_value"],
                messages=messages,
            )
        else:
            prompt = prompt_factory(
                model=model, messages=messages, custom_llm_provider="anthropic"
            )
    else:
        prompt = ""
        for message in messages:
            if "role" in message:
                if message["role"] == "user":
                    prompt += f"{message['content']}"
                else:
                    prompt += f"{message['content']}"
            else:
                prompt += f"{message['content']}"
    return prompt


"""
BEDROCK AUTH Keys/Vars
os.environ['AWS_ACCESS_KEY_ID'] = ""
os.environ['AWS_SECRET_ACCESS_KEY'] = ""
"""


# set os.environ['AWS_REGION_NAME'] = <your-region_name>


def completion(
    model: str,
    messages: list,
    custom_prompt_dict: dict,
    model_response: ModelResponse,
    print_verbose: Callable,
    encoding,
    logging_obj,
    optional_params=None,
    litellm_params=None,
    logger_fn=None,
):
    exception_mapping_worked = False
    try:
        # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
        aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
        aws_access_key_id = optional_params.pop("aws_access_key_id", None)
        aws_region_name = optional_params.pop("aws_region_name", None)
        aws_bedrock_runtime_endpoint = optional_params.pop(
            "aws_bedrock_runtime_endpoint", None
        )

        # use passed in BedrockRuntime.Client if provided, otherwise create a new one
        client = optional_params.pop("aws_bedrock_client", None)

        # only init client, if user did not pass one
        if client is None:
            client = init_bedrock_client(
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                aws_region_name=aws_region_name,
                aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
            )

        model = model
        modelId = (
            optional_params.pop("model_id", None) or model
        )  # default to model if not passed
        provider = model.split(".")[0]
        prompt = convert_messages_to_prompt(
            model, messages, provider, custom_prompt_dict
        )
        inference_params = copy.deepcopy(optional_params)
        stream = inference_params.pop("stream", False)
        if provider == "anthropic":
            ## LOAD CONFIG
            config = litellm.AmazonAnthropicConfig.get_config()
            for k, v in config.items():
                if (
                    k not in inference_params
                ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                    inference_params[k] = v
            data = json.dumps({"prompt": prompt, **inference_params})
        elif provider == "ai21":
            ## LOAD CONFIG
            config = litellm.AmazonAI21Config.get_config()
            for k, v in config.items():
                if (
                    k not in inference_params
                ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                    inference_params[k] = v

            data = json.dumps({"prompt": prompt, **inference_params})
        elif provider == "cohere":
            ## LOAD CONFIG
            config = litellm.AmazonCohereConfig.get_config()
            for k, v in config.items():
                if (
                    k not in inference_params
                ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                    inference_params[k] = v
            if optional_params.get("stream", False) == True:
                inference_params[
                    "stream"
                ] = True  # cohere requires stream = True in inference params
            data = json.dumps({"prompt": prompt, **inference_params})
        elif provider == "meta":
            ## LOAD CONFIG
            config = litellm.AmazonLlamaConfig.get_config()
            for k, v in config.items():
                if (
                    k not in inference_params
                ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                    inference_params[k] = v
            data = json.dumps({"prompt": prompt, **inference_params})
        elif provider == "amazon":  # amazon titan
            ## LOAD CONFIG
            config = litellm.AmazonTitanConfig.get_config()
            for k, v in config.items():
                if (
                    k not in inference_params
                ):  # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
                    inference_params[k] = v

            data = json.dumps(
                {
                    "inputText": prompt,
                    "textGenerationConfig": inference_params,
                }
            )
        else:
            data = json.dumps({})

        ## COMPLETION CALL
        accept = "application/json"
        contentType = "application/json"
        if stream == True:
            if provider == "ai21":
                ## LOGGING
                request_str = f"""
                response = client.invoke_model(
                    body={data},
                    modelId={modelId},
                    accept=accept,
                    contentType=contentType
                )
                """
                logging_obj.pre_call(
                    input=prompt,
                    api_key="",
                    additional_args={
                        "complete_input_dict": data,
                        "request_str": request_str,
                    },
                )

                response = client.invoke_model(
                    body=data, modelId=modelId, accept=accept, contentType=contentType
                )

                response = response.get("body").read()
                return response
            else:
                ## LOGGING
                request_str = f"""
                response = client.invoke_model_with_response_stream(
                    body={data},
                    modelId={modelId},
                    accept=accept,
                    contentType=contentType
                )
                """
                logging_obj.pre_call(
                    input=prompt,
                    api_key="",
                    additional_args={
                        "complete_input_dict": data,
                        "request_str": request_str,
                    },
                )

                response = client.invoke_model_with_response_stream(
                    body=data, modelId=modelId, accept=accept, contentType=contentType
                )
                response = response.get("body")
                return response
        try:
            ## LOGGING
            request_str = f"""
            response = client.invoke_model(
                body={data},
                modelId={modelId},
                accept=accept,
                contentType=contentType
            )
            """
            logging_obj.pre_call(
                input=prompt,
                api_key="",
                additional_args={
                    "complete_input_dict": data,
                    "request_str": request_str,
                },
            )
            response = client.invoke_model(
                body=data, modelId=modelId, accept=accept, contentType=contentType
            )
        except client.exceptions.ValidationException as e:
            if "The provided model identifier is invalid" in str(e):
                raise BedrockError(status_code=404, message=str(e))
            raise BedrockError(status_code=400, message=str(e))
        except Exception as e:
            raise BedrockError(status_code=500, message=str(e))

        response_body = json.loads(response.get("body").read())

        ## LOGGING
        logging_obj.post_call(
            input=prompt,
            api_key="",
            original_response=json.dumps(response_body),
            additional_args={"complete_input_dict": data},
        )
        print_verbose(f"raw model_response: {response}")
        ## RESPONSE OBJECT
        outputText = "default"
        if provider == "ai21":
            outputText = response_body.get("completions")[0].get("data").get("text")
        elif provider == "anthropic":
            outputText = response_body["completion"]
            model_response["finish_reason"] = response_body["stop_reason"]
        elif provider == "cohere":
            outputText = response_body["generations"][0]["text"]
        elif provider == "meta":
            outputText = response_body["generation"]
        else:  # amazon titan
            outputText = response_body.get("results")[0].get("outputText")

        response_metadata = response.get("ResponseMetadata", {})
        if response_metadata.get("HTTPStatusCode", 500) >= 400:
            raise BedrockError(
                message=outputText,
                status_code=response_metadata.get("HTTPStatusCode", 500),
            )
        else:
            try:
                if len(outputText) > 0:
                    model_response["choices"][0]["message"]["content"] = outputText
            except:
                raise BedrockError(
                    message=json.dumps(outputText),
                    status_code=response_metadata.get("HTTPStatusCode", 500),
                )

        ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
        prompt_tokens = len(encoding.encode(prompt))
        completion_tokens = len(
            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
        )

        model_response["created"] = int(time.time())
        model_response["model"] = model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        model_response.usage = usage
        return model_response
    except BedrockError as e:
        exception_mapping_worked = True
        raise e
    except Exception as e:
        if exception_mapping_worked:
            raise e
        else:
            import traceback

            raise BedrockError(status_code=500, message=traceback.format_exc())


def _embedding_func_single(
    model: str,
    input: str,
    client: Any,
    optional_params=None,
    encoding=None,
    logging_obj=None,
):
    # logic for parsing in - calling - parsing out model embedding calls
    ## FORMAT EMBEDDING INPUT ##
    provider = model.split(".")[0]
    inference_params = copy.deepcopy(optional_params)
    inference_params.pop(
        "user", None
    )  # make sure user is not passed in for bedrock call
    modelId = (
        optional_params.pop("model_id", None) or model
    )  # default to model if not passed
    if provider == "amazon":
        input = input.replace(os.linesep, " ")
        data = {"inputText": input, **inference_params}
        # data = json.dumps(data)
    elif provider == "cohere":
        inference_params["input_type"] = inference_params.get(
            "input_type", "search_document"
        )  # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
        data = {"texts": [input], **inference_params}  # type: ignore
    body = json.dumps(data).encode("utf-8")
    ## LOGGING
    request_str = f"""
    response = client.invoke_model(
        body={body},
        modelId={modelId},
        accept="*/*",
        contentType="application/json",
    )"""  # type: ignore
    logging_obj.pre_call(
        input=input,
        api_key="",  # boto3 is used for init.
        additional_args={
            "complete_input_dict": {"model": modelId, "texts": input},
            "request_str": request_str,
        },
    )
    try:
        response = client.invoke_model(
            body=body,
            modelId=modelId,
            accept="*/*",
            contentType="application/json",
        )
        response_body = json.loads(response.get("body").read())
        ## LOGGING
        logging_obj.post_call(
            input=input,
            api_key="",
            additional_args={"complete_input_dict": data},
            original_response=json.dumps(response_body),
        )
        if provider == "cohere":
            response = response_body.get("embeddings")
            # flatten list
            response = [item for sublist in response for item in sublist]
            return response
        elif provider == "amazon":
            return response_body.get("embedding")
    except Exception as e:
        raise BedrockError(
            message=f"Embedding Error with model {model}: {e}", status_code=500
        )


def embedding(
    model: str,
    input: Union[list, str],
    api_key: Optional[str] = None,
    logging_obj=None,
    model_response=None,
    optional_params=None,
    encoding=None,
):
    ### BOTO3 INIT ###
    # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
    aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
    aws_access_key_id = optional_params.pop("aws_access_key_id", None)
    aws_region_name = optional_params.pop("aws_region_name", None)
    aws_bedrock_runtime_endpoint = optional_params.pop(
        "aws_bedrock_runtime_endpoint", None
    )

    # use passed in BedrockRuntime.Client if provided, otherwise create a new one
    client = init_bedrock_client(
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        aws_region_name=aws_region_name,
        aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
    )
    if type(input) == str:
        embeddings = [
            _embedding_func_single(
                model,
                input,
                optional_params=optional_params,
                client=client,
                logging_obj=logging_obj,
            )
        ]
    else:
        ## Embedding Call
        embeddings = [
            _embedding_func_single(
                model,
                i,
                optional_params=optional_params,
                client=client,
                logging_obj=logging_obj,
            )
            for i in input
        ]  # [TODO]: make these parallel calls

    ## Populate OpenAI compliant dictionary
    embedding_response = []
    for idx, embedding in enumerate(embeddings):
        embedding_response.append(
            {
                "object": "embedding",
                "index": idx,
                "embedding": embedding,
            }
        )
    model_response["object"] = "list"
    model_response["data"] = embedding_response
    model_response["model"] = model
    input_tokens = 0

    input_str = "".join(input)

    input_tokens += len(encoding.encode(input_str))

    usage = Usage(
        prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens + 0
    )
    model_response.usage = usage

    return model_response