# gemma-finetuned/handler.py
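"""Custom handler for Hugging Face Inference Endpoints: loads the
google/gemma-1.1-2b-it base model, merges the factshlab/autotrain-pjkul-jliyi
LoRA adapter into it, and serves text generation."""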
from typing import Dict, List, Any
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch


class EndpointHandler:
def __init__(self, path=""):
# Define the base and adapter model names
self.base_model_name = "google/gemma-1.1-2b-it"
self.adapter_model_name = "factshlab/autotrain-pjkul-jliyi"
# Load base model
base_model = AutoModelForCausalLM.from_pretrained(
self.base_model_name,
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
device_map="auto",
)
# Load the adapter and merge with the base model
self.model = PeftModel.from_pretrained(base_model, self.adapter_model_name)
self.model = self.model.merge_and_unload() # Merging LoRA adapter
# Load the tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name, trust_remote_code=True)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"  # decoder-only models should be left-padded for generation
        # device_map="auto" has already placed the weights, so read the device
        # back from the model instead of moving it again
        self.device = next(self.model.parameters()).device

def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
data args:
inputs (:obj:`str`): The input text for the model.
Return:
A :obj:`list` | `dict`: The prediction from the model, serialized and returned.
"""
        # Extract the input text; fall back to the raw payload if no "inputs" key is present
        text = data.pop("inputs", data)
        # Tokenize the input and move the tensors to the model's device
        inputs = self.tokenizer(text, return_tensors="pt", padding=True).to(self.device)
# Generate prediction
outputs = self.model.generate(**inputs, max_new_tokens=50)
# Decode the generated tokens to text
prediction = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
# Return the result in a JSON-serializable format
return [{"generated_text": prediction}]
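

# A minimal local smoke test, kept outside the handler class; the prompt below
# is an illustrative assumption. Running this downloads both the base model and
# the adapter, so it requires access to the corresponding repositories.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {"inputs": "Explain in one sentence what a LoRA adapter is."}
    print(handler(payload))  # e.g. [{"generated_text": "..."}]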