Update handler.py
Browse files- handler.py +5 -11
handler.py
CHANGED
@@ -29,25 +29,19 @@ class EndpointHandler:
|
|
29 |
allowed_ids.add(token_id)
|
30 |
return allowed_ids
|
31 |
|
32 |
-
def filter_allowed_tokens(input_ids: torch.Tensor, scores: np.ndarray, allowed_token_ids: set[int]) -> np.ndarray:
|
33 |
"""
|
34 |
-
Modify scores to allow only tokens in the allowed_token_ids set
|
35 |
Handles both 1D and 2D scores arrays.
|
36 |
"""
|
37 |
-
# Define the range of exempt tokens
|
38 |
-
exempt_range = set(range(128000, 128256))
|
39 |
-
|
40 |
-
# Combine exempt tokens with allowed tokens
|
41 |
-
effective_allowed_ids = allowed_token_ids.union(exempt_range)
|
42 |
-
|
43 |
if scores.ndim == 1:
|
44 |
# 1D case: Apply mask directly
|
45 |
-
mask = np.isin(np.arange(scores.shape[0]), list(
|
46 |
scores[~mask] = float('-inf')
|
47 |
elif scores.ndim == 2:
|
48 |
# 2D case: Apply mask across each row
|
49 |
for i in range(scores.shape[0]):
|
50 |
-
mask = np.isin(np.arange(scores.shape[1]), list(
|
51 |
scores[i, ~mask] = float('-inf')
|
52 |
else:
|
53 |
raise ValueError(f"Unsupported scores dimension: {scores.ndim}")
|
@@ -95,4 +89,4 @@ class EndpointHandler:
|
|
95 |
# Decode the output
|
96 |
generated_text = response["choices"][0]["message"]["content"]
|
97 |
|
98 |
-
return [{"generated_text": generated_text}]
|
|
|
29 |
allowed_ids.add(token_id)
|
30 |
return allowed_ids
|
31 |
|
32 |
+
def filter_allowed_tokens(self, input_ids: torch.Tensor, scores: np.ndarray, allowed_token_ids: set[int]) -> np.ndarray:
|
33 |
"""
|
34 |
+
Modify scores to allow only tokens in the allowed_token_ids set.
|
35 |
Handles both 1D and 2D scores arrays.
|
36 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
if scores.ndim == 1:
|
38 |
# 1D case: Apply mask directly
|
39 |
+
mask = np.isin(np.arange(scores.shape[0]), list(allowed_token_ids))
|
40 |
scores[~mask] = float('-inf')
|
41 |
elif scores.ndim == 2:
|
42 |
# 2D case: Apply mask across each row
|
43 |
for i in range(scores.shape[0]):
|
44 |
+
mask = np.isin(np.arange(scores.shape[1]), list(allowed_token_ids))
|
45 |
scores[i, ~mask] = float('-inf')
|
46 |
else:
|
47 |
raise ValueError(f"Unsupported scores dimension: {scores.ndim}")
|
|
|
89 |
# Decode the output
|
90 |
generated_text = response["choices"][0]["message"]["content"]
|
91 |
|
92 |
+
return [{"generated_text": generated_text}]
|