|
from typing import Dict, List, Any |
|
from transformers import pipeline, AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification |
|
from sentence_transformers import SentenceTransformer |
|
import torch |
|
import os |
|
import logging |
|
|
|
# Module-wide logger shared by everything in this file. The root logger is
# configured for INFO at import time so the handler's messages are emitted.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
|
|
|
class EndpointHandler:
    """Load a local model directory and serve predictions from it.

    The task is inferred from the directory (see ``_determine_task``) and is
    one of ``"text-generation"``, ``"text-classification"`` or
    ``"sentence-embedding"``.  Instances are callable: pass a request dict
    and receive a list of result dicts.

    Attributes:
        path: Model directory given to the constructor.
        task: Inferred task name.
        model: Loaded model (a ``SentenceTransformer`` for embeddings).
        tokenizer: Tokenizer for the pipeline-backed tasks, else ``None``.
        pipeline: ``transformers`` pipeline for the pipeline-backed tasks,
            else ``None``.
    """

    # ``config.model_type`` values that identify a task.
    _GENERATION_MODEL_TYPES = frozenset({"gpt2"})
    _CLASSIFICATION_MODEL_TYPES = frozenset({"bert", "distilbert", "roberta"})

    # Lower-cased directory names that identify a task.
    _GENERATION_MODEL_NAMES = frozenset({"fine_tuned_gpt2", "merged_distilgpt2"})
    _CLASSIFICATION_MODEL_NAMES = frozenset(
        {"emotion_classifier", "emotion_model", "intent_classifier", "intent_fallback"}
    )
    _EMBEDDING_MODEL_NAMES = frozenset({"intent_encoder", "sentence_transformer"})

    def __init__(self, path=""):
        """Infer the task for ``path`` and load the matching model.

        Args:
            path: Directory containing the model files; ``""`` means the
                current working directory.

        Raises:
            ValueError: If the task cannot be determined or is unsupported.
        """
        self.path = path
        # Always defined so callers can inspect them regardless of task.
        self.tokenizer = None
        self.pipeline = None

        try:
            self.task = self._determine_task()
        except Exception as e:
            logger.error(f"Failed to determine task: {str(e)}")
            raise

        logger.info(f"Initializing model for task: {self.task} at path: {path}")

        if self.task == "sentence-embedding":
            self.model = SentenceTransformer(path)
            return

        if self.task == "text-generation":
            model_cls = AutoModelForCausalLM
        elif self.task == "text-classification":
            model_cls = AutoModelForSequenceClassification
        else:
            raise ValueError(f"Unsupported task: {self.task} for model at {path}")

        # Half precision and GPU placement only when CUDA is present.
        use_cuda = torch.cuda.is_available()
        self.model = model_cls.from_pretrained(
            path,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.pipeline = pipeline(
            self.task,
            model=self.model,
            tokenizer=self.tokenizer,
            device=0 if use_cuda else -1,
        )

    @classmethod
    def _resolve_task(cls, model_type, model_name, has_sentence_bert_config):
        """Map model metadata to a task name, or ``None`` if nothing matches.

        Embedding markers are tested first: sentence-transformer checkpoints
        commonly report a ``model_type`` such as ``"bert"``, which would
        otherwise be claimed by the text-classification rule.
        """
        if model_name in cls._EMBEDDING_MODEL_NAMES or has_sentence_bert_config:
            return "sentence-embedding"
        if model_type in cls._GENERATION_MODEL_TYPES or model_name in cls._GENERATION_MODEL_NAMES:
            return "text-generation"
        if model_type in cls._CLASSIFICATION_MODEL_TYPES or model_name in cls._CLASSIFICATION_MODEL_NAMES:
            return "text-classification"
        return None

    def _determine_task(self):
        """Infer the task from ``config.json`` and the directory contents.

        Returns:
            ``"text-generation"``, ``"text-classification"`` or
            ``"sentence-embedding"``.

        Raises:
            ValueError: If ``config.json`` is missing or cannot be loaded, or
                if no rule matches the model.
        """
        config_path = os.path.join(self.path, "config.json")
        if not os.path.exists(config_path):
            logger.error(f"config.json not found in {self.path}")
            raise ValueError(f"config.json not found in {self.path}")

        try:
            config = AutoConfig.from_pretrained(self.path)
        except Exception as e:
            logger.error(f"Failed to load config: {str(e)}")
            raise ValueError(f"Invalid config.json in {self.path}: {str(e)}") from e
        model_type = getattr(config, "model_type", None)

        # abspath drops trailing separators and resolves "" / "." to the real
        # directory, so the final path component is always the directory name.
        model_name = os.path.basename(os.path.abspath(self.path)).lower()
        # Joined-path check (rather than os.listdir) also works for path="".
        has_sentence_bert_config = os.path.exists(
            os.path.join(self.path, "sentence_bert_config.json")
        )
        logger.info(f"Model name: {model_name}, Model type: {model_type}")

        task = self._resolve_task(model_type, model_name, has_sentence_bert_config)
        if task is None:
            raise ValueError(
                f"Could not determine task for model_type: {model_type}, model_name: {model_name}"
            )
        return task

    @staticmethod
    def _flatten(result):
        """Normalise pipeline output to a flat list of dicts.

        Accepts a single dict, a list of dicts, or a list of lists of dicts.
        """
        if isinstance(result, dict):
            return [result]
        flat = []
        for entry in result:
            if isinstance(entry, dict):
                flat.append(entry)
            else:
                flat.extend(entry)
        return flat

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run inference for one request.

        Args:
            data: Request dict.  ``"inputs"`` holds the text (or list of
                texts); optional ``"parameters"`` holds keyword arguments
                forwarded to the pipeline, overriding this handler's
                defaults.  Parameters are not used for sentence embeddings.

        Returns:
            A list of dicts: ``{"generated_text": ...}`` items for
            generation, ``{"label": ..., "score": ...}`` items for
            classification, or a single ``{"embeddings": ...}`` item.  On
            missing input or any inference failure, a single
            ``{"error": ...}`` item is returned instead of raising.
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", None)
        if not inputs:
            logger.warning("No inputs provided")
            return [{"error": "No inputs provided"}]

        try:
            logger.info(f"Processing inputs for task: {self.task}")
            overrides = dict(parameters or {})

            if self.task == "text-generation":
                options = {"num_return_sequences": 1}
                # Skip the max_length default when the caller bounds the
                # output with max_new_tokens, to avoid setting both.
                if "max_new_tokens" not in overrides:
                    options["max_length"] = 50
                # Caller-supplied values win over the defaults.
                options.update(overrides)
                result = self.pipeline(inputs, **options)
                return [
                    {"generated_text": item["generated_text"]}
                    for item in self._flatten(result)
                ]

            if self.task == "text-classification":
                options = {}
                # Default to all label scores unless the caller asks for a
                # specific top_k.
                if "top_k" not in overrides:
                    options["return_all_scores"] = True
                options.update(overrides)
                result = self.pipeline(inputs, **options)
                return [
                    {"label": item["label"], "score": item["score"]}
                    for item in self._flatten(result)
                ]

            if self.task == "sentence-embedding":
                embeddings = self.model.encode(inputs)
                return [{"embeddings": embeddings.tolist()}]

            return [{"error": f"Unsupported task: {self.task}"}]
        except Exception as e:
            # Failures are reported to the caller instead of propagating;
            # the traceback is kept in the log.
            logger.exception(f"Inference failed: {str(e)}")
            return [{"error": f"Inference failed: {str(e)}"}]