# Lang2mol-Diff / src/control_gen/eval_control.py
# (Hugging Face upload residue: "Add application file", commit 7dd9869, user ndhieunguyen)
import torch, argparse, json
import benepar, spacy_stanza
import numpy as np
import sys, os
import csv
from nltk.tree import Tree
sys.path.insert(0, os.path.join(sys.path[0], '../scripts/'))
from tree_helper import chart_from_tree, pad_charts, padded_chart_from_spans
sys.path.insert(0, os.path.join(sys.path[0], '../../misc/self-attentive-parser/src/'))
import evaluate
from spacy.lang.en import English
from collections import defaultdict
from transformers import AutoModelForCausalLM, AutoTokenizer
from improved_diffusion.rounding import rounding_func, load_models, load_tokenizer
# Module-level blank-English spaCy pipeline; only its rule-based tokenizer is
# used (no tagging/parsing) to split generated samples into word tokens.
nlp = English()
tokenizer_spacy = nlp.tokenizer
def eval_ppl2(args, text_samples):
    '''
    Score generated samples with an autoregressive LM that shares the
    diffusion model's word-level vocabulary (loaded via ``load_tokenizer``).

    :param args: parsed CLI args; reads ``model_name_or_path`` (the AR scorer)
        and overwrites ``args.model_path`` with a hard-coded diffusion
        checkpoint whose directory holds the vocab file.
    :param text_samples: dict mapping a gold key -> list of word-token lists
        (the ``paired`` output of :func:`read_files`).
    :return: None; prints the mean NLL and PPL (``e ** NLL``) over all keys.
    '''
    print(f'loading from {args.model_name_or_path}')
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path, # path to the AR model trained for LMing this task.
    ).cuda()  # NOTE(review): assumes a CUDA device is available.
    if 'r2l' in args.model_name_or_path:
        print('Use the right-to-left encoding.')
    # Hard-coded checkpoint path: only its *directory* is used below, to
    # locate the diffusion model's word-level vocab.
    args.model_path = 'predictability/diffusion_models_v6/diff_e2e-tgt_pad_rand16_transformer_' \
                      'lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart/ema_0.9999_200000.pt'
    tokenizer = load_tokenizer('e2e-tgt', 'random', os.path.split(args.model_path)[0])
    # print(args.modality, tokenizer, args.experiment)
    reverse_tokenizer = {v: k for k, v in tokenizer.items()}  # word -> id
    full_score = []
    for idxx, (gold, full_word_lst) in enumerate(text_samples.items()):
        # print(len(full_word_lst), full_word_lst[0])
        agg_loss = []
        for x in full_word_lst:
            # x = " ".join(x).split()
            if 'r2l' in args.model_name_or_path:
                # Right-to-left model: feed the reversed sequence.
                string = ["START"] + list(reversed(x)) + ["END"]
                tokenized_x = [reverse_tokenizer.get(s, reverse_tokenizer['UNK']) for s in string]
            else:
                tokenized_x = [reverse_tokenizer['START']] + [reverse_tokenizer.get(s, reverse_tokenizer['UNK']) for s in x] \
                              + [reverse_tokenizer['END']]
            # print(tokenized_x)
            tokenized_x = torch.LongTensor(tokenized_x).cuda()
            labels = tokenized_x.clone()
            # Ignore PAD positions in the LM loss (HF convention: label -100).
            labels[labels == reverse_tokenizer['PAD']] = -100
            model_output = model(tokenized_x, labels=labels)
            # print(model_output.loss)
            # if idxx == 3:
            #     print(tokenized_x, model_output.loss.item())
            agg_loss.append(model_output.loss.item())
        example_mean_score = torch.tensor(agg_loss).mean()
        # print(f'\nthe mean loss is {example_mean_score} for index', idxx )
        full_score.append(example_mean_score)
    full_score_ = np.array(full_score).mean()
    print(f'full NLL score is {full_score_} for {len(full_score)}')
    print(f'full PPL score is {np.e ** full_score_} for {len(full_score)}')
