import os
import warnings
from pathlib import Path

import google.generativeai as palm_api
import toml

from pingpong import PingPong
from pingpong.pingpong import PPManager
from pingpong.pingpong import PromptFmt
from pingpong.pingpong import UIFmt
from pingpong.gradio import GradioChatUIFmt

from .utils import set_palm_api_key


# Set PaLM API Key
# NOTE: this runs as a side effect of importing the module.
set_palm_api_key()

# Load PaLM Prompt Templates
# The path is relative to the current working directory, not to this file,
# so the import fails if 'assets/palm_prompts.toml' is not reachable from there.
palm_prompts = toml.load(Path('.') / 'assets' / 'palm_prompts.toml')

class PaLMChatPromptFmt(PromptFmt):
    """Prompt formatter that renders an exchange as PaLM chat message dicts.

    Deprecated: every method emits a ``DeprecationWarning`` when called.
    """

    @classmethod
    def ctx(cls, context):
        """Emit the deprecation warning; ``context`` is unused and ``None`` is returned."""
        warnings.warn(
            "The 'palmchat' is deprecated and will not be supported in future versions.",
            DeprecationWarning,
            stacklevel=2,
        )

    @classmethod
    def prompt(cls, pingpong, truncate_size):
        """Convert one exchange into a list of ``{"author", "content"}`` messages.

        Args:
            pingpong: Object exposing ``ping`` (user text) and ``pong`` (reply
                text, possibly ``None``).
            truncate_size: Upper bound used to slice both texts; ``None``
                keeps them whole.

        Returns:
            A list holding the ``USER`` message, followed by an ``AI`` message
            only when the reply is neither ``None`` nor blank.
        """
        warnings.warn(
            "The 'palmchat' is deprecated and will not be supported in future versions.",
            DeprecationWarning,
            stacklevel=2,
        )
        messages = [{"author": "USER", "content": pingpong.ping[:truncate_size]}]

        reply = pingpong.pong
        has_reply = reply is not None and reply.strip() != ""
        if has_reply:
            messages.append({"author": "AI", "content": reply[:truncate_size]})

        return messages

class PaLMChatPPManager(PPManager):
    """Exchange manager that flattens stored exchanges into PaLM chat messages.

    Deprecated: ``build_prompts`` emits a ``DeprecationWarning`` when called.
    """

    def build_prompts(self, from_idx: int=0, to_idx: int=-1, fmt: PromptFmt=PaLMChatPromptFmt, truncate_size: int=None):
        """Build the message list for the exchanges in ``[from_idx, to_idx)``.

        Args:
            from_idx: Index of the first exchange to include.
            to_idx: Exclusive end index; ``-1`` or any value at or beyond the
                number of stored exchanges means "through the last one".
            fmt: Formatter whose ``prompt`` turns one exchange into messages.
            truncate_size: Forwarded unchanged to ``fmt.prompt``.

        Returns:
            One flat list with the messages of every selected exchange, in order.
        """
        warnings.warn(
            "The 'palmchat' is deprecated and will not be supported in future versions.",
            DeprecationWarning,
            stacklevel=2,
        )
        total = len(self.pingpongs)
        stop = total if (to_idx == -1 or to_idx >= total) else to_idx

        messages = []
        for exchange in self.pingpongs[from_idx:stop]:
            messages.extend(fmt.prompt(exchange, truncate_size=truncate_size))

        return messages

