onisj committed on
Commit
88ef157
·
verified ·
1 Parent(s): c06380b

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +98 -0
handler.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import pipeline, AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification
3
+ from sentence_transformers import SentenceTransformer
4
+ import torch
5
+ import os
6
+ import logging
7
+
8
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    """Inference handler that serves one model directory.

    The task is inferred from the directory name, the ``model_type`` field of
    the model config and the presence of ``sentence_bert_config.json``.
    Supported tasks are ``"text-generation"``, ``"text-classification"`` and
    ``"sentence-embedding"``.

    Attributes set by ``__init__``:
        path: the model directory passed by the caller.
        task: the task string returned by ``_determine_task``.
        model: the loaded model object.
        tokenizer: the loaded tokenizer (``None`` for sentence-embedding).
        pipeline: the transformers pipeline (``None`` for sentence-embedding).
    """

    # Lookup tables used by ``_resolve_task``.
    _GENERATION_TYPES = ("gpt2",)
    _CLASSIFICATION_TYPES = ("bert", "distilbert", "roberta")
    _GENERATION_NAMES = ("fine_tuned_gpt2", "merged_distilgpt2")
    _CLASSIFICATION_NAMES = (
        "emotion_classifier",
        "emotion_model",
        "intent_classifier",
        "intent_fallback",
    )
    _EMBEDDING_NAMES = ("intent_encoder", "sentence_transformer")
    _SENTENCE_BERT_MARKER = "sentence_bert_config.json"

    def __init__(self, path=""):
        """Detect the task for the model at *path* and load it.

        Raises:
            ValueError: if the task cannot be determined or is unsupported.
        """
        self.path = path
        # Always define these so every task exposes the same attributes.
        self.model = None
        self.tokenizer = None
        self.pipeline = None
        try:
            self.task = self._determine_task()
        except Exception as e:
            logger.error(f"Failed to determine task: {str(e)}")
            raise

        logger.info(f"Initializing model for task: {self.task} at path: {path}")
        if self.task == "text-generation":
            self._build_pipeline("text-generation", AutoModelForCausalLM)
        elif self.task == "text-classification":
            self._build_pipeline(
                "text-classification", AutoModelForSequenceClassification
            )
        elif self.task == "sentence-embedding":
            self.model = SentenceTransformer(path)
        else:
            raise ValueError(f"Unsupported task: {self.task} for model at {path}")

    def _build_pipeline(self, task, model_cls):
        """Load model and tokenizer from ``self.path`` and wrap them in a pipeline.

        Uses float16 on device 0 when CUDA is available, float32 on CPU
        (device -1) otherwise.
        """
        use_cuda = torch.cuda.is_available()
        self.model = model_cls.from_pretrained(
            self.path,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.path)
        self.pipeline = pipeline(
            task,
            model=self.model,
            tokenizer=self.tokenizer,
            device=0 if use_cuda else -1,
        )

    @classmethod
    def _resolve_task(cls, model_name, model_type, has_sentence_bert_config):
        """Map (directory name, config model_type, marker file) to a task string.

        Raises:
            ValueError: if no rule matches.
        """
        if model_type in cls._GENERATION_TYPES or model_name in cls._GENERATION_NAMES:
            return "text-generation"
        if model_name in cls._CLASSIFICATION_NAMES:
            return "text-classification"
        # The embedding check must run before the model_type fallback below:
        # a sentence encoder built on BERT also reports model_type "bert" and
        # would otherwise always be routed to text-classification.
        if model_name in cls._EMBEDDING_NAMES or has_sentence_bert_config:
            return "sentence-embedding"
        if model_type in cls._CLASSIFICATION_TYPES:
            return "text-classification"
        raise ValueError(
            f"Could not determine task for model_type: {model_type}, model_name: {model_name}"
        )

    def _determine_task(self):
        """Return the task string for the model stored at ``self.path``.

        Raises:
            ValueError: if ``config.json`` is missing or cannot be loaded, or
                if no task rule matches.
        """
        config_path = os.path.join(self.path, "config.json")
        if not os.path.exists(config_path):
            logger.error(f"config.json not found in {self.path}")
            raise ValueError(f"config.json not found in {self.path}")

        try:
            config = AutoConfig.from_pretrained(self.path)
            model_type = getattr(config, "model_type", None)
        except Exception as e:
            logger.error(f"Failed to load config: {str(e)}")
            raise ValueError(f"Invalid config.json in {self.path}: {str(e)}") from e

        # normpath drops a trailing separator, so "models/foo/" still yields "foo".
        model_name = os.path.basename(os.path.normpath(self.path)).lower()
        # A path check (rather than os.listdir) also works for the default path "".
        has_marker = os.path.exists(
            os.path.join(self.path, self._SENTENCE_BERT_MARKER)
        )
        logger.info(f"Model name: {model_name}, Model type: {model_type}")
        return self._resolve_task(model_name, model_type, has_marker)

    @staticmethod
    def _flatten(result):
        """Return *result* as a flat list of dicts, removing one nesting level.

        Accepts a single dict, a list of dicts, or a list of lists of dicts,
        so single and batched pipeline outputs are handled the same way.
        """
        if isinstance(result, dict):
            return [result]
        flat = []
        for entry in result:
            if isinstance(entry, dict):
                flat.append(entry)
            else:
                flat.extend(entry)
        return flat

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run inference for one request.

        Args:
            data: request payload. ``data["inputs"]`` holds the model input;
                the optional ``data["parameters"]`` dict is forwarded to the
                pipeline as keyword arguments and overrides the defaults
                used here.

        Returns:
            A list of dicts: ``{"generated_text": ...}`` items for
            text-generation, ``{"label": ..., "score": ...}`` items for
            text-classification, one ``{"embeddings": ...}`` item for
            sentence-embedding, or one ``{"error": ...}`` item on failure.
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", None)
        if not inputs:
            logger.warning("No inputs provided")
            return [{"error": "No inputs provided"}]

        try:
            logger.info(f"Processing inputs for task: {self.task}")
            # Merge caller parameters over the defaults instead of passing both
            # as keywords, which raised TypeError on any duplicated key.
            overrides = dict(parameters or {})
            if self.task == "text-generation":
                call_kwargs = {"max_length": 50, "num_return_sequences": 1}
                if "max_new_tokens" in overrides:
                    # Documented to take precedence over max_length; dropping
                    # the default avoids sending both.
                    del call_kwargs["max_length"]
                call_kwargs.update(overrides)
                result = self.pipeline(inputs, **call_kwargs)
                return [
                    {"generated_text": item["generated_text"]}
                    for item in self._flatten(result)
                ]
            elif self.task == "text-classification":
                call_kwargs = {"return_all_scores": True, **overrides}
                result = self.pipeline(inputs, **call_kwargs)
                return [
                    {"label": item["label"], "score": item["score"]}
                    for item in self._flatten(result)
                ]
            elif self.task == "sentence-embedding":
                embeddings = self.model.encode(inputs)
                return [{"embeddings": embeddings.tolist()}]
            return [{"error": f"Unsupported task: {self.task}"}]
        except Exception as e:
            # Request boundary: log with traceback and report the failure to
            # the caller instead of propagating it.
            logger.exception("Inference failed: %s", e)
            return [{"error": f"Inference failed: {str(e)}"}]