from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, TextStreamer
import transformers
import torch

from huggingface_hub import login
import os 

import logging

# Authenticate with the Hugging Face Hub only when a token is configured.
# ``login(token=None)`` falls back to an interactive prompt, which blocks
# non-interactive runs; without a token, public checkpoints still load and any
# token cached by a previous CLI login is still picked up by the Hub client.
_hf_token = os.getenv('HF_TOKEN')
if _hf_token:
    login(token=_hf_token)
else:
    logging.warning('HF_TOKEN is not set; skipping Hugging Face Hub login')

class Model(torch.nn.Module):
    """Load a Hugging Face tokenizer + model pair and sample text from it.

    Encoder-decoder checkpoints (e.g. ``google-t5/t5-large``) are loaded with
    ``AutoModelForSeq2SeqLM``; every other checkpoint is loaded as a causal LM
    with ``AutoModelForCausalLM``. Weights are loaded in bfloat16 with
    ``device_map="auto"``.
    """

    # Number of Model instances constructed so far (incremented by ``update``).
    number_of_models = 0
    # Checkpoints this wrapper is intended to be used with.
    __model_list__ = [
        "Qwen/Qwen2-1.5B-Instruct",
        "lmsys/vicuna-7b-v1.5",
        "google-t5/t5-large",
        "mistralai/Mistral-7B-Instruct-v0.1",
        "meta-llama/Meta-Llama-3.1-8B-Instruct"
    ]

    def __init__(self, model_name="Qwen/Qwen2-1.5B-Instruct") -> None:
        """Load the tokenizer and model for ``model_name``.

        Args:
            model_name: Hugging Face Hub id (or local path) of the checkpoint.
        """
        super().__init__()

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.name = model_name

        logging.info('start loading model %s', self.name)

        # Pick the model head from the checkpoint's config instead of a
        # hard-coded name, so any encoder-decoder checkpoint gets a Seq2Seq head.
        config = transformers.AutoConfig.from_pretrained(model_name)
        is_encoder_decoder = bool(getattr(config, "is_encoder_decoder", False))
        model_cls = AutoModelForSeq2SeqLM if is_encoder_decoder else AutoModelForCausalLM
        self.model = model_cls.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, device_map="auto"
        )

        # ``padding=True`` raises when the tokenizer defines no pad token;
        # reuse EOS as the pad token in that case.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Decoder-only models continue from the last position of every row, so
        # batched prompts must be left-padded or generation starts after pads.
        if not is_encoder_decoder:
            self.tokenizer.padding_side = "left"

        logging.info('Loaded model %s', self.name)

        self.update()

    @classmethod
    def update(cls):
        """Increment the class-wide ``number_of_models`` counter."""
        cls.number_of_models += 1

    def return_mode_name(self):
        """Return the checkpoint name this instance was built from."""
        return self.name

    def return_model_name(self):
        """Correctly spelled alias of ``return_mode_name``."""
        return self.return_mode_name()

    def return_tokenizer(self):
        """Return the loaded tokenizer."""
        return self.tokenizer

    def return_model(self):
        """Return the loaded transformers model.

        (Previously returned ``self.pipeline``, which is never set and raised
        AttributeError.)
        """
        return self.model

    def _sample(self, encoding, temp, max_new_tokens, streamer=None):
        """Run sampling ``generate`` on a tokenizer output, moving it to the model device."""
        device = self.model.device
        attention_mask = encoding.get("attention_mask")
        if attention_mask is not None:
            # Without the mask the model attends to pad tokens in padded batches.
            attention_mask = attention_mask.to(device)
        return self.model.generate(
            input_ids=encoding["input_ids"].to(device),
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temp,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            streamer=streamer,
        )

    def gen(self, content_list, temp=0.1, max_length=500, streaming=False):
        """Sample one completion per prompt.

        Args:
            content_list: list of prompt strings; a single string is treated
                as a one-element list.
            temp: sampling temperature passed to ``generate``.
            max_length: maximum number of *new* tokens per prompt (passed as
                ``max_new_tokens``).
            streaming: if True, prompts are generated one at a time and each
                completion is printed through a ``TextStreamer`` (prompt
                skipped) while it is produced.

        Returns:
            List of decoded output sequences (special tokens skipped), one per
            prompt, in both modes. For causal LMs ``generate`` returns the
            prompt followed by the continuation, so the decoded text includes
            the prompt.
        """
        if isinstance(content_list, str):
            content_list = [content_list]
        if not content_list:
            return []

        if streaming:
            outputs = []
            for prompt in content_list:
                # Tokenize each prompt alone so no pad tokens are fed to the model.
                encoding = self.tokenizer(prompt, return_tensors="pt", truncation=True)
                streamer = TextStreamer(self.tokenizer, skip_prompt=True)
                outputs.append(self._sample(encoding, temp, max_length, streamer)[0])
        else:
            encoding = self.tokenizer(
                content_list, return_tensors="pt", padding=True, truncation=True
            )
            outputs = self._sample(encoding, temp, max_length)

        return [self.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]