# MathT5-base / MathT5.py
# Last commit: "Update MathT5.py" (da41eb2) by jmeadows17
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
def pretty_print(text, prompt=True):
    """Print a prompt or a derivation with its statements spread over separate lines.

    Args:
        text: The string to display.
        prompt: When true, ``text`` is treated as a prompt: it is split into
            sections on ``", "`` and each section into premises on ``" and "``;
            premises are printed with a standalone ``and`` line between them.
            When false, ``text`` is treated as a derivation: it is split into
            equations on ``" and "`` and the connective itself is not printed.

    Returns:
        None. The formatted text is written to standard output.
    """
    gap = "\n\n\n"
    if prompt:
        blocks = [
            (gap + "and" + gap).join(section.split(" and "))
            for section in text.split(", ")
        ]
    else:
        # Split on the space-delimited connective only. Splitting on the bare
        # substring "and" would cut tokens that merely contain those letters
        # (e.g. "\land") and leave stray spaces around every equation.
        blocks = text.split(" and ")
    print(gap.join(blocks))
def load_model(model_id):
    """Load the T5 tokenizer and model for ``model_id``.

    The model is moved to ``'cuda'`` when ``torch.cuda.is_available()`` reports
    a GPU, otherwise to ``'cpu'``.

    Args:
        model_id: Identifier passed unchanged to both ``from_pretrained`` calls.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    target = 'cpu'
    if torch.cuda.is_available():
        target = 'cuda'
    tok = T5Tokenizer.from_pretrained(model_id)
    net = T5ForConditionalGeneration.from_pretrained(model_id)
    return tok, net.to(target)
def _restore_backslashes(generated_text):
    """Re-attach backslashes that were detached or dropped in decoded text."""
    # A backslash followed by a space is glued back onto the following token.
    derivation = generated_text.replace("\\ ", "\\")
    tokens = derivation.split(" ")
    # Map each backslash-free spelling to the backslashed token it came from.
    # The first occurrence in the text wins, so the result does not depend on
    # set iteration order when two tokens collapse to the same bare spelling.
    restored = {}
    for token in tokens:
        if "\\" in token:
            restored.setdefault(token.replace("\\", ""), token)
    # Any bare token matching a known backslashed one gets its backslashes back.
    return " ".join(restored.get(token, token) for token in tokens)


def inference(prompt, tokenizer, model):
    """Generate text for ``prompt`` and repair backslashes in the decoded output.

    Args:
        prompt: Input string; encoded with truncation at 512 tokens.
        tokenizer: Object providing ``encode`` and ``decode`` as used below.
        model: Object providing ``generate``; its ``device`` attribute, when
            present, decides where the input ids are placed.

    Returns:
        The first generated sequence, decoded without special tokens and
        post-processed so that tokens whose backslash was separated by a space
        or dropped are restored to their backslashed form.
    """
    # Put the inputs on the device the model actually lives on; fall back to
    # the CUDA-availability check only when the model does not expose one.
    device = getattr(model, "device", None)
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    input_ids = tokenizer.encode(
        prompt, return_tensors='pt', max_length=512, truncation=True
    ).to(device)
    output = model.generate(input_ids=input_ids, max_length=512, early_stopping=True)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    # post-processing
    return _restore_backslashes(generated_text)