def eval_ppl(args, text_samples):
    '''
    Evaluating using GPT2 finetuned on this task...

    Words outside the diffusion model's vocab are first mapped to ``'UNK'``;
    the resulting string is then tokenized with the GPT-2 tokenizer, so the
    scorer sees a vocabulary comparable to the diffusion model's.

    :param args: parsed CLI args; reads ``model_name_or_path`` and overwrites
        ``args.model_path`` with a hard-coded diffusion checkpoint (only its
        directory is used, to locate the diffusion vocab).
    :param text_samples: dict mapping gold key -> list of word-token lists.
    :return: None; prints mean NLL and PPL (``e ** NLL``).
    '''
    # load model
    print(f'loading from {args.model_name_or_path}')
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path, # path to the AR model trained for LMing this task.
    ).cuda()  # NOTE(review): assumes a CUDA device is available.
    # load tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    print('finished loading models.')
    args.model_path = 'predictability/diffusion_models_v6/diff_e2e-tgt_pad_rand16_transformer_' \
                      'lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart/ema_0.9999_200000.pt'
    diff_tokenizer = load_tokenizer('e2e-tgt', 'random', os.path.split(args.model_path)[0])
    reverse_diff_tokenizer = {v: k for k, v in diff_tokenizer.items()}  # word -> id
    full_score = []
    for gold, full_word_lst in text_samples.items():
        agg_loss = []
        for x in full_word_lst:
            # Map words unknown to the diffusion vocab to the literal 'UNK'.
            x = [kk if kk in reverse_diff_tokenizer else 'UNK' for kk in x]
            x = tokenizer.bos_token + " ".join(x) + tokenizer.eos_token
            # print(x)
            # should also add BOS EOS token?
            tokenized_x = tokenizer(x, return_tensors='pt') #[reverse_tokenizer[s] for s in x]
            input_ids = tokenized_x['input_ids'].cuda()
            labels = input_ids.clone()
            # print(tokenized_x)
            # tokenized_x = torch.LongTensor(tokenized_x).cuda()
            # labels = tokenized_x.clone()
            # labels[labels == reverse_tokenizer['PAD']] = -100
            model_output = model(input_ids, labels=labels)
            agg_loss.append(model_output.loss.item())
        example_mean_score = torch.tensor(agg_loss).mean()
        # print(f'\nthe mean loss is {example_mean_score}', )
        full_score.append(example_mean_score)
    full_score_ = np.array(full_score).mean()
    print(f'full NLL score is {full_score_} for {len(full_score)}')
    print(f'full PPL score is {np.e ** full_score_} for {len(full_score)}')
