"""Custom inference handler that serves a causal chat/code model.

The handler builds a two-message chat (system + user), renders it with the
tokenizer's chat template, generates a completion, and returns only the newly
generated text as a single assistant message.
"""

from typing import Any, Dict, List, Optional

from transformers import AutoModelForCausalLM, AutoTokenizer

# Hub id of the model loaded when no explicit ``model_id`` is given.
DEFAULT_MODEL_ID = "sijieaaa/CodeModel-V1-3B-2024-02-07"
# System message prepended to every request.
DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."
# Generation length used unless the request overrides it via "parameters".
DEFAULT_MAX_NEW_TOKENS = 512


class EndpointHandler:
    """Load a causal LM once and answer single-turn chat requests with it."""

    def __init__(self, path: Optional[str] = None, model_id: str = DEFAULT_MODEL_ID) -> None:
        """Load the model and tokenizer.

        Args:
            path: Accepted for compatibility with the endpoint runtime's
                ``EndpointHandler(path)`` call; it is not used here. The
                weights are always loaded from ``model_id``.
            model_id: Hub id (or local directory) passed to ``from_pretrained``.
        """
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype="auto",  # keep the dtype stored in the checkpoint
            device_map="auto",  # let accelerate place weights on available devices
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate an assistant reply for one user prompt.

        Args:
            data: Request payload. ``data["inputs"]`` is used as the user
                message content (presumably a ``str`` — it is fed to the chat
                template unchanged). An optional ``data["parameters"]`` dict is
                forwarded as keyword arguments to ``model.generate`` and
                overrides the defaults (``max_new_tokens=512``).

        Returns:
            A one-element list ``[{"role": "assistant", "content": reply}]``
            where ``reply`` is the decoded continuation without the prompt.

        Raises:
            KeyError: if the payload has no ``"inputs"`` field.
        """
        try:
            prompt = data["inputs"]
        except KeyError:
            raise KeyError("request payload must contain an 'inputs' field") from None

        # NOTE(review): these kwargs come from the request; consider capping
        # max_new_tokens if the endpoint is exposed to untrusted callers.
        generate_kwargs: Dict[str, Any] = {"max_new_tokens": DEFAULT_MAX_NEW_TOKENS}
        generate_kwargs.update(data.get("parameters") or {})

        messages = [
            {"role": "system", "content": DEFAULT_SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,  # end the prompt with the assistant turn header
        )
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)

        generated_ids = self.model.generate(**model_inputs, **generate_kwargs)

        # generate() returns prompt + continuation; keep only the new tokens.
        new_token_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        reply = self.tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)[0]

        return [{"role": "assistant", "content": reply}]