// experiments/spider-configs/BART-large-en/gap-bart.jsonnet
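// Jsonnet model config for GAP's BART-large encoder on the English Spider
// dataset. It extends the shared nl2code base config and receives all
// hyperparameters through the `args` top-level argument.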
local _0428_base = import 'nl2code-base.libsonnet';
local _data_path = 'data/spider-en/';
local _output_from = true;
local _fs = 2;
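// _output_from and _fs (the grammar's factorize_sketch setting) are baked into
// the preprocessed-data save_path strings below, so changing either one points
// the pipeline at a different preprocessing cache.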
function(args) _0428_base(output_from=_output_from, data_path=_data_path) + {
    local lr_s = '%0.1e' % args.lr,
    local bert_lr_s = '%0.1e' % args.bert_lr,
    local end_lr_s = if args.end_lr == 0 then '0e0' else '%0.1e' % args.end_lr,

    local base_bert_enc_size = 1024,
    local enc_size = base_bert_enc_size,
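    // Human-readable run identifier built from the formatted learning rates
    // above; `att` (presumably the attempt index) also seeds the run in the
    // train section below.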
    model_name: 'bs=%(bs)d,lr=%(lr)s,bert_lr=%(bert_lr)s,end_lr=%(end_lr)s,att=%(att)d' % (args + {
        lr: lr_s,
        bert_lr: bert_lr_s,
        end_lr: end_lr_s,
    }),
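    // Encoder: swaps the base encoder for BART plus a relation-aware
    // transformer on top. Fields marked with `::` become hidden, which
    // effectively removes the inherited base-config fields from the
    // manifested JSON.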
    model+: {
        encoder+: {
            name: 'spider-bart',
            batch_encs_update:: null,
            question_encoder:: null,
            column_encoder:: null,
            table_encoder:: null,
            dropout:: null,
            update_config+: {
                name: 'relational_transformer',
                num_layers: args.num_layers,
                num_heads: 8,
                sc_link: args.sc_link,
                cv_link: args.cv_link,
            },
            summarize_header: args.summarize_header,
            use_column_type: args.use_column_type,
            bart_version: args.bart_version,
            pretrained_checkpoint: args.pretrained_checkpoint,
            top_k_learnable:: null,
            word_emb_size:: null,
        },
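        // Encoder preprocessing: computes schema-linking (sc_link) and
        // cell-value-linking (cv_link) relations and caches the result
        // under save_path.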
        encoder_preproc+: {
            word_emb:: null,
            min_freq:: null,
            max_count:: null,
            db_path: _data_path + 'database',
            compute_sc_link: args.sc_link,
            compute_cv_link: args.cv_link,
            fix_issue_16_primary_keys: true,
            bart_version: args.bart_version,
            pretrained_checkpoint: args.pretrained_checkpoint,
            count_tokens_in_word_emb_for_vocab:: null,
            save_path: _data_path + 'BART-large-nl2code-1115,output_from=%s,fs=%d,emb=bart,cvlink' % [_output_from, _fs],
        },
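        // Decoder preprocessing: configures the SQL grammar; encoder-side
        // options are hidden (`::`) because they do not apply here.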
        decoder_preproc+: {
            grammar+: {
                end_with_from: args.end_with_from,
                clause_order: args.clause_order,
                infer_from_conditions: true,
                factorize_sketch: _fs,
            },
            save_path: _data_path + 'BART-large-nl2code-1115,output_from=%s,fs=%d,emb=bart,cvlink' % [_output_from, _fs],

            compute_sc_link:: null,
            compute_cv_link:: null,
            db_path:: null,
            fix_issue_16_primary_keys:: null,
            bart_version:: null,
            pretrained_checkpoint:: null,
        },
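        // NL2Code tree decoder; the very precise dropout value was
        // presumably produced by a hyperparameter search.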
        decoder+: {
            name: 'NL2Code',
            dropout: 0.20687225956012834,
            desc_attn: 'mha',
            enc_recurrent_size: enc_size,
            recurrent_size: args.decoder_hidden_size,
            loss_type: 'softmax',
            use_align_mat: args.use_align_mat,
            use_align_loss: args.use_align_loss,
        },
    },
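    // args.att seeds model, data, and parameter initialization, so each
    // attempt index yields a distinct but repeatable run.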
    train+: {
        batch_size: args.bs,
        num_batch_accumulated: args.num_batch_accumulated,
        clip_grad: 1,

        model_seed: args.att,
        data_seed: args.att,
        init_seed: args.att,
    },
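    // The optimizer's lr fields appear to be placeholders; the effective
    // rates come from the scheduler's start_lrs below.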
    optimizer: {
        name: 'bertAdamw',
        lr: 0.0,
        bert_lr: 0.0,
    },
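    // Polynomial-decay schedule with two parameter groups, one at args.lr and
    // one at args.bert_lr; warmup lasts 1/8 of max_steps, which this file does
    // not set (it is expected from the base config or the experiment file).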
    lr_scheduler+: {
        name: 'bert_warmup_polynomial_group',
        start_lrs: [args.lr, args.bert_lr],
        end_lr: args.end_lr,
        num_warmup_steps: $.train.max_steps / 8,
    },
}
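
// A minimal sketch of evaluating this config by hand with the jsonnet CLI.
// All values below are hypothetical; the actual set of fields `args` must
// carry is determined by the callers in this repo (the training harness
// normally supplies them), and nl2code-base.libsonnet must be resolvable on
// the import path (add -J <dir> if it is not next to this file):
//
//   jsonnet --tla-code args='{
//     bs: 12, num_batch_accumulated: 4, att: 0,
//     lr: 7.4e-4, bert_lr: 3e-6, end_lr: 0,
//     num_layers: 8, sc_link: true, cv_link: true,
//     summarize_header: "avg", use_column_type: false,
//     bart_version: "facebook/bart-large",
//     pretrained_checkpoint: "pretrained/pytorch_model.bin",
//     end_with_from: true, clause_order: null,
//     decoder_hidden_size: 512,
//     use_align_mat: true, use_align_loss: true,
//   }' gap-bart.jsonnet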