class GradioPaLMChatPPManager(PaLMChatPPManager):
    """PaLM chat manager that can also render its exchanges for a UI formatter.

    Deprecated: ``build_uis`` emits a ``DeprecationWarning`` when called.
    """

    def build_uis(self, from_idx: int=0, to_idx: int=-1, fmt: UIFmt=GradioChatUIFmt):
        """Render the exchanges in ``[from_idx, to_idx)`` with ``fmt.ui``.

        Args:
            from_idx: Index of the first exchange to include.
            to_idx: Exclusive end index; ``-1`` or any value at or beyond the
                number of stored exchanges means "through the last one".
            fmt: Formatter whose ``ui`` converts one exchange.

        Returns:
            A list with one ``fmt.ui`` result per selected exchange, in order.
        """
        warnings.warn(
            "The 'palmchat' is deprecated and will not be supported in future versions.",
            DeprecationWarning,
            stacklevel=2,
        )
        total = len(self.pingpongs)
        if to_idx == -1 or to_idx >= total:
            to_idx = total

        return [fmt.ui(exchange) for exchange in self.pingpongs[from_idx:to_idx]]

async def gen_text(
    prompt,
    mode="chat", #chat or text
    parameters=None,
    use_filter=True
):
    """Send ``prompt`` to the PaLM API and return the response plus its text.

    Deprecated: emits a ``DeprecationWarning`` on every call.

    Args:
        prompt: Passed as ``messages`` to ``palm_api.chat_async`` in chat
            mode, or as ``prompt`` to ``palm_api.generate_text`` otherwise.
        mode: ``"chat"`` selects the chat call; any other value selects text
            generation.
        parameters: Keyword arguments forwarded to the API call. When
            ``None``, defaults (model, sampling values, safety settings) are
            built here according to ``mode``.
        use_filter: When true, a response that reports any filters is
            rejected. When false and ``parameters`` is ``None``, every default
            safety threshold is set to 4. Caller-supplied ``parameters`` are
            never modified.

    Returns:
        A ``(response, response_txt)`` tuple, where ``response_txt`` is
        ``response.last`` in chat mode and ``response.result`` otherwise.

    Raises:
        EnvironmentError: The API call raised; the original error is chained
            as ``__cause__``.
        Exception: ``use_filter`` is true and the response reports filters.
    """
    warnings.warn("The 'palmchat' is deprecated and will not be supported in future versions.", DeprecationWarning, stacklevel=2)
    chat_mode = mode == "chat"

    if parameters is None:
        # default safety settings
        safety_settings = [{"category":"HARM_CATEGORY_DEROGATORY","threshold":1},
                           {"category":"HARM_CATEGORY_TOXICITY","threshold":1},
                           {"category":"HARM_CATEGORY_VIOLENCE","threshold":2},
                           {"category":"HARM_CATEGORY_SEXUAL","threshold":2},
                           {"category":"HARM_CATEGORY_MEDICAL","threshold":2},
                           {"category":"HARM_CATEGORY_DANGEROUS","threshold":2}]
        if not use_filter:
            # NOTE(review): 4 presumably means "block nothing" in the PaLM
            # safety API — confirm against the API reference.
            for setting in safety_settings:
                setting['threshold'] = 4

        # Settings shared by both modes; only the model and one extra key differ.
        shared = {
            'candidate_count': 1,
            'temperature': 1.0,
            'top_k': 40,
            'top_p': 0.95,
            'safety_settings': safety_settings,
        }
        if chat_mode:
            parameters = {'model': 'models/chat-bison-001', 'context': "", **shared}
        else:
            parameters = {'model': 'models/text-bison-001', 'max_output_tokens': 1024, **shared}

    try:
        if chat_mode:
            response = await palm_api.chat_async(**parameters, messages=prompt)
        else:
            # NOTE(review): this call is not awaited, so it likely holds the
            # event loop until it returns — confirm whether that is acceptable.
            response = palm_api.generate_text(**parameters, prompt=prompt)
    except Exception as err:
        # Catch Exception rather than everything so KeyboardInterrupt and
        # (on Python 3.8+) asyncio task cancellation still propagate, and
        # chain the original error so the real failure stays visible.
        raise EnvironmentError("PaLM API is not available.") from err

    if use_filter and len(response.filters) > 0:
        raise Exception("PaLM API has withheld a response due to content safety concerns.")

    response_txt = response.last if chat_mode else response.result
    return response, response_txt