from typing import Any, Dict, List

from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    """Custom Inference Endpoints handler that serves a chat-style causal LM."""

    def __init__(self, path: str = ""):
        # `path` is the local repository path passed in by Inference Endpoints;
        # this handler ignores it and loads a fixed checkpoint from the Hub.
        model_id = "sijieaaa/CodeModel-V1-3B-2024-02-07"
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype="auto",   # use the dtype recorded in the checkpoint
            device_map="auto",    # place weights on the available device(s)
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj: `str`)
            date (:obj: `str`)
        Return:
                A :obj:`list` | `dict`: will be serialized and returned
        """
        prompt = data["inputs"]

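        # Wrap the raw prompt in a single-turn conversation for the chat template.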
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
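        # Tokenize the rendered prompt and move it to the model's device.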
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)

        # Decoding follows the model's generation_config defaults; sampling
        # parameters (do_sample, temperature, top_p, ...) could be passed here.
        generated_ids = self.model.generate(
            **model_inputs,
            max_new_tokens=512,
        )
        # Slice off the prompt tokens so only the new completion is decoded.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # A streaming variant could yield partial responses instead of
        # returning once; see the sketch below the class.
        return [
            {"role": "assistant", "content": response}
        ]
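

# A possible streaming variant (a sketch, not used by the endpoint): run
# generate() in a background thread and yield decoded chunks as they arrive
# via transformers.TextIteratorStreamer. The helper name `stream_response`
# is illustrative only.
def stream_response(handler: EndpointHandler, prompt: str):
    from threading import Thread

    from transformers import TextIteratorStreamer

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    text = handler.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = handler.tokenizer([text], return_tensors="pt").to(handler.model.device)

    # skip_prompt drops the echoed input; extra kwargs are decode options.
    streamer = TextIteratorStreamer(
        handler.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    thread = Thread(
        target=handler.model.generate,
        kwargs={**inputs, "max_new_tokens": 512, "streamer": streamer},
    )
    thread.start()
    for chunk in streamer:  # yields decoded text pieces as they are generated
        yield chunk
    thread.join()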


def test():
    # Initialize the handler the way Inference Endpoints would.
    my_handler = EndpointHandler(path=".")

    # Prepare a sample payload and run it through the handler.
    payload = {"inputs": "I am quite excited how this will turn out"}
    prediction = my_handler(payload)

    # Show the result.
    print("prediction", prediction)


if __name__ == "__main__":
    test()