My-Chat / modules /grammar.py
Lee Thanh
Upload All
0eeee8c
raw
history blame
869 Bytes
from torch_grammar import GrammarSampler
from transformers.generation.logits_process import LogitsProcessor
from modules import shared
# Module-wide grammar state shared by every GrammarLogitsProcessor instance.
# `grammar_string` is the cache key: the grammar text that `sampler` is
# associated with. `sampler` and `grammar` stay None while no grammar is active.
grammar_string = ''
sampler = grammar = None
class GrammarLogitsProcessor(LogitsProcessor):
    """Logits processor that routes scores through a grammar sampler.

    The ``GrammarSampler`` is cached in the module-level globals ``sampler``
    and ``grammar_string`` and is only rebuilt when the grammar text differs
    from the cached one. The callable actually applied to the scores lives
    in the module-level global ``grammar`` and is refreshed on every
    construction.
    """

    def __init__(self, string):
        """Prepare the module-level grammar state for ``string``.

        Args:
            string: Grammar text. It is stripped and terminated with a
                newline before being passed to ``GrammarSampler`` together
                with ``'root'`` (presumably the start rule name -- confirm
                against torch_grammar) and ``shared.tokenizer``. Text that
                is empty after stripping disables the grammar.

        Raises:
            Whatever ``GrammarSampler`` raises for the given text. In that
            case the cached sampler and cache key are left untouched.
        """
        global sampler, grammar, grammar_string

        if string != grammar_string:
            stripped = string.strip()
            # Build the replacement sampler BEFORE touching the cache. If
            # construction raises, the cache key must not already point at
            # the new text; otherwise the next call with the same text would
            # skip the rebuild and silently reuse the previous sampler.
            new_sampler = (
                GrammarSampler(stripped + '\n', 'root', shared.tokenizer)
                if stripped != ''
                else None
            )
            sampler = new_sampler
            grammar_string = string

        # A fresh logits processor is requested from the sampler on every
        # construction, even when the sampler itself came from the cache.
        grammar = sampler.logits_processor() if sampler is not None else None

    def __call__(self, input_ids, scores):
        """Return ``scores`` passed through the active grammar.

        Args:
            input_ids: Forwarded unchanged to the grammar callable.
            scores: Scores to process.

        Returns:
            The result of ``grammar(input_ids, scores)``, or ``scores``
            unchanged when no grammar is active.
        """
        if grammar is None:
            return scores

        return grammar(input_ids, scores)