kjcjohnson
commited on
Commit
·
f715cab
1
Parent(s):
4804703
fixes
Browse files
loop.py
CHANGED
@@ -45,6 +45,7 @@ class EndpointHandler():
|
|
45 |
|
46 |
if grammar_str is None or len(grammar_str) == 0 or grammar_str.isspace():
|
47 |
logits_processors = None
|
|
|
48 |
else:
|
49 |
print("=== GOT GRAMMAR ===")
|
50 |
print(grammar_str)
|
@@ -61,7 +62,7 @@ class EndpointHandler():
|
|
61 |
|
62 |
#input_ids = self.tokenizer([inputs], add_special_tokens=False, return_tensors="pt", padding=True)["input_ids"]
|
63 |
input_ids = self.tokenizer.apply_chat_template(
|
64 |
-
[
|
65 |
tokenize=True,
|
66 |
add_generation_prompt=True,
|
67 |
return_tensors="pt"
|
@@ -85,7 +86,8 @@ class EndpointHandler():
|
|
85 |
output_scores=True
|
86 |
)
|
87 |
|
88 |
-
gad_oracle_processor
|
|
|
89 |
|
90 |
# Detokenize generated output
|
91 |
input_length = 1 if self.model.config.is_encoder_decoder else input_ids.shape[1]
|
|
|
45 |
|
46 |
if grammar_str is None or len(grammar_str) == 0 or grammar_str.isspace():
|
47 |
logits_processors = None
|
48 |
+
gad_oracle_processor = None
|
49 |
else:
|
50 |
print("=== GOT GRAMMAR ===")
|
51 |
print(grammar_str)
|
|
|
62 |
|
63 |
#input_ids = self.tokenizer([inputs], add_special_tokens=False, return_tensors="pt", padding=True)["input_ids"]
|
64 |
input_ids = self.tokenizer.apply_chat_template(
|
65 |
+
[{"role": "user", "content": inputs}],
|
66 |
tokenize=True,
|
67 |
add_generation_prompt=True,
|
68 |
return_tensors="pt"
|
|
|
86 |
output_scores=True
|
87 |
)
|
88 |
|
89 |
+
if gad_oracle_processor is not None:
|
90 |
+
gad_oracle_processor.reset()
|
91 |
|
92 |
# Detokenize generated output
|
93 |
input_length = 1 if self.model.config.is_encoder_decoder else input_ids.shape[1]
|