def read_files(args):
    '''
    Load generated samples from ``args.input_text``.

    :param args: parsed CLI args; reads ``input_format`` ('file' or 'paired'),
        ``input_text`` (path) and ``mode``.
    :return: for ``input_format == 'file'``: a list of tokenized sentences
        (PAD removed, START/END stripped); for ``'paired'``: a dict mapping a
        control-target key (tuple) -> list of tokenized sentences.
        NOTE(review): any other ``input_format`` falls through and returns
        None implicitly.
    '''
    if args.input_format == 'file':
        text_samples = []
        if args.input_text.endswith('json'):
            # Each line is a JSON list whose first element is the sample text.
            with open(args.input_text, 'r') as f:
                for line in f:
                    words = [x.text for x in tokenizer_spacy(json.loads(line)[0])]
                    text_samples.append(words)
                    # text_samples.append(json.loads(line)[0].split(' '))
        else:
            # Plain text: one whitespace-separated sample per line.
            with open(args.input_text, 'r') as f:
                for line in f:
                    text_samples.append(line.strip().split())
        # remove trailing PAD tokens.
        text_samples2 = []
        for sent in text_samples:
            tempsent = [x for x in sent if x != 'PAD']
            # NOTE(review): an all-PAD sample leaves tempsent empty and the
            # next index would raise IndexError — assumed not to occur.
            if tempsent[0] == 'START':
                tempsent = tempsent[1:]
            if tempsent[-1] == 'END':
                tempsent = tempsent[:-1]
            # Tree modes: a trailing newline token confuses EVALB, map it to '.'.
            if tempsent[-1] == '\n' and args.mode in ['e2e-tgt-tree', 'e2e-tgt-tree-paired']:
                tempsent[-1] = '.'
            text_samples2.append(tempsent)
        return text_samples2
    elif args.input_format == 'paired':
        import ast
        # nlp = English()
        # tokenizer = nlp.tokenizer
        result_lst = defaultdict(list)
        if args.input_text.endswith('json'):
            with open(args.input_text, 'r') as f:
                for line in f:
                    try:
                        line = json.loads(line)
                    # NOTE(review): bare except used as a deliberate fallback
                    # to Python-literal parsing for non-JSON lines; it would
                    # also hide genuine errors.
                    except:
                        if args.mode == 'e2e-tgt-spans-paired':
                            line = ast.literal_eval(line)
                            # Keys arrive as stringified span lists; re-parse
                            # the first element into a hashable tuple key.
                            line = {tuple(ast.literal_eval(k[0])) : v for k, v in line.items()}
                            result_lst.update(line)
                        else:
                            line = ast.literal_eval(line)
                            result_lst.update(line)
        elif args.input_text.endswith('log'):
            # CSV-style log: row[0] is the control target, row[2] the sample.
            with open(args.input_text, 'r') as csvfile:
                roc_reader = csv.reader(csvfile) #delimiter=' ', quotechar='|')
                for idx, row in enumerate(roc_reader):
                    if idx == 0: continue  # skip header row
                    if args.mode == 'e2e-tgt-spans-paired' or args.mode == 'e2e-tgt-length-paired':
                        pos = tuple(ast.literal_eval(row[0]))
                        if args.mode == 'e2e-tgt-length-paired':
                            pos = list(pos)
                            pos[0] = int(pos[0]) + 2  # because this count didn't accounted for START and END
                            pos = tuple(pos)
                    else:
                        pos = tuple(row[0].split())
                    result_lst[pos].append(row[2])
        clean_result_lst = {}
        for k, text_samples in result_lst.items():
            text_samples2 = []
            for sent in text_samples:
                sent = sent.split(' ')
                # KEY DEBUG.
                # sent = [x.text for x in tokenizer_spacy(sent)]
                # print(sent, sent2)
                # 10/0
                tempsent = [x for x in sent if x != 'PAD']
                if tempsent[0] == 'START':
                    tempsent = tempsent[1:]
                if tempsent[-1] == 'END':
                    tempsent = tempsent[:-1]
                if tempsent[-1] == '\n' and args.mode == 'e2e-tgt-tree':
                    tempsent[-1] = '.'
                # KEY DEBUG.
                # Re-tokenize with spaCy so generated text matches the gold
                # tokenization used elsewhere.
                tempsent = " ".join(tempsent)
                tempsent = [x.text for x in tokenizer_spacy(tempsent)]
                text_samples2.append(tempsent)
            if k[0] == 'START' and k[-1] == 'END':
                kk_ = k[1:-1]
            else:
                kk_ = k
            clean_result_lst[kk_] = text_samples2 # remove start and end from the training data.
        return clean_result_lst
def eval_parse(parser, generated, tree_vocab):
    '''
    Constituency-parse each generated sentence and extract labelled spans.

    :param parser: a ``benepar.Parser`` instance.
    :param generated: iterable of pre-tokenized sentences (lists of words).
    :param tree_vocab: the parser's label vocabulary, passed to
        ``chart_from_tree``.
    :return: ``(parse_lst, spans_lst)`` — one parse tree and one span list
        per input sentence, in order.
    '''
    sent_lst = []
    for sent in generated:
        # print(sent)
        # Wrap pre-tokenized words so benepar skips its own tokenization.
        input_sentence1 = benepar.InputSentence(
            words=sent,
        )
        sent_lst.append(input_sentence1)
    parse_lst = list(parser.parse_sents(sent_lst))
    # print(examples['text'][:10])
    assert len(parse_lst) == len(generated)
    # print(parse_lst[:2])
    spans_lst = []
    for parse in parse_lst:
        chart, spans = chart_from_tree(tree_vocab, parse, verbose=True)
        spans_lst.append(spans)
    return parse_lst, spans_lst
