Upload 2 files
Browse files- get_annotations.py +51 -0
- get_ent_clusters.py +82 -0
get_annotations.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import csv
|
3 |
+
|
4 |
+
from datasets import Dataset, DatasetDict
|
5 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
6 |
+
|
# --- Configuration ---
# Path / hub id of the fine-tuned coreference model loaded further below.
# NOTE(review): the tokenizer is loaded from the base "t5-3b" checkpoint,
# not from model_name -- confirm the two share a vocabulary.
model_name = 't5-literary-coreference'
device = 'cuda'  # assumes a CUDA-capable GPU is available -- TODO confirm

print("Loading in data")

# Expected columns: at least "input" and "output" (used by preprocess_function
# and the generation loop below).
df = pd.read_csv('example_input.csv')
df = df.sample(frac=1)  # Shuffle dataframe contents

# Wrap the dataframe in a datasets.Dataset / DatasetDict under the
# "annotate" split so it can be .map()-ed and iterated below.
to_annotate = Dataset.from_pandas(df)

speech_excerpts = DatasetDict({"annotate": to_annotate})

print("Loading models")
# Change max_model_length to fit your data
tokenizer = AutoTokenizer.from_pretrained("t5-3b", model_max_length=500)
def preprocess_function(examples, input_text="input", output_text="output"):
    """Tokenize one batch of examples for seq2seq training/eval.

    Encodes the source column and attaches the tokenized target column's
    ids under the "labels" key, the layout transformers seq2seq models
    expect. Relies on the module-level `tokenizer`; both sides are
    truncated at 500 tokens.
    """
    encoded = tokenizer(examples[input_text], max_length=500, truncation=True)
    encoded["labels"] = tokenizer(
        examples[output_text], max_length=500, truncation=True
    )["input_ids"]
    return encoded
# Batched tokenization of the whole split using the function above.
# NOTE(review): tokenized_speech_excerpts is never used after this line --
# the generation loop below re-tokenizes each item individually; confirm
# whether this map call is still needed.
tokenized_speech_excerpts = speech_excerpts.map(preprocess_function, batched=True)

# Load the fine-tuned seq2seq model and move it to the configured device.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device=device)

print("Begin creating annotations")
header = ["input", "model_output"]
rows = []

# Generate one annotation per example (unbatched); outputs are capped at
# 500 tokens to match the tokenizer's model_max_length above. Each result
# row pairs the raw input text with the decoded model output.
for item in speech_excerpts["annotate"]:
    input_ids = tokenizer(item["input"], return_tensors="pt").input_ids
    result = model.generate(input_ids.to(device=device), max_length = 500)
    rows.append([item["input"], tokenizer.decode(result[0], skip_special_tokens = True)])
# Write the (input, model_output) pairs. newline="" is required by the csv
# module to avoid blank rows on Windows, and the with-block guarantees the
# file is flushed and closed even if writing raises.
with open("results.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(header)
    writer.writerows(rows)

print("Finished")
get_ent_clusters.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
import csv
|
5 |
+
|
def extract_paren(annotation):
    """Pull every bracketed entity mention out of an annotation string.

    Scans for each "[" and copies characters until the matching "]",
    tracking nesting depth (inner "["/"]" pairs adjust the depth and trim
    the last three collected characters when they close). A completed
    mention is kept only when it carries a ": N" cluster tag, in which
    case "|W" is appended, where W is the number of spaces before the
    mention minus one per earlier ": N" tag -- presumably a word index
    into the untagged text (verify against downstream use).
    """
    ents = []
    for start, ch in enumerate(annotation):
        if ch != "[":
            continue
        ent = "["
        depth = 0
        for pos in range(start + 1, len(annotation)):
            c = annotation[pos]
            if c == "[":
                depth += 1
            elif c == "]":
                if depth > 0:
                    depth -= 1
                    # closing a nested mention: drop its trailing tag chars
                    ent = ent[:len(ent) - 3]
                else:
                    ent += "]"
                    if re.search(r": [0-9]{1,3}", ent):
                        prefix = annotation[:start]
                        tags_before = len(re.findall(r": [0-9]{1,3}", prefix))
                        ent += "|" + str(prefix.count(" ") - tags_before)
                        ents.append(ent)
                    break
            else:
                ent += c
    return ents
def create_clusters(ents):
    """Group extracted entity strings by their ": N" cluster id.

    Each entry in `ents` is expected to contain a ": N" tag (as produced
    by extract_paren). The tag and any square brackets are stripped from
    the entity text, and the cleaned string is appended to the list for
    cluster N. Entries without a tag are reported and skipped.
    """
    clusters = {}
    for ent in ents:
        tag = re.search(r": [0-9]{1,3}", ent)
        if tag is None:
            # No cluster tag -- malformed entity; report and move on.
            print("OH NO:", ent)
            print()
            continue
        cluster_id = int(re.search(r"[0-9]{1,3}", tag.group()).group())
        cleaned = ent.replace("[", "").replace("]", "").replace(tag.group(), "")
        clusters.setdefault(cluster_id, []).append(cleaned)
    return clusters
headers = ["input", "model_output", "model_output_clusters"]

# results.csv is produced by get_annotations.py (columns: input, model_output).
df = pd.read_csv("results.csv")

rows = []
for _, row in df.iterrows():
    annotation = row["model_output"]

    # Skip non-string cells (e.g. NaN from empty model outputs).
    if not isinstance(annotation, str):
        continue

    ann_ents = extract_paren(annotation)
    # Empty extraction -> empty cluster dict (avoids the redundant
    # double-assignment the original had).
    ann_clusters = create_clusters(ann_ents) if ann_ents else {}

    # Clusters are serialized via str() so they fit in a CSV cell.
    rows.append([row["input"], annotation, str(ann_clusters)])

# newline="" is required by the csv module to avoid blank rows on Windows;
# the with-block guarantees the file is closed even if writing raises.
with open("cluster_results.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(headers)
    writer.writerows(rows)