subhankarfynd committed on
Commit
70b9491
·
1 Parent(s): 43440d8

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +71 -0
handler.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import bitsandbytes as bnb
3
+ import torch
4
+ import transformers
5
+ from datasets import load_dataset
6
+ from typing import Dict, List, Any
7
+ from peft import (
8
+ LoraConfig,
9
+ PeftConfig,
10
+ PeftModel,
11
+ get_peft_model,
12
+ prepare_model_for_kbit_training,
13
+ )
14
+ from transformers import (
15
+ AutoConfig,
16
+ LlamaTokenizer,
17
+ LlamaForCausalLM,
18
+ #AutoModelForCausalLM,
19
+ #AutoTokenizer,
20
+ BitsAndBytesConfig,
21
+ )
22
+ import json
23
+
24
# Quantization settings passed as `quantization_config` when the base model is
# loaded: 4-bit weights in the "nf4" format with double quantization enabled,
# using bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    load_in_4bit=True,
)
30
+
31
+
32
import os

from huggingface_hub import login

# The Hub access token is read from the environment instead of being stored in
# source control.
# NOTE(review): the token that used to be hard-coded here was published with
# this commit and should be revoked on the Hub.
access_token_read = os.environ.get("HF_TOKEN") or os.environ.get(
    "HUGGING_FACE_HUB_TOKEN"
)

# Only authenticate when a token was actually provided; without one the
# explicit login step is skipped.
if access_token_read:
    login(token=access_token_read)
35
+
36
+
37
class EndpointHandler:
    """Request handler that asks a PEFT-adapted Llama model for intent and entity.

    On construction it loads the base model named in the adapter's config
    (quantized via the module-level ``bnb_config``), the matching tokenizer and
    the adapter itself. Calling the instance wraps the incoming message in a
    fixed instruction prompt and returns the generated text.
    """

    def __init__(self, path=''):
        """Load tokenizer, quantized base model and PEFT adapter.

        Args:
            path: Location of the PEFT adapter. Its config supplies the name of
                the base model to load.
        """
        peft_config = PeftConfig.from_pretrained(path)
        base_model_name = peft_config.base_model_name_or_path

        base_model = LlamaForCausalLM.from_pretrained(
            base_model_name,
            return_dict=True,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
        )

        self.tokenizer = LlamaTokenizer.from_pretrained(base_model_name)
        self.tokenizer.pad_token_id = 0
        self.tokenizer.padding_side = "left"

        # Wrap the quantized base model with the adapter weights.
        self.model = PeftModel.from_pretrained(base_model, path)

        # Reuse the model's own generation config, capped at 100 new tokens.
        # Generation pads with and stops on the tokenizer's EOS token id.
        self.generation_config = self.model.generation_config
        self.generation_config.max_new_tokens = 100
        self.generation_config.pad_token_id = self.tokenizer.eos_token_id
        self.generation_config.eos_token_id = self.tokenizer.eos_token_id

    def __call__(self, data: Dict[str, Any]):
        """Generate the model's answer for one request.

        Args:
            data: Request payload. The message is taken from (and removed
                from) the ``"inputs"`` key; if that key is absent the whole
                payload is formatted into the prompt instead.

        Returns:
            The decoded text of the newly generated tokens, i.e. the model
            output without the instruction prompt.
        """
        prompt = data.pop("inputs", data)
        input_message = f"""[INST]You are an assistant that detects the intent and entity of user's message. Possible entity stores are JioMart, JioFiber, JioCinema, Tira Beauty, netmeds and milkbasket. Detect the intent and entity of the following user's message[/INST]\nUser: {prompt}\nAssistant: """.strip()

        # The model is loaded with device_map="auto", so send the inputs to
        # wherever the model reports it lives rather than assuming "cuda:0".
        encoding = self.tokenizer(input_message, return_tensors="pt").to(
            self.model.device
        )

        with torch.inference_mode():
            outputs = self.model.generate(
                input_ids=encoding.input_ids,
                attention_mask=encoding.attention_mask,
                generation_config=self.generation_config,
            )

        # Drop the prompt by token count, not by character count: decoding is
        # not guaranteed to reproduce the prompt text character for character,
        # which would make a len(input_message) slice cut at the wrong place.
        # NOTE(review): assumes generate() returns the prompt ids followed by
        # the new ids, as for decoder-only models — confirm for this model.
        prompt_length = encoding.input_ids.shape[1]
        completion_ids = outputs[0][prompt_length:]
        return self.tokenizer.decode(completion_ids, skip_special_tokens=True)