def levenshteinDistance(s1, s2):
    """Return the Levenshtein edit distance between two sequences.

    Works on any indexable sequences with comparable elements (strings,
    lists of tokens, ...). Uses the classic two-row dynamic program, keeping
    the shorter sequence as the row dimension for memory efficiency.
    """
    # Ensure s1 is the shorter sequence so the DP row stays small.
    if len(s1) > len(s2):
        s1, s2 = s2, s1
    previous_row = list(range(len(s1) + 1))
    for row, item2 in enumerate(s2, start=1):
        current_row = [row]
        for col, item1 in enumerate(s1):
            if item1 == item2:
                cost = previous_row[col]          # no edit needed
            else:
                cost = 1 + min(previous_row[col],      # substitution
                               previous_row[col + 1],  # deletion
                               current_row[-1])        # insertion
            current_row.append(cost)
        previous_row = current_row
    return previous_row[-1]
def score_spans(gold_spans, generated_span):
    '''
    Fraction of the gold span set recovered by the generated spans (recall).

    NOTE(review): ``set([gold_spans])`` wraps the *whole* ``gold_spans``
    object as one element, so this expects a single hashable span key (e.g.
    the tuple keys produced by ``read_files`` in spans-paired mode); passing
    a list of spans would raise TypeError. Confirm against callers.
    '''
    print(gold_spans)
    print(generated_span)
    gold_spans = set([gold_spans])
    generated_span = set(generated_span)
    intersection = gold_spans.intersection(generated_span)
    print(intersection, len(intersection) / len(gold_spans))
    # union = gold_spans.union(generated_span)
    # print(len(union), len(intersection))
    # if unlabeled:
    #     print(generated_span)
    #     unlabeled_gold_spans = set([(a,b) for (a, b, v) in gold_spans])
    #     unlabeled_generated_span =set([(a,b) for (a, b, v) in generated_span])
    #     intersection = gold_spans.intersection(generated_span)
    #     union = gold_spans.union(generated_span)
    return len(intersection) / len(gold_spans)
def score_tree(gold_tree, pred_trees):
    '''
    EVALB F-score of predicted parses against a single gold parse.

    Leaf words and preterminal labels are normalized away first (leaves
    renumbered, preterminals relabelled 'NN'), so only the bracketing
    structure is compared.

    :param gold_tree: gold ``nltk.Tree`` (mutated in place by the reset).
    :param pred_trees: list of predicted ``nltk.Tree`` (mutated in place).
    :return: the score object returned by ``evaluate.evalb`` (has ``fscore``).
    '''
    # print([x.leaves() for x in pred_trees])
    def reset_leaves(tree_):
        # Replace every leaf with a running index and every preterminal label
        # with 'NN' so EVALB scores structure, not words/POS tags.
        simple_increm = 0
        for s in tree_.subtrees(lambda t: t.height() == 2):
            s[0] = simple_increm
            s._label = 'NN'
            simple_increm += 1
        return simple_increm
    # reset.
    increm_gold = reset_leaves(gold_tree)
    # print(increm_gold)
    for i, pred in enumerate(pred_trees):
        increm_pred = reset_leaves(pred)
        # print(increm_pred, 'pred', i)
    use_evalb = True
    if use_evalb:
        # print(len(gold_tree), len(pred_trees), gold_tree)
        # EVALB compares pairwise, so repeat the gold tree once per prediction.
        gold_trees = [gold_tree] * len(pred_trees)
        print(len(gold_tree.leaves()), [len(x.leaves()) for x in pred_trees])
        # print(pred_trees[0])
        dev_fscore = evaluate.evalb('diffusion_lm/misc/self-attentive-parser/EVALB',
                                    gold_trees, pred_trees)
        print(dev_fscore)
        return dev_fscore
