damienliccia committed
Commit 6249ec3 · verified · 1 Parent(s): ea23c1e

Create handler.py

Files changed (1):
  1. handler.py (+56 -0)
handler.py ADDED
@@ -0,0 +1,56 @@
+ import torch
+ from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
+
+ class EndpointHandler:
+     def __init__(self, model_dir):
+         # Load tokenizer and model from the deployed model directory
+         self.tokenizer = MBart50TokenizerFast.from_pretrained(model_dir)
+         self.model = MBartForConditionalGeneration.from_pretrained(model_dir)
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.model.to(self.device)
+         self.max_length = 1024
+
+     def _validate_input(self, inputs):
+         # Accept either a single string or a list of {"input": ...} dictionaries
+         if isinstance(inputs, str):
+             return [inputs]
+         elif isinstance(inputs, list) and all(isinstance(item, dict) and "input" in item for item in inputs):
+             return [item["input"] for item in inputs]
+         raise ValueError("Input must be a string or a list of dictionaries with an 'input' key")
+
+     def _prepare_input(self, text):
+         # Prepend the source language code expected by the model
+         return f"ru_RU {text}"
+
+     def process(self, inputs):
+         try:
+             # Validate and prepare the input texts
+             texts = self._validate_input(inputs)
+             prepared_texts = [self._prepare_input(text) for text in texts]
+
+             # Tokenization
+             encoded = self.tokenizer(
+                 prepared_texts,
+                 return_tensors="pt",
+                 padding=True,
+                 truncation=True,
+                 max_length=self.max_length
+             ).to(self.device)
+
+             # Inference
+             with torch.no_grad():
+                 outputs = self.model.generate(
+                     **encoded,
+                     max_length=self.max_length,
+                     num_beams=5,
+                     do_sample=False
+                 )
+
+             # Post-processing
+             translations = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+             return [{"output": translation} for translation in translations]
+
+         except Exception as e:
+             return [{"error": str(e)}]
+
+     def __call__(self, data):
+         if not isinstance(data, dict) or "inputs" not in data:
+             return [{"error": "Request must contain 'inputs' field"}]
+         return self.process(data["inputs"])
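For context, a minimal local usage sketch of this handler (the model directory path and the sample texts below are illustrative placeholders, not part of this commit):

    handler = EndpointHandler("./model")  # hypothetical path to the downloaded model files
    # Single-string request, as accepted by __call__
    print(handler({"inputs": "Пример текста для перевода."}))
    # Batched request using the list-of-dicts form accepted by _validate_input
    print(handler({"inputs": [{"input": "Первый текст"}, {"input": "Второй текст"}]}))

Each call returns a list of {"output": ...} dictionaries, or a single-element list with an {"error": ...} entry if validation or generation fails.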