import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoProcessor, Llama4ForConditionalGeneration
import torch
# from transformers import BitsAndBytesConfig
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     llm_int8_enable_fp32_cpu_offload=True,
# )
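# The commented-out BitsAndBytesConfig above is an alternative loading path: enabling it
# (together with the quantization_config kwarg in from_pretrained below) would load the
# model in 4-bit via bitsandbytes instead of bf16, assuming bitsandbytes is installed.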
#Qwen/Qwen2.5-14B-Instruct-1M
#Qwen/Qwen2-0.5B
# model_name = "bartowski/simplescaling_s1-32B-GGUF"
# subfolder = "Qwen-0.5B-GRPO/checkpoint-1868"
# filename = "simplescaling_s1-32B-Q4_K_S.gguf"
# model_name = "simplescaling/s1.1-32B"
# model_name = "unsloth/Llama-4-Scout-17B-16E-Instruct-GGUF"
model_name = "unsloth/Llama-4-Scout-17B-16E-Instruct-unsloth-bnb-4bit"
model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
filename = "Llama-4-Scout-17B-16E-Instruct-UD-IQ2_XXS.gguf"
torch_dtype = torch.bfloat16  # torch.float16 or torch.float32 also work
cache_dir = "/data"
# model = AutoModelForCausalLM.from_pretrained(
#     model_name,
#     # subfolder=subfolder,
#     gguf_file=filename,
#     torch_dtype=torch_dtype,
#     device_map="auto",
#     cache_dir=cache_dir,
# )
model = Llama4ForConditionalGeneration.from_pretrained(
    model_name,
    # default is eager attention
    # attn_implementation="flex_attention",
    # gguf_file=filename,
    cache_dir=cache_dir,
    torch_dtype=torch_dtype,
    # quantization_config=bnb_config,
    device_map="auto",
)
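# Optional sanity check of where the layers landed; hf_device_map is populated by
# accelerate when device_map="auto" is used.
# print(model.hf_device_map)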
# processor = AutoProcessor.from_pretrained(model_name, cache_dir = cache_dir)
processor = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
# , gguf_file=filename
# , subfolder=subfolder
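# Note: a plain AutoTokenizer is enough here because the chat is text-only; the
# commented AutoProcessor line above would be needed to pass image inputs to Llama 4.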
SYSTEM_PROMPT = "You are a friendly Chatbot."
# """
# Respond in the following format:
#
# ...
#
#
# ...
#
# """
@spaces.GPU
def generate(prompt, history):
    messages = [
        # {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    # Earlier tokenizer-based generation path, kept for reference:
    # text = tokenizer.apply_chat_template(
    #     messages,
    #     # tokenize=False,
    #     tokenize=True,
    #     add_generation_prompt=True,
    # )
    # model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    # generated_ids = model.generate(
    #     **model_inputs,
    #     max_new_tokens=512,
    # )
    # generated_ids = [
    #     output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    # ]
    # response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    # return response
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        # tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    outputs = model.generate(
        **inputs.to(model.device),
        max_new_tokens=100,
    )
    # Decode only the newly generated tokens, dropping the prompt portion.
    response = processor.batch_decode(
        outputs[:, inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True,
    )[0]
    return response
chat_interface = gr.ChatInterface(
    fn=generate,
)
chat_interface.launch(share=True)