def score_pos(gold_pos, generated_pos):
    """Edit-distance similarity of a generated POS sequence to the gold one.

    Returns ``1 - dist / len(gold_pos)``: 1.0 for an exact match, lower (and
    possibly negative) the more edits are needed.
    """
    distance = levenshteinDistance(gold_pos, generated_pos)
    similarity = 1 - (distance / len(gold_pos))
    return similarity
def score_pos_em(gold_pos, generated_pos):
    """Positionwise exact-match accuracy of a generated POS sequence.

    The generated sequence is first truncated or right-padded with 'PAD' to
    the gold length, then the fraction of positions whose tags match exactly
    is returned (padding positions therefore count as mismatches).

    :param gold_pos: reference sequence of POS tags.
    :param generated_pos: predicted sequence of POS tags.
    :return: float in [0, 1].
    """
    # print(len(gold_pos), len(generated_pos), gold_pos, generated_pos)
    if len(generated_pos) > len(gold_pos):
        generated_pos = generated_pos[:len(gold_pos)]
    elif len(generated_pos) < len(gold_pos):
        generated_pos = generated_pos + ['PAD'] * (len(gold_pos) - len(generated_pos))
    assert len(gold_pos) == len(generated_pos)
    # Fix: the original counter was named `all`, shadowing the builtin; use
    # a comprehension over the equal-length sequences instead.
    correct = sum(1 for g, p in zip(gold_pos, generated_pos) if g == p)
    return correct / len(gold_pos)
def score_attributes(gold_att, generated):
    """Binary containment score.

    :param gold_att: the target attribute value (substring to look for).
    :param generated: the generated text to search in.
    :return: 1.0 when ``gold_att`` occurs inside ``generated``, else 0.0.
    """
    return 1. if gold_att in generated else 0.
def eval_pos(tagger, generated_text):
    """POS-tag each generated sentence.

    :param tagger: callable mapping a sentence string to an iterable of
        tokens exposing a ``pos_`` attribute (e.g. a spaCy/stanza pipeline).
    :param generated_text: iterable of tokenized sentences (word lists);
        each is re-joined with single spaces before tagging.
    :return: list with one POS-tag list per input sentence, in order.
    """
    tag_sequences = []
    for words in generated_text:
        tagged = tagger(" ".join(words))
        tag_sequences.append([token.pos_ for token in tagged])
    return tag_sequences
