File size: 8,218 Bytes
20b7679
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import argparse
from datetime import datetime
from const import lca_names, sca_names, lingfeat_names
import os, json
from copy import deepcopy
import numpy as np

def parse_args(ckpt=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', default='/data/mohamed/data')
    parser.add_argument('--data', default='ling_conversion')
    parser.add_argument('--data_sources')
    parser.add_argument('--data_type', default='text')
    parser.add_argument('--aim_repo', default='/data/mohamed/')
    parser.add_argument('--ckpt_dir', default='/data/mohamed/checkpoints')
    parser.add_argument('--kld_annealing', default='cyclic')
    parser.add_argument('--lingpred_annealing', default='mono')
    parser.add_argument('--ling_embed_type', default = 'one-layer')
    parser.add_argument('--combine_weight', default=1, type=float)
    parser.add_argument('--alpha_kld', default=1, type=float)
    parser.add_argument('--alpha_lingpred', default=1, type=float)
    parser.add_argument('--alpha_sem', default=1, type=float)
    parser.add_argument('--max_grad_norm', default=10, type=float)
    parser.add_argument('--sem_loss_tao', default=0.5, type=float)
    parser.add_argument('--sem_loss_eps', default=1, type=float)
    parser.add_argument('--ckpt')
    parser.add_argument('--disc_ckpt')
    parser.add_argument('--sem_ckpt')
    parser.add_argument('--lng_ids')
    parser.add_argument('--lng_ids_idx', type=int)
    parser.add_argument('--lng_ids_path', default='/data/mohamed/indices')
    parser.add_argument('--preds_dir', default='/data/mohamed/preds')
    parser.add_argument('--model_name', default="google/flan-t5-base")
    parser.add_argument('--disc_type', default="t5")
    parser.add_argument('--aim_exp', default='ling-conversion')
    parser.add_argument('--sem_loss_type', default='dedicated')
    parser.add_argument('--combine_method', default='none')
    parser.add_argument('--train_log', type=int, default=200)
    parser.add_argument('--val_log', type=int, default=2000)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--eval_batch_size', type=int, default=32)
    parser.add_argument('--max_eval_samples', type=int, default=1000)
    parser.add_argument('--test_batch_size', type=int, default=1)
    parser.add_argument('--hidden_dim', type=int, default=500)
    parser.add_argument('--latent_dim', type=int, default=150)
    parser.add_argument('--lng_dim', type=int, default=40)
    parser.add_argument('--disc_lng_dim', type=int)
    parser.add_argument('--use_lora', action='store_true')
    parser.add_argument('--lora_r', type=int, default=64)
    parser.add_argument('--gpu', type=str, default='0')
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--grad_accumulation', type=int, default=1)
    parser.add_argument('--n_ica', type=int, default=10)
    parser.add_argument('--max_length', type=int, default=200)
    parser.add_argument('--total_steps', type=int)
    parser.add_argument('--kld_const', type=float, default=1)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--kl_weight', type=float, default=1e-1)
    parser.add_argument('--weight_decay', type=float, default=1e-2)
    parser.add_argument('--ling_dropout', type=float, default=0.1)
    parser.add_argument('--predict_fn', default = 'logs/test.txt')
    parser.add_argument('--save_predict', action='store_true')
    parser.add_argument('--use_ica', action='store_true')
    parser.add_argument('--pretrain_gen', action='store_true')
    parser.add_argument('--pretrain_sem', action='store_true')
    parser.add_argument('--pretrain_disc', action='store_true')
    parser.add_argument('--linggen_type', default='none')
    parser.add_argument('--linggen_input', default='s+l')
    parser.add_argument('--aug_same', action='store_true')
    parser.add_argument('--ling_vae', action='store_true')
    parser.add_argument('--process_lingpred', action='store_true')
    parser.add_argument('--fudge_lambda', type=float, default=1.0)
    parser.add_argument('--use_lingpred', action='store_true')
    parser.add_argument('--ling2_only', action='store_true')
    parser.add_argument('--cycle_loss', action='store_true')
    parser.add_argument('--disc_loss', action='store_true')
    parser.add_argument('--sem_loss', action='store_true')
    parser.add_argument('--sim_loss', action='store_true')
    parser.add_argument('--optuna', action='store_true')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--demo', action='store_true')
    parser.add_argument('--fudge', action='store_true')
    parser.add_argument('--fb_log', default='feedback_logs/default.txt')
    parser.add_argument('--eval_only', action='store_true')
    parser.add_argument('--predict_with_feedback', action='store_true')
    parser.add_argument('--feedback_param', default = 's')
    parser.add_argument('--eval_ling', action='store_true')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--major_arg', default = 0, type=int)
    parser.add_argument('--quantize_lng', action='store_true')
    parser.add_argument('--quant_nbins', type=int, default=20)
    parser.add_argument('--src_lng', default = 'ling')
    parser.add_argument('--to_restore', nargs='+', default=[])
    # args = parser.parse_args()
    args, unknown = parser.parse_known_args()
    args.name = f'{datetime.now().strftime("%m%d_%H-%M-%S")}-{args.data}-{args.combine_method}'

    major_arg = args.major_arg
    to_restore = [
            'total_steps','major_arg','gpu','demo', 'eval_only', 'save_predict', 'predict_fn', 'fudge', 'predict_with_feedback',
            'feedback_param', 'fb_log', 'data_dir', 'data', 'disc_ckpt', 'disc_type', 'sem_ckpt', 'fudge_lambda', 'test_batch_size', 'src_lng'
            ] + args.to_restore
    to_restore = {k: args.__dict__[k] for k in to_restore}

    if not args.disc_loss or args.disc_ckpt:
        args.disc_steps = 0

    if args.data_sources is not None:
        args.data_sources = args.data_sources.split(',')

    if ckpt is not None:
        args.ckpt = ckpt

    args_list = [args]
    if args.ckpt:
        if ',' in args.ckpt:
            ckpts = args.ckpt.split(',')
            args_list = [deepcopy(args) for _ in range(len(ckpts))]
            for i in range(len(ckpts)):
                args_path = ckpts[i].replace('_best', '').replace('.pt', '.json')
                with open(args_path) as f:
                    args_list[i].__dict__.update(json.load(f))
                args_list[i].__dict__.update(to_restore)
                args_list[i].ckpt = ckpts[i]
        else:
            args_path = args.ckpt.replace('_best', '').replace('.pt', '.json')
            ckpt = args.ckpt
            with open(args_path) as f:
                args.__dict__.update(json.load(f))
                args.__dict__.update(to_restore)
                args.ckpt = ckpt

    lng_names = lca_names + sca_names + lingfeat_names
    for i in range(len(args_list)):
        if args_list[i].lng_ids or args_list[i].lng_ids_idx:
            if args_list[i].lng_ids_idx:
                lng_ids = np.load(os.path.join(args_list[i].lng_ids_path, f'{args_list[i].lng_ids_idx}.npy'))
            elif args_list[i].lng_ids[0].isnumeric():
                lng_ids = [int(x) for x in args_list[i].lng_ids.split(',')]
            elif ',' in args_list[i].lng_ids:
                lng_ids = [lng_names.index(x) for x in args_list[i].lng_ids.split(',')]
            else:
                lng_ids = np.load(args_list[i].lng_ids)
            args_list[i].lng_dim = len(lng_ids)
            args_list[i].lng_ids = lng_ids.tolist()
            # lng_names = [lng_names[i] for i in lng_ids]
        elif args_list[i].use_ica:
            args_list[i].lng_dim = args_list[i].n_ica
        if args_list[i].disc_lng_dim is None:
            args_list[i].disc_lng_dim = args_list[i].lng_dim

    if not args.ckpt and not args.eval_only:
        args_path = os.path.join(args.ckpt_dir, '%s.json'%args.name)
        with open(args_path, 'w') as f:
            s = json.dumps(args.__dict__)
            f.write(s)

    return args_list[major_arg], args_list, lng_names