GiaPhu committed (verified)
Commit 5a26ca7 · Parent(s): 9337d23

Upload handler.py

Files changed (1)
  1. handler.py  +27 -0
handler.py ADDED
@@ -0,0 +1,27 @@
+ from typing import Any, Dict, List
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
+ class EndpointHandler:
+     def __init__(self, path: str = ""):
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         try:
+             self.model = AutoModelForSeq2SeqLM.from_pretrained(path).to(self.device)
+             self.tokenizer = AutoTokenizer.from_pretrained(path)
+         except Exception as e:
+             print(f"Error loading model or tokenizer from path {path}: {e}")
+             raise  # fail fast rather than serving requests with a broken model
+
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         inputs = data.get("inputs", "")
+         if not inputs:
+             return [{"error": "No inputs provided"}]
+
+         # Tokenize and move the input tensors to the same device as the model
+         tokenized_input = self.tokenizer(inputs, return_tensors="pt", truncation=True)
+         input_ids = tokenized_input["input_ids"].to(self.device)
+         attention_mask = tokenized_input["attention_mask"].to(self.device)
+
+         summary_ids = self.model.generate(input_ids=input_ids, attention_mask=attention_mask)
+         summary_text = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
+         return [{"summary": summary_text}]
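For reference, a minimal local smoke test of this handler might look like the sketch below. The "./model" path and the sample input text are illustrative assumptions, not files or values defined in this commit; on a deployed endpoint the {"inputs": ...} payload comes from the request body instead.

# Illustrative smoke test: "./model" is an assumed local checkpoint path,
# not something provided by this commit.
from handler import EndpointHandler

handler = EndpointHandler(path="./model")
payload = {"inputs": "Paste a long article here to summarize."}
result = handler(payload)
print(result)  # expected shape: [{"summary": "..."}]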