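"""Custom handler for Hugging Face Inference Endpoints.

Detects whether the checkpoint at `path` serves text generation, text
classification, or sentence embeddings, and routes each request to the
matching pipeline.
"""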
import logging
import os
from typing import Any, Dict, List

import torch
from sentence_transformers import SentenceTransformer
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    pipeline,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class EndpointHandler:
    def __init__(self, path=""):
        """Load the checkpoint at `path` and build the pipeline for its task."""
        self.path = path
        try:
            self.task = self._determine_task()
        except Exception as e:
            logger.error(f"Failed to determine task: {str(e)}")
            raise

        logger.info(f"Initializing model for task: {self.task} at path: {path}")
        if self.task == "text-generation":
            self.model = AutoModelForCausalLM.from_pretrained(
                path, 
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
            )
            self.tokenizer = AutoTokenizer.from_pretrained(path)
            self.pipeline = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device=0 if torch.cuda.is_available() else -1
            )
        elif self.task == "text-classification":
            self.model = AutoModelForSequenceClassification.from_pretrained(
                path,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
            )
            self.tokenizer = AutoTokenizer.from_pretrained(path)
            self.pipeline = pipeline(
                "text-classification",
                model=self.model,
                tokenizer=self.tokenizer,
                device=0 if torch.cuda.is_available() else -1
            )
        elif self.task == "sentence-embedding":
            self.model = SentenceTransformer(path)
        else:
            raise ValueError(f"Unsupported task: {self.task} for model at {path}")

    def _determine_task(self):
        """Infer the task from config.json's model_type and the directory name."""
        config_path = os.path.join(self.path, "config.json")
        if not os.path.exists(config_path):
            logger.error(f"config.json not found in {self.path}")
            raise ValueError(f"config.json not found in {self.path}")
        
        try:
            config = AutoConfig.from_pretrained(self.path)
            model_type = getattr(config, "model_type", None)
        except Exception as e:
            logger.error(f"Failed to load config: {str(e)}")
            raise ValueError(f"Invalid config.json in {self.path}: {str(e)}")
        
        text_generation_types = ["gpt2"]
        text_classification_types = ["bert", "distilbert", "roberta"]

        model_name = self.path.split("/")[-1].lower()
        logger.info(f"Model name: {model_name}, Model type: {model_type}")
        # Check the sentence-embedding signals first: SentenceTransformer
        # checkpoints usually report model_type "bert", so the
        # text-classification branch below would otherwise capture them.
        if model_name in ["intent_encoder", "sentence_transformer"] or "sentence_bert_config.json" in os.listdir(self.path):
            return "sentence-embedding"
        if model_type in text_generation_types or model_name in ["fine_tuned_gpt2", "merged_distilgpt2"]:
            return "text-generation"
        if model_type in text_classification_types or model_name in ["emotion_classifier", "emotion_model", "intent_classifier", "intent_fallback"]:
            return "text-classification"
        raise ValueError(f"Could not determine task for model_type: {model_type}, model_name: {model_name}")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run inference on a payload of the form {"inputs": ..., "parameters": {...}}.

        Returns a list of dicts keyed by generated_text, label/score, or
        embeddings, depending on the detected task; failures are reported as
        [{"error": ...}].
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", None)
        if not inputs:
            logger.warning("No inputs provided")
            return [{"error": "No inputs provided"}]

        try:
            logger.info(f"Processing inputs for task: {self.task}")
            if self.task == "text-generation":
                # Merge caller parameters into the defaults; passing both as
                # keywords would raise a TypeError on duplicates like max_length.
                gen_kwargs = {"max_length": 50, "num_return_sequences": 1}
                gen_kwargs.update(parameters or {})
                result = self.pipeline(inputs, **gen_kwargs)
                return [{"generated_text": item["generated_text"]} for item in result]
            elif self.task == "text-classification":
                cls_kwargs = {"return_all_scores": True}
                cls_kwargs.update(parameters or {})
                result = self.pipeline(inputs, **cls_kwargs)
                return [{"label": item["label"], "score": item["score"]} for sublist in result for item in sublist]
            elif self.task == "sentence-embedding":
                embeddings = self.model.encode(inputs)
                return [{"embeddings": embeddings.tolist()}]
            return [{"error": f"Unsupported task: {self.task}"}]
        except Exception as e:
            logger.error(f"Inference failed: {str(e)}")
            return [{"error": f"Inference failed: {str(e)}"}]