def eval_(args, text_samples):
    '''
    Dispatch control-target evaluation by ``args.mode``.

    Modes: 'e2e-tgt-tree' / 'e2e-tgt-pos' (single hard-coded gold target),
    and the paired modes 'e2e-tgt-pos-paired', 'e2e-tgt-tree-paired',
    'e2e-tgt-spans-paired', 'e2e-tgt-attribute-paired',
    'e2e-tgt-length-paired', 'e2e-tgt-attribute', which expect
    ``text_samples`` to be the dict produced by ``read_files`` in 'paired'
    format. Results are printed, not returned.

    NOTE(review): the chain mixes fresh ``if`` statements with ``elif``
    (after 'e2e-tgt-pos-paired' and after 'e2e-tgt-spans-paired'), so the
    later groups are evaluated independently of the first chain — kept as-is.
    '''
    if args.mode == 'e2e-tgt-tree':
        parser = benepar.Parser("benepar_en3")
        tree_vocab = parser._parser.config["label_vocab"]
        if args.gold_ref == 'full':
            # toy1 = 'START Located in riverside area , Alimentum restaurant is a place to bring the whole family . \n END'.split()
            # toy1 = 'START Alimentum is not a family - friendly place , located in city centre . \n END'.split()
            # Hard-coded gold sentence used as the parse target.
            toy1 = ['START', 'The', 'Vaults', 'pub', 'near', 'Café', 'Adriatic', 'has', 'a', '5', 'star', 'rating',
                    '.', 'Prices', 'start', 'at', '£', '30', '.', 'END']
            input_sentence1 = benepar.InputSentence(
                words=toy1[1:-1],
            )
            gold_parse = list(parser.parse_sents([input_sentence1]))[0]
            chart, gold_spans = chart_from_tree(tree_vocab, gold_parse, verbose=True)
            print(len(toy1[1:-1]), len(list(gold_parse.leaves())))
        elif args.gold_ref == 'span':
            # spans = [(10, 14, 'ADJP')]
            # Only the last assignment takes effect; earlier ones are
            # leftover experiment variants.
            gold_spans = [(0, 4, 'S::VP')]
            gold_spans = [(0, 0, 'NP')]
            gold_spans = [(9, 13, 'ADJP')]
            # gold_spans = [(9, 13, 'PP')]
        print(text_samples[:1])
        # correct for length:
        # NOTE(review): when gold_ref == 'span', ``gold_parse`` was never
        # assigned and the next line raises NameError — only gold_ref ==
        # 'full' actually works in this mode.
        target_len = len(gold_parse.leaves())
        print(gold_parse.leaves(), 'target')
        # Truncate/pad every sample to the gold length so EVALB can compare.
        for i, x in enumerate(text_samples):
            if len(x) == target_len:
                continue
            elif len(x) > target_len:
                text_samples[i] = x[:target_len]
            else:
                print('padded to same length', (target_len-len(x)))
                text_samples[i] = x + ['.'] * (target_len-len(x))
                # print(text_samples[i])
                # print('SAD, our model is shorter??')
        generated_parse, generated_span = eval_parse(parser, text_samples, tree_vocab)
        # print(gold_spans)
        # print(generated_span[:2])
        evalb_score = score_tree(gold_parse, generated_parse)
        print([len(x) for x in text_samples])
        score_lst = []
        for x in generated_span:
            score_lst.append(score_spans(gold_spans, x))
        print(np.array(score_lst).mean())
    elif args.mode == 'e2e-tgt-pos':
        tagger = spacy_stanza.load_pipeline("en", processors='tokenize,mwt,pos', ) #processors={"tokenize": "spacy",}
        if args.gold_ref == 'full':
            toy1 = 'START The Mill is a coffee shop with an expensive menu near The Sorrento . \n END'.split()
            # Second assignment overrides the first: this is the active gold.
            toy1 = ['START', 'The', 'Vaults', 'pub', 'near', 'Café', 'Adriatic', 'has', 'a', '5', 'star', 'rating', '.',
                    'Prices', 'start', 'at', '£', '30', '.', '\n', 'END']
            sent_full = " ".join(toy1[1:-1])
            doc = tagger(sent_full)
            gold_pos = [token.pos_ for token in doc]
        elif args.gold_ref == 'span':
            gold_pos = [(9, 'PROPN')]
        generated_pos = eval_pos(tagger, text_samples)
        score_lst = []
        score_lst2 = []
        for x in generated_pos:
            print(gold_pos)
            print(x)
            print()
            score_lst.append(score_pos(gold_pos, x))
            score_lst2.append(score_pos_em(gold_pos, x))
        print(np.array(score_lst).mean())
        print(np.array(score_lst2).mean())
    elif args.mode == 'e2e-tgt-pos-paired':
        import stanza
        nlp = spacy_stanza.load_pipeline("en", processors={"tokenize": "spacy"})
        print(nlp)
        # nlp = stanza.Pipeline("en", processors={"tokenize": "spacy", 'pos': 'combined'}, package=None)
        full_score = []
        for gold, full_word_lst in text_samples.items():
            print(gold, len(full_word_lst), full_word_lst[:2])
            # full_word_lst = full_word_lst[:2]
            # Batch all samples of this key into one string for a single
            # pipeline call; fall back to per-sentence calls if the token
            # count no longer lines up.
            sent_lst = [" ".join(seq) for seq in full_word_lst]
            sent_full = " ".join(sent_lst)
            # print(sent_lst)
            try:
                doc = nlp(sent_full)
                doc_token_pos = [(token.text, token.pos_,) for token in doc]
                len_lst = [len(seq) for seq in full_word_lst]
                print(sum(len_lst), len(doc_token_pos), 'should be equal!!! ')
                assert sum(len_lst) == len(doc_token_pos)
                pos_lst = []
                init_idx = 0
                # Slice the flat tag list back into per-sample sequences.
                for len_temp in len_lst:
                    pos_lst.append([x[1] for x in doc_token_pos[init_idx:init_idx + len_temp]])
                    init_idx = init_idx + len_temp
            except:
                print(f'stanza pipeline failed... for this {gold}')
                # parse each sentence separately...
                pos_lst = []
                for single_sent in sent_lst:
                    doc = nlp(single_sent)
                    # doc_token_pos = [(token.text, token.pos_,) for token in doc]
                    pos_lst.append([ token.pos_ for token in doc])
            score_lst = []
            score_lst2 = []
            for x in pos_lst:
                score_lst.append(score_pos(gold, x))
                score_lst2.append(score_pos_em(gold, x))
            score_ed = np.array(score_lst).mean()
            score_em = np.array(score_lst2).mean()
            print(len(score_lst), score_ed, score_em)
            full_score.append(score_em)
        full_score_em = np.array(full_score).mean()
        # NOTE(review): "\pm" is a literal backslash-p-m (LaTeX plus-minus),
        # not an escape sequence; newer Pythons warn about it.
        print(full_score_em, f"\pm {np.array(full_score).std()}", len(full_score))
    if args.mode == 'e2e-tgt-tree-paired':
        parser = benepar.Parser("benepar_en3")
        tree_vocab = parser._parser.config["label_vocab"]
        full_score = []
        for idx, (gold_parse, full_word_lst) in enumerate(text_samples.items()):
            # to avoid evalb complain --> change \n to .
            gold_parse_str = gold_parse[0]
            gold_parse_str = gold_parse_str.replace('\n', '.')
            # print([gold_parse_str], 'gold tree string ')
            gold_parse = Tree.fromstring(gold_parse_str)
            target_len = len(gold_parse.leaves())
            # print(gold_parse.leaves(), 'target')
            # print(full_word_lst)
            # Truncate/pad generations to the gold leaf count for EVALB.
            for i, x in enumerate(full_word_lst):
                if len(x) == target_len:
                    continue
                elif len(x) > target_len:
                    print('generated seq is longer than gold seq')
                    full_word_lst[i] = x[:target_len]
                else:
                    print('padded to same length', (target_len - len(x)))
                    full_word_lst[i] = x + ['.'] * (target_len - len(x))
                    # print(text_samples[i])
                    # print('SAD, our model is shorter??')
            generated_parse, generated_span = eval_parse(parser, full_word_lst, tree_vocab)
            evalb_score = score_tree(gold_parse, generated_parse) # inputs are nltk.Tree
            # print(type(evalb_score))
            print(evalb_score.fscore)
            full_score.append(evalb_score.fscore)
        full_score_f1 = np.array(full_score).mean()
        # print(full_score_f1, len(full_score))
        print(full_score_f1, f"\pm {np.array(full_score).std()}", len(full_score))
    elif args.mode == 'e2e-tgt-spans-paired':
        parser = benepar.Parser("benepar_en3")
        tree_vocab = parser._parser.config["label_vocab"]
        full_score = []
        for idx, (gold_spans, full_word_lst) in enumerate(text_samples.items()):
            # to avoid evalb complain --> change \n to .
            print(gold_spans, '11 gold')
            generated_parse, generated_span = eval_parse(parser, full_word_lst, tree_vocab)
            score_lst = []
            for x in generated_span:
                score_lst.append(score_spans(gold_spans, x))
            print(score_lst)
            score_lst_mean = np.array(score_lst).mean()
            full_score.append(score_lst_mean)
        full_score_span = np.array(full_score).mean()
        print(full_score_span, f"\pm {np.array(full_score).std()}", len(full_score))
    if args.mode == 'e2e-tgt-attribute-paired':
        full_score = []
        for idx, (attribute, full_word_lst) in enumerate(text_samples.items()):
            # print(attribute)
            # Key looks like ('field', ':', 'value', ...); keep the value part.
            attribute = " ".join(attribute).split(':')[1].strip()
            gold_attribute = attribute
            score_lst = []
            for i, x in enumerate(full_word_lst):
                # print(gold_attribute, x)
                score_lst.append(score_attributes(gold_attribute, " ".join(x)))
            score_lst_mean = np.array(score_lst).mean()
            full_score.append(score_lst_mean)
        full_score_mean = np.array(full_score).mean()
        # print(full_score_mean, len(full_score))
        print(full_score_mean, f"\pm {np.array(full_score).std()}", len(full_score))
    if args.mode == 'e2e-tgt-length-paired':
        full_score = []
        for idx, (attribute, full_word_lst) in enumerate(text_samples.items()):
            tgt_len = int(attribute[0]) - 2 # remove START and END.
            score_lst = []
            for i, x in enumerate(full_word_lst):
                if tgt_len == len(x):
                # if np.abs(tgt_len - len(x)) <= 2:
                    score_lst.append(1.)
                else:
                    score_lst.append(0.)
            score_lst_mean = np.array(score_lst).mean()
            full_score.append(score_lst_mean)
        full_score_mean = np.array(full_score).mean()
        # print(full_score_mean, len(full_score))
        print(full_score_mean, f"\pm {np.array(full_score).std()}", len(full_score))
    elif args.mode == 'e2e-tgt-attribute':
        # NOTE(review): empty gold attribute — `'' in x` is always True, so
        # this prints 1.0 unconditionally; looks like an unfinished stub.
        gold_attribute = ""
        score_lst = []
        for x in text_samples:
            score_lst.append(score_attributes(gold_attribute, x))
        print(np.array(score_lst).mean())
