gardarjuto committed
Commit fc73940 • 1 Parent(s): eb01dd5

fix error about max_length
handler.py CHANGED (+6 -5)
@@ -34,7 +34,6 @@ class EndpointHandler:
             path, device_map="auto", torch_dtype=torch.bfloat16
         )
         LOGGER.info(f"Inference model loaded from {path}")
-        LOGGER.info(f"Model outline: {self.model}")
         LOGGER.info(f"Model device: {self.model.device}")
 
         # Fix the pad and bos tokens to avoid bug in the tokenizer
@@ -45,7 +44,7 @@ class EndpointHandler:
         )
 
     def check_valid_inputs(
-        self, input_a: str, input_b: str, task: int
+        self, input_a: str, input_b: str, task: int
     ) -> bool:
         """
         Check if the inputs are valid
@@ -104,11 +103,13 @@ class EndpointHandler:
         input_a = data.pop("input_a", None)
         input_b = data.pop("input_b", None)
         task = data.pop("task", None)
-        parameters = data.pop("parameters",
+        parameters = data.pop("parameters", {})
 
         # Check valid inputs
-        if not self.check_valid_inputs(input_a, input_b, task):
-            return []
+        if not self.check_valid_inputs(input_a, input_b, task):
+            return [{"error": "Invalid inputs"}]
+        if "max_new_tokens" not in parameters:
+            parameters["max_new_tokens"] = 512
 
         # Tokenize the input
         tokenized_input = self.tokenize_input(input_a, input_b, task)
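For context, here is a minimal standalone sketch of the parameter handling this commit lands (the helper name resolve_generation_params is hypothetical; handler.py inlines this logic in its request-handling method). It shows why defaulting parameters to a dict matters: with a non-dict default such as None, the "max_new_tokens" membership test would raise a TypeError, and without the 512-token fallback, transformers would generate under its short default max_length, presumably the error named in the commit message.

# Hypothetical standalone version of the patched logic, for illustration only.
def resolve_generation_params(data: dict) -> dict:
    parameters = data.pop("parameters", {})  # dict default, as in the commit
    if "max_new_tokens" not in parameters:   # a None default would raise TypeError here
        parameters["max_new_tokens"] = 512   # the commit's fallback token budget
    return parameters

# A request that omits "parameters" gets the fallback...
assert resolve_generation_params({"input_a": "a", "input_b": "b", "task": 0}) == {"max_new_tokens": 512}
# ...while an explicit caller-supplied value still wins.
assert resolve_generation_params({"parameters": {"max_new_tokens": 64}}) == {"max_new_tokens": 64}

The switch from return [] to return [{"error": "Invalid inputs"}] is also a small contract change: callers that previously treated an empty response list as "invalid input" should now look for the "error" key in the first element instead.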