kjcjohnson commited on
Commit
f715cab
·
1 Parent(s): 4804703
Files changed (1) hide show
  1. loop.py +4 -2
loop.py CHANGED
@@ -45,6 +45,7 @@ class EndpointHandler():
45
 
46
  if grammar_str is None or len(grammar_str) == 0 or grammar_str.isspace():
47
  logits_processors = None
 
48
  else:
49
  print("=== GOT GRAMMAR ===")
50
  print(grammar_str)
@@ -61,7 +62,7 @@ class EndpointHandler():
61
 
62
  #input_ids = self.tokenizer([inputs], add_special_tokens=False, return_tensors="pt", padding=True)["input_ids"]
63
  input_ids = self.tokenizer.apply_chat_template(
64
- [ {"role": "user", "content": "inputs"}],
65
  tokenize=True,
66
  add_generation_prompt=True,
67
  return_tensors="pt"
@@ -85,7 +86,8 @@ class EndpointHandler():
85
  output_scores=True
86
  )
87
 
88
- gad_oracle_processor.reset()
 
89
 
90
  # Detokenize generated output
91
  input_length = 1 if self.model.config.is_encoder_decoder else input_ids.shape[1]
 
45
 
46
  if grammar_str is None or len(grammar_str) == 0 or grammar_str.isspace():
47
  logits_processors = None
48
+ gad_oracle_processor = None
49
  else:
50
  print("=== GOT GRAMMAR ===")
51
  print(grammar_str)
 
62
 
63
  #input_ids = self.tokenizer([inputs], add_special_tokens=False, return_tensors="pt", padding=True)["input_ids"]
64
  input_ids = self.tokenizer.apply_chat_template(
65
+ [{"role": "user", "content": inputs}],
66
  tokenize=True,
67
  add_generation_prompt=True,
68
  return_tensors="pt"
 
86
  output_scores=True
87
  )
88
 
89
+ if gad_oracle_processor is not None:
90
+ gad_oracle_processor.reset()
91
 
92
  # Detokenize generated output
93
  input_length = 1 if self.model.config.is_encoder_decoder else input_ids.shape[1]