Update handler.py
Browse files- handler.py +11 -2
handler.py
CHANGED
@@ -1,19 +1,28 @@
|
|
1 |
from typing import Dict, Any, List
|
2 |
import torch
|
|
|
3 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
4 |
|
|
|
|
|
|
|
5 |
class EndpointHandler():
|
6 |
def __init__(self, path=""):
|
7 |
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
8 |
try:
|
|
|
9 |
self.model = AutoModelForSeq2SeqLM.from_pretrained(f"{path}/model_v2/").to(self.device)
|
10 |
self.tokenizer = AutoTokenizer.from_pretrained(f"{path}/model_v2/")
|
11 |
except Exception as e:
|
12 |
-
|
13 |
# Handle error (e.g., exit or set model/tokenizer to None)
|
14 |
|
15 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
16 |
-
|
|
|
|
|
|
|
|
|
17 |
if not inputs:
|
18 |
return [{"error": "No inputs provided"}]
|
19 |
|
|
|
1 |
from typing import Dict, Any, List
|
2 |
import torch
|
3 |
+
import logging
|
4 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
5 |
|
6 |
+
logging.basicConfig(level=logging.INFO)
|
7 |
+
logger = logging.getLogger(__name__)
|
8 |
+
|
9 |
class EndpointHandler():
|
10 |
def __init__(self, path=""):
|
11 |
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
12 |
try:
|
13 |
+
logger.info(f"Loading model and tokenizer from path: {path}")
|
14 |
self.model = AutoModelForSeq2SeqLM.from_pretrained(f"{path}/model_v2/").to(self.device)
|
15 |
self.tokenizer = AutoTokenizer.from_pretrained(f"{path}/model_v2/")
|
16 |
except Exception as e:
|
17 |
+
logger.error(f"Error loading model or tokenizer from path {path}: {e}")
|
18 |
# Handle error (e.g., exit or set model/tokenizer to None)
|
19 |
|
20 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
21 |
+
if self.model is None or self.tokenizer is None:
|
22 |
+
error_message = "Model or tokenizer not properly initialized"
|
23 |
+
logger.error(error_message)
|
24 |
+
return [{"error": error_message}]
|
25 |
+
inputs = data.get("inputs")
|
26 |
if not inputs:
|
27 |
return [{"error": "No inputs provided"}]
|
28 |
|