kjcjohnson committed on
Commit
4804703
·
1 Parent(s): 32cbc00

Use chat template and allow empty grammar

Browse files
Files changed (1) hide show
  1. loop.py +21 -12
loop.py CHANGED
@@ -43,20 +43,29 @@ class EndpointHandler():
43
  max_new_tokens = safe_int_cast(data.get("max-new-tokens"), MAX_NEW_TOKENS)
44
  max_time = safe_int_cast(data.get("max-time"), MAX_TIME)
45
 
46
- print("=== GOT GRAMMAR ===")
47
- print(grammar_str)
48
- print("===================")
49
- grammar = IncrementalGrammarConstraint(grammar_str, "root", self.tokenizer)
 
 
 
50
 
51
- # Initialize logits processor for the grammar
52
- gad_oracle_processor = GrammarAlignedOracleLogitsProcessor(grammar)
53
- inf_nan_remove_processor = InfNanRemoveLogitsProcessor()
54
- logits_processors = LogitsProcessorList([
55
- inf_nan_remove_processor,
56
- gad_oracle_processor,
57
- ])
58
 
59
- input_ids = self.tokenizer([inputs], add_special_tokens=False, return_tensors="pt", padding=True)["input_ids"]
 
 
 
 
 
 
60
  input_ids = input_ids.to(self.model.device)
61
 
62
  output = self.model.generate(
 
43
  max_new_tokens = safe_int_cast(data.get("max-new-tokens"), MAX_NEW_TOKENS)
44
  max_time = safe_int_cast(data.get("max-time"), MAX_TIME)
45
 
46
+ if grammar_str is None or len(grammar_str) == 0 or grammar_str.isspace():
47
+ logits_processors = None
48
+ else:
49
+ print("=== GOT GRAMMAR ===")
50
+ print(grammar_str)
51
+ print("===================")
52
+ grammar = IncrementalGrammarConstraint(grammar_str, "root", self.tokenizer)
53
 
54
+ # Initialize logits processor for the grammar
55
+ gad_oracle_processor = GrammarAlignedOracleLogitsProcessor(grammar)
56
+ inf_nan_remove_processor = InfNanRemoveLogitsProcessor()
57
+ logits_processors = LogitsProcessorList([
58
+ inf_nan_remove_processor,
59
+ gad_oracle_processor,
60
+ ])
61
 
62
+ #input_ids = self.tokenizer([inputs], add_special_tokens=False, return_tensors="pt", padding=True)["input_ids"]
63
+ input_ids = self.tokenizer.apply_chat_template(
64
+ [ {"role": "user", "content": inputs}],
65
+ tokenize=True,
66
+ add_generation_prompt=True,
67
+ return_tensors="pt"
68
+ )
69
  input_ids = input_ids.to(self.model.device)
70
 
71
  output = self.model.generate(