from utilities.setup import *

import os
from typing import Any, Dict, List

from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer


class EndpointHandler():
    def __init__(self, path=""):
        """Initialize the handler and load the model of interest."""
        print("Reading config")
        self.path = path
        self.HF_TOKEN = os.getenv("HF_TOKEN")
        self.wd = os.getcwd()
        self.model_name = os.path.basename(self.wd)
        print("Loading model")
        self.model, self.tokenizer = self.load_model()

    def load_model(self):
        """Load the fine-tuned PEFT (LoRA) model and its tokenizer in 4-bit."""
        model = AutoPeftModelForCausalLM.from_pretrained(
            self.path,
            load_in_4bit=True,
            token=self.HF_TOKEN,  # required when the base model or adapter repo is gated
        )
        tokenizer = AutoTokenizer.from_pretrained(self.path, token=self.HF_TOKEN)
        return model, tokenizer

    def prompt_formatter(self, prompt):
        """Tokenize the prompt and move the tensors to the GPU.

        Prompts are expected to arrive already formatted in Alpaca style.
        """
        inputs = self.tokenizer([prompt], return_tensors="pt").to("cuda")
        return inputs, prompt

    def infer(self, prompt, max_new_tokens=1000):
        """Run generation on a single prompt and decode the completion."""
        # TODO: add streaming capability
        inputs, _ = self.prompt_formatter(prompt)
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            use_cache=True,
        )
        completion = self.tokenizer.batch_decode(outputs)
        return completion

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`str`): the prompt to complete
        Return:
            A :obj:`list` of :obj:`dict` that will be serialized and returned
        """
        request = data.get("inputs")
        if request is not None:
            prediction = self.infer(request)
            return [{"prediction": prediction}]
        return [{"error": "no input received."}]
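
# --- Usage sketch (illustrative, not part of the handler contract) ---
# A minimal local smoke test, assuming the adapter weights live in the
# current directory and a CUDA device is available. The {"inputs": ...}
# payload shape matches what Hugging Face Inference Endpoints sends to a
# custom handler's __call__; the prompt text here is only an example.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    payload = {"inputs": "Write a haiku about GPUs."}
    print(handler(payload))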