# LingConv / model.py
import types
import torch
import torch.nn.functional as F
import numpy as np
from torch import nn
from transformers import T5ForConditionalGeneration, T5EncoderModel, T5Tokenizer, AutoModel, LogitsProcessor, LogitsProcessorList, PreTrainedModel
from functools import partial
from undecorate import unwrap
from types import MethodType
from utils import *
from ling_disc import DebertaReplacedTokenizer
from const import *
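# Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I), so the sampling
# step stays differentiable with respect to mu and logvar.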
def vae_sample(mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return eps * std + mu
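# Simple MLP variational autoencoder over fixed-size feature vectors: a two-layer
# encoder, linear mu/log-variance heads, and a three-layer decoder back to input_dim.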
class VAE(nn.Module):
def __init__(self, args):
super().__init__()
self.encoder = nn.Sequential(
nn.Linear(args.input_dim, args.hidden_dim),
nn.ReLU(),
nn.Linear(args.hidden_dim, args.hidden_dim),
nn.ReLU(),
)
self.decoder = nn.Sequential(
nn.Linear(args.latent_dim, args.hidden_dim),
nn.ReLU(),
nn.Linear(args.hidden_dim, args.hidden_dim),
nn.ReLU(),
nn.Linear(args.hidden_dim, args.input_dim),
)
self.fc_mu = nn.Linear(args.hidden_dim, args.latent_dim)
self.fc_var = nn.Linear(args.hidden_dim, args.latent_dim)
def forward(self, x):
h = self.encoder(x)
mu = self.fc_mu(h)
logvar = self.fc_var(h)
x = vae_sample(mu, logvar)
o = self.decoder(x)
return o, (mu, logvar)
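# Predicts a target linguistic-feature vector from sentence1 with a flan-t5-small encoder
# (mean-pooled). With linggen_input == 's+l', a linear embedding of sentence1's features is
# added to the token embeddings. The head is either variational ('vae': mu/logvar + sampling)
# or deterministic ('det': a single projection), selected by args.linggen_type.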
class LingGenerator(nn.Module):
def __init__(self, args, hidden_dim=1000):
super().__init__()
self.gen = T5EncoderModel.from_pretrained('google/flan-t5-small')
self.hidden_size = self.gen.config.d_model
self.ling_embed = nn.Linear(args.lng_dim, self.hidden_size)
# self.gen = nn.Sequential(
# nn.Linear(args.lng_dim, 2*hidden_dim),
# nn.ReLU(),
# nn.BatchNorm1d(2*hidden_dim),
# nn.Linear(2*hidden_dim, 2*hidden_dim),
# nn.ReLU(),
# nn.BatchNorm1d(2*hidden_dim),
# nn.Linear(2*hidden_dim, hidden_dim),
# nn.ReLU(),
# )
self.gen_type = args.linggen_type
self.gen_input = args.linggen_input
if self.gen_type == 'vae':
            # The VAE head operates on the encoder's mean-pooled hidden state (d_model wide).
            self.gen_mu = nn.Linear(self.hidden_size, args.lng_dim)
            self.gen_logvar = nn.Linear(self.hidden_size, args.lng_dim)
elif self.gen_type == 'det':
self.projection = nn.Linear(self.hidden_size, args.lng_dim)
def forward(self, batch):
inputs_embeds = self.gen.shared(batch['sentence1_input_ids'])
inputs_att_mask = batch['sentence1_attention_mask']
bs = inputs_embeds.shape[0]
if self.gen_input == 's+l':
sent1_ling = self.ling_embed(batch['sentence1_ling'])
sent1_ling = sent1_ling.view(bs, 1, -1)
inputs_embeds = inputs_embeds + sent1_ling
gen = self.gen(inputs_embeds=inputs_embeds,
attention_mask=inputs_att_mask).last_hidden_state.mean(1)
# gen = self.gen(batch['sentence1_ling'])
cache = {}
if self.gen_type == 'vae':
mu = self.gen_mu(gen)
logvar = self.gen_logvar(gen)
output = vae_sample(mu, logvar)
cache['linggen_mu'] = mu
cache['linggen_logvar'] = logvar
elif self.gen_type == 'det':
output = self.projection(gen)
return output, cache
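# Linguistic-feature discriminator: regresses a sentence's linguistic-feature vector from
# either hard input_ids or soft generator logits. The backbone is a T5 encoder with a linear
# head, or a DeBERTa regression model (DebertaReplacedTokenizer), selected by disc_type.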
class LingDisc(nn.Module):
def __init__(self,
model_name,
disc_type,
disc_ckpt,
lng_dim=40,
quant_nbins=1,
disc_lng_dim=None,
lng_ids=None,
**kwargs):
super().__init__()
if disc_type == 't5':
self.encoder = T5EncoderModel.from_pretrained(model_name)
hidden_dim = self.encoder.config.d_model
self.dropout = nn.Dropout(0.2)
self.lng_dim = disc_lng_dim if disc_lng_dim else lng_dim
            self.quant = quant_nbins > 1
            # NOTE: the quantized (binned) head is currently disabled regardless of quant_nbins.
            self.quant = False
if self.quant:
self.ling_classifier = nn.Linear(hidden_dim, self.lng_dim * quant_nbins)
else:
self.ling_classifier = nn.Linear(hidden_dim, self.lng_dim)
lng_ids = torch.tensor(lng_ids) if lng_ids is not None else None
# from const import used_indices
# lng_ids = torch.tensor(used_indices)
self.register_buffer('lng_ids', lng_ids)
elif disc_type == 'deberta':
            self.encoder = DebertaReplacedTokenizer.from_pretrained(
                pretrained_model_name_or_path=disc_ckpt,
                tok_model_name=model_name,
                problem_type='regression', num_labels=40)
self.quant = False
self.disc_type = disc_type
def forward(self, **batch):
        if 'attention_mask' not in batch:
if 'input_ids' in batch:
att_mask = torch.ones_like(batch['input_ids'])
else:
att_mask = torch.ones_like(batch['logits'])[:,:,0]
else:
att_mask = batch['attention_mask']
if 'input_ids' in batch:
enc_output = self.encoder(input_ids=batch['input_ids'],
attention_mask=att_mask)
elif 'logits' in batch:
logits = batch['logits']
scores = F.softmax(logits, dim = -1)
onehot = F.one_hot(logits.argmax(-1), num_classes=logits.shape[2]).float().to(logits.device)
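            # Straight-through estimator: the forward pass uses the hard one-hot of the
            # argmax, while gradients flow through the soft softmax scores.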
onehot_ = scores - scores.detach() + onehot
embed_layer = self.encoder.get_input_embeddings()
if isinstance(embed_layer, nn.Sequential):
for i, module in enumerate(embed_layer):
if i == 0:
embeds = torch.matmul(onehot_, module.weight)
else:
embeds = module(embeds)
else:
                embeds = torch.matmul(onehot_, embed_layer.weight)
enc_output = self.encoder(inputs_embeds=embeds,
attention_mask=att_mask)
if self.disc_type == 't5':
sent_emb = self.dropout(enc_output.last_hidden_state.mean(1))
bs = sent_emb.shape[0]
output = self.ling_classifier(sent_emb)
if self.quant:
output = output.reshape(bs, -1, self.lng_dim)
if self.lng_ids is not None:
output = torch.index_select(output, 1, self.lng_ids)
elif self.disc_type == 'deberta':
output = enc_output.logits
return output
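# Semantic-equivalence scorer: encodes "sentence1 </s> sentence2" (token ids, or soft
# embeddings built from generator logits for sentence2), mean-pools the encoder states,
# and projects them to a single logit.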
class SemEmb(T5EncoderModel):
def __init__(self, config, sep_token_id):
super().__init__(config)
self.sep_token_id = sep_token_id
hidden_dim = self.config.d_model
self.projection = nn.Sequential(nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(hidden_dim, 1))
def compare_sem(self, **batch):
bs = batch['sentence1_attention_mask'].shape[0]
ones = torch.ones((bs, 1), device=batch['sentence1_attention_mask'].device)
sep = torch.ones((bs, 1), dtype=torch.long,
device=batch['sentence1_attention_mask'].device) * self.sep_token_id
att_mask = torch.cat([batch['sentence1_attention_mask'], ones, batch['sentence2_attention_mask']], dim=1)
if 'logits' in batch:
input_ids = torch.cat([batch['sentence1_input_ids'], sep], dim=1)
embeds1 = self.shared(input_ids)
logits = batch['logits']
scores = F.softmax(logits, dim = -1)
onehot = F.one_hot(logits.argmax(-1), num_classes=logits.shape[2]).float().to(logits.device)
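            # Straight-through estimator: hard one-hot forward, soft-score gradients backward.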
onehot_ = scores - scores.detach() + onehot
embeds2 = onehot_ @ self.shared.weight
embeds1_2 = torch.cat([embeds1, embeds2], dim=1)
hidden_units = self(inputs_embeds=embeds1_2,
attention_mask=att_mask).last_hidden_state.mean(1)
elif 'sentence2_input_ids' in batch:
input_ids = torch.cat([batch['sentence1_input_ids'], sep, batch['sentence2_input_ids']], dim=1)
hidden_units = self(input_ids=input_ids,
attention_mask=att_mask).last_hidden_state.mean(1)
probs = self.projection(hidden_units)
return probs
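# Module-level replacement for T5's prepare_inputs_for_generation. It is bound to the model
# with functools.partial and MethodType in EncoderDecoderVAE.__init__ so that the linguistic
# embeddings (sent1_ling / sent2_ling) can be concatenated with, or added to, the decoder
# input embeddings at every generation step.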
def prepare_inputs_for_generation(
combine_method,
ling2_only,
self,
input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
sent1_ling=None,
sent2_ling=None,
**kwargs
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
input_ids = input_ids[:, -1:]
input_ids = input_ids.clone()
decoder_inputs_embeds = self.shared(input_ids)
if combine_method == 'decoder_add_first':
sent2_ling = torch.cat([sent2_ling,
torch.repeat_interleave(torch.zeros_like(sent2_ling), input_ids.shape[1] - 1, dim=1)], dim = 1)
if combine_method == 'decoder_concat':
if ling2_only:
decoder_inputs_embeds = torch.cat([sent2_ling, decoder_inputs_embeds], dim=1)
else:
decoder_inputs_embeds = torch.cat([sent1_ling, sent2_ling, decoder_inputs_embeds], dim=1)
    elif combine_method == 'decoder_add' or (past_key_values is None and combine_method == 'decoder_add_first'):
if ling2_only:
decoder_inputs_embeds = decoder_inputs_embeds + sent2_ling
else:
decoder_inputs_embeds = decoder_inputs_embeds + sent1_ling + sent2_ling
return {
"decoder_inputs_embeds": decoder_inputs_embeds,
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache,
}
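# LogitsProcessor that adds a fixed bias vector (the target linguistic embedding) to the
# decoder logits at each generation step; used by the 'logits_add' combine method.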
class LogitsAdd(LogitsProcessor):
def __init__(self, sent2_ling):
super().__init__()
self.sent2_ling = sent2_ling
def __call__(self, input_ids, scores):
return scores + self.sent2_ling
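# Main conditional generation model: T5 with linguistic-feature conditioning. The
# combine_method argument controls where the feature embeddings are injected (encoder
# inputs, encoder outputs, decoder inputs, a fusion layer over encoder states, or the
# output logits), and ling_vae optionally makes the embeddings variational.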
class EncoderDecoderVAE(T5ForConditionalGeneration):
def __init__(self, config, args, pad_token_id, sepeos_token_id, vocab_size = 32128):
super().__init__(config)
self.prepare_inputs_for_generation = types.MethodType(
partial(prepare_inputs_for_generation, args.combine_method, args.ling2_only),
self)
self.args = args
self.pad_token_id = pad_token_id
self.eos_token_id = sepeos_token_id
        hidden_dim = self.config.d_model if 'logits' not in args.combine_method else vocab_size
if args.combine_method == 'fusion1':
self.fusion = nn.Sequential(
nn.Linear(hidden_dim + 2 * args.lng_dim, hidden_dim),
)
elif args.combine_method == 'fusion2':
self.fusion = nn.Sequential(
nn.Linear(hidden_dim + 2 * args.lng_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
)
elif 'concat' in args.combine_method or 'add' in args.combine_method:
if args.ling_embed_type == 'two-layer':
self.ling_embed = nn.Sequential(
nn.Linear(args.lng_dim, args.lng_dim),
nn.ReLU(),
nn.Linear(args.lng_dim, hidden_dim),
)
else:
self.ling_embed = nn.Linear(args.lng_dim, hidden_dim)
self.ling_dropout = nn.Dropout(args.ling_dropout)
if args.ling_vae:
self.ling_mu = nn.Linear(hidden_dim, hidden_dim)
self.ling_logvar = nn.Linear(hidden_dim, hidden_dim)
                # Guard: ling_embed is an nn.Sequential for the 'two-layer' embed type and has no .weight.
                if isinstance(self.ling_embed, nn.Linear):
                    nn.init.xavier_uniform_(self.ling_embed.weight)
                nn.init.xavier_uniform_(self.ling_mu.weight)
                nn.init.xavier_uniform_(self.ling_logvar.weight)
generate_with_grad = unwrap(self.generate)
self.generate_with_grad = MethodType(generate_with_grad, self)
def get_fusion_layer(self):
if 'fusion' in self.args.combine_method:
return self.fusion
elif 'concat' in self.args.combine_method or 'add' in self.args.combine_method:
return self.ling_embed
else:
return None
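    # VAE reparameterization sample. Note: the name shadows GenerationMixin.sample in some
    # transformers versions; sampling-based decoding is not enabled in decode() below.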
def sample(self, mu, logvar):
std = torch.exp(0.5 * logvar)
return mu + std * torch.randn_like(std)
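    # Encodes sentence1. For 'input_concat'/'input_add', the (optionally VAE-sampled)
    # linguistic embeddings are concatenated with or added to the token embeddings before
    # the encoder. Returns (encoder output, attention mask, cache of intermediate values).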
def encode(self, batch):
if 'inputs_embeds' in batch:
inputs_embeds = batch['inputs_embeds']
else:
inputs_embeds = self.shared(batch['sentence1_input_ids'])
inputs_att_mask = batch['sentence1_attention_mask']
bs = inputs_embeds.shape[0]
cache = {}
if self.args.combine_method in ('input_concat', 'input_add'):
if 'sent1_ling_embed' in batch:
sent1_ling = batch['sent1_ling_embed']
else:
sent1_ling = self.ling_embed(self.ling_dropout(batch['sentence1_ling']))
if 'sent2_ling_embed' in batch:
sent2_ling = batch['sent2_ling_embed']
else:
sent2_ling = self.ling_embed(self.ling_dropout(batch['sentence2_ling']))
if self.args.ling_vae:
sent1_ling = F.leaky_relu(sent1_ling)
sent1_mu, sent1_logvar = self.ling_mu(sent1_ling), self.ling_logvar(sent1_ling)
sent1_ling = self.sample(sent1_mu, sent1_logvar)
sent2_ling = F.leaky_relu(sent2_ling)
sent2_mu, sent2_logvar = self.ling_mu(sent2_ling), self.ling_logvar(sent2_ling)
sent2_ling = self.sample(sent2_mu, sent2_logvar)
cache.update({'sent1_mu': sent1_mu, 'sent1_logvar': sent1_logvar,
'sent2_mu': sent2_mu, 'sent2_logvar': sent2_logvar,
'sent1_ling': sent1_ling, 'sent2_ling': sent2_ling})
else:
cache.update({'sent1_ling': sent1_ling, 'sent2_ling': sent2_ling})
sent1_ling = sent1_ling.view(bs, 1, -1)
sent2_ling = sent2_ling.view(bs, 1, -1)
if self.args.combine_method == 'input_concat':
if self.args.ling2_only:
inputs_embeds = torch.cat([inputs_embeds, sent2_ling], dim=1)
inputs_att_mask = torch.cat([inputs_att_mask,
torch.ones((bs, 1)).to(inputs_embeds.device)], dim=1)
else:
inputs_embeds = torch.cat([inputs_embeds, sent1_ling, sent2_ling], dim=1)
inputs_att_mask = torch.cat([inputs_att_mask,
torch.ones((bs, 2)).to(inputs_embeds.device)], dim=1)
elif self.args.combine_method == 'input_add':
if self.args.ling2_only:
inputs_embeds = inputs_embeds + sent2_ling
else:
inputs_embeds = inputs_embeds + sent1_ling + sent2_ling
return self.encoder(inputs_embeds=inputs_embeds,
attention_mask=inputs_att_mask), inputs_att_mask, cache
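    # Decodes conditioned on the encoder output. Depending on combine_method, linguistic
    # embeddings are concatenated/added to the decoder input embeddings, fused with the
    # encoder states, or added to the logits. With generate=True this calls
    # generate_with_grad (so gradients can flow through decoding); otherwise it
    # teacher-forces sentence2 and returns the cross-entropy loss.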
def decode(self, batch, enc_output, inputs_att_mask, generate):
bs = inputs_att_mask.shape[0]
cache = {}
if self.args.combine_method in ('embed_concat', 'decoder_concat', 'decoder_add', 'logits_add', 'decoder_add_first'):
if 'sent1_ling_embed' in batch:
sent1_ling = batch['sent1_ling_embed']
elif 'sentence1_ling' in batch:
sent1_ling = self.ling_embed(self.ling_dropout(batch['sentence1_ling']))
else:
sent1_ling = None
if 'sent2_ling_embed' in batch:
sent2_ling = batch['sent2_ling_embed']
else:
sent2_ling = self.ling_embed(self.ling_dropout(batch['sentence2_ling']))
if self.args.ling_vae:
sent1_ling = F.leaky_relu(sent1_ling)
sent1_mu, sent1_logvar = self.ling_mu(sent1_ling), self.ling_logvar(sent1_ling)
sent1_ling = self.sample(sent1_mu, sent1_logvar)
sent2_ling = F.leaky_relu(sent2_ling)
sent2_mu, sent2_logvar = self.ling_mu(sent2_ling), self.ling_logvar(sent2_ling)
sent2_ling = self.sample(sent2_mu, sent2_logvar)
cache.update({'sent1_mu': sent1_mu, 'sent1_logvar': sent1_logvar,
'sent2_mu': sent2_mu, 'sent2_logvar': sent2_logvar,
'sent1_ling': sent1_ling, 'sent2_ling': sent2_ling})
else:
cache.update({'sent2_ling': sent2_ling})
if sent1_ling is not None:
cache.update({'sent1_ling': sent1_ling})
if sent1_ling is not None:
sent1_ling = sent1_ling.view(bs, 1, -1)
sent2_ling = sent2_ling.view(bs, 1, -1)
if self.args.combine_method == 'decoder_add_first' and not generate:
sent2_ling = torch.cat([sent2_ling,
torch.repeat_interleave(torch.zeros_like(sent2_ling), batch['sentence2_input_ids'].shape[1] - 1, dim=1)], dim = 1)
else:
sent1_ling, sent2_ling = None, None
if self.args.combine_method == 'embed_concat':
enc_output.last_hidden_state = torch.cat([enc_output.last_hidden_state,
sent1_ling, sent2_ling], dim=1)
inputs_att_mask = torch.cat([inputs_att_mask,
torch.ones((bs, 2)).to(inputs_att_mask.device)], dim=1)
elif 'fusion' in self.args.combine_method:
sent1_ling = batch['sentence1_ling'].unsqueeze(1)\
.expand(-1, enc_output.last_hidden_state.shape[1], -1)
sent2_ling = batch['sentence2_ling'].unsqueeze(1)\
.expand(-1, enc_output.last_hidden_state.shape[1], -1)
if self.args.ling2_only:
combined_embedding = torch.cat([enc_output.last_hidden_state, sent2_ling], dim=2)
else:
combined_embedding = torch.cat([enc_output.last_hidden_state, sent1_ling, sent2_ling], dim=2)
enc_output.last_hidden_state = self.fusion(combined_embedding)
if generate:
if self.args.combine_method == 'logits_add':
logits_processor = LogitsProcessorList([LogitsAdd(sent2_ling.view(bs, -1))])
else:
logits_processor = LogitsProcessorList()
dec_output = self.generate_with_grad(
attention_mask=inputs_att_mask,
encoder_outputs=enc_output,
sent1_ling=sent1_ling,
sent2_ling=sent2_ling,
return_dict_in_generate=True,
output_scores=True,
logits_processor = logits_processor,
# renormalize_logits=True,
# do_sample=True,
# top_p=0.8,
eos_token_id=self.eos_token_id,
# min_new_tokens=3,
# repetition_penalty=1.2,
max_length=self.args.max_length,
)
scores = torch.stack(dec_output.scores, 1)
cache.update({'scores': scores})
return dec_output.sequences, cache
decoder_input_ids = self._shift_right(batch['sentence2_input_ids'])
decoder_inputs_embeds = self.shared(decoder_input_ids)
decoder_att_mask = batch['sentence2_attention_mask']
labels = batch['sentence2_input_ids'].clone()
labels[labels == self.pad_token_id] = -100
if self.args.combine_method == 'decoder_concat':
if self.args.ling2_only:
decoder_inputs_embeds = torch.cat([sent2_ling, decoder_inputs_embeds], dim=1)
decoder_att_mask = torch.cat([torch.ones((bs, 1)).to(decoder_inputs_embeds.device), decoder_att_mask], dim=1)
labels = torch.cat([torch.ones((bs, 1), dtype=torch.int64).to(decoder_inputs_embeds.device) * self.pad_token_id,
labels], dim=1)
else:
decoder_inputs_embeds = torch.cat([sent1_ling, sent2_ling, decoder_inputs_embeds], dim=1)
decoder_att_mask = torch.cat([torch.ones((bs, 2)).to(decoder_inputs_embeds.device), decoder_att_mask], dim=1)
labels = torch.cat([torch.ones((bs, 2), dtype=torch.int64).to(decoder_inputs_embeds.device) * self.pad_token_id,
labels], dim=1)
        elif self.args.combine_method in ('decoder_add', 'decoder_add_first'):
if self.args.ling2_only:
decoder_inputs_embeds = decoder_inputs_embeds + self.args.combine_weight * sent2_ling
else:
decoder_inputs_embeds = decoder_inputs_embeds + sent1_ling + sent2_ling
dec_output = self(
decoder_inputs_embeds=decoder_inputs_embeds,
decoder_attention_mask=decoder_att_mask,
encoder_outputs=enc_output,
attention_mask=inputs_att_mask,
labels=labels,
)
if self.args.combine_method == 'logits_add':
dec_output.logits = dec_output.logits + self.args.combine_weight * sent2_ling
vocab_size = dec_output.logits.size(-1)
dec_output.loss = F.cross_entropy(dec_output.logits.view(-1, vocab_size), labels.view(-1))
return dec_output, cache
def convert(self, batch, generate=False):
enc_output, enc_att_mask, cache = self.encode(batch)
dec_output, cache2 = self.decode(batch, enc_output, enc_att_mask, generate)
cache.update(cache2)
return dec_output, enc_output, cache
def infer_with_cache(self, batch):
dec_output, _, cache = self.convert(batch, generate = True)
return dec_output, cache
def infer(self, batch):
dec_output, _ = self.infer_with_cache(batch)
return dec_output
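    # Inference with gradient feedback: treats the conditioning signal (linguistic embedding,
    # source embeddings, or decoder logits, per args.feedback_param) as a learnable parameter
    # and refines it with a backtracking line search, driven by the linguistic-discriminator
    # loss and gated by the semantic-similarity model (sem_emb).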
def infer_with_feedback_BP(self, ling_disc, sem_emb, batch, tokenizer):
from torch.autograd import grad
interpolations = []
def line_search():
best_val = None
best_loss = None
eta = 1e3
sem_prob = 1
patience = 4
while patience > 0:
param_ = param - eta * grads
with torch.no_grad():
new_loss, pred = get_loss(param_)
max_len = pred.shape[1]
lens = torch.where(pred == self.eos_token_id, 1, 0).argmax(-1) + 1
batch.update({
'sentence2_input_ids': pred,
'sentence2_attention_mask': sequence_mask(lens, max_len = max_len)
})
sem_prob = torch.sigmoid(sem_emb.compare_sem(**batch)).item()
# if sem_prob <= 0.1:
# patience -= 1
if new_loss < loss and sem_prob >= 0.90 and lens.item() > 1:
return param_
eta *= 2.25
patience -= 1
return False
def get_loss(param):
if self.args.feedback_param == 'l':
batch.update({'sent2_ling_embed': param})
elif self.args.feedback_param == 's':
batch.update({'inputs_embeds': param})
if self.args.feedback_param == 'logits':
logits = param
pred = param.argmax(-1)
else:
pred, cache = self.infer_with_cache(batch)
logits = cache['scores']
out = ling_disc(logits = logits)
probs = F.softmax(out, 1)
if ling_disc.quant:
loss = F.cross_entropy(out, batch['sentence2_discr'])
else:
loss = F.mse_loss(out, batch['sentence2_ling'])
return loss, pred
if self.args.feedback_param == 'l':
ling2_embed = self.ling_embed(batch['sentence2_ling'])
param = torch.nn.Parameter(ling2_embed, requires_grad = True)
elif self.args.feedback_param == 's':
inputs_embeds = self.shared(batch['sentence1_input_ids'])
param = torch.nn.Parameter(inputs_embeds, requires_grad = True)
elif self.args.feedback_param == 'logits':
logits = self.infer_with_cache(batch)[1]['scores']
param = torch.nn.Parameter(logits, requires_grad = True)
target_np = batch['sentence2_ling'][0].cpu().numpy()
while True:
loss, pred = get_loss(param)
pred_text = tokenizer.batch_decode(pred.cpu().numpy(),
skip_special_tokens=True)[0]
interpolations.append(pred_text)
if loss < 1:
break
self.zero_grad()
grads = grad(loss, param)[0]
param = line_search()
if param is False:
break
return pred, [pred_text, interpolations]
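# Helpers to freeze/unfreeze the parameters of a module, or of everything except a named submodule.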
def set_grad(module, state):
if module is not None:
for p in module.parameters():
p.requires_grad = state
def set_grad_except(model, name, state):
for n, p in model.named_parameters():
if not name in n:
p.requires_grad = state
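# Convenience wrapper around SemEmb: loads a checkpoint and returns the probability that
# two raw sentences are semantically equivalent.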
class SemEmbPipeline():
def __init__(self,
ckpt = "/data/mohamed/checkpoints/ling_conversion_sem_emb_best.pt"):
self.tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
        # SemEmb is a T5EncoderModel subclass, so it is loaded via from_pretrained
        # (mirroring get_model below) rather than being wrapped around a model instance.
        self.model = SemEmb.from_pretrained('google/flan-t5-base', self.tokenizer.get_vocab()['</s>'])
state = torch.load(ckpt)
self.model.load_state_dict(state['model'], strict=False)
self.model.eval()
self.model.cuda()
def __call__(self, sentence1, sentence2):
sentence1 = self.tokenizer(sentence1, return_attention_mask = True, return_tensors = 'pt')
sentence2 = self.tokenizer(sentence2, return_attention_mask = True, return_tensors = 'pt')
sem_logit = self.model(
sentence1_input_ids = sentence1.input_ids.cuda(),
sentence1_attention_mask = sentence1.attention_mask.cuda(),
sentence2_input_ids = sentence2.input_ids.cuda(),
sentence2_attention_mask = sentence2.attention_mask.cuda(),
)
sem_prob = torch.sigmoid(sem_logit).item()
return sem_prob
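# Convenience wrapper around LingDisc: tokenizes a raw sentence and returns its predicted
# linguistic-feature vector.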
class LingDiscPipeline():
def __init__(self,
model_name="google/flan-t5-base",
disc_type='deberta',
disc_ckpt='/data/mohamed/checkpoints/ling_disc/deberta-v3-small_flan-t5-base_40',
# disc_type='t5',
# disc_ckpt='/data/mohamed/checkpoints/ling_conversion_ling_disc.pt',
):
self.tokenizer = T5Tokenizer.from_pretrained(model_name)
self.model = LingDisc(model_name, disc_type, disc_ckpt)
self.model.eval()
self.model.cuda()
def __call__(self, sentence):
inputs = self.tokenizer(sentence, return_tensors = 'pt')
with torch.no_grad():
ling_pred = self.model(input_ids=inputs.input_ids.cuda())
return ling_pred
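# Factory that builds the main EncoderDecoderVAE (or, when pretraining the discriminator,
# returns the discriminator itself as `model`), the linguistic discriminator, and the
# semantic-similarity model, according to the run configuration. Note that the LingGenerator
# instantiated when args.linggen_type != 'none' is not returned.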
def get_model(args, tokenizer, device):
if args.pretrain_disc or args.disc_loss or args.disc_ckpt:
ling_disc = LingDisc(args.model_name, args.disc_type, args.disc_model_path).to(device)
else:
ling_disc = None
if args.linggen_type != 'none':
ling_gen = LingGenerator(args).to(device)
if not args.pretrain_disc:
model = EncoderDecoderVAE.from_pretrained(args.model_path, args, tokenizer.pad_token_id, tokenizer.eos_token_id).to(device)
else:
model = ling_disc
if args.sem_loss or args.sem_ckpt:
if args.sem_loss_type == 'shared':
sem_emb = model.encoder
elif args.sem_loss_type == 'dedicated':
sem_emb = SemEmb.from_pretrained(args.sem_model_path, tokenizer.eos_token_id).to(device)
else:
raise NotImplementedError('Semantic loss type')
else:
sem_emb = None
return model, ling_disc, sem_emb
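# Example usage (a minimal sketch, assuming an `args` namespace that provides the fields
# referenced above, e.g. model_name, model_path, combine_method, ling2_only, lng_dim,
# ling_dropout, ling_vae, max_length, and the disc/sem flags):
#
#     tokenizer = T5Tokenizer.from_pretrained(args.model_name)
#     model, ling_disc, sem_emb = get_model(args, tokenizer, 'cuda')
#     batch = ...  # sentence1 input_ids/attention_mask plus sentence1_ling/sentence2_ling tensors
#     pred_ids = model.infer(batch)
#     print(tokenizer.batch_decode(pred_ids.cpu(), skip_special_tokens=True))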