if __name__ == '__main__':
    # 'diffusion_lm/improved_diffusion/out_gen/diff_e2e-tgt_pad_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart.ema_0.9999_200000.pt.infill_control_tree_50x64x16_tree_partial-cat-lgv0.1.json'
    parser = argparse.ArgumentParser(description='training args.')
    parser.add_argument('--input_text', type=str, default='diffusion_lm/improved_diffusion/out_gen/diff_e2e-tgt_pad_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart.ema_0.9999_200000.pt.'
                                                          'infill_control_tree_50x64x16_tree_partial-cat-lgv0.1.json',)
    # NOTE(review): read_files only handles 'file' and 'paired'; the default
    # 'batch' falls through and read_files returns None — callers must pass
    # an explicit --input_format.
    parser.add_argument('--input_format', type=str, default='batch', help='wp, wikitext')
    parser.add_argument('--mode', type=str, default='e2e-tgt-tree', help='')
    parser.add_argument('--gold_ref', type=str, default='full', help='')
    parser.add_argument('--model_name_or_path', type=str, default='predictability/diff_models/e2e-tgt_e=20_b=64_m=gpt2_wikitext-103-raw-v1_101_wp_finetune_UNK', help='')
    # default='predictability/diff_models/e2e-tgt_e=6_b=10_m=gpt2_wikitext-103-raw-v1_101_wp_pad', help='')
    args = parser.parse_args()
    # Pipeline: load samples, run the mode-specific control evaluation, then
    # score fluency with the finetuned GPT-2 LM.
    text_samples = read_files(args)
    eval_(args, text_samples)
    eval_ppl(args, text_samples)
    # eval_ppl2(args, text_samples)