# Lang2mol-Diff / src/scripts/batch_decode.py
# Page header (web-extraction residue): ndhieunguyen · "Add application file" · 7dd9869
import os, sys, glob
# full_lst = glob.glob('diff_models_synth128*')
# full_lst = glob.glob('diff_models_synth32*')
# full_lst = glob.glob('diff_models_synth32_3_rand16*')
# full_lst = glob.glob('diff_models_synth_rand_16_trans_lr_1e-5_long_Lsimple')
# Batch decoding driver: for every model directory matched by a glob, sample
# from its latest checkpoint (unless samples already exist) and then score the
# samples with the evaluation command that belongs to the run's modality.
#
# Usage: python batch_decode.py MODEL_DIR_GLOB [TOP_P] [CHECKPOINT_PREFIX]

# Checked in this order; the first substring found in the run name wins.
# 'e2e-tgt' must stay ahead of 'e2e' because the latter is a substring of it.
MODALITY_MARKERS = (
    'synth', 'pos', 'image', 'roc', 'e2e-tgt',
    'simple-wiki', 'book', 'yelp', 'commonGen', 'e2e',
)

# Autoregressive scoring models, keyed by modality ('synth' is handled
# separately in ar_model_path because it also depends on kk).
AR_MODEL_PATHS = {
    'pos': 'predictability/diff_models/pos_e=15_b=20_m=gpt2_wikitext-103-raw-v1_s=102',
    'e2e-tgt': "predictability/diff_models/e2e-tgt_e=15_b=20_m=gpt2_wikitext-103-raw-v1_101_None",
    'roc': "predictability/diff_models/roc_e=6_b=10_m=gpt2_wikitext-103-raw-v1_101_wp_pad_v1",
}
SYNTH128_AR_MODEL_PATH = 'predictability/diff_models/synth_e=15_b=10_m=gpt2_wikitext-103-raw-v1_None'
SYNTH_AR_MODEL_PATH = 'predictability/diff_models/synth_e=15_b=20_m=gpt2_wikitext-103-raw-v1_None'


def parse_cli(argv):
    """Split the command line into ``(folder_glob, top_p, pattern_)``.

    Args:
        argv: Full argument vector with the program name first (same layout
            as ``sys.argv``).

    Returns:
        A tuple of the model-directory glob (``argv[1]``), the ``top_p`` value
        forwarded to the sampler (``argv[2]``, default ``-1.0``) and the
        checkpoint filename prefix (``argv[3]``, default ``'model'``).

    Raises:
        SystemExit: If the mandatory directory glob is missing.
    """
    if len(argv) < 2:
        raise SystemExit('usage: batch_decode.py MODEL_DIR_GLOB [TOP_P] [CHECKPOINT_PREFIX]')
    # argv[k] exists only when len(argv) > k; the original compared against
    # k - 1 and therefore raised IndexError whenever an optional was omitted.
    top_p = argv[2] if len(argv) > 2 else -1.0
    pattern_ = argv[3] if len(argv) > 3 else 'model'
    return argv[1], top_p, pattern_


def detect_modality(name):
    """Return the first entry of MODALITY_MARKERS contained in *name*, or None."""
    return next((marker for marker in MODALITY_MARKERS if marker in name), None)


def parse_run_name(name):
    """Decode the settings embedded in an underscore-separated run name.

    Every value has a fallback, so a name with too few fields never raises.

    Args:
        name: Base name of the run directory.

    Returns:
        A dict with the keys ``model_arch``, ``model_arch_field``, ``mode``,
        ``dim``, ``diffusion_steps``, ``noise_schedule``, ``num_channels``,
        ``modality`` and ``kk``. ``modality`` and ``kk`` are ``None`` when the
        name carries no recognised marker.
    """
    parts = name.split('_')

    def field(idx):
        # Missing positional fields become None instead of raising IndexError.
        return parts[idx] if idx < len(parts) else None

    model_arch = 'conv-unet' if 'conv-unet' in name else 'transformer'
    mode = 'image' if 'conv' in model_arch else 'text'

    # Field 3 looks like 'rand16'; fall back to the raw field when it does not.
    dim_field = field(3)
    try:
        dim = int(dim_field.split('rand')[1])
    except (AttributeError, IndexError, ValueError):
        dim = dim_field

    try:
        diffusion_steps = int(parts[6])
    except (IndexError, ValueError):
        diffusion_steps = 4000

    # Explicit membership test rather than `assert`, which vanishes under -O.
    noise_schedule = field(7)
    if noise_schedule not in ('cosine', 'linear'):
        noise_schedule = 'cosine'

    # The last field is read as '<prefix>h<channels>'.
    try:
        num_channels = int(parts[-1].split('h')[1])
    except (IndexError, ValueError):
        num_channels = 128

    if 'synth32' in name:
        kk = 32
    elif 'synth128' in name:
        kk = 128
    else:
        kk = None

    return {
        'model_arch': model_arch,
        'model_arch_field': field(4),
        'mode': mode,
        'dim': dim,
        'diffusion_steps': diffusion_steps,
        'noise_schedule': noise_schedule,
        'num_channels': num_channels,
        'modality': detect_modality(name),
        'kk': kk,
    }


