mohdelgaar committed
Commit 132da4a · 1 Parent(s): 251ecbb

improve performance

Files changed (4):
  1. app.py +11 -9
  2. const.py +0 -1
  3. model.py +309 -180
  4. options.py +57 -20
app.py CHANGED
@@ -1,7 +1,8 @@
 import spacy
 import nltk
 nltk.download('wordnet', quiet=True)
-spacy.cli.download('en_core_web_sm')
+if not spacy.util.is_package('en_core_web_sm'):
+    spacy.cli.download('en_core_web_sm')
 from compute_lng import compute_lng
 
 import torch
@@ -111,7 +112,7 @@ def impute_targets():
     shared_state.target = round_ling(interp_raw).tolist()
     return shared_state.target
 
-def generate_with_feedback(sent1, approx):
+def generate_with_feedback(sent1, approx, progress=gr.Progress()):
     if sent1 == '':
         raise gr.Error('Please input a source text.')
 
@@ -122,24 +123,25 @@ def generate_with_feedback(sent1, approx):
     input_ids = tokenizer.encode(sent1, return_tensors='pt').to(device)
     ling2 = torch.tensor(scaler.transform([shared_state.target])).float().to(device)
     inputs = {
-        'sentence1_input_ids': input_ids,
+        'input_ids': input_ids,
         'sentence2_ling': ling2,
-        'sentence1_attention_mask': torch.ones_like(input_ids)
+        'attention_mask': torch.ones_like(input_ids)
     }
 
-    pred, (pred_text, interpolations) = model.infer_with_feedback_BP(ling_disc, sem_emb, inputs, tokenizer)
+    progress((0, None), unit='intermediate paraphrase generated.')
+    pred, (pred_text, interpolations) = model.infer_with_feedback_BP(ling_disc, sem_emb, inputs, tokenizer, progress)
 
     interpolation = '-- ' + '\n-- '.join(interpolations)
     # Return both the generation results and the updated slider values
     return [pred_text, interpolation] + [gr.update(value=val) for val in shared_state.target]
 
-def generate_random(sent1, count, approx):
+def generate_random(sent1, count, approx, progress=gr.Progress()):
     if sent1 == '':
         raise gr.Error('Please input a source text.')
     preds, interpolations = [], []
     orig_active_indices = shared_state.active_indices
     shared_state.active_indices = set(range(len(lng_names)))
-    for c in range(count):
+    for c in progress.tqdm(range(count), desc='Generating random sentences', unit='paraphrases'):
         idx = np.random.randint(0, len(ling_collection))
         ling_ex = ling_collection[idx]
         shared_state.target = ling_ex.copy()
@@ -167,7 +169,7 @@ def generate_random(sent1, count, approx):
     shared_state.active_indices = orig_active_indices
     return '\n***\n'.join(preds), '\n***\n'.join(interpolations)
 
-def estimate_gen(sent1, sent2, approx):
+def estimate_gen(sent1, sent2, approx, progress=gr.Progress()):
     if 'approximate' in approx:
         input_ids = tokenizer.encode(sent2, return_tensors='pt').to(device)
         with torch.no_grad():
@@ -183,7 +185,7 @@ def estimate_gen(sent1, sent2, approx):
 
     orig_active_indices = shared_state.active_indices
     shared_state.active_indices = set(range(len(lng_names)))
-    gen = generate_with_feedback(sent1, approx)[:2]
+    gen = generate_with_feedback(sent1, approx, progress)[:2]
     shared_state.active_indices = orig_active_indices
     return gen + [gr.update(value=val) for val in shared_state.target]
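
Note: two patterns in the app.py change above generalize. The guarded download fetches the spaCy model only when it is missing, and a gr.Progress argument threaded through a handler lets Gradio report per-step progress. A minimal, self-contained sketch of both, assuming a plain Gradio app (the handler body is illustrative, not this app's model code):

import spacy
import gradio as gr

# Fetch the spaCy model only when it is not already installed,
# so repeated app restarts skip the network round-trip.
if not spacy.util.is_package('en_core_web_sm'):
    spacy.cli.download('en_core_web_sm')

def generate_random(sent1, count, progress=gr.Progress()):
    preds = []
    # progress.tqdm wraps the iterable and reports each iteration in the UI.
    for _ in progress.tqdm(range(int(count)), desc='Generating random sentences', unit='paraphrases'):
        preds.append(sent1[::-1])  # stand-in for one model call
    return '\n***\n'.join(preds)

demo = gr.Interface(generate_random,
                    [gr.Textbox(label='Source'), gr.Slider(1, 10, step=1, label='Count')],
                    gr.Textbox(label='Outputs'))
demo.launch()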
 
const.py CHANGED
@@ -1030,7 +1030,6 @@ used_indices = [
     63, 64, 65, 66, 67, 68, 73, 121, 124, 129, 134, 136, 254,
     257, 258, 261, 263, 272, 274
 ]
-lftk_used_indices = [1, 7, 8, 9, 10, 11, 12, 17, 65, 68, 73, 78, 80, 198, 201, 202, 205, 207, 216, 218]
 
 eval_indices = [4,5,6,18,257,272]
 eval_indices = [used_indices.index(idx) for idx in eval_indices]
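
For context, the surviving eval_indices remap in const.py converts absolute feature ids into positions within used_indices. A tiny illustration with made-up values:

used_indices = [4, 5, 6, 18, 257, 272, 300]
eval_indices = [4, 5, 6, 18, 257, 272]
# Map each absolute feature id to its position in used_indices.
eval_indices = [used_indices.index(idx) for idx in eval_indices]
print(eval_indices)  # [0, 1, 2, 3, 4, 5]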
model.py CHANGED
@@ -10,6 +10,10 @@ from types import MethodType
 from utils import *
 from ling_disc import DebertaReplacedTokenizer
 from const import *
+from lingconv_t5 import LingConvT5ForConditionalGeneration
+from dataclasses import dataclass
+from transformers.modeling_outputs import Seq2SeqLMOutput
+from typing import Optional, Dict, Any
 
 
 
@@ -77,9 +81,9 @@ class LingGenerator(nn.Module):
         bs = inputs_embeds.shape[0]
 
         if self.gen_input == 's+l':
-            sent1_ling = self.ling_embed(batch['sentence1_ling'])
-            sent1_ling = sent1_ling.view(bs, 1, -1)
-            inputs_embeds = inputs_embeds + sent1_ling
+            sentence1_ling = self.ling_embed(batch['sentence1_ling'])
+            sentence1_ling = sentence1_ling.view(bs, 1, -1)
+            inputs_embeds = inputs_embeds + sentence1_ling
 
         gen = self.gen(inputs_embeds=inputs_embeds,
                        attention_mask=inputs_att_mask).last_hidden_state.mean(1)
@@ -185,13 +189,13 @@ class SemEmb(T5EncoderModel):
                                         nn.Linear(hidden_dim, 1))
 
     def compare_sem(self, **batch):
-        bs = batch['sentence1_attention_mask'].shape[0]
-        ones = torch.ones((bs, 1), device=batch['sentence1_attention_mask'].device)
+        bs = batch['attention_mask'].shape[0]
+        ones = torch.ones((bs, 1), device=batch['attention_mask'].device)
         sep = torch.ones((bs, 1), dtype=torch.long,
-                         device=batch['sentence1_attention_mask'].device) * self.sep_token_id
-        att_mask = torch.cat([batch['sentence1_attention_mask'], ones, batch['sentence2_attention_mask']], dim=1)
+                         device=batch['attention_mask'].device) * self.sep_token_id
+        att_mask = torch.cat([batch['attention_mask'], ones, batch['sentence2_attention_mask']], dim=1)
         if 'logits' in batch:
-            input_ids = torch.cat([batch['sentence1_input_ids'], sep], dim=1)
+            input_ids = torch.cat([batch['input_ids'], sep], dim=1)
             embeds1 = self.shared(input_ids)
 
             logits = batch['logits']
@@ -201,11 +205,11 @@ class SemEmb(T5EncoderModel):
 
             embeds2 = onehot_ @ self.shared.weight
             embeds1_2 = torch.cat([embeds1, embeds2], dim=1)
-            hidden_units = self(inputs_embeds=embeds1_2,
+            hidden_units = super().forward(inputs_embeds=embeds1_2,
                                 attention_mask=att_mask).last_hidden_state.mean(1)
         elif 'sentence2_input_ids' in batch:
-            input_ids = torch.cat([batch['sentence1_input_ids'], sep, batch['sentence2_input_ids']], dim=1)
-            hidden_units = self(input_ids=input_ids,
+            input_ids = torch.cat([batch['input_ids'], sep, batch['sentence2_input_ids']], dim=1)
+            hidden_units = super().forward(input_ids=input_ids,
                                 attention_mask=att_mask).last_hidden_state.mean(1)
         probs = self.projection(hidden_units)
         return probs
@@ -222,31 +226,36 @@ def prepare_inputs_for_generation(
         cross_attn_head_mask=None,
         use_cache=None,
         encoder_outputs=None,
-        sent1_ling=None,
-        sent2_ling=None,
+        sentence1_ling=None,
+        sentence2_ling=None,
         **kwargs
 ):
-
     # cut decoder_input_ids if past is used
     if past_key_values is not None:
         input_ids = input_ids[:, -1:]
 
+    cached = use_cache and len(past_key_values) > 0
+
     input_ids = input_ids.clone()
    decoder_inputs_embeds = self.shared(input_ids)
 
-    if combine_method == 'decoder_add_first':
-        sent2_ling = torch.cat([sent2_ling,
-            torch.repeat_interleave(torch.zeros_like(sent2_ling), input_ids.shape[1] - 1, dim=1)], dim=1)
-    if combine_method == 'decoder_concat':
+    if combine_method == 'layer_injection':
+        # For layer injection, we'll pass the ling embeddings separately
+        ling_embed = sentence2_ling if ling2_only else (sentence1_ling + sentence2_ling)
+    elif combine_method == 'decoder_add_first' and not cached:
+        sentence2_ling = torch.cat([sentence2_ling,
+            torch.repeat_interleave(torch.zeros_like(sentence2_ling), input_ids.shape[1] - 1, dim=1)], dim=1)
+    elif combine_method == 'decoder_concat':
         if ling2_only:
-            decoder_inputs_embeds = torch.cat([sent2_ling, decoder_inputs_embeds], dim=1)
+            decoder_inputs_embeds = torch.cat([sentence2_ling, decoder_inputs_embeds], dim=1)
         else:
-            decoder_inputs_embeds = torch.cat([sent1_ling, sent2_ling, decoder_inputs_embeds], dim=1)
-    elif combine_method == 'decoder_add' or (past_key_values is None and combine_method == 'decoder_add_first'):
+            decoder_inputs_embeds = torch.cat([sentence1_ling, sentence2_ling, decoder_inputs_embeds], dim=1)
+
+    if combine_method == 'decoder_add' or (not cached and combine_method == 'decoder_add_first'):
         if ling2_only:
-            decoder_inputs_embeds = decoder_inputs_embeds + sent2_ling
+            decoder_inputs_embeds = decoder_inputs_embeds + sentence2_ling
         else:
-            decoder_inputs_embeds = decoder_inputs_embeds + sent1_ling + sent2_ling
+            decoder_inputs_embeds = decoder_inputs_embeds + sentence1_ling + sentence2_ling
 
     return {
         "decoder_inputs_embeds": decoder_inputs_embeds,
@@ -257,19 +266,27 @@ def prepare_inputs_for_generation(
         "decoder_head_mask": decoder_head_mask,
         "cross_attn_head_mask": cross_attn_head_mask,
         "use_cache": use_cache,
+        "ling_embed": ling_embed if combine_method == 'layer_injection' else None,
     }
 
 class LogitsAdd(LogitsProcessor):
-    def __init__(self, sent2_ling):
+    def __init__(self, sentence2_ling):
         super().__init__()
-        self.sent2_ling = sent2_ling
+        self.sentence2_ling = sentence2_ling
 
     def __call__(self, input_ids, scores):
-        return scores + self.sent2_ling
+        return scores + self.sentence2_ling
 
-class EncoderDecoderVAE(T5ForConditionalGeneration):
+class EncoderDecoderVAE(LingConvT5ForConditionalGeneration):
     def __init__(self, config, args, pad_token_id, sepeos_token_id, vocab_size = 32128):
+        if args.combine_method == 'layer_injection':
+            if args.injection_layer < 0 or args.injection_layer >= config.num_decoder_layers:
+                raise ValueError(f"Invalid injection layer: {args.injection_layer}. Must be between 0 and {config.num_decoder_layers - 1}.")
+            config.ling_injection_layer = args.injection_layer
+            config.ling_injection_type = args.injection_type  # 'first' or 'all'
+
         super().__init__(config)
+
         self.prepare_inputs_for_generation = types.MethodType(
             partial(prepare_inputs_for_generation, args.combine_method, args.ling2_only),
             self)
@@ -287,7 +304,7 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
                 nn.ReLU(),
                 nn.Linear(hidden_dim, hidden_dim),
             )
-        elif 'concat' in args.combine_method or 'add' in args.combine_method:
+        elif 'concat' in args.combine_method or 'add' in args.combine_method or 'layer_injection' in args.combine_method:
             if args.ling_embed_type == 'two-layer':
                 self.ling_embed = nn.Sequential(
                     nn.Linear(args.lng_dim, args.lng_dim),
@@ -297,6 +314,7 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
             else:
                 self.ling_embed = nn.Linear(args.lng_dim, hidden_dim)
             self.ling_dropout = nn.Dropout(args.ling_dropout)
+            self.ling_embed.apply(self._init_weights)
 
         if args.ling_vae:
             self.ling_mu = nn.Linear(hidden_dim, hidden_dim)
@@ -306,8 +324,20 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
             nn.init.xavier_uniform_(self.ling_logvar.weight)
 
 
-        generate_with_grad = unwrap(self.generate)
+        generate_with_grad = unwrap(super().generate)
         self.generate_with_grad = MethodType(generate_with_grad, self)
+        self.generate_original = super().generate
+
+    def _init_weights(self, module):
+        std = self.args.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
 
     def get_fusion_layer(self):
         if 'fusion' in self.args.combine_method:
@@ -321,122 +351,143 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
         std = torch.exp(0.5 * logvar)
         return mu + std * torch.randn_like(std)
 
-    def encode(self, batch):
-        if 'inputs_embeds' in batch:
-            inputs_embeds = batch['inputs_embeds']
+    def _process_ling_embeddings(self, sentence1_ling, sentence2_ling,
+                                 sentence1_ling_embed, sentence2_ling_embed, bs):
+        """Helper method to process linguistic embeddings"""
+        cache = {}
+
+        # Process sentence1 embedding
+        if sentence1_ling_embed is not None:
+            sentence1_ling = sentence1_ling_embed
+        elif sentence1_ling is not None:
+            sentence1_ling = self.ling_embed(self.ling_dropout(sentence1_ling))
         else:
-            inputs_embeds = self.shared(batch['sentence1_input_ids'])
-        inputs_att_mask = batch['sentence1_attention_mask']
+            sentence1_ling = None
+
+        # Process sentence2 embedding
+        if sentence2_ling_embed is not None:
+            sentence2_ling = sentence2_ling_embed
+        elif sentence2_ling is not None:
+            sentence2_ling = self.ling_embed(self.ling_dropout(sentence2_ling))
+        else:
+            sentence2_ling = None
+
+        # Apply VAE if configured
+        if self.args.ling_vae and sentence1_ling is not None and sentence2_ling is not None:
+            sentence1_ling = F.leaky_relu(sentence1_ling)
+            sent1_mu, sent1_logvar = self.ling_mu(sentence1_ling), self.ling_logvar(sentence1_ling)
+            sentence1_ling = self.sample(sent1_mu, sent1_logvar)
+
+            sentence2_ling = F.leaky_relu(sentence2_ling)
+            sent2_mu, sent2_logvar = self.ling_mu(sentence2_ling), self.ling_logvar(sentence2_ling)
+            sentence2_ling = self.sample(sent2_mu, sent2_logvar)
+
+            cache.update({
+                'sent1_mu': sent1_mu, 'sent1_logvar': sent1_logvar,
+                'sent2_mu': sent2_mu, 'sent2_logvar': sent2_logvar,
+                'sentence1_ling': sentence1_ling, 'sentence2_ling': sentence2_ling
+            })
+        else:
+            if sentence2_ling is not None:
+                cache['sentence2_ling'] = sentence2_ling
+            if sentence1_ling is not None:
+                cache['sentence1_ling'] = sentence1_ling
+
+        # Reshape embeddings
+        if sentence1_ling is not None:
+            sentence1_ling = sentence1_ling.view(bs, 1, -1)
+        if sentence2_ling is not None:
+            sentence2_ling = sentence2_ling.view(bs, 1, -1)
+
+        return sentence1_ling, sentence2_ling, cache
+
+    def encode(self,
+               input_ids=None,
+               attention_mask=None,
+               sentence1_ling=None,
+               sentence2_ling=None,
+               sentence1_ling_embed=None,
+               sentence2_ling_embed=None,
+               inputs_embeds=None,
+               ):
+        if inputs_embeds is None:
+            inputs_embeds = self.shared(input_ids)
+        inputs_att_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids)
         bs = inputs_embeds.shape[0]
-        cache = {}
+
         if self.args.combine_method in ('input_concat', 'input_add'):
-            if 'sent1_ling_embed' in batch:
-                sent1_ling = batch['sent1_ling_embed']
-            else:
-                sent1_ling = self.ling_embed(self.ling_dropout(batch['sentence1_ling']))
-            if 'sent2_ling_embed' in batch:
-                sent2_ling = batch['sent2_ling_embed']
-            else:
-                sent2_ling = self.ling_embed(self.ling_dropout(batch['sentence2_ling']))
-            if self.args.ling_vae:
-                sent1_ling = F.leaky_relu(sent1_ling)
-                sent1_mu, sent1_logvar = self.ling_mu(sent1_ling), self.ling_logvar(sent1_ling)
-                sent1_ling = self.sample(sent1_mu, sent1_logvar)
-
-                sent2_ling = F.leaky_relu(sent2_ling)
-                sent2_mu, sent2_logvar = self.ling_mu(sent2_ling), self.ling_logvar(sent2_ling)
-                sent2_ling = self.sample(sent2_mu, sent2_logvar)
-                cache.update({'sent1_mu': sent1_mu, 'sent1_logvar': sent1_logvar,
-                              'sent2_mu': sent2_mu, 'sent2_logvar': sent2_logvar,
-                              'sent1_ling': sent1_ling, 'sent2_ling': sent2_ling})
-            else:
-                cache.update({'sent1_ling': sent1_ling, 'sent2_ling': sent2_ling})
-            sent1_ling = sent1_ling.view(bs, 1, -1)
-            sent2_ling = sent2_ling.view(bs, 1, -1)
+            sentence1_ling, sentence2_ling, cache = self._process_ling_embeddings(
+                sentence1_ling, sentence2_ling,
+                sentence1_ling_embed, sentence2_ling_embed, bs
+            )
+
             if self.args.combine_method == 'input_concat':
                 if self.args.ling2_only:
-                    inputs_embeds = torch.cat([inputs_embeds, sent2_ling], dim=1)
+                    inputs_embeds = torch.cat([inputs_embeds, sentence2_ling], dim=1)
                     inputs_att_mask = torch.cat([inputs_att_mask,
                                                  torch.ones((bs, 1)).to(inputs_embeds.device)], dim=1)
                 else:
-                    inputs_embeds = torch.cat([inputs_embeds, sent1_ling, sent2_ling], dim=1)
+                    inputs_embeds = torch.cat([inputs_embeds, sentence1_ling, sentence2_ling], dim=1)
                     inputs_att_mask = torch.cat([inputs_att_mask,
                                                  torch.ones((bs, 2)).to(inputs_embeds.device)], dim=1)
             elif self.args.combine_method == 'input_add':
                 if self.args.ling2_only:
-                    inputs_embeds = inputs_embeds + sent2_ling
+                    inputs_embeds = inputs_embeds + sentence2_ling
                 else:
-                    inputs_embeds = inputs_embeds + sent1_ling + sent2_ling
+                    inputs_embeds = inputs_embeds + sentence1_ling + sentence2_ling
+        else:
+            cache = {}
+
         return self.encoder(inputs_embeds=inputs_embeds,
                             attention_mask=inputs_att_mask), inputs_att_mask, cache
 
-    def decode(self, batch, enc_output, inputs_att_mask, generate):
-        bs = inputs_att_mask.shape[0]
+    def decode(self,
+               sentence2_input_ids=None,
+               sentence1_ling=None,
+               sentence2_ling=None,
+               encoder_outputs=None,
+               encoder_attention_mask=None,
+               decoder_inputs_embeds=None,
+               decoder_attention_mask=None,
+               generate=False,
+               sentence1_ling_embed=None,
+               sentence2_ling_embed=None,
+               ling_embed=None,
+               generate_with_grad=False,
+               **kwargs
+               ):
+        bs = encoder_outputs[0].shape[0]
         cache = {}
-        if self.args.combine_method in ('embed_concat', 'decoder_concat', 'decoder_add', 'logits_add', 'decoder_add_first'):
-            if 'sent1_ling_embed' in batch:
-                sent1_ling = batch['sent1_ling_embed']
-            elif 'sentence1_ling' in batch:
-                sent1_ling = self.ling_embed(self.ling_dropout(batch['sentence1_ling']))
-            else:
-                sent1_ling = None
-            if 'sent2_ling_embed' in batch:
-                sent2_ling = batch['sent2_ling_embed']
-            else:
-                sent2_ling = self.ling_embed(self.ling_dropout(batch['sentence2_ling']))
-            if self.args.ling_vae:
-                sent1_ling = F.leaky_relu(sent1_ling)
-                sent1_mu, sent1_logvar = self.ling_mu(sent1_ling), self.ling_logvar(sent1_ling)
-                sent1_ling = self.sample(sent1_mu, sent1_logvar)
-
-                sent2_ling = F.leaky_relu(sent2_ling)
-                sent2_mu, sent2_logvar = self.ling_mu(sent2_ling), self.ling_logvar(sent2_ling)
-                sent2_ling = self.sample(sent2_mu, sent2_logvar)
-                cache.update({'sent1_mu': sent1_mu, 'sent1_logvar': sent1_logvar,
-                              'sent2_mu': sent2_mu, 'sent2_logvar': sent2_logvar,
-                              'sent1_ling': sent1_ling, 'sent2_ling': sent2_ling})
-            else:
-                cache.update({'sent2_ling': sent2_ling})
-                if sent1_ling is not None:
-                    cache.update({'sent1_ling': sent1_ling})
-            if sent1_ling is not None:
-                sent1_ling = sent1_ling.view(bs, 1, -1)
-            sent2_ling = sent2_ling.view(bs, 1, -1)
-            if self.args.combine_method == 'decoder_add_first' and not generate:
-                sent2_ling = torch.cat([sent2_ling,
-                    torch.repeat_interleave(torch.zeros_like(sent2_ling), batch['sentence2_input_ids'].shape[1] - 1, dim=1)], dim=1)
-        else:
-            sent1_ling, sent2_ling = None, None
-
-        if self.args.combine_method == 'embed_concat':
-            enc_output.last_hidden_state = torch.cat([enc_output.last_hidden_state,
-                                                      sent1_ling, sent2_ling], dim=1)
-            inputs_att_mask = torch.cat([inputs_att_mask,
-                                         torch.ones((bs, 2)).to(inputs_att_mask.device)], dim=1)
-        elif 'fusion' in self.args.combine_method:
-            sent1_ling = batch['sentence1_ling'].unsqueeze(1)\
-                .expand(-1, enc_output.last_hidden_state.shape[1], -1)
-            sent2_ling = batch['sentence2_ling'].unsqueeze(1)\
-                .expand(-1, enc_output.last_hidden_state.shape[1], -1)
-            if self.args.ling2_only:
-                combined_embedding = torch.cat([enc_output.last_hidden_state, sent2_ling], dim=2)
-            else:
-                combined_embedding = torch.cat([enc_output.last_hidden_state, sent1_ling, sent2_ling], dim=2)
-            enc_output.last_hidden_state = self.fusion(combined_embedding)
-
+        if decoder_inputs_embeds is None:
+            if self.args.combine_method in ('embed_concat', 'decoder_concat', 'decoder_add',
+                                            'logits_add', 'decoder_add_first', 'layer_injection'):
+                sentence1_ling, sentence2_ling, cache = self._process_ling_embeddings(
+                    sentence1_ling, sentence2_ling,
+                    sentence1_ling_embed, sentence2_ling_embed, bs
+                )
+
+                if (self.args.combine_method == 'decoder_add_first' or
+                        (self.args.combine_method == 'layer_injection' and
+                         self.args.injection_type == 'first')) and not generate:
+                    sentence2_ling = torch.cat([sentence2_ling,
+                        torch.repeat_interleave(torch.zeros_like(sentence2_ling),
+                                                sentence2_input_ids.shape[1] - 1, dim=1)], dim=1)
+            else:
+                sentence1_ling, sentence2_ling = None, None
+
         if generate:
             if self.args.combine_method == 'logits_add':
-                logits_processor = LogitsProcessorList([LogitsAdd(sent2_ling.view(bs, -1))])
+                logits_processor = LogitsProcessorList([LogitsAdd(sentence2_ling.view(bs, -1))])
             else:
                 logits_processor = LogitsProcessorList()
 
-            dec_output = self.generate_with_grad(
-                attention_mask=inputs_att_mask,
-                encoder_outputs=enc_output,
-                sent1_ling=sent1_ling,
-                sent2_ling=sent2_ling,
-                return_dict_in_generate=True,
-                output_scores=True,
+            generate_fn = self.generate_with_grad if generate_with_grad else self.generate_original
+            dec_output = generate_fn(
+                attention_mask=encoder_attention_mask,
+                encoder_outputs=encoder_outputs,
+                sentence1_ling=sentence1_ling,
+                sentence2_ling=sentence2_ling,
                 logits_processor = logits_processor,
                 # renormalize_logits=True,
                 # do_sample=True,
@@ -445,68 +496,135 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
                # min_new_tokens=3,
                # repetition_penalty=1.2,
                max_length=self.args.max_length,
+                use_cache=True,
+                **kwargs
            )
-            scores = torch.stack(dec_output.scores, 1)
-            cache.update({'scores': scores})
-            return dec_output.sequences, cache
-
-        decoder_input_ids = self._shift_right(batch['sentence2_input_ids'])
-        decoder_inputs_embeds = self.shared(decoder_input_ids)
-        decoder_att_mask = batch['sentence2_attention_mask']
-        labels = batch['sentence2_input_ids'].clone()
-        labels[labels == self.pad_token_id] = -100
-
-        if self.args.combine_method == 'decoder_concat':
-            if self.args.ling2_only:
-                decoder_inputs_embeds = torch.cat([sent2_ling, decoder_inputs_embeds], dim=1)
-                decoder_att_mask = torch.cat([torch.ones((bs, 1)).to(decoder_inputs_embeds.device), decoder_att_mask], dim=1)
-                labels = torch.cat([torch.ones((bs, 1), dtype=torch.int64).to(decoder_inputs_embeds.device) * self.pad_token_id,
-                                    labels], dim=1)
-            else:
-                decoder_inputs_embeds = torch.cat([sent1_ling, sent2_ling, decoder_inputs_embeds], dim=1)
-                decoder_att_mask = torch.cat([torch.ones((bs, 2)).to(decoder_inputs_embeds.device), decoder_att_mask], dim=1)
-                labels = torch.cat([torch.ones((bs, 2), dtype=torch.int64).to(decoder_inputs_embeds.device) * self.pad_token_id,
-                                    labels], dim=1)
-        elif self.args.combine_method == 'decoder_add' or self.args.combine_method == 'decoder_add_first':
-            if self.args.ling2_only:
-                decoder_inputs_embeds = decoder_inputs_embeds + self.args.combine_weight * sent2_ling
-            else:
-                decoder_inputs_embeds = decoder_inputs_embeds + sent1_ling + sent2_ling
+            return dec_output, cache
 
-        dec_output = self(
+        if sentence2_input_ids is not None:
+            labels = sentence2_input_ids.clone()
+            labels[labels == self.pad_token_id] = -100
+        else:
+            labels = None
+
+        if decoder_inputs_embeds is None:
+            decoder_input_ids = self._shift_right(sentence2_input_ids)
+            decoder_inputs_embeds = self.shared(decoder_input_ids)
+
+            if self.args.combine_method == 'decoder_concat':
+                if self.args.ling2_only:
+                    decoder_inputs_embeds = torch.cat([sentence2_ling, decoder_inputs_embeds], dim=1)
+                    decoder_attention_mask = torch.cat([torch.ones((bs, 1)).to(decoder_inputs_embeds.device), decoder_attention_mask], dim=1)
+                    labels = torch.cat([torch.ones((bs, 1), dtype=torch.int64).to(decoder_inputs_embeds.device) * self.pad_token_id,
+                                        labels], dim=1)
+                else:
+                    decoder_inputs_embeds = torch.cat([sentence1_ling, sentence2_ling, decoder_inputs_embeds], dim=1)
+                    decoder_attention_mask = torch.cat([torch.ones((bs, 2)).to(decoder_inputs_embeds.device), decoder_attention_mask], dim=1)
+                    labels = torch.cat([torch.ones((bs, 2), dtype=torch.int64).to(decoder_inputs_embeds.device) * self.pad_token_id,
+                                        labels], dim=1)
+            elif self.args.combine_method == 'decoder_add' or self.args.combine_method == 'decoder_add_first':
+                if self.args.ling2_only:
+                    decoder_inputs_embeds = decoder_inputs_embeds + self.args.combine_weight * sentence2_ling
+                else:
+                    decoder_inputs_embeds = decoder_inputs_embeds + sentence1_ling + sentence2_ling
+
+        if ling_embed is None:
+            ling_embed = sentence2_ling
+
+        dec_output = super().forward(
             decoder_inputs_embeds=decoder_inputs_embeds,
-            decoder_attention_mask=decoder_att_mask,
-            encoder_outputs=enc_output,
-            attention_mask=inputs_att_mask,
+            decoder_attention_mask=decoder_attention_mask,
+            encoder_outputs=encoder_outputs,
+            attention_mask=encoder_attention_mask,
             labels=labels,
+            ling_embed=ling_embed,
+            **kwargs
         )
         if self.args.combine_method == 'logits_add':
-            dec_output.logits = dec_output.logits + self.args.combine_weight * sent2_ling
+            dec_output.logits = dec_output.logits + self.args.combine_weight * sentence2_ling
             vocab_size = dec_output.logits.size(-1)
             dec_output.loss = F.cross_entropy(dec_output.logits.view(-1, vocab_size), labels.view(-1))
         return dec_output, cache
 
+    def generate(self, *args, **kwargs):
+        return self.forward(*args, **kwargs, generate=True)
 
-    def convert(self, batch, generate=False):
-        enc_output, enc_att_mask, cache = self.encode(batch)
-        dec_output, cache2 = self.decode(batch, enc_output, enc_att_mask, generate)
+
+    def forward(self,
+                input_ids=None,
+                attention_mask=None,
+                labels=None,
+                decoder_attention_mask=None,
+                decoder_inputs_embeds=None,
+                sentence1_ling=None,
+                sentence2_ling=None,
+                sentence1_ling_embed=None,
+                sentence2_ling_embed=None,
+                inputs_embeds=None,
+                generate=False,
+                encoder_outputs=None,
+                encoder_attention_mask=None,
+                ling_embed=None,
+                generate_with_grad=False,
+                **kwargs):
+
+        cache = {}
+        if encoder_outputs is None:
+            encoder_outputs, encoder_attention_mask, cache = self.encode(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                sentence1_ling=sentence1_ling,
+                sentence2_ling=sentence2_ling,
+                sentence1_ling_embed=sentence1_ling_embed,
+                sentence2_ling_embed=sentence2_ling_embed,
+                inputs_embeds=inputs_embeds
+            )
+
+        dec_output, cache2 = self.decode(
+            sentence2_input_ids=labels,
+            sentence1_ling=sentence1_ling,
+            sentence2_ling=sentence2_ling,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            decoder_attention_mask=decoder_attention_mask,
+            encoder_outputs=encoder_outputs,
+            encoder_attention_mask=encoder_attention_mask,
+            generate=generate,
+            sentence1_ling_embed=sentence1_ling_embed,
+            sentence2_ling_embed=sentence2_ling_embed,
+            ling_embed=ling_embed,
+            generate_with_grad=generate_with_grad,
+            **kwargs
+        )
+
         cache.update(cache2)
-        return dec_output, enc_output, cache
+        if generate:
+            return dec_output
+        else:
+            return MySeq2SeqLMOutput(
+                loss=dec_output.loss,
+                logits=dec_output.logits,
+                past_key_values=dec_output.past_key_values,
+                decoder_hidden_states=dec_output.decoder_hidden_states,
+                decoder_attentions=dec_output.decoder_attentions,
+                cross_attentions=dec_output.cross_attentions,
+                encoder_last_hidden_state=encoder_outputs[0],
+                encoder_hidden_states=getattr(encoder_outputs, 'hidden_states', None),
+                encoder_attentions=getattr(encoder_outputs, 'attentions', None),
+                cache=cache
+            )
 
     def infer_with_cache(self, batch):
-        dec_output, _, cache = self.convert(batch, generate = True)
+        dec_output, _, cache = self(batch, generate = True)
         return dec_output, cache
 
     def infer(self, batch):
         dec_output, _ = self.infer_with_cache(batch)
         return dec_output
 
-    def infer_with_feedback_BP(self, ling_disc, sem_emb, batch, tokenizer):
+    def infer_with_feedback_BP(self, ling_disc, sem_emb, batch, tokenizer, progress=None):
         from torch.autograd import grad
         interpolations = []
         def line_search():
-            best_val = None
-            best_loss = None
             eta = 1e3
             sem_prob = 1
             patience = 4
@@ -516,13 +634,11 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
                 new_loss, pred = get_loss(param_)
                 max_len = pred.shape[1]
                 lens = torch.where(pred == self.eos_token_id, 1, 0).argmax(-1) + 1
-                batch.update({
-                    'sentence2_input_ids': pred,
-                    'sentence2_attention_mask': sequence_mask(lens, max_len = max_len)
-                })
-                sem_prob = torch.sigmoid(sem_emb.compare_sem(**batch)).item()
-                # if sem_prob <= 0.1:
-                #     patience -= 1
+                sem_batch = {**batch,
+                    'sentence2_input_ids': pred,
+                    'sentence2_attention_mask': sequence_mask(lens, max_len = max_len)
+                }
+                sem_prob = torch.sigmoid(sem_emb.compare_sem(**sem_batch)).item()
                 if new_loss < loss and sem_prob >= 0.90 and lens.item() > 1:
                     return param_
                 eta *= 2.25
@@ -531,7 +647,7 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
 
         def get_loss(param):
             if self.args.feedback_param == 'l':
-                batch.update({'sent2_ling_embed': param})
+                batch.update({'sentence2_ling_embed': param})
             elif self.args.feedback_param == 's':
                 batch.update({'inputs_embeds': param})
 
@@ -539,8 +655,9 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
                 logits = param
                 pred = param.argmax(-1)
             else:
-                pred, cache = self.infer_with_cache(batch)
-                logits = cache['scores']
+                outputs = self.generate(**batch, output_scores=True, return_dict_in_generate=True, generate_with_grad=True)
+                pred = outputs.sequences
+                logits = torch.stack(outputs.scores, dim=1)
             out = ling_disc(logits = logits)
             probs = F.softmax(out, 1)
             if ling_disc.quant:
@@ -553,13 +670,13 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
             ling2_embed = self.ling_embed(batch['sentence2_ling'])
             param = torch.nn.Parameter(ling2_embed, requires_grad = True)
         elif self.args.feedback_param == 's':
-            inputs_embeds = self.shared(batch['sentence1_input_ids'])
+            inputs_embeds = self.shared(batch['input_ids'])
            param = torch.nn.Parameter(inputs_embeds, requires_grad = True)
        elif self.args.feedback_param == 'logits':
            logits = self.infer_with_cache(batch)[1]['scores']
            param = torch.nn.Parameter(logits, requires_grad = True)
-        target_np = batch['sentence2_ling'][0].cpu().numpy()
-        while True:
+        num_iter = 0
+        while num_iter < 3:
            loss, pred = get_loss(param)
            pred_text = tokenizer.batch_decode(pred.cpu().numpy(),
                                               skip_special_tokens=True)[0]
@@ -571,6 +688,9 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
             param = line_search()
             if param is False:
                 break
+            num_iter += 1
+            if progress is not None:
+                progress((num_iter, None), unit='intermediate paraphrase generated.')
         return pred, [pred_text, interpolations]
 
 def set_grad(module, state):
@@ -609,7 +729,7 @@ class LingDiscPipeline():
     def __init__(self,
                  model_name="google/flan-t5-base",
                  disc_type='deberta',
-                 disc_ckpt='/data/mohamed/checkpoints/ling_disc/deberta-v3-small_flan-t5-base_40',
+                 disc_ckpt='mohdelgaar/lingconv-discriminator',
                  # disc_type='t5',
                  # disc_ckpt='/data/mohamed/checkpoints/ling_conversion_ling_disc.pt',
                  ):
@@ -629,15 +749,13 @@ def get_model(args, tokenizer, device):
         ling_disc = LingDisc(args.model_name, args.disc_type, args.disc_model_path).to(device)
     else:
         ling_disc = None
-    if args.linggen_type != 'none':
-        ling_gen = LingGenerator(args).to(device)
 
-    if not args.pretrain_disc:
+    if args.model_path:
         model = EncoderDecoderVAE.from_pretrained(args.model_path, args, tokenizer.pad_token_id, tokenizer.eos_token_id).to(device)
     else:
-        model = ling_disc
+        model = EncoderDecoderVAE.from_pretrained(args.model_name, args, tokenizer.pad_token_id, tokenizer.eos_token_id).to(device)
 
-    if args.sem_loss or args.sem_ckpt:
+    if args.sem_loss or args.model_path:
         if args.sem_loss_type == 'shared':
             sem_emb = model.encoder
         elif args.sem_loss_type == 'dedicated':
@@ -649,3 +767,14 @@ def get_model(args, tokenizer, device):
 
     return model, ling_disc, sem_emb
 
+@dataclass
+class MySeq2SeqLMOutput(Seq2SeqLMOutput):
+    """
+    Extends Seq2SeqLMOutput to include a cache dictionary for additional model outputs.
+
+    Args:
+        cache (`Dict[str, Any]`):
+            Dictionary containing additional model outputs like linguistic features,
+            VAE parameters, scores, etc.
+    """
+    cache: Optional[Dict[str, Any]] = None
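
Two implementation details in model.py are easy to miss. First, the constructor swaps in the custom prepare_inputs_for_generation by binding a module-level function to the instance: functools.partial pre-fills combine_method and ling2_only, and types.MethodType supplies self. A minimal sketch of that binding pattern (all names here are illustrative, not this repo's code):

import types
from functools import partial

def prepare(combine_method, ling2_only, self, step):
    # combine_method and ling2_only are pre-bound by partial;
    # self is supplied by MethodType, step is the normal call argument.
    return f'{self.name}: {combine_method}, ling2_only={ling2_only}, step={step}'

class Model:
    def __init__(self, name):
        self.name = name
        self.prepare = types.MethodType(partial(prepare, 'decoder_add_first', True), self)

print(Model('demo').prepare(0))  # demo: decoder_add_first, ling2_only=True, step=0

Second, MySeq2SeqLMOutput relies on Hugging Face ModelOutput subclasses being dataclasses, so a subclass can carry an extra field alongside the standard ones. A standalone sketch of the same extension pattern (the class name here is illustrative):

from dataclasses import dataclass
from typing import Any, Dict, Optional
from transformers.modeling_outputs import Seq2SeqLMOutput

@dataclass
class CachedSeq2SeqLMOutput(Seq2SeqLMOutput):
    # Extra dictionary for auxiliary outputs (e.g. VAE mu/logvar, scores).
    cache: Optional[Dict[str, Any]] = None

out = CachedSeq2SeqLMOutput(logits=None, cache={'scores': None})
print(out.cache)  # {'scores': None}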
options.py CHANGED
@@ -1,16 +1,28 @@
-import os, json
 import argparse
-import numpy as np
 from datetime import datetime
 from const import lftkplus_names
+import os, json
 from copy import deepcopy
+import numpy as np
 
+def str2bool(v):
+    if isinstance(v, bool):
+        return v
+    if v.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    else:
+        raise argparse.ArgumentTypeError('Boolean value expected.')
 
 def parse_args(ckpt=None):
     parser = argparse.ArgumentParser()
+    parser.add_argument('--do_train', action='store_true')
+    parser.add_argument('--do_eval', action='store_true')
+    parser.add_argument('--do_predict', action='store_true')
     parser.add_argument('--data_dir', default='/data/mohamed/data')
     parser.add_argument('--data', default='ling_conversion')
-    parser.add_argument('--data_sources')
+    parser.add_argument('--data_sources', default='qqp,mrpc,stsb')
     parser.add_argument('--data_type', default='text')
     parser.add_argument('--aim_repo', default='/data/mohamed/')
     parser.add_argument('--ckpt_dir', default='/data/mohamed/checkpoints')
@@ -25,7 +37,7 @@ def parse_args(ckpt=None):
     parser.add_argument('--sem_loss_tao', default=0.5, type=float)
     parser.add_argument('--sem_loss_eps', default=1, type=float)
     parser.add_argument('--ckpt')
-    parser.add_argument('--disc_ckpt')
+    parser.add_argument('--disc_ckpt', default='mohdelgaar/lingconv-discriminator')
     parser.add_argument('--sem_ckpt')
     parser.add_argument('--lng_ids')
     parser.add_argument('--lng_ids_idx', type=int)
@@ -36,30 +48,34 @@ def parse_args(ckpt=None):
     parser.add_argument('--sem_path', default="mohdelgaar/lingconv-semantic-classifier")
     parser.add_argument('--sem_model_path', default="mohdelgaar/lingconv-semantic-classifier")
     parser.add_argument('--disc_model_path', default="mohdelgaar/lingconv-discriminator")
-    parser.add_argument('--disc_type', default="t5")
-    parser.add_argument('--aim_exp', default='ling-conversion')
+    parser.add_argument('--disc_type', default="deberta")
+    parser.add_argument('--aim_exp', default='lingconv-1201')
     parser.add_argument('--sem_loss_type', default='dedicated')
-    parser.add_argument('--combine_method', default='none')
+    parser.add_argument('--combine_method', default='decoder_add_first')
+    parser.add_argument('--injection_type', default='first')
+    parser.add_argument('--injection_layer', type=int, default=1)
     parser.add_argument('--train_log', type=int, default=200)
-    parser.add_argument('--val_log', type=int, default=2000)
-    parser.add_argument('--batch_size', type=int, default=64)
-    parser.add_argument('--eval_batch_size', type=int, default=32)
-    parser.add_argument('--max_eval_samples', type=int, default=1000)
-    parser.add_argument('--test_batch_size', type=int, default=1)
+    parser.add_argument('--val_log', type=int, default=1000)
+    parser.add_argument('--warmup_steps', type=int, default=1000)
+    parser.add_argument('--batch_size', type=int, default=16)
+    parser.add_argument('--eval_batch_size', type=int, default=256)
+    parser.add_argument('--max_eval_samples', type=int, default=3000)
+    parser.add_argument('--test_batch_size', type=int, default=256)
     parser.add_argument('--hidden_dim', type=int, default=500)
     parser.add_argument('--latent_dim', type=int, default=150)
     parser.add_argument('--lng_dim', type=int, default=40)
-    parser.add_argument('--disc_lng_dim', type=int)
+    parser.add_argument('--disc_lng_dim', type=int, default=40)
     parser.add_argument('--use_lora', action='store_true')
     parser.add_argument('--lora_r', type=int, default=64)
     parser.add_argument('--gpu', type=str, default='0')
-    parser.add_argument('--epochs', type=int, default=10)
+    parser.add_argument('--epochs', type=int, default=2)
     parser.add_argument('--grad_accumulation', type=int, default=1)
     parser.add_argument('--n_ica', type=int, default=10)
     parser.add_argument('--max_length', type=int, default=200)
     parser.add_argument('--total_steps', type=int)
     parser.add_argument('--kld_const', type=float, default=1)
-    parser.add_argument('--lr', type=float, default=1e-4)
+    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--initializer_range', type=float, default=0.02)
     parser.add_argument('--kl_weight', type=float, default=1e-1)
     parser.add_argument('--weight_decay', type=float, default=1e-2)
     parser.add_argument('--ling_dropout', type=float, default=0.1)
@@ -71,12 +87,12 @@ def parse_args(ckpt=None):
     parser.add_argument('--pretrain_disc', action='store_true')
     parser.add_argument('--linggen_type', default='none')
     parser.add_argument('--linggen_input', default='s+l')
-    parser.add_argument('--aug_same', action='store_true')
+    parser.add_argument("--aug_same", type=str2bool, nargs='?', const=True, default=False)
     parser.add_argument('--ling_vae', action='store_true')
     parser.add_argument('--process_lingpred', action='store_true')
     parser.add_argument('--fudge_lambda', type=float, default=1.0)
     parser.add_argument('--use_lingpred', action='store_true')
-    parser.add_argument('--ling2_only', action='store_true')
+    parser.add_argument('--ling2_only', action='store_true', default=True)
     parser.add_argument('--cycle_loss', action='store_true')
     parser.add_argument('--disc_loss', action='store_true')
     parser.add_argument('--sem_loss', action='store_true')
@@ -96,19 +112,36 @@ def parse_args(ckpt=None):
     parser.add_argument('--quant_nbins', type=int, default=20)
     parser.add_argument('--src_lng', default = 'ling')
     parser.add_argument('--to_restore', nargs='+', default=[])
+    parser.add_argument('--freeze_lm', action='store_true',
+                        help='Freeze the language model and only train the linguistic embedding')
+    parser.add_argument('--prepend_prompt', action='store_true',
+                        help='Prepend "generate a paraphrase: " to input text')
+    parser.add_argument('--prompt_text', type=str, default="generate a paraphrase: ",
+                        help='Text to prepend to input if prepend_prompt is True')
+    parser.add_argument('--do_imputation', action='store_true',
+                        help='Whether to perform imputation on linguistic features')
+    parser.add_argument('--imputation_percentage', type=int, default=20,
+                        help='Percentage of features to impute (20, 40, 60, 80)')
+    parser.add_argument('--imputation_seed', type=int, default=0,
+                        help='Seed for imputation set selection (0, 1, 2)')
     # args = parser.parse_args()
     args, unknown = parser.parse_known_args()
     args.name = f'{datetime.now().strftime("%m%d_%H-%M-%S")}-{args.data}-{args.combine_method}'
 
     major_arg = args.major_arg
     to_restore = [
+        'total_steps', 'major_arg', 'gpu', 'demo', 'eval_only', 'save_predict', 'predict_fn', 'fudge', 'predict_with_feedback',
+        'feedback_param', 'fb_log', 'data_dir', 'data', 'disc_ckpt', 'disc_type', 'sem_ckpt', 'fudge_lambda', 'eval_batch_size', 'test_batch_size', 'max_eval_samples',
+        'do_train', 'do_eval', 'do_predict',
     ] + args.to_restore
     to_restore = {k: args.__dict__[k] for k in to_restore}
 
     if not args.disc_loss or args.disc_ckpt:
         args.disc_steps = 0
 
-    if args.data_sources is not None:
+    if args.data_sources == 'all':
+        args.data_sources = None
+    elif args.data_sources is not None:
         args.data_sources = args.data_sources.split(',')
 
     if ckpt is not None:
@@ -120,13 +153,17 @@ def parse_args(ckpt=None):
         ckpts = args.ckpt.split(',')
         args_list = [deepcopy(args) for _ in range(len(ckpts))]
         for i in range(len(ckpts)):
-            args_path = ckpts[i].replace('_best', '').replace('.pt', '.json')
+            args_path = ckpts[i].replace('_best', '').replace('.pt', '') + '.json'
            with open(args_path) as f:
                args_list[i].__dict__.update(json.load(f))
            args_list[i].__dict__.update(to_restore)
            args_list[i].ckpt = ckpts[i]
    else:
-        args_path = args.ckpt.replace('_best', '').replace('.pt', '.json')
+        args.ckpt = args.ckpt.rstrip('/')
+        if 'checkpoint-' in args.ckpt:
+            args_path = os.path.dirname(args.ckpt) + '.json'
+        else:
+            args_path = args.ckpt.replace('.pt', '') + '.json'
         ckpt = args.ckpt
         with open(args_path) as f:
             args.__dict__.update(json.load(f))
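
The str2bool helper added in options.py lets a boolean option accept an explicit value while a bare flag still means true, unlike action='store_true'. A minimal usage sketch mirroring the --aug_same definition:

import argparse

def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

parser = argparse.ArgumentParser()
# nargs='?' with const=True makes a bare --aug_same mean True,
# while an explicit value like --aug_same false still parses.
parser.add_argument('--aug_same', type=str2bool, nargs='?', const=True, default=False)

print(parser.parse_args([]))                    # Namespace(aug_same=False)
print(parser.parse_args(['--aug_same']))        # Namespace(aug_same=True)
print(parser.parse_args(['--aug_same', 'no']))  # Namespace(aug_same=False)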