Commit
·
70b9491
1
Parent(s):
43440d8
Create handler.py
Browse files- handler.py +71 -0
handler.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import bitsandbytes as bnb
|
3 |
+
import torch
|
4 |
+
import transformers
|
5 |
+
from datasets import load_dataset
|
6 |
+
from typing import Dict, List, Any
|
7 |
+
from peft import (
|
8 |
+
LoraConfig,
|
9 |
+
PeftConfig,
|
10 |
+
PeftModel,
|
11 |
+
get_peft_model,
|
12 |
+
prepare_model_for_kbit_training,
|
13 |
+
)
|
14 |
+
from transformers import (
|
15 |
+
AutoConfig,
|
16 |
+
LlamaTokenizer,
|
17 |
+
LlamaForCausalLM,
|
18 |
+
#AutoModelForCausalLM,
|
19 |
+
#AutoTokenizer,
|
20 |
+
BitsAndBytesConfig,
|
21 |
+
)
|
22 |
+
import json
|
23 |
+
|
24 |
+
bnb_config = BitsAndBytesConfig(
|
25 |
+
load_in_4bit=True,
|
26 |
+
bnb_4bit_use_double_quant=True,
|
27 |
+
bnb_4bit_quant_type="nf4",
|
28 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
29 |
+
)
|
30 |
+
|
31 |
+
|
32 |
+
from huggingface_hub import login
|
33 |
+
access_token_read = "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"
|
34 |
+
login(token = access_token_read)
|
35 |
+
|
36 |
+
|
37 |
+
class EndpointHandler:
|
38 |
+
def __init__(self, path=''):
|
39 |
+
PEFT_MODEL = path
|
40 |
+
config = PeftConfig.from_pretrained(PEFT_MODEL)
|
41 |
+
self.model = LlamaForCausalLM.from_pretrained(
|
42 |
+
config.base_model_name_or_path,
|
43 |
+
return_dict=True,
|
44 |
+
quantization_config=bnb_config,
|
45 |
+
device_map="auto",
|
46 |
+
trust_remote_code=True,
|
47 |
+
)
|
48 |
+
self.tokenizer = LlamaTokenizer.from_pretrained(config.base_model_name_or_path)
|
49 |
+
self.tokenizer.pad_token_id = (0)
|
50 |
+
self.tokenizer.padding_side = "left"
|
51 |
+
self.model = PeftModel.from_pretrained(self.model, PEFT_MODEL)
|
52 |
+
self.generation_config = self.model.generation_config
|
53 |
+
self.generation_config.max_new_tokens = 100
|
54 |
+
self.generation_config.pad_token_id = self.tokenizer.eos_token_id
|
55 |
+
self.generation_config.eos_token_id = self.tokenizer.eos_token_id
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
|
60 |
+
def __call__(self, data: Dict[str, Any]):
|
61 |
+
prompt = data.pop("inputs", data)
|
62 |
+
DEVICE = "cuda:0"
|
63 |
+
input_message = f"""[INST]You are an assistant that detects the intent and entity of user's message. Possible entity stores are JioMart, JioFiber, JioCinema, Tira Beauty, netmeds and milkbasket. Detect the intent and entity of the following user's message[/INST]\nUser: {prompt}\nAssistant: """.strip()
|
64 |
+
encoding = self.tokenizer(input_message, return_tensors="pt").to(DEVICE)
|
65 |
+
with torch.inference_mode():
|
66 |
+
outputs = self.model.generate(
|
67 |
+
input_ids=encoding.input_ids,
|
68 |
+
attention_mask=encoding.attention_mask,
|
69 |
+
generation_config=self.generation_config
|
70 |
+
)
|
71 |
+
return self.tokenizer.decode(outputs[0], skip_special_tokens=True)[len(input_message):]
|