import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
import spaces
from threading import Thread
from typing import Iterator

model_id = "mistralai/Mistral-Nemo-Instruct-2407"

MAX_INPUT_TOKEN_LENGTH = 4096

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    load_in_8bit=True
)

@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9
) -> Iterator[str]:
    conversation = []
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        num_beams=1
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

# Set up Gradio interface
iface = gr.ChatInterface(
    generate,
    chatbot=gr.Chatbot(height=600),
    textbox=gr.Textbox(placeholder="Enter your message here...", container=False, scale=7),
    title="Chat with Mistral Nemo",
    description="This is a chat interface for the Mistral Nemo model. Ask questions and get answers!",
    retry_btn="Retry",
    undo_btn="Undo Last",
    clear_btn="Clear",
    additional_inputs=[
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Maximum number of new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

# Launch the interface
iface.launch()