import os
import time
import bitsandbytes as bnb
import torch
import transformers
from datasets import load_dataset
from typing import Dict, List, Any
from peft import (
    LoraConfig,
    PeftConfig,
    PeftModel,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from transformers import (
    AutoConfig,
    LlamaTokenizer,
    LlamaForCausalLM,
    # AutoModelForCausalLM,
    # AutoTokenizer,
    BitsAndBytesConfig,
)
import json

# 4-bit NF4 quantization with double quantization so the Llama base model
# fits in GPU memory; compute is done in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

from huggingface_hub import login

# SECURITY: never hard-code a Hugging Face token in source control — the
# previously committed token must be treated as compromised and revoked.
# Read the token from the environment instead, and skip login entirely when
# none is provided (sufficient for public models).
access_token_read = os.environ.get("HF_ACCESS_TOKEN") or os.environ.get("HF_TOKEN")
if access_token_read:
    login(token=access_token_read)


class EndpointHandler:
    """Inference-endpoint handler serving a 4-bit quantized Llama model with a
    PEFT/LoRA adapter that detects the intent and entity of a user message.

    The interface (``__init__(path)`` / ``__call__(data)``) follows the
    Hugging Face custom-handler convention.
    """

    def __init__(self, path: str = ""):
        """Load the base model (4-bit quantized), tokenizer, and LoRA adapter.

        Args:
            path: Repository or local path of the PEFT adapter; its config
                points at the base model to load.
        """
        PEFT_MODEL = path
        config = PeftConfig.from_pretrained(PEFT_MODEL)

        self.model = LlamaForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            return_dict=True,
            quantization_config=bnb_config,
            device_map="auto",  # let accelerate place layers on available devices
            trust_remote_code=True,
        )

        self.tokenizer = LlamaTokenizer.from_pretrained(config.base_model_name_or_path)
        # Llama has no pad token by default; use id 0 and left-pad so the
        # generated continuation directly follows the prompt.
        self.tokenizer.pad_token_id = 0
        self.tokenizer.padding_side = "left"

        # Attach the LoRA adapter weights on top of the quantized base model.
        self.model = PeftModel.from_pretrained(self.model, PEFT_MODEL)

        self.generation_config = self.model.generation_config
        self.generation_config.max_new_tokens = 100
        self.generation_config.pad_token_id = self.tokenizer.eos_token_id
        self.generation_config.eos_token_id = self.tokenizer.eos_token_id

    def __call__(self, data: Dict[str, Any]) -> str:
        """Generate the assistant's intent/entity answer for one request.

        Args:
            data: Request payload; the user message is read from ``data["inputs"]``
                (falling back to the payload itself if that key is absent).

        Returns:
            The decoded model continuation (prompt stripped off).
        """
        prompt = data.pop("inputs", data)
        input_message = f"""[INST]You are an assistant that detects the intent and entity of user's message. Possible entity stores are JioMart, JioFiber, JioCinema, Tira Beauty, netmeds and milkbasket. 
Detect the intent and entity of the following user's message[/INST]\nUser: {prompt}\nAssistant: """.strip()

        # Use the model's own device rather than a hard-coded "cuda:0": with
        # device_map="auto" the placement may differ, and a CPU-only host
        # would otherwise crash.
        encoding = self.tokenizer(input_message, return_tensors="pt").to(self.model.device)
        with torch.inference_mode():
            outputs = self.model.generate(
                input_ids=encoding.input_ids,
                attention_mask=encoding.attention_mask,
                generation_config=self.generation_config,
            )

        # Strip the prompt at the *token* level before decoding. Decoding the
        # full sequence and slicing by len(input_message) characters is wrong
        # whenever the tokenizer does not round-trip the prompt text exactly.
        generated_tokens = outputs[0][encoding.input_ids.shape[1]:]
        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)