from typing import Dict, List, Any
from modelscope import AutoModelForCausalLM, AutoTokenizer
import torch

class EndpointHandler:
    def __init__(self, path=""):
        self.tokenizer =AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path, device_map='cuda')
            
    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
        sys_prompt=data["prompt"]
        list=data["inputs"]
        prompt=f"<|im_start|>system\n{sys_prompt}.<|im_end|>\n"
        for item in list:
            if item["role"]=="assistant":
                content=item["content"]
                prompt+=f"<|im_start|>assistant\n{content}<|im_end|>\n"
            else:
                content=item["content"]
                prompt+=f"<|im_start|>user\n{content}<|im_end|>\n"
        prompt+="<|im_start|>assistant\n"

        #for chat in prompts:
            #print(chat)
        encodeds = self.tokenizer.encode(prompt, return_tensors="pt")
        model_inputs = encodeds.to("cuda")
        self.model.to("cuda")
        generated_ids = self.model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
        decoded = self.tokenizer.decode(generated_ids[0])
        return decoded