Commit 132da4a
Parent(s): 251ecbb
improve performance
app.py CHANGED

@@ -1,7 +1,8 @@
 import spacy
 import nltk
 nltk.download('wordnet', quiet=True)
-spacy.cli.download('en_core_web_sm')
+if not spacy.util.is_package('en_core_web_sm'):
+    spacy.cli.download('en_core_web_sm')
 from compute_lng import compute_lng
 
 import torch
@@ -111,7 +112,7 @@ def impute_targets():
     shared_state.target = round_ling(interp_raw).tolist()
     return shared_state.target
 
-def generate_with_feedback(sent1, approx):
+def generate_with_feedback(sent1, approx, progress=gr.Progress()):
     if sent1 == '':
         raise gr.Error('Please input a source text.')
 
@@ -122,24 +123,25 @@ def generate_with_feedback(sent1, approx):
     input_ids = tokenizer.encode(sent1, return_tensors='pt').to(device)
     ling2 = torch.tensor(scaler.transform([shared_state.target])).float().to(device)
     inputs = {
-        '…
+        'input_ids': input_ids,
         'sentence2_ling': ling2,
-        '…
+        'attention_mask': torch.ones_like(input_ids)
     }
 
-    pred, (pred_text, interpolations) = model.infer_with_feedback_BP(ling_disc, sem_emb, inputs, tokenizer)
+    progress((0, None), unit='intermediate paraphrase generated.')
+    pred, (pred_text, interpolations) = model.infer_with_feedback_BP(ling_disc, sem_emb, inputs, tokenizer, progress)
 
     interpolation = '-- ' + '\n-- '.join(interpolations)
     # Return both the generation results and the updated slider values
     return [pred_text, interpolation] + [gr.update(value=val) for val in shared_state.target]
 
-def generate_random(sent1, count, approx):
+def generate_random(sent1, count, approx, progress=gr.Progress()):
     if sent1 == '':
         raise gr.Error('Please input a source text.')
     preds, interpolations = [], []
     orig_active_indices = shared_state.active_indices
     shared_state.active_indices = set(range(len(lng_names)))
-    for c in range(count):
+    for c in progress.tqdm(range(count), desc='Generating random sentences', unit='paraphrases'):
         idx = np.random.randint(0, len(ling_collection))
         ling_ex = ling_collection[idx]
         shared_state.target = ling_ex.copy()
@@ -167,7 +169,7 @@ def generate_random(sent1, count, approx):
     shared_state.active_indices = orig_active_indices
     return '\n***\n'.join(preds), '\n***\n'.join(interpolations)
 
-def estimate_gen(sent1, sent2, approx):
+def estimate_gen(sent1, sent2, approx, progress=gr.Progress()):
     if 'approximate' in approx:
         input_ids = tokenizer.encode(sent2, return_tensors='pt').to(device)
         with torch.no_grad():
@@ -183,7 +185,7 @@ def estimate_gen(sent1, sent2, approx):
 
     orig_active_indices = shared_state.active_indices
     shared_state.active_indices = set(range(len(lng_names)))
-    gen = generate_with_feedback(sent1, approx)[:2]
+    gen = generate_with_feedback(sent1, approx, progress)[:2]
     shared_state.active_indices = orig_active_indices
     return gen + [gr.update(value=val) for val in shared_state.target]
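Note on the pattern above: Gradio injects a live tracker whenever a handler declares a `progress=gr.Progress()` default; calling `progress((i, None), unit=...)` reports an indeterminate count, and `progress.tqdm(...)` wraps an iterable the way `generate_random` does. A minimal, self-contained sketch of the same pattern (the `slow_fn` demo is illustrative, not part of this commit):

    import time
    import gradio as gr

    def slow_fn(text, progress=gr.Progress()):
        # Gradio swaps the default Progress() for a live tracker at call time.
        progress((0, None), unit='steps')  # tuple form: count with unknown total
        for _ in progress.tqdm(range(5), desc='Working'):
            time.sleep(0.5)
        return text.upper()

    gr.Interface(slow_fn, 'text', 'text').launch()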
const.py CHANGED

@@ -1030,7 +1030,6 @@ used_indices = [
     63, 64, 65, 66, 67, 68, 73, 121, 124, 129, 134, 136, 254,
     257, 258, 261, 263, 272, 274
 ]
-lftk_used_indices = [1, 7, 8, 9, 10, 11, 12, 17, 65, 68, 73, 78, 80, 198, 201, 202, 205, 207, 216, 218]
 
 eval_indices = [4,5,6,18,257,272]
 eval_indices = [used_indices.index(idx) for idx in eval_indices]
model.py CHANGED

@@ -10,6 +10,10 @@ from types import MethodType
 from utils import *
 from ling_disc import DebertaReplacedTokenizer
 from const import *
+from lingconv_t5 import LingConvT5ForConditionalGeneration
+from dataclasses import dataclass
+from transformers.modeling_outputs import Seq2SeqLMOutput
+from typing import Optional, Dict, Any
 
 
 
@@ -77,9 +81,9 @@ class LingGenerator(nn.Module):
         bs = inputs_embeds.shape[0]
 
         if self.gen_input == 's+l':
-            sent1_ling = self.ling_embed(batch['sentence1_ling'])
-            sent1_ling = sent1_ling.view(bs, 1, -1)
-            inputs_embeds = inputs_embeds + sent1_ling
+            sentence1_ling = self.ling_embed(batch['sentence1_ling'])
+            sentence1_ling = sentence1_ling.view(bs, 1, -1)
+            inputs_embeds = inputs_embeds + sentence1_ling
 
         gen = self.gen(inputs_embeds=inputs_embeds,
                        attention_mask=inputs_att_mask).last_hidden_state.mean(1)
@@ -185,13 +189,13 @@ class SemEmb(T5EncoderModel):
                                         nn.Linear(hidden_dim, 1))
 
     def compare_sem(self, **batch):
-        bs = batch['…
-        ones = torch.ones((bs, 1), device=batch['…
+        bs = batch['attention_mask'].shape[0]
+        ones = torch.ones((bs, 1), device=batch['attention_mask'].device)
         sep = torch.ones((bs, 1), dtype=torch.long,
-                         device=batch['…
-        att_mask = torch.cat([batch['…
+                         device=batch['attention_mask'].device) * self.sep_token_id
+        att_mask = torch.cat([batch['attention_mask'], ones, batch['sentence2_attention_mask']], dim=1)
         if 'logits' in batch:
-            input_ids = torch.cat([batch['…
+            input_ids = torch.cat([batch['input_ids'], sep], dim=1)
             embeds1 = self.shared(input_ids)
 
             logits = batch['logits']
@@ -201,11 +205,11 @@ class SemEmb(T5EncoderModel):
 
             embeds2 = onehot_ @ self.shared.weight
             embeds1_2 = torch.cat([embeds1, embeds2], dim=1)
-            hidden_units = …
+            hidden_units = super().forward(inputs_embeds=embeds1_2,
                                            attention_mask=att_mask).last_hidden_state.mean(1)
         elif 'sentence2_input_ids' in batch:
-            input_ids = torch.cat([batch['…
-            hidden_units = …
+            input_ids = torch.cat([batch['input_ids'], sep, batch['sentence2_input_ids']], dim=1)
+            hidden_units = super().forward(input_ids=input_ids,
                                            attention_mask=att_mask).last_hidden_state.mean(1)
         probs = self.projection(hidden_units)
         return probs
@@ -222,31 +226,36 @@ def prepare_inputs_for_generation(
     cross_attn_head_mask=None,
     use_cache=None,
    encoder_outputs=None,
-    …
-    …
+    sentence1_ling=None,
+    sentence2_ling=None,
     **kwargs
 ):
-    …
     # cut decoder_input_ids if past is used
     if past_key_values is not None:
         input_ids = input_ids[:, -1:]
 
+    cached = use_cache and len(past_key_values) > 0
+
     input_ids = input_ids.clone()
     decoder_inputs_embeds = self.shared(input_ids)
 
-    if combine_method == '…
-    …
+    if combine_method == 'layer_injection':
+        # For layer injection, we'll pass the ling embeddings separately
+        ling_embed = sentence2_ling if ling2_only else (sentence1_ling + sentence2_ling)
+    elif combine_method == 'decoder_add_first' and not cached:
+        sentence2_ling = torch.cat([sentence2_ling,
+            torch.repeat_interleave(torch.zeros_like(sentence2_ling), input_ids.shape[1] - 1, dim=1)], dim = 1)
+    elif combine_method == 'decoder_concat':
         if ling2_only:
-            decoder_inputs_embeds = torch.cat([…
+            decoder_inputs_embeds = torch.cat([sentence2_ling, decoder_inputs_embeds], dim=1)
         else:
-            decoder_inputs_embeds = torch.cat([…
-    …
+            decoder_inputs_embeds = torch.cat([sentence1_ling, sentence2_ling, decoder_inputs_embeds], dim=1)
+
+    if combine_method == 'decoder_add' or (not cached and combine_method == 'decoder_add_first'):
         if ling2_only:
-            decoder_inputs_embeds = decoder_inputs_embeds + …
+            decoder_inputs_embeds = decoder_inputs_embeds + sentence2_ling
         else:
-            decoder_inputs_embeds = decoder_inputs_embeds + …
+            decoder_inputs_embeds = decoder_inputs_embeds + sentence1_ling + sentence2_ling
 
     return {
         "decoder_inputs_embeds": decoder_inputs_embeds,
@@ -257,19 +266,27 @@ def prepare_inputs_for_generation(
         "decoder_head_mask": decoder_head_mask,
         "cross_attn_head_mask": cross_attn_head_mask,
         "use_cache": use_cache,
+        "ling_embed": ling_embed if combine_method == 'layer_injection' else None,
     }
 
 class LogitsAdd(LogitsProcessor):
-    def __init__(self, sent2_ling):
+    def __init__(self, sentence2_ling):
         super().__init__()
-        self.sent2_ling = sent2_ling
+        self.sentence2_ling = sentence2_ling
 
     def __call__(self, input_ids, scores):
-        return scores + self.sent2_ling
+        return scores + self.sentence2_ling
 
-class EncoderDecoderVAE(T5ForConditionalGeneration):
+class EncoderDecoderVAE(LingConvT5ForConditionalGeneration):
     def __init__(self, config, args, pad_token_id, sepeos_token_id, vocab_size = 32128):
+        if args.combine_method == 'layer_injection':
+            if args.injection_layer < 0 or args.injection_layer >= config.num_decoder_layers:
+                raise ValueError(f"Invalid injection layer: {args.injection_layer}. Must be between 0 and {config.num_decoder_layers - 1}.")
+            config.ling_injection_layer = args.injection_layer
+            config.ling_injection_type = args.injection_type  # 'first' or 'all'
+
         super().__init__(config)
+
         self.prepare_inputs_for_generation = types.MethodType(
             partial(prepare_inputs_for_generation, args.combine_method, args.ling2_only),
             self)
@@ -287,7 +304,7 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
                 nn.ReLU(),
                 nn.Linear(hidden_dim, hidden_dim),
             )
-        elif 'concat' in args.combine_method or 'add' in args.combine_method:
+        elif 'concat' in args.combine_method or 'add' in args.combine_method or 'layer_injection' in args.combine_method:
             if args.ling_embed_type == 'two-layer':
                 self.ling_embed = nn.Sequential(
                     nn.Linear(args.lng_dim, args.lng_dim),
@@ -297,6 +314,7 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
             else:
                 self.ling_embed = nn.Linear(args.lng_dim, hidden_dim)
             self.ling_dropout = nn.Dropout(args.ling_dropout)
+            self.ling_embed.apply(self._init_weights)
 
         if args.ling_vae:
             self.ling_mu = nn.Linear(hidden_dim, hidden_dim)
@@ -306,8 +324,20 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
             nn.init.xavier_uniform_(self.ling_logvar.weight)
 
 
-        generate_with_grad = unwrap(…
+        generate_with_grad = unwrap(super().generate)
         self.generate_with_grad = MethodType(generate_with_grad, self)
+        self.generate_original = super().generate
+
+    def _init_weights(self, module):
+        std = self.args.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
 
     def get_fusion_layer(self):
         if 'fusion' in self.args.combine_method:
@@ -321,122 +351,143 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
         std = torch.exp(0.5 * logvar)
         return mu + std * torch.randn_like(std)
 
-    def encode(self, …
-    …
+    def _process_ling_embeddings(self, sentence1_ling, sentence2_ling,
+                                 sentence1_ling_embed, sentence2_ling_embed, bs):
+        """Helper method to process linguistic embeddings"""
+        cache = {}
+
+        # Process sentence1 embedding
+        if sentence1_ling_embed is not None:
+            sentence1_ling = sentence1_ling_embed
+        elif sentence1_ling is not None:
+            sentence1_ling = self.ling_embed(self.ling_dropout(sentence1_ling))
         else:
-            …
+            sentence1_ling = None
+
+        # Process sentence2 embedding
+        if sentence2_ling_embed is not None:
+            sentence2_ling = sentence2_ling_embed
+        elif sentence2_ling is not None:
+            sentence2_ling = self.ling_embed(self.ling_dropout(sentence2_ling))
+        else:
+            sentence2_ling = None
+
+        # Apply VAE if configured
+        if self.args.ling_vae and sentence1_ling is not None and sentence2_ling is not None:
+            sentence1_ling = F.leaky_relu(sentence1_ling)
+            sent1_mu, sent1_logvar = self.ling_mu(sentence1_ling), self.ling_logvar(sentence1_ling)
+            sentence1_ling = self.sample(sent1_mu, sent1_logvar)
+
+            sentence2_ling = F.leaky_relu(sentence2_ling)
+            sent2_mu, sent2_logvar = self.ling_mu(sentence2_ling), self.ling_logvar(sentence2_ling)
+            sentence2_ling = self.sample(sent2_mu, sent2_logvar)
+
+            cache.update({
+                'sent1_mu': sent1_mu, 'sent1_logvar': sent1_logvar,
+                'sent2_mu': sent2_mu, 'sent2_logvar': sent2_logvar,
+                'sentence1_ling': sentence1_ling, 'sentence2_ling': sentence2_ling
+            })
+        else:
+            if sentence2_ling is not None:
+                cache['sentence2_ling'] = sentence2_ling
+            if sentence1_ling is not None:
+                cache['sentence1_ling'] = sentence1_ling
+
+        # Reshape embeddings
+        if sentence1_ling is not None:
+            sentence1_ling = sentence1_ling.view(bs, 1, -1)
+        if sentence2_ling is not None:
+            sentence2_ling = sentence2_ling.view(bs, 1, -1)
+
+        return sentence1_ling, sentence2_ling, cache
+
+    def encode(self,
+               input_ids=None,
+               attention_mask=None,
+               sentence1_ling=None,
+               sentence2_ling=None,
+               sentence1_ling_embed=None,
+               sentence2_ling_embed=None,
+               inputs_embeds=None,
+               ):
+        if inputs_embeds is None:
+            inputs_embeds = self.shared(input_ids)
+        inputs_att_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids)
         bs = inputs_embeds.shape[0]
-        …
+
         if self.args.combine_method in ('input_concat', 'input_add'):
-            …
-                sent2_ling = batch['sent2_ling_embed']
-            else:
-                sent2_ling = self.ling_embed(self.ling_dropout(batch['sentence2_ling']))
-            if self.args.ling_vae:
-                sent1_ling = F.leaky_relu(sent1_ling)
-                sent1_mu, sent1_logvar = self.ling_mu(sent1_ling), self.ling_logvar(sent1_ling)
-                sent1_ling = self.sample(sent1_mu, sent1_logvar)
-
-                sent2_ling = F.leaky_relu(sent2_ling)
-                sent2_mu, sent2_logvar = self.ling_mu(sent2_ling), self.ling_logvar(sent2_ling)
-                sent2_ling = self.sample(sent2_mu, sent2_logvar)
-                cache.update({'sent1_mu': sent1_mu, 'sent1_logvar': sent1_logvar,
-                              'sent2_mu': sent2_mu, 'sent2_logvar': sent2_logvar,
-                              'sent1_ling': sent1_ling, 'sent2_ling': sent2_ling})
-            else:
-                cache.update({'sent1_ling': sent1_ling, 'sent2_ling': sent2_ling})
-            sent1_ling = sent1_ling.view(bs, 1, -1)
-            sent2_ling = sent2_ling.view(bs, 1, -1)
+            sentence1_ling, sentence2_ling, cache = self._process_ling_embeddings(
+                sentence1_ling, sentence2_ling,
+                sentence1_ling_embed, sentence2_ling_embed, bs
+            )
+
         if self.args.combine_method == 'input_concat':
             if self.args.ling2_only:
-                inputs_embeds = torch.cat([inputs_embeds, …
+                inputs_embeds = torch.cat([inputs_embeds, sentence2_ling], dim=1)
                 inputs_att_mask = torch.cat([inputs_att_mask,
                                              torch.ones((bs, 1)).to(inputs_embeds.device)], dim=1)
             else:
-                inputs_embeds = torch.cat([inputs_embeds, …
+                inputs_embeds = torch.cat([inputs_embeds, sentence1_ling, sentence2_ling], dim=1)
                 inputs_att_mask = torch.cat([inputs_att_mask,
                                              torch.ones((bs, 2)).to(inputs_embeds.device)], dim=1)
         elif self.args.combine_method == 'input_add':
             if self.args.ling2_only:
-                inputs_embeds = inputs_embeds + …
+                inputs_embeds = inputs_embeds + sentence2_ling
             else:
-                inputs_embeds = inputs_embeds + …
+                inputs_embeds = inputs_embeds + sentence1_ling + sentence2_ling
+        else:
+            cache = {}
+
         return self.encoder(inputs_embeds=inputs_embeds,
                             attention_mask=inputs_att_mask), inputs_att_mask, cache
 
-    def decode(self, …
-    …
+    def decode(self,
+               sentence2_input_ids=None,
+               sentence1_ling=None,
+               sentence2_ling=None,
+               encoder_outputs=None,
+               encoder_attention_mask=None,
+               decoder_inputs_embeds=None,
+               decoder_attention_mask=None,
+               generate=False,
+               sentence1_ling_embed=None,
+               sentence2_ling_embed=None,
+               ling_embed=None,
+               generate_with_grad=False,
+               **kwargs
+               ):
+        bs = encoder_outputs[0].shape[0]
         cache = {}
-        …
-            sent2_ling = F.leaky_relu(sent2_ling)
-            sent2_mu, sent2_logvar = self.ling_mu(sent2_ling), self.ling_logvar(sent2_ling)
-            sent2_ling = self.sample(sent2_mu, sent2_logvar)
-            cache.update({'sent1_mu': sent1_mu, 'sent1_logvar': sent1_logvar,
-                          'sent2_mu': sent2_mu, 'sent2_logvar': sent2_logvar,
-                          'sent1_ling': sent1_ling, 'sent2_ling': sent2_ling})
-        else:
-            cache.update({'sent2_ling': sent2_ling})
-            if sent1_ling is not None:
-                cache.update({'sent1_ling': sent1_ling})
-        if sent1_ling is not None:
-            sent1_ling = sent1_ling.view(bs, 1, -1)
-        sent2_ling = sent2_ling.view(bs, 1, -1)
-        if self.args.combine_method == 'decoder_add_first' and not generate:
-            sent2_ling = torch.cat([sent2_ling,
-                torch.repeat_interleave(torch.zeros_like(sent2_ling), batch['sentence2_input_ids'].shape[1] - 1, dim=1)], dim = 1)
-        else:
-            sent1_ling, sent2_ling = None, None
-        …
-        if self.args.combine_method == 'embed_concat':
-            enc_output.last_hidden_state = torch.cat([enc_output.last_hidden_state,
-                                                      sent1_ling, sent2_ling], dim=1)
-            inputs_att_mask = torch.cat([inputs_att_mask,
-                                         torch.ones((bs, 2)).to(inputs_att_mask.device)], dim=1)
-        elif 'fusion' in self.args.combine_method:
-            sent1_ling = batch['sentence1_ling'].unsqueeze(1)\
-                .expand(-1, enc_output.last_hidden_state.shape[1], -1)
-            sent2_ling = batch['sentence2_ling'].unsqueeze(1)\
-                .expand(-1, enc_output.last_hidden_state.shape[1], -1)
-            if self.args.ling2_only:
-                combined_embedding = torch.cat([enc_output.last_hidden_state, sent2_ling], dim=2)
+        if decoder_inputs_embeds is None:
+            if self.args.combine_method in ('embed_concat', 'decoder_concat', 'decoder_add',
+                                            'logits_add', 'decoder_add_first', 'layer_injection'):
+                sentence1_ling, sentence2_ling, cache = self._process_ling_embeddings(
+                    sentence1_ling, sentence2_ling,
+                    sentence1_ling_embed, sentence2_ling_embed, bs
+                )
+
+                if (self.args.combine_method == 'decoder_add_first' or
+                    (self.args.combine_method == 'layer_injection' and
+                     self.args.injection_type == 'first')) and not generate:
+                    sentence2_ling = torch.cat([sentence2_ling,
+                        torch.repeat_interleave(torch.zeros_like(sentence2_ling),
+                            sentence2_input_ids.shape[1] - 1, dim=1)], dim = 1)
         else:
-            …
+            sentence1_ling, sentence2_ling = None, None
+
         if generate:
             if self.args.combine_method == 'logits_add':
-                logits_processor = LogitsProcessorList([LogitsAdd(…
+                logits_processor = LogitsProcessorList([LogitsAdd(sentence2_ling.view(bs, -1))])
             else:
                 logits_processor = LogitsProcessorList()
 
-            …
-                output_scores=True,
+            generate_fn = self.generate_with_grad if generate_with_grad else self.generate_original
+            dec_output = generate_fn(
+                attention_mask=encoder_attention_mask,
+                encoder_outputs=encoder_outputs,
+                sentence1_ling=sentence1_ling,
+                sentence2_ling=sentence2_ling,
                 logits_processor = logits_processor,
                 # renormalize_logits=True,
                 # do_sample=True,
@@ -445,68 +496,135 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
                 # min_new_tokens=3,
                 # repetition_penalty=1.2,
                 max_length=self.args.max_length,
+                use_cache=True,
+                **kwargs
             )
-            …
-            cache.update({'scores': scores})
-            return dec_output.sequences, cache
-
-        decoder_input_ids = self._shift_right(batch['sentence2_input_ids'])
-        decoder_inputs_embeds = self.shared(decoder_input_ids)
-        decoder_att_mask = batch['sentence2_attention_mask']
-        labels = batch['sentence2_input_ids'].clone()
-        labels[labels == self.pad_token_id] = -100
-        …
-        if self.args.combine_method == 'decoder_concat':
-            if self.args.ling2_only:
-                decoder_inputs_embeds = torch.cat([sent2_ling, decoder_inputs_embeds], dim=1)
-                decoder_att_mask = torch.cat([torch.ones((bs, 1)).to(decoder_inputs_embeds.device), decoder_att_mask], dim=1)
-                labels = torch.cat([torch.ones((bs, 1), dtype=torch.int64).to(decoder_inputs_embeds.device) * self.pad_token_id,
-                                    labels], dim=1)
-            else:
-                decoder_inputs_embeds = torch.cat([sent1_ling, sent2_ling, decoder_inputs_embeds], dim=1)
-                decoder_att_mask = torch.cat([torch.ones((bs, 2)).to(decoder_inputs_embeds.device), decoder_att_mask], dim=1)
-                labels = torch.cat([torch.ones((bs, 2), dtype=torch.int64).to(decoder_inputs_embeds.device) * self.pad_token_id,
-                                    labels], dim=1)
-        elif self.args.combine_method == 'decoder_add' or self.args.combine_method == 'decoder_add_first' :
-            if self.args.ling2_only:
-                decoder_inputs_embeds = decoder_inputs_embeds + self.args.combine_weight * sent2_ling
-            else:
-                decoder_inputs_embeds = decoder_inputs_embeds + sent1_ling + sent2_ling
+            return dec_output, cache
 
-        …
+        if sentence2_input_ids is not None:
+            labels = sentence2_input_ids.clone()
+            labels[labels == self.pad_token_id] = -100
+        else:
+            labels = None
+
+        if decoder_inputs_embeds is None:
+            decoder_input_ids = self._shift_right(sentence2_input_ids)
+            decoder_inputs_embeds = self.shared(decoder_input_ids)
+
+            if self.args.combine_method == 'decoder_concat':
+                if self.args.ling2_only:
+                    decoder_inputs_embeds = torch.cat([sentence2_ling, decoder_inputs_embeds], dim=1)
+                    decoder_attention_mask = torch.cat([torch.ones((bs, 1)).to(decoder_inputs_embeds.device), decoder_attention_mask], dim=1)
+                    labels = torch.cat([torch.ones((bs, 1), dtype=torch.int64).to(decoder_inputs_embeds.device) * self.pad_token_id,
+                                        labels], dim=1)
+                else:
+                    decoder_inputs_embeds = torch.cat([sentence1_ling, sentence2_ling, decoder_inputs_embeds], dim=1)
+                    decoder_attention_mask = torch.cat([torch.ones((bs, 2)).to(decoder_inputs_embeds.device), decoder_attention_mask], dim=1)
+                    labels = torch.cat([torch.ones((bs, 2), dtype=torch.int64).to(decoder_inputs_embeds.device) * self.pad_token_id,
+                                        labels], dim=1)
+            elif self.args.combine_method == 'decoder_add' or self.args.combine_method == 'decoder_add_first' :
+                if self.args.ling2_only:
+                    decoder_inputs_embeds = decoder_inputs_embeds + self.args.combine_weight * sentence2_ling
+                else:
+                    decoder_inputs_embeds = decoder_inputs_embeds + sentence1_ling + sentence2_ling
+
+        if ling_embed is None:
+            ling_embed = sentence2_ling
+
+        dec_output = super().forward(
             decoder_inputs_embeds=decoder_inputs_embeds,
-            decoder_attention_mask=…
-            encoder_outputs=…
-            attention_mask=…
+            decoder_attention_mask=decoder_attention_mask,
+            encoder_outputs=encoder_outputs,
+            attention_mask=encoder_attention_mask,
             labels=labels,
+            ling_embed=ling_embed,
+            **kwargs
         )
         if self.args.combine_method == 'logits_add':
-            dec_output.logits = dec_output.logits + self.args.combine_weight * …
+            dec_output.logits = dec_output.logits + self.args.combine_weight * sentence2_ling
             vocab_size = dec_output.logits.size(-1)
             dec_output.loss = F.cross_entropy(dec_output.logits.view(-1, vocab_size), labels.view(-1))
         return dec_output, cache
 
+    def generate(self, *args, **kwargs):
+        return self.forward(*args, **kwargs, generate=True)
+
+
+    def forward(self,
+                input_ids=None,
+                attention_mask=None,
+                labels=None,
+                decoder_attention_mask=None,
+                decoder_inputs_embeds=None,
+                sentence1_ling=None,
+                sentence2_ling=None,
+                sentence1_ling_embed=None,
+                sentence2_ling_embed=None,
+                inputs_embeds=None,
+                generate=False,
+                encoder_outputs=None,
+                encoder_attention_mask=None,
+                ling_embed=None,
+                generate_with_grad=False,
+                **kwargs):
 
-    …
-    …
-    …
+        cache = {}
+        if encoder_outputs is None:
+            encoder_outputs, encoder_attention_mask, cache = self.encode(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                sentence1_ling=sentence1_ling,
+                sentence2_ling=sentence2_ling,
+                sentence1_ling_embed=sentence1_ling_embed,
+                sentence2_ling_embed=sentence2_ling_embed,
+                inputs_embeds=inputs_embeds
+            )
+
+        dec_output, cache2 = self.decode(
+            sentence2_input_ids=labels,
+            sentence1_ling=sentence1_ling,
+            sentence2_ling=sentence2_ling,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            decoder_attention_mask=decoder_attention_mask,
+            encoder_outputs=encoder_outputs,
+            encoder_attention_mask=encoder_attention_mask,
+            generate=generate,
+            sentence1_ling_embed=sentence1_ling_embed,
+            sentence2_ling_embed=sentence2_ling_embed,
+            ling_embed=ling_embed,
+            generate_with_grad=generate_with_grad,
+            **kwargs
+        )
+
         cache.update(cache2)
-        …
+        if generate:
+            return dec_output
+        else:
+            return MySeq2SeqLMOutput(
+                loss=dec_output.loss,
+                logits=dec_output.logits,
+                past_key_values=dec_output.past_key_values,
+                decoder_hidden_states=dec_output.decoder_hidden_states,
+                decoder_attentions=dec_output.decoder_attentions,
+                cross_attentions=dec_output.cross_attentions,
+                encoder_last_hidden_state=encoder_outputs[0],
+                encoder_hidden_states=getattr(encoder_outputs, 'hidden_states', None),
+                encoder_attentions=getattr(encoder_outputs, 'attentions', None),
+                cache=cache
+            )
 
     def infer_with_cache(self, batch):
-        dec_output, _, cache = self…
+        dec_output, _, cache = self(batch, generate = True)
         return dec_output, cache
 
     def infer(self, batch):
         dec_output, _ = self.infer_with_cache(batch)
         return dec_output
 
-    def infer_with_feedback_BP(self, ling_disc, sem_emb, batch, tokenizer):
+    def infer_with_feedback_BP(self, ling_disc, sem_emb, batch, tokenizer, progress=None):
         from torch.autograd import grad
         interpolations = []
         def line_search():
-            best_val = None
-            best_loss = None
             eta = 1e3
             sem_prob = 1
             patience = 4
@@ -516,13 +634,11 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
                 new_loss, pred = get_loss(param_)
                 max_len = pred.shape[1]
                 lens = torch.where(pred == self.eos_token_id, 1, 0).argmax(-1) + 1
-                batch…
-                …
-                sem_prob = torch.sigmoid(sem_emb.compare_sem(**…
-                # if sem_prob <= 0.1:
-                #     patience -= 1
+                sem_batch = {**batch,
+                             'sentence2_input_ids': pred,
+                             'sentence2_attention_mask': sequence_mask(lens, max_len = max_len)
+                             }
+                sem_prob = torch.sigmoid(sem_emb.compare_sem(**sem_batch)).item()
                 if new_loss < loss and sem_prob >= 0.90 and lens.item() > 1:
                     return param_
                 eta *= 2.25
@@ -531,7 +647,7 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
 
         def get_loss(param):
             if self.args.feedback_param == 'l':
-                batch.update({'sent2_ling_embed': param})
+                batch.update({'sentence2_ling_embed': param})
             elif self.args.feedback_param == 's':
                 batch.update({'inputs_embeds': param})
 
@@ -539,8 +655,9 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
                 logits = param
                 pred = param.argmax(-1)
             else:
-                …
-                …
+                outputs = self.generate(**batch, output_scores=True, return_dict_in_generate=True, generate_with_grad=True)
+                pred = outputs.sequences
+                logits = torch.stack(outputs.scores, dim=1)
             out = ling_disc(logits = logits)
             probs = F.softmax(out, 1)
             if ling_disc.quant:
@@ -553,13 +670,13 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
             ling2_embed = self.ling_embed(batch['sentence2_ling'])
             param = torch.nn.Parameter(ling2_embed, requires_grad = True)
         elif self.args.feedback_param == 's':
-            inputs_embeds = self.shared(batch['…
+            inputs_embeds = self.shared(batch['input_ids'])
            param = torch.nn.Parameter(inputs_embeds, requires_grad = True)
         elif self.args.feedback_param == 'logits':
             logits = self.infer_with_cache(batch)[1]['scores']
             param = torch.nn.Parameter(logits, requires_grad = True)
-        …
-        while …
+        num_iter = 0
+        while num_iter < 3:
             loss, pred = get_loss(param)
             pred_text = tokenizer.batch_decode(pred.cpu().numpy(),
                                                skip_special_tokens=True)[0]
@@ -571,6 +688,9 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
             param = line_search()
             if param is False:
                 break
+            num_iter += 1
+            if progress is not None:
+                progress((num_iter, None), unit='intermediate paraphrase generated.')
         return pred, [pred_text, interpolations]
 
 def set_grad(module, state):
@@ -609,7 +729,7 @@ class LingDiscPipeline():
     def __init__(self,
                  model_name="google/flan-t5-base",
                  disc_type='deberta',
-                 disc_ckpt='/…
+                 disc_ckpt='mohdelgaar/lingconv-discriminator',
                  # disc_type='t5',
                  # disc_ckpt='/data/mohamed/checkpoints/ling_conversion_ling_disc.pt',
                  ):
@@ -629,15 +749,13 @@ def get_model(args, tokenizer, device):
         ling_disc = LingDisc(args.model_name, args.disc_type, args.disc_model_path).to(device)
     else:
         ling_disc = None
-    if args.linggen_type != 'none':
-        ling_gen = LingGenerator(args).to(device)
 
-    if …
+    if args.model_path:
         model = EncoderDecoderVAE.from_pretrained(args.model_path, args, tokenizer.pad_token_id, tokenizer.eos_token_id).to(device)
     else:
-        model = …
+        model = EncoderDecoderVAE.from_pretrained(args.model_name, args, tokenizer.pad_token_id, tokenizer.eos_token_id).to(device)
 
-    if args.sem_loss or args.…
+    if args.sem_loss or args.model_path:
         if args.sem_loss_type == 'shared':
             sem_emb = model.encoder
         elif args.sem_loss_type == 'dedicated':
@@ -649,3 +767,14 @@ def get_model(args, tokenizer, device):
 
     return model, ling_disc, sem_emb
 
+@dataclass
+class MySeq2SeqLMOutput(Seq2SeqLMOutput):
+    """
+    Extends Seq2SeqLMOutput to include a cache dictionary for additional model outputs.
+
+    Args:
+        cache (`Dict[str, Any]`):
+            Dictionary containing additional model outputs like linguistic features,
+            VAE parameters, scores, etc.
+    """
+    cache: Optional[Dict[str, Any]] = None
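Note on `LogitsAdd` above: it plugs into the standard `transformers` logits-processor hook, which lets `generate()` rewrite the next-token scores at every decoding step. A minimal, self-contained sketch of that hook (the `BiasLogits` class and the zero bias are illustrative stand-ins, not part of this commit; a real bias would come from the linguistic embedding):

    import torch
    from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer,
                              LogitsProcessor, LogitsProcessorList)

    class BiasLogits(LogitsProcessor):
        # Adds a fixed bias to the vocabulary logits at each decoding step,
        # the same mechanism LogitsAdd uses for combine_method='logits_add'.
        def __init__(self, bias):
            self.bias = bias  # shape (batch, vocab_size)

        def __call__(self, input_ids, scores):
            return scores + self.bias

    tok = AutoTokenizer.from_pretrained('google/flan-t5-base')
    model = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-base')
    enc = tok('paraphrase: the cat sat on the mat', return_tensors='pt')
    bias = torch.zeros((1, model.config.vocab_size))
    out = model.generate(**enc, logits_processor=LogitsProcessorList([BiasLogits(bias)]))
    print(tok.decode(out[0], skip_special_tokens=True))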
options.py CHANGED

@@ -1,16 +1,28 @@
-import os, json
 import argparse
-import numpy as np
 from datetime import datetime
 from const import lftkplus_names
+import os, json
 from copy import deepcopy
+import numpy as np
 
+def str2bool(v):
+    if isinstance(v, bool):
+        return v
+    if v.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    else:
+        raise argparse.ArgumentTypeError('Boolean value expected.')
 
 def parse_args(ckpt=None):
     parser = argparse.ArgumentParser()
+    parser.add_argument('--do_train', action='store_true')
+    parser.add_argument('--do_eval', action='store_true')
+    parser.add_argument('--do_predict', action='store_true')
     parser.add_argument('--data_dir', default='/data/mohamed/data')
     parser.add_argument('--data', default='ling_conversion')
-    parser.add_argument('--data_sources')
+    parser.add_argument('--data_sources', default='qqp,mrpc,stsb')
     parser.add_argument('--data_type', default='text')
     parser.add_argument('--aim_repo', default='/data/mohamed/')
     parser.add_argument('--ckpt_dir', default='/data/mohamed/checkpoints')
@@ -25,7 +37,7 @@ def parse_args(ckpt=None):
     parser.add_argument('--sem_loss_tao', default=0.5, type=float)
     parser.add_argument('--sem_loss_eps', default=1, type=float)
     parser.add_argument('--ckpt')
-    parser.add_argument('--disc_ckpt')
+    parser.add_argument('--disc_ckpt', default='mohdelgaar/lingconv-discriminator')
     parser.add_argument('--sem_ckpt')
     parser.add_argument('--lng_ids')
     parser.add_argument('--lng_ids_idx', type=int)
@@ -36,30 +48,34 @@ def parse_args(ckpt=None):
     parser.add_argument('--sem_path', default="mohdelgaar/lingconv-semantic-classifier")
     parser.add_argument('--sem_model_path', default="mohdelgaar/lingconv-semantic-classifier")
     parser.add_argument('--disc_model_path', default="mohdelgaar/lingconv-discriminator")
-    parser.add_argument('--disc_type', default="…
-    parser.add_argument('--aim_exp', default='…
+    parser.add_argument('--disc_type', default="deberta")
+    parser.add_argument('--aim_exp', default='lingconv-1201')
     parser.add_argument('--sem_loss_type', default='dedicated')
-    parser.add_argument('--combine_method', default='…
+    parser.add_argument('--combine_method', default='decoder_add_first')
+    parser.add_argument('--injection_type', default='first')
+    parser.add_argument('--injection_layer', type=int, default=1)
     parser.add_argument('--train_log', type=int, default=200)
-    parser.add_argument('--val_log', type=int, default=…
-    parser.add_argument('--…
-    parser.add_argument('--…
-    parser.add_argument('--…
-    parser.add_argument('--…
+    parser.add_argument('--val_log', type=int, default=1000)
+    parser.add_argument('--warmup_steps', type=int, default=1000)
+    parser.add_argument('--batch_size', type=int, default=16)
+    parser.add_argument('--eval_batch_size', type=int, default=256)
+    parser.add_argument('--max_eval_samples', type=int, default=3000)
+    parser.add_argument('--test_batch_size', type=int, default=256)
     parser.add_argument('--hidden_dim', type=int, default=500)
     parser.add_argument('--latent_dim', type=int, default=150)
     parser.add_argument('--lng_dim', type=int, default=40)
-    parser.add_argument('--disc_lng_dim', type=int)
+    parser.add_argument('--disc_lng_dim', type=int, default=40)
     parser.add_argument('--use_lora', action='store_true')
     parser.add_argument('--lora_r', type=int, default=64)
     parser.add_argument('--gpu', type=str, default='0')
-    parser.add_argument('--epochs', type=int, default=…
+    parser.add_argument('--epochs', type=int, default=2)
     parser.add_argument('--grad_accumulation', type=int, default=1)
     parser.add_argument('--n_ica', type=int, default=10)
     parser.add_argument('--max_length', type=int, default=200)
     parser.add_argument('--total_steps', type=int)
     parser.add_argument('--kld_const', type=float, default=1)
-    parser.add_argument('--lr', type=float, default=1e-…
+    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--initializer_range', type=float, default=0.02)
     parser.add_argument('--kl_weight', type=float, default=1e-1)
     parser.add_argument('--weight_decay', type=float, default=1e-2)
     parser.add_argument('--ling_dropout', type=float, default=0.1)
@@ -71,12 +87,12 @@ def parse_args(ckpt=None):
     parser.add_argument('--pretrain_disc', action='store_true')
     parser.add_argument('--linggen_type', default='none')
     parser.add_argument('--linggen_input', default='s+l')
-    parser.add_argument(…
+    parser.add_argument("--aug_same", type=str2bool, nargs='?', const=True, default=False)
     parser.add_argument('--ling_vae', action='store_true')
     parser.add_argument('--process_lingpred', action='store_true')
     parser.add_argument('--fudge_lambda', type=float, default=1.0)
     parser.add_argument('--use_lingpred', action='store_true')
-    parser.add_argument('--ling2_only', action='store_true')
+    parser.add_argument('--ling2_only', action='store_true', default=True)
     parser.add_argument('--cycle_loss', action='store_true')
     parser.add_argument('--disc_loss', action='store_true')
     parser.add_argument('--sem_loss', action='store_true')
@@ -96,19 +112,36 @@ def parse_args(ckpt=None):
     parser.add_argument('--quant_nbins', type=int, default=20)
     parser.add_argument('--src_lng', default = 'ling')
     parser.add_argument('--to_restore', nargs='+', default=[])
+    parser.add_argument('--freeze_lm', action='store_true',
+                        help='Freeze the language model and only train the linguistic embedding')
+    parser.add_argument('--prepend_prompt', action='store_true',
+                        help='Prepend "generate a paraphrase: " to input text')
+    parser.add_argument('--prompt_text', type=str, default="generate a paraphrase: ",
+                        help='Text to prepend to input if prepend_prompt is True')
+    parser.add_argument('--do_imputation', action='store_true',
+                        help='Whether to perform imputation on linguistic features')
+    parser.add_argument('--imputation_percentage', type=int, default=20,
+                        help='Percentage of features to impute (20, 40, 60, 80)')
+    parser.add_argument('--imputation_seed', type=int, default=0,
+                        help='Seed for imputation set selection (0, 1, 2)')
     # args = parser.parse_args()
     args, unknown = parser.parse_known_args()
     args.name = f'{datetime.now().strftime("%m%d_%H-%M-%S")}-{args.data}-{args.combine_method}'
 
     major_arg = args.major_arg
     to_restore = [
+        'total_steps','major_arg','gpu','demo', 'eval_only', 'save_predict', 'predict_fn', 'fudge', 'predict_with_feedback',
+        'feedback_param', 'fb_log', 'data_dir', 'data', 'disc_ckpt', 'disc_type', 'sem_ckpt', 'fudge_lambda', 'eval_batch_size', 'test_batch_size', 'max_eval_samples',
+        'do_train', 'do_eval', 'do_predict',
    ] + args.to_restore
     to_restore = {k: args.__dict__[k] for k in to_restore}
 
     if not args.disc_loss or args.disc_ckpt:
         args.disc_steps = 0
 
-    if args.data_sources…
+    if args.data_sources == 'all':
+        args.data_sources = None
+    elif args.data_sources is not None:
         args.data_sources = args.data_sources.split(',')
 
     if ckpt is not None:
@@ -120,13 +153,17 @@ def parse_args(ckpt=None):
         ckpts = args.ckpt.split(',')
         args_list = [deepcopy(args) for _ in range(len(ckpts))]
         for i in range(len(ckpts)):
-            args_path = ckpts[i].replace('_best', '').replace('.pt', '.json')
+            args_path = ckpts[i].replace('_best', '').replace('.pt', '') + '.json'
             with open(args_path) as f:
                 args_list[i].__dict__.update(json.load(f))
             args_list[i].__dict__.update(to_restore)
             args_list[i].ckpt = ckpts[i]
     else:
-        …
+        args.ckpt = args.ckpt.rstrip('/')
+        if 'checkpoint-' in args.ckpt:
+            args_path = os.path.dirname(args.ckpt) + '.json'
+        else:
+            args_path = args.ckpt.replace('.pt', '') + '.json'
         ckpt = args.ckpt
         with open(args_path) as f:
             args.__dict__.update(json.load(f))
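Note on the `--aug_same` flag above: `type=str2bool` with `nargs='?'` and `const=True` accepts a bare flag, an explicit value, or omission. A small stand-alone demo of how the three spellings parse (the demo itself is not part of the commit):

    import argparse

    def str2bool(v):
        if isinstance(v, bool):
            return v
        if v.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        if v.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('Boolean value expected.')

    parser = argparse.ArgumentParser()
    parser.add_argument('--aug_same', type=str2bool, nargs='?', const=True, default=False)
    print(parser.parse_args(['--aug_same']).aug_same)        # True  (bare flag uses const)
    print(parser.parse_args(['--aug_same', 'no']).aug_same)  # False (value parsed by str2bool)
    print(parser.parse_args([]).aug_same)                    # False (default)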