# handler.py — custom Hugging Face Inference Endpoint handler
# (originally created by subhankarfynd / Fynd, commit 70b9491)
import json
import os
import time
from typing import Dict, List, Any

import bitsandbytes as bnb
import torch
import transformers
from datasets import load_dataset
from peft import (
    LoraConfig,
    PeftConfig,
    PeftModel,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from transformers import (
    AutoConfig,
    LlamaTokenizer,
    LlamaForCausalLM,
    #AutoModelForCausalLM,
    #AutoTokenizer,
    BitsAndBytesConfig,
)
# Quantization settings shared by the model load below: 4-bit NF4 weights
# with double (nested) quantization to cut memory further, computing in
# bfloat16 for speed/stability.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
from huggingface_hub import login

# SECURITY: a Hugging Face access token was previously hardcoded on this line
# and committed to the repository — that token must be treated as leaked and
# revoked. Read the token from the environment instead (configure HF_TOKEN as
# a secret on the endpoint); skip login when it is not provided so public
# models still work.
access_token_read = os.environ.get("HF_TOKEN", "")
if access_token_read:
    login(token=access_token_read)
class EndpointHandler:
    """Inference Endpoint handler for an intent/entity-detection assistant.

    Loads a 4-bit quantized Llama base model (resolved from the PEFT adapter
    config at `path`), attaches the LoRA adapter weights, and generates a
    completion for the user's message.
    """

    def __init__(self, path=''):
        """Load base model + tokenizer and apply the PEFT adapter at `path`.

        `path` is the adapter repo/directory; the base checkpoint is taken
        from the adapter's `base_model_name_or_path`.
        """
        peft_model_id = path
        config = PeftConfig.from_pretrained(peft_model_id)
        self.model = LlamaForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            return_dict=True,
            quantization_config=bnb_config,  # module-level 4-bit NF4 config
            device_map="auto",  # may place layers on GPU(s) or CPU
            trust_remote_code=True,
        )
        self.tokenizer = LlamaTokenizer.from_pretrained(config.base_model_name_or_path)
        # Pad with token id 0 and pad on the left, as is conventional for
        # decoder-only generation.
        self.tokenizer.pad_token_id = 0
        self.tokenizer.padding_side = "left"
        # Attach the LoRA adapter weights on top of the quantized base model.
        self.model = PeftModel.from_pretrained(self.model, peft_model_id)
        self.generation_config = self.model.generation_config
        self.generation_config.max_new_tokens = 100
        self.generation_config.pad_token_id = self.tokenizer.eos_token_id
        self.generation_config.eos_token_id = self.tokenizer.eos_token_id

    def __call__(self, data: Dict[str, Any]):
        """Generate an intent/entity detection for `data["inputs"]`.

        Returns the decoded completion with the echoed prompt stripped off.
        """
        prompt = data.pop("inputs", data)
        # FIX: the device was hardcoded to "cuda:0", which crashes on
        # CPU-only deployments and ignores device_map="auto" placement.
        # Use the device the model actually ended up on.
        device = self.model.device
        input_message = f"""[INST]You are an assistant that detects the intent and entity of user's message. Possible entity stores are JioMart, JioFiber, JioCinema, Tira Beauty, netmeds and milkbasket. Detect the intent and entity of the following user's message[/INST]\nUser: {prompt}\nAssistant: """.strip()
        encoding = self.tokenizer(input_message, return_tensors="pt").to(device)
        with torch.inference_mode():
            outputs = self.model.generate(
                input_ids=encoding.input_ids,
                attention_mask=encoding.attention_mask,
                generation_config=self.generation_config,
            )
        # Decode and drop the prompt prefix so only the assistant's
        # completion is returned.
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)[len(input_message):]