def sample_output_path(tgt, modality, out_dir, top_p):
    """Return the path where samples for checkpoint *tgt* are looked up.

    The file name is ``<run dir>.<checkpoint file>.samples_<top_p>`` with a
    ``.json`` suffix for the e2e modalities and ``.txt`` otherwise.
    """
    ckpt_dir, ckpt_file = os.path.split(tgt)
    model_base_name = os.path.basename(ckpt_dir) + f'.{ckpt_file}'
    suffix = 'json' if modality in ('e2e-tgt', 'e2e') else 'txt'
    return os.path.join(out_dir, f"{model_base_name}.samples_{top_p}.{suffix}")


def ar_model_path(modality, kk):
    """Return the AR scoring model for *modality*, or None if there is none.

    For 'synth' the b=10 model is used when ``kk == 128`` and the b=20 model
    for every other value of *kk* (including None).
    """
    if modality == 'synth':
        return SYNTH128_AR_MODEL_PATH if kk == 128 else SYNTH_AR_MODEL_PATH
    return AR_MODEL_PATHS.get(modality)


def main(argv=None):
    """Sample from and evaluate every model directory matched on the CLI.

    Args:
        argv: Argument vector; defaults to ``sys.argv``.

    Returns:
        The list of sample files that were found or generated.
    """
    argv = sys.argv if argv is None else argv
    folder_glob, top_p, pattern_ = parse_cli(argv)
    print(f'top_p = {top_p}')
    print(f'pattern_ = {pattern_}')

    # out_dir = 'diffusion_lm/improved_diffusion/out_gen_large_nucleus'
    # out_dir = 'diffusion_lm/improved_diffusion/out_gen_v2_nucleus'
    out_dir = 'generation_outputs'
    output_lst = []

    for folder in glob.glob(folder_glob):
        print(folder)
        # Plain string sort: the "latest" checkpoint is the lexicographically
        # greatest name (assumes zero-padded step numbers -- TODO confirm).
        checkpoints = sorted(glob.glob(f"{folder}/{pattern_}*pt"))
        if not checkpoints:
            continue
        tgt = checkpoints[-1]
        # normpath drops a trailing separator so the run name is never empty.
        run_name = os.path.basename(os.path.normpath(folder))
        print(run_name)

        info = parse_run_name(run_name)
        modality = info['modality']
        print(info['mode'], info['model_arch_field'])
        print(info['diffusion_steps'])
        print(info['noise_schedule'])
        print(tgt, info['model_arch'], info['dim'], info['num_channels'])

        num_samples = 547 if modality == 'e2e' else 50
        COMMAND = f'python scripts/{info["mode"]}_sample.py ' \
                  f'--model_path {tgt} --batch_size 50 --num_samples {num_samples} --top_p {top_p} ' \
                  f'--out_dir {out_dir} '
        print(COMMAND)

        out_path2 = sample_output_path(tgt, modality, out_dir, top_p)
        output_cands = glob.glob(out_path2)
        print(out_path2, output_cands)
        if not output_cands:
            # No cached samples: run the sampler, then look again.
            os.system(COMMAND)
            output_cands = glob.glob(out_path2)
            if not output_cands:
                # Report and move on instead of aborting the whole batch.
                print(f'sampling produced no output at {out_path2}; skipping {run_name}',
                      file=sys.stderr)
                continue
        out_path2 = output_cands[0]
        output_lst.append(out_path2)

        if modality == 'e2e':
            # e2e runs are scored with MBR decoding plus the e2e-metrics
            # script; the scores land next to the checkpoint.
            COMMAND1 = f"python diffusion_lm/e2e_data/mbr.py {out_path2}"
            os.system(COMMAND1)
            COMMAND2 = f"python e2e-metrics/measure_scores.py " \
                       f"diffusion_lm/improved_diffusion/out_gen_v2_dropout2/1_valid_gold " \
                       f"{out_path2}.clean -p -t -H > {os.path.join(os.path.split(tgt)[0], 'e2e_valid_eval.txt')}"
            print(COMMAND2)
            os.system(COMMAND2)
            continue

        model_name_path = ar_model_path(modality, info['kk'])
        if model_name_path is None:
            print('not trained a AR model yet... only look at the output plz.')
            continue

        COMMAND = f"python scripts/ppl_under_ar.py " \
                  f"--model_path {tgt} " \
                  f"--modality {modality} --experiment random " \
                  f"--model_name_or_path {model_name_path} " \
                  f"--input_text {out_path2} --mode eval"
        print(COMMAND)
        print()
        os.system(COMMAND)

    print('output lists:')
    print("\n".join(output_lst))
    return output_lst


if __name__ == '__main__':
    main()