torinriley committed on
Commit
1d624af
·
verified ·
1 Parent(s): 340599d

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +48 -47
handler.py CHANGED
@@ -48,50 +48,51 @@ class EndpointHandler:
48
  self.model.load_state_dict(checkpoint["model_state_dict"])
49
  self.model.eval()
50
 
51
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
52
- """
53
- Process the incoming request and return the translation.
54
- """
55
- try:
56
- inputs = data.get("inputs", "")
57
- if not inputs:
58
- return [{"error": "No 'inputs' provided in request"}]
59
-
60
- # Precompute the encoder output
61
- source = self.tokenizer_src.encode(inputs)
62
- source = torch.cat([
63
- torch.tensor([self.tokenizer_src.token_to_id("[SOS]")], dtype=torch.int64),
64
- torch.tensor(source.ids, dtype=torch.int64),
65
- torch.tensor([self.tokenizer_src.token_to_id("[EOS]")], dtype=torch.int64),
66
- torch.tensor([self.tokenizer_src.token_to_id("[PAD]")] * (350 - len(source.ids) - 2), dtype=torch.int64)
67
- ], dim=0).to(self.device)
68
- source_mask = (source != self.tokenizer_src.token_to_id("[PAD]")).unsqueeze(0).unsqueeze(1).int().to(self.device)
69
- encoder_output = self.model.encode(source, source_mask)
70
-
71
- # Generate translation word by word
72
- decoder_input = torch.empty(1, 1).fill_(self.tokenizer_tgt.token_to_id("[SOS]")).type_as(source).to(self.device)
73
- predicted_words = []
74
-
75
- while decoder_input.size(1) < 350:
76
- decoder_mask = torch.triu(
77
- torch.ones((1, decoder_input.size(1), decoder_input.size(1))),
78
- diagonal=1
79
- ).type(torch.int).type_as(source_mask).to(self.device)
80
- out = self.model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
81
-
82
- # Project next token
83
- prob = self.model.project(out[:, -1])
84
- _, next_word = torch.max(prob, dim=1)
85
- decoder_input = torch.cat(
86
- [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(self.device)], dim=1)
87
-
88
- decoded_word = self.tokenizer_tgt.decode([next_word.item()])
89
- if next_word == self.tokenizer_tgt.token_to_id("[EOS]"):
90
- break
91
-
92
- predicted_words.append(decoded_word)
93
-
94
- predicted_translation = " ".join(predicted_words).replace("[EOS]", "").strip()
95
- return [{"translation": predicted_translation}]
96
- except Exception as e:
97
- return [{"error": str(e)}]
 
 
48
  self.model.load_state_dict(checkpoint["model_state_dict"])
49
  self.model.eval()
50
 
51
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Translate the text found under ``data["inputs"]`` using greedy decoding.

    Args:
        data: Request payload; the source text is read from its "inputs" key.

    Returns:
        A one-element list containing either ``{"translation": <text>}`` on
        success, or ``{"error": <message>}`` when the input is missing, is
        too long to fit the fixed sequence length, or any exception is
        raised while encoding/decoding.
    """
    # Fixed sequence length: the source is padded to it and decoding stops at it.
    max_len = 350

    try:
        inputs = data.get("inputs", "")
        if not inputs:
            return [{"error": "No 'inputs' provided in request"}]

        # Look the special-token ids up once instead of on every use.
        src_sos = self.tokenizer_src.token_to_id("[SOS]")
        src_eos = self.tokenizer_src.token_to_id("[EOS]")
        src_pad = self.tokenizer_src.token_to_id("[PAD]")
        tgt_sos = self.tokenizer_tgt.token_to_id("[SOS]")
        tgt_eos = self.tokenizer_tgt.token_to_id("[EOS]")

        token_ids = self.tokenizer_src.encode(inputs).ids

        # Two slots are reserved for [SOS]/[EOS]. A negative count would
        # silently produce no padding and an over-length sequence, so
        # reject such input explicitly.
        num_padding = max_len - len(token_ids) - 2
        if num_padding < 0:
            return [{
                "error": (
                    f"Input is too long: {len(token_ids)} tokens, "
                    f"at most {max_len - 2} are supported"
                )
            }]

        # Layout: [SOS] tokens... [EOS] [PAD]... (always max_len entries).
        # NOTE(review): `source` is 1-D (no batch dimension) while
        # `decoder_input` below is (1, 1) — confirm model.encode expects this.
        source = torch.tensor(
            [src_sos, *token_ids, src_eos, *([src_pad] * num_padding)],
            dtype=torch.int64,
        ).to(self.device)

        # Nonzero where the source holds a real (non-padding) token.
        source_mask = (source != src_pad).unsqueeze(0).unsqueeze(1).int().to(self.device)

        predicted_words = []

        # Inference only: skip autograd bookkeeping for every decode step.
        with torch.no_grad():
            encoder_output = self.model.encode(source, source_mask)

            decoder_input = torch.full(
                (1, 1), tgt_sos, dtype=torch.int64, device=self.device
            )

            while decoder_input.size(1) < max_len:
                size = decoder_input.size(1)
                # NOTE(review): this mask is nonzero strictly ABOVE the
                # diagonal (future positions), whereas source_mask uses
                # nonzero for "keep". If the model treats zero as "masked",
                # this should be inverted (`== 0`) — confirm against the
                # model's attention implementation before changing.
                decoder_mask = torch.triu(
                    torch.ones((1, size, size)),
                    diagonal=1,
                ).type(torch.int).type_as(source_mask).to(self.device)

                out = self.model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
                prob = self.model.project(out[:, -1])

                # Greedy choice: index of the highest-scoring entry.
                _, next_word = torch.max(prob, dim=1)
                next_id = next_word.item()

                decoder_input = torch.cat(
                    [
                        decoder_input,
                        torch.full((1, 1), next_id, dtype=torch.int64, device=self.device),
                    ],
                    dim=1,
                )

                if next_id == tgt_eos:
                    break

                predicted_words.append(self.tokenizer_tgt.decode([next_id]))

        predicted_translation = " ".join(predicted_words).replace("[EOS]", "").strip()

        return [{"translation": predicted_translation}]
    except Exception as e:
        # Endpoint boundary: report the failure in the response payload
        # instead of letting it propagate to the serving framework.
        return [{"error": str(e)}]
98
+