# t5-literary-coreference / get_annotations.py
import csv

import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Model and device configuration.
model_name = 't5-literary-coreference'
# Fall back to CPU when no GPU is available instead of crashing on .to('cuda').
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print("Loading in data")
# Input CSV is expected to contain "input" and "output" text columns
# (both are consumed by preprocess_function below).
df = pd.read_csv('example_input.csv')
df = df.sample(frac=1)  # Shuffle dataframe contents
to_annotate = Dataset.from_pandas(df)
speech_excerpts = DatasetDict({"annotate": to_annotate})

print("Loading models")
# Change max_model_length to fit your data
tokenizer = AutoTokenizer.from_pretrained("t5-3b", model_max_length=500)
def preprocess_function(examples, input_text="input", output_text="output"):
    """Tokenize a batch of input/output text pairs for a seq2seq model.

    Both sides are truncated to 500 tokens; the tokenized targets are
    attached under the "labels" key, as transformers seq2seq models expect.
    """
    encoded = tokenizer(examples[input_text], max_length=500, truncation=True)
    target_ids = tokenizer(examples[output_text], max_length=500, truncation=True)["input_ids"]
    encoded["labels"] = target_ids
    return encoded
# NOTE(review): this tokenized dataset is built but never fed to the model
# below (the loop re-tokenizes each item itself); kept for compatibility.
tokenized_speech_excerpts = speech_excerpts.map(preprocess_function, batched=True)

model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device=device)
model.eval()  # inference only — disable dropout etc.

print("Begin creating annotations")
header = ["input", "model_output"]
rows = []
# Inference needs no gradients; no_grad avoids building autograd state
# and cuts memory use during generation.
with torch.no_grad():
    for item in speech_excerpts["annotate"]:
        # Truncate to the 500-token window; the original call had no
        # truncation, so long excerpts could exceed model_max_length.
        input_ids = tokenizer(
            item["input"], return_tensors="pt", max_length=500, truncation=True
        ).input_ids
        result = model.generate(input_ids.to(device=device), max_length=500)
        rows.append(
            [item["input"], tokenizer.decode(result[0], skip_special_tokens=True)]
        )
# Write results via a context manager so the file is closed even on error.
# newline="" is required by the csv module to avoid blank lines on Windows;
# utf-8 makes the output encoding explicit and platform-independent.
with open("results.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(header)
    writer.writerows(rows)

print("Finished")