import argparse
import csv
import json
import os
import sys
from collections import defaultdict

import numpy as np
import torch
import benepar
import spacy_stanza
from nltk.tree import Tree
from spacy.lang.en import English
from transformers import AutoModelForCausalLM, AutoTokenizer

# Local helpers: chart/span utilities and the EVALB wrapper from the
# self-attentive-parser submodule (paths resolved relative to this script).
sys.path.insert(0, os.path.join(sys.path[0], '../scripts/'))
from tree_helper import chart_from_tree, pad_charts, padded_chart_from_spans
sys.path.insert(0, os.path.join(sys.path[0], '../../misc/self-attentive-parser/src/'))
import evaluate

from improved_diffusion.rounding import rounding_func, load_models, load_tokenizer

# Whitespace-level spaCy tokenizer used to normalize text before tagging/parsing.
nlp = English()
tokenizer_spacy = nlp.tokenizer

def eval_ppl2(args, text_samples):
    '''
    Evaluate NLL/PPL of the samples under an AR model that shares the
    diffusion model's word-level vocabulary.
    :param text_samples: dict mapping a control target to a list of token lists.
    '''
    print(f'loading from {args.model_name_or_path}')
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,  # path to the AR model trained for LMing this task.
    ).cuda()
    if 'r2l' in args.model_name_or_path:
        print('Use the right-to-left encoding.')

    # Load the diffusion model's word-level vocabulary.
    args.model_path = 'predictability/diffusion_models_v6/diff_e2e-tgt_pad_rand16_transformer_' \
                      'lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart/ema_0.9999_200000.pt'
    tokenizer = load_tokenizer('e2e-tgt', 'random', os.path.split(args.model_path)[0])
    reverse_tokenizer = {v: k for k, v in tokenizer.items()}

    full_score = []
    for idxx, (gold, full_word_lst) in enumerate(text_samples.items()):
        agg_loss = []
        for x in full_word_lst:
            if 'r2l' in args.model_name_or_path:
                string = ["START"] + list(reversed(x)) + ["END"]
                tokenized_x = [reverse_tokenizer.get(s, reverse_tokenizer['UNK']) for s in string]
            else:
                tokenized_x = ([reverse_tokenizer['START']]
                               + [reverse_tokenizer.get(s, reverse_tokenizer['UNK']) for s in x]
                               + [reverse_tokenizer['END']])
            tokenized_x = torch.LongTensor(tokenized_x).cuda()
            labels = tokenized_x.clone()
            labels[labels == reverse_tokenizer['PAD']] = -100  # mask PAD out of the loss.
            model_output = model(tokenized_x, labels=labels)
            agg_loss.append(model_output.loss.item())
        example_mean_score = torch.tensor(agg_loss).mean()
        full_score.append(example_mean_score)
    full_score_ = np.array(full_score).mean()
    print(f'full NLL score is {full_score_} for {len(full_score)}')
    print(f'full PPL score is {np.e ** full_score_} for {len(full_score)}')

def eval_ppl(args, text_samples):
    '''
    Evaluate NLL/PPL using a GPT-2 model fine-tuned on this task.
    :param text_samples: dict mapping a control target to a list of token lists.
    '''
    # load model
    print(f'loading from {args.model_name_or_path}')
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,  # path to the AR model trained for LMing this task.
    ).cuda()
    # load tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    print('finished loading models.')

    # Load the diffusion model's vocabulary so OOV words can be mapped to UNK.
    args.model_path = 'predictability/diffusion_models_v6/diff_e2e-tgt_pad_rand16_transformer_' \
                      'lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart/ema_0.9999_200000.pt'
    diff_tokenizer = load_tokenizer('e2e-tgt', 'random', os.path.split(args.model_path)[0])
    reverse_diff_tokenizer = {v: k for k, v in diff_tokenizer.items()}

    full_score = []
    for gold, full_word_lst in text_samples.items():
        agg_loss = []
        for x in full_word_lst:
            # Replace words outside the diffusion vocabulary with UNK, then
            # wrap the sequence in GPT-2's BOS/EOS tokens.
            x = [kk if kk in reverse_diff_tokenizer else 'UNK' for kk in x]
            x = tokenizer.bos_token + " ".join(x) + tokenizer.eos_token
            tokenized_x = tokenizer(x, return_tensors='pt')
            input_ids = tokenized_x['input_ids'].cuda()
            labels = input_ids.clone()
            model_output = model(input_ids, labels=labels)
            agg_loss.append(model_output.loss.item())
        example_mean_score = torch.tensor(agg_loss).mean()
        full_score.append(example_mean_score)
    full_score_ = np.array(full_score).mean()
    print(f'full NLL score is {full_score_} for {len(full_score)}')
    print(f'full PPL score is {np.e ** full_score_} for {len(full_score)}')

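# Note: for Hugging Face causal LMs, model(input_ids, labels=labels).loss is the
# mean per-token cross-entropy in nats, so np.e ** mean_loss above is the
# perplexity of the samples under the autoregressive model.
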
def read_files(args):
    '''
    Read generated samples from disk.
    :param args:
    :return: a list of tokenized sentences ('file' format), or a dict mapping
        control targets to lists of tokenized sentences ('paired' format).
    '''
    if args.input_format == 'file':
        text_samples = []
        if args.input_text.endswith('json'):
            with open(args.input_text, 'r') as f:
                for line in f:
                    words = [x.text for x in tokenizer_spacy(json.loads(line)[0])]
                    text_samples.append(words)
        else:
            with open(args.input_text, 'r') as f:
                for line in f:
                    text_samples.append(line.strip().split())

        # Remove PAD tokens and the START/END markers.
        text_samples2 = []
        for sent in text_samples:
            tempsent = [x for x in sent if x != 'PAD']
            if tempsent[0] == 'START':
                tempsent = tempsent[1:]
            if tempsent[-1] == 'END':
                tempsent = tempsent[:-1]
            if tempsent[-1] == '\n' and args.mode in ['e2e-tgt-tree', 'e2e-tgt-tree-paired']:
                tempsent[-1] = '.'  # EVALB rejects newline tokens.
            text_samples2.append(tempsent)
        return text_samples2
    elif args.input_format == 'paired':
        import ast
        result_lst = defaultdict(list)
        if args.input_text.endswith('json'):
            with open(args.input_text, 'r') as f:
                for line in f:
                    try:
                        line = json.loads(line)
                    except json.JSONDecodeError:
                        # Not valid JSON: fall back to Python-literal parsing.
                        if args.mode == 'e2e-tgt-spans-paired':
                            line = ast.literal_eval(line)
                            line = {tuple(ast.literal_eval(k[0])): v for k, v in line.items()}
                            result_lst.update(line)
                        else:
                            line = ast.literal_eval(line)
                            result_lst.update(line)
        elif args.input_text.endswith('log'):
            with open(args.input_text, 'r') as csvfile:
                roc_reader = csv.reader(csvfile)
                for idx, row in enumerate(roc_reader):
                    if idx == 0:
                        continue  # skip the header row.
                    if args.mode == 'e2e-tgt-spans-paired' or args.mode == 'e2e-tgt-length-paired':
                        pos = tuple(ast.literal_eval(row[0]))
                        if args.mode == 'e2e-tgt-length-paired':
                            pos = list(pos)
                            pos[0] = int(pos[0]) + 2  # this count did not account for START and END.
                            pos = tuple(pos)
                    else:
                        pos = tuple(row[0].split())
                    result_lst[pos].append(row[2])

        clean_result_lst = {}
        for k, text_samples in result_lst.items():
            text_samples2 = []
            for sent in text_samples:
                sent = sent.split(' ')
                tempsent = [x for x in sent if x != 'PAD']
                if tempsent[0] == 'START':
                    tempsent = tempsent[1:]
                if tempsent[-1] == 'END':
                    tempsent = tempsent[:-1]
                if tempsent[-1] == '\n' and args.mode == 'e2e-tgt-tree':
                    tempsent[-1] = '.'  # EVALB rejects newline tokens.
                # Re-tokenize with spaCy so the tokenization matches the gold side.
                tempsent = " ".join(tempsent)
                tempsent = [x.text for x in tokenizer_spacy(tempsent)]
                text_samples2.append(tempsent)
            # Also strip START and END from the control target.
            if k[0] == 'START' and k[-1] == 'END':
                kk_ = k[1:-1]
            else:
                kk_ = k
            clean_result_lst[kk_] = text_samples2
        return clean_result_lst

def eval_parse(parser, generated, tree_vocab):
    '''Parse each generated sentence with benepar and extract its labeled spans.'''
    sent_lst = []
    for sent in generated:
        input_sentence1 = benepar.InputSentence(
            words=sent,
        )
        sent_lst.append(input_sentence1)
    parse_lst = list(parser.parse_sents(sent_lst))
    assert len(parse_lst) == len(generated)
    spans_lst = []
    for parse in parse_lst:
        chart, spans = chart_from_tree(tree_vocab, parse, verbose=True)
        spans_lst.append(spans)
    return parse_lst, spans_lst

def levenshteinDistance(s1, s2):
    if len(s1) > len(s2):
        s1, s2 = s2, s1
    distances = range(len(s1) + 1)
    for i2, c2 in enumerate(s2):
        distances_ = [i2 + 1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                distances_.append(distances[i1])
            else:
                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
        distances = distances_
    return distances[-1]

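# Illustrative check (a standard Levenshtein example, not part of the original script):
# levenshteinDistance('kitten', 'sitting') == 3  (substitute k->s, e->i, insert g).
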
def score_spans(gold_spans, generated_span):
    # Fraction of gold spans recovered in the generated parse (span recall).
    print(gold_spans)
    print(generated_span)
    if isinstance(gold_spans, list):
        # A list of (start, end, label) tuples, e.g. from chart_from_tree.
        gold_spans = set(tuple(s) for s in gold_spans)
    else:
        # A single span tuple used as the control target.
        gold_spans = set([gold_spans])
    generated_span = set(generated_span)
    intersection = gold_spans.intersection(generated_span)
    print(intersection, len(intersection) / len(gold_spans))
    return len(intersection) / len(gold_spans)

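# Illustrative example with hypothetical spans: for a gold span (9, 13, 'ADJP') and
# generated spans [(0, 2, 'NP'), (9, 13, 'ADJP')], the returned recall is 1.0.
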
def score_tree(gold_tree, pred_trees):
    def reset_leaves(tree_):
        # Replace terminals with integer ids and POS tags with 'NN' so that
        # EVALB compares bracketing structure only.
        simple_increm = 0
        for s in tree_.subtrees(lambda t: t.height() == 2):
            s[0] = simple_increm
            s._label = 'NN'
            simple_increm += 1
        return simple_increm

    increm_gold = reset_leaves(gold_tree)
    for i, pred in enumerate(pred_trees):
        increm_pred = reset_leaves(pred)

    use_evalb = True
    if use_evalb:
        # Score every predicted tree against the same gold tree.
        gold_trees = [gold_tree] * len(pred_trees)
        print(len(gold_tree.leaves()), [len(x.leaves()) for x in pred_trees])
        dev_fscore = evaluate.evalb('diffusion_lm/misc/self-attentive-parser/EVALB',
                                    gold_trees, pred_trees)
        print(dev_fscore)
        return dev_fscore

def score_pos(gold_pos, generated_pos):
    # Normalized edit-distance similarity: 1 - ED / gold length.
    ed = levenshteinDistance(gold_pos, generated_pos)
    return 1 - (ed / len(gold_pos))

def score_pos_em(gold_pos, generated_pos):
    # Position-wise exact-match accuracy, after truncating or padding the
    # generated sequence to the gold length.
    if len(generated_pos) > len(gold_pos):
        generated_pos = generated_pos[:len(gold_pos)]
    elif len(generated_pos) < len(gold_pos):
        generated_pos = generated_pos + ['PAD'] * (len(gold_pos) - len(generated_pos))
    assert len(gold_pos) == len(generated_pos)
    correct = 0
    total = 0
    for x1, x2 in zip(gold_pos, generated_pos):
        if x1 == x2:
            correct += 1
        total += 1
    return correct / total

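# Illustrative example with hypothetical tags: score_pos_em(['DET', 'NOUN', 'VERB'],
# ['DET', 'NOUN', 'NOUN']) == 2/3 after length alignment.
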
def score_attributes(gold_att, generated):
    # 1 if the gold attribute string occurs verbatim in the generated text.
    if gold_att in generated:
        return 1.
    else:
        return 0.

def eval_pos(tagger, generated_text):
    # POS-tag each generated sentence and return the tag sequences.
    generated_pos = []
    for sent in generated_text:
        sent_full = " ".join(sent)
        doc = tagger(sent_full)
        generated_pos.append([token.pos_ for token in doc])
    return generated_pos

def eval_(args, text_samples):
    '''Dispatch to the mode-specific control-success evaluation.'''
    if args.mode == 'e2e-tgt-tree':
        parser = benepar.Parser("benepar_en3")
        tree_vocab = parser._parser.config["label_vocab"]
        gold_parse = None
        if args.gold_ref == 'full':
            # Other gold sentences tried during development:
            # toy1 = 'START Located in riverside area , Alimentum restaurant is a place to bring the whole family . \n END'.split()
            # toy1 = 'START Alimentum is not a family - friendly place , located in city centre . \n END'.split()
            toy1 = ['START', 'The', 'Vaults', 'pub', 'near', 'Café', 'Adriatic', 'has', 'a', '5', 'star', 'rating',
                    '.', 'Prices', 'start', 'at', '£', '30', '.', 'END']
            input_sentence1 = benepar.InputSentence(
                words=toy1[1:-1],
            )
            gold_parse = list(parser.parse_sents([input_sentence1]))[0]
            chart, gold_spans = chart_from_tree(tree_vocab, gold_parse, verbose=True)
            print(len(toy1[1:-1]), len(list(gold_parse.leaves())))
        elif args.gold_ref == 'span':
            # Other gold spans tried during development:
            # gold_spans = [(0, 4, 'S::VP')]
            # gold_spans = [(0, 0, 'NP')]
            gold_spans = [(9, 13, 'ADJP')]
            print(text_samples[:1])

        if gold_parse is not None:
            # Truncate/pad samples to the gold length. The guard avoids a
            # NameError: gold_parse only exists when gold_ref == 'full'.
            target_len = len(gold_parse.leaves())
            print(gold_parse.leaves(), 'target')
            for i, x in enumerate(text_samples):
                if len(x) == target_len:
                    continue
                elif len(x) > target_len:
                    text_samples[i] = x[:target_len]
                else:
                    print('padded to same length', (target_len - len(x)))
                    text_samples[i] = x + ['.'] * (target_len - len(x))

        generated_parse, generated_span = eval_parse(parser, text_samples, tree_vocab)
        if gold_parse is not None:
            evalb_score = score_tree(gold_parse, generated_parse)
        print([len(x) for x in text_samples])
        score_lst = []
        for x in generated_span:
            score_lst.append(score_spans(gold_spans, x))
        print(np.array(score_lst).mean())
    elif args.mode == 'e2e-tgt-pos':
        tagger = spacy_stanza.load_pipeline("en", processors='tokenize,mwt,pos')
        if args.gold_ref == 'full':
            # Another gold sentence tried during development:
            # toy1 = 'START The Mill is a coffee shop with an expensive menu near The Sorrento . \n END'.split()
            toy1 = ['START', 'The', 'Vaults', 'pub', 'near', 'Café', 'Adriatic', 'has', 'a', '5', 'star', 'rating', '.',
                    'Prices', 'start', 'at', '£', '30', '.', '\n', 'END']
            sent_full = " ".join(toy1[1:-1])
            doc = tagger(sent_full)
            gold_pos = [token.pos_ for token in doc]
        elif args.gold_ref == 'span':
            gold_pos = [(9, 'PROPN')]
        generated_pos = eval_pos(tagger, text_samples)
        score_lst = []
        score_lst2 = []
        for x in generated_pos:
            print(gold_pos)
            print(x)
            print()
            score_lst.append(score_pos(gold_pos, x))
            score_lst2.append(score_pos_em(gold_pos, x))
        print(np.array(score_lst).mean())
        print(np.array(score_lst2).mean())
    elif args.mode == 'e2e-tgt-pos-paired':
        nlp = spacy_stanza.load_pipeline("en", processors={"tokenize": "spacy"})
        print(nlp)
        full_score = []
        for gold, full_word_lst in text_samples.items():
            print(gold, len(full_word_lst), full_word_lst[:2])
            sent_lst = [" ".join(seq) for seq in full_word_lst]
            sent_full = " ".join(sent_lst)
            try:
                # Tag all samples in one pass, then split the tags back per sample.
                doc = nlp(sent_full)
                doc_token_pos = [(token.text, token.pos_,) for token in doc]
                len_lst = [len(seq) for seq in full_word_lst]
                print(sum(len_lst), len(doc_token_pos), 'should be equal')
                assert sum(len_lst) == len(doc_token_pos)
                pos_lst = []
                init_idx = 0
                for len_temp in len_lst:
                    pos_lst.append([x[1] for x in doc_token_pos[init_idx:init_idx + len_temp]])
                    init_idx = init_idx + len_temp
            except Exception:
                print(f'stanza pipeline failed... for this {gold}')
                # Fall back to tagging each sentence separately.
                pos_lst = []
                for single_sent in sent_lst:
                    doc = nlp(single_sent)
                    pos_lst.append([token.pos_ for token in doc])
            score_lst = []
            score_lst2 = []
            for x in pos_lst:
                score_lst.append(score_pos(gold, x))
                score_lst2.append(score_pos_em(gold, x))
            score_ed = np.array(score_lst).mean()
            score_em = np.array(score_lst2).mean()
            print(len(score_lst), score_ed, score_em)
            full_score.append(score_em)
        full_score_em = np.array(full_score).mean()
        print(full_score_em, f"\\pm {np.array(full_score).std()}", len(full_score))
    elif args.mode == 'e2e-tgt-tree-paired':
        parser = benepar.Parser("benepar_en3")
        tree_vocab = parser._parser.config["label_vocab"]
        full_score = []
        for idx, (gold_parse, full_word_lst) in enumerate(text_samples.items()):
            # EVALB rejects newline tokens --> change \n to .
            gold_parse_str = gold_parse[0]
            gold_parse_str = gold_parse_str.replace('\n', '.')
            gold_parse = Tree.fromstring(gold_parse_str)
            target_len = len(gold_parse.leaves())
            # Truncate/pad each sample to the gold length.
            for i, x in enumerate(full_word_lst):
                if len(x) == target_len:
                    continue
                elif len(x) > target_len:
                    print('generated seq is longer than gold seq')
                    full_word_lst[i] = x[:target_len]
                else:
                    print('padded to same length', (target_len - len(x)))
                    full_word_lst[i] = x + ['.'] * (target_len - len(x))
            generated_parse, generated_span = eval_parse(parser, full_word_lst, tree_vocab)
            evalb_score = score_tree(gold_parse, generated_parse)  # inputs are nltk.Tree
            print(evalb_score.fscore)
            full_score.append(evalb_score.fscore)
        full_score_f1 = np.array(full_score).mean()
        print(full_score_f1, f"\\pm {np.array(full_score).std()}", len(full_score))
    elif args.mode == 'e2e-tgt-spans-paired':
        parser = benepar.Parser("benepar_en3")
        tree_vocab = parser._parser.config["label_vocab"]
        full_score = []
        for idx, (gold_spans, full_word_lst) in enumerate(text_samples.items()):
            print(gold_spans, 'gold spans')
            generated_parse, generated_span = eval_parse(parser, full_word_lst, tree_vocab)
            score_lst = []
            for x in generated_span:
                score_lst.append(score_spans(gold_spans, x))
            print(score_lst)
            score_lst_mean = np.array(score_lst).mean()
            full_score.append(score_lst_mean)
        full_score_span = np.array(full_score).mean()
        print(full_score_span, f"\\pm {np.array(full_score).std()}", len(full_score))
    elif args.mode == 'e2e-tgt-attribute-paired':
        full_score = []
        for idx, (attribute, full_word_lst) in enumerate(text_samples.items()):
            # The control target looks like 'field : value'; keep the value part.
            attribute = " ".join(attribute).split(':')[1].strip()
            gold_attribute = attribute
            score_lst = []
            for i, x in enumerate(full_word_lst):
                score_lst.append(score_attributes(gold_attribute, " ".join(x)))
            score_lst_mean = np.array(score_lst).mean()
            full_score.append(score_lst_mean)
        full_score_mean = np.array(full_score).mean()
        print(full_score_mean, f"\\pm {np.array(full_score).std()}", len(full_score))
    elif args.mode == 'e2e-tgt-length-paired':
        full_score = []
        for idx, (attribute, full_word_lst) in enumerate(text_samples.items()):
            tgt_len = int(attribute[0]) - 2  # remove START and END.
            score_lst = []
            for i, x in enumerate(full_word_lst):
                if tgt_len == len(x):
                    score_lst.append(1.)
                else:
                    score_lst.append(0.)
            score_lst_mean = np.array(score_lst).mean()
            full_score.append(score_lst_mean)
        full_score_mean = np.array(full_score).mean()
        print(full_score_mean, f"\\pm {np.array(full_score).std()}", len(full_score))
    elif args.mode == 'e2e-tgt-attribute':
        gold_attribute = ""  # placeholder: fill in the target attribute string.
        score_lst = []
        for x in text_samples:
            score_lst.append(score_attributes(gold_attribute, x))
        print(np.array(score_lst).mean())

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='training args.')
    parser.add_argument('--input_text', type=str,
                        default='diffusion_lm/improved_diffusion/out_gen/diff_e2e-tgt_pad_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart.ema_0.9999_200000.pt.'
                                'infill_control_tree_50x64x16_tree_partial-cat-lgv0.1.json')
    parser.add_argument('--input_format', type=str, default='batch', help='file or paired')
    parser.add_argument('--mode', type=str, default='e2e-tgt-tree',
                        help='e2e-tgt-tree, e2e-tgt-pos, e2e-tgt-pos-paired, e2e-tgt-tree-paired, '
                             'e2e-tgt-spans-paired, e2e-tgt-attribute-paired, e2e-tgt-length-paired, '
                             'e2e-tgt-attribute')
    parser.add_argument('--gold_ref', type=str, default='full', help='full or span')
    parser.add_argument('--model_name_or_path', type=str,
                        default='predictability/diff_models/e2e-tgt_e=20_b=64_m=gpt2_wikitext-103-raw-v1_101_wp_finetune_UNK',
                        help='')
    # alternative: default='predictability/diff_models/e2e-tgt_e=6_b=10_m=gpt2_wikitext-103-raw-v1_101_wp_pad'
    args = parser.parse_args()

    text_samples = read_files(args)
    eval_(args, text_samples)
    eval_ppl(args, text_samples)
    # eval_ppl2(args, text_samples)
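
# Example invocation (hypothetical script name and paths; adjust to your setup).
# Only flags defined above are used:
# python eval_control.py --input_format paired --mode e2e-tgt-pos-paired \
#     --input_text path/to/generated_samples.log \
#     --model_name_or_path path/to/finetuned_ar_model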