import types
import torch
import torch.nn.functional as F
import numpy as np
from torch import nn
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5EncoderModel, AutoModel, LogitsProcessor, LogitsProcessorList
from functools import partial
from compute_lng import compute_lng
from undecorate import unwrap
from types import MethodType
from utils import *
from ling_disc import DebertaReplacedTokenizer
from const import *



def vae_sample(mu, logvar):
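    """Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)."""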
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return eps * std + mu

class VAE(nn.Module):
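    """Small MLP variational autoencoder over fixed-size feature vectors."""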
    def __init__(self, args):
        super().__init__()
        self.encoder = nn.Sequential(
                nn.Linear(args.input_dim, args.hidden_dim),
                nn.ReLU(),
                nn.Linear(args.hidden_dim, args.hidden_dim),
                nn.ReLU(),
                )
        self.decoder = nn.Sequential(
                nn.Linear(args.latent_dim, args.hidden_dim),
                nn.ReLU(),
                nn.Linear(args.hidden_dim, args.hidden_dim),
                nn.ReLU(),
                nn.Linear(args.hidden_dim, args.input_dim),
                )
        self.fc_mu = nn.Linear(args.hidden_dim, args.latent_dim)
        self.fc_var = nn.Linear(args.hidden_dim, args.latent_dim)

    def forward(self, x):
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_var(h)
        x = vae_sample(mu, logvar)
        o = self.decoder(x)
        return o, (mu, logvar)

class LingGenerator(nn.Module):
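    """Generates a target linguistic-attribute vector from sentence1.

    A flan-t5-small encoder mean-pools the input tokens (optionally shifted by
    an embedding of sentence1's own attributes when linggen_input == 's+l');
    the output head is either variational ('vae') or a deterministic
    projection ('det').
    """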
    def __init__(self, args, hidden_dim=1000):
        super().__init__()

        self.gen = T5EncoderModel.from_pretrained('google/flan-t5-small')
        self.hidden_size = self.gen.config.d_model
        self.ling_embed = nn.Linear(args.lng_dim, self.hidden_size)

        self.gen_type = args.linggen_type
        self.gen_input = args.linggen_input
        if self.gen_type == 'vae':
            # Heads consume the mean-pooled encoder state, which is
            # d_model wide, not the legacy MLP width.
            self.gen_mu = nn.Linear(self.hidden_size, args.lng_dim)
            self.gen_logvar = nn.Linear(self.hidden_size, args.lng_dim)
        elif self.gen_type == 'det':
            self.projection = nn.Linear(self.hidden_size, args.lng_dim)

    def forward(self, batch):
        inputs_embeds = self.gen.shared(batch['sentence1_input_ids'])
        inputs_att_mask = batch['sentence1_attention_mask']
        bs = inputs_embeds.shape[0]

        if self.gen_input == 's+l':
            sent1_ling = self.ling_embed(batch['sentence1_ling'])
            sent1_ling = sent1_ling.view(bs, 1, -1)
            inputs_embeds = inputs_embeds + sent1_ling

        gen = self.gen(inputs_embeds=inputs_embeds,
                attention_mask=inputs_att_mask).last_hidden_state.mean(1)

        cache = {}
        if self.gen_type == 'vae':
            mu = self.gen_mu(gen)
            logvar = self.gen_logvar(gen)
            output = vae_sample(mu, logvar)
            cache['linggen_mu'] = mu
            cache['linggen_logvar'] = logvar
        elif self.gen_type == 'det':
            output = self.projection(gen)

        return output, cache


class LingDisc(nn.Module):
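    """Linguistic-attribute discriminator.

    Predicts a vector of linguistic indices from token ids, or from soft
    logits via a straight-through one-hot, using either a T5 encoder with a
    linear head or a DeBERTa regressor.
    """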
    def __init__(self,
                 model_name,
                 disc_type,
                 disc_ckpt,
                 lng_dim=40,
                 quant_nbins=1,
                 disc_lng_dim=None,
                 lng_ids=None,
                 **kwargs):
        super().__init__()
        if disc_type == 't5':
            self.encoder = T5EncoderModel.from_pretrained(model_name)
            hidden_dim = self.encoder.config.d_model
            self.dropout = nn.Dropout(0.2)
            self.lng_dim = disc_lng_dim if disc_lng_dim else lng_dim
            self.quant = quant_nbins > 1
            # NOTE: quantization is force-disabled here, overriding quant_nbins.
            self.quant = False
            if self.quant:
                self.ling_classifier = nn.Linear(hidden_dim, self.lng_dim * quant_nbins)
            else:
                self.ling_classifier = nn.Linear(hidden_dim, self.lng_dim)
            lng_ids = torch.tensor(lng_ids) if lng_ids is not None else None
            self.register_buffer('lng_ids', lng_ids)
        elif disc_type == 'deberta':
            self.encoder = DebertaReplacedTokenizer.from_pretrained(
                    pretrained_model_name_or_path=disc_ckpt,
                    tok_model_name=model_name,
                    problem_type='regression', num_labels=40)
            self.quant = False

        self.disc_type = disc_type

    def forward(self, **batch):
        if 'attention_mask' not in batch:
            if 'input_ids' in batch:
                att_mask = torch.ones_like(batch['input_ids'])
            else:
                att_mask = torch.ones_like(batch['logits'][:, :, 0])
        else:
            att_mask = batch['attention_mask']
        if 'input_ids' in batch:
            enc_output = self.encoder(input_ids=batch['input_ids'],
                    attention_mask=att_mask)
        elif 'logits' in batch:
            logits = batch['logits']
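            # Straight-through one-hot: the forward pass uses the hard argmax,
            # while gradients flow through the soft softmax distribution.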
            scores = F.softmax(logits, dim=-1)
            onehot = F.one_hot(logits.argmax(-1), num_classes=logits.shape[2]).float().to(logits.device)
            onehot_ = scores - scores.detach() + onehot

            embed_layer = self.encoder.get_input_embeddings()
            if isinstance(embed_layer, nn.Sequential):
                for i, module in enumerate(embed_layer):
                    if i == 0:
                        embeds = torch.matmul(onehot_, module.weight)
                    else:
                        embeds = module(embeds)
            else:
                embeds = torch.matmul(onehot_, embed_layer.weight)

            enc_output = self.encoder(inputs_embeds=embeds,
                    attention_mask=att_mask)
        if self.disc_type == 't5':
            sent_emb = self.dropout(enc_output.last_hidden_state.mean(1))
            bs = sent_emb.shape[0]
            output = self.ling_classifier(sent_emb)
            if self.quant:
                output = output.reshape(bs, -1, self.lng_dim)
            if self.lng_ids is not None:
                output = torch.index_select(output, 1, self.lng_ids)
        elif self.disc_type == 'deberta':
            output = enc_output.logits
        return output

class SemEmb(nn.Module):
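    """Scores semantic equivalence of a sentence pair by encoding
    'sentence1 </s> sentence2', mean-pooling the encoder states, and
    projecting to a single logit."""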
    def __init__(self, backbone, sep_token_id):
        super().__init__()
        self.backbone = backbone
        self.sep_token_id = sep_token_id
        hidden_dim = self.backbone.config.d_model
        self.projection = nn.Sequential(nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(hidden_dim, 1))

    def forward(self, **batch):
        bs = batch['sentence1_attention_mask'].shape[0]
        ones = torch.ones((bs, 1), device=batch['sentence1_attention_mask'].device)
        sep = torch.ones((bs, 1), dtype=torch.long,
                device=batch['sentence1_attention_mask'].device) * self.sep_token_id
        att_mask = torch.cat([batch['sentence1_attention_mask'], ones, batch['sentence2_attention_mask']], dim=1)
        if 'logits' in batch:
            input_ids = torch.cat([batch['sentence1_input_ids'], sep], dim=1)
            embeds1 = self.backbone.shared(input_ids)

            logits = batch['logits']
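            # Straight-through one-hot (hard forward, soft gradients), as in LingDisc.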
            scores = F.softmax(logits, dim=-1)
            onehot = F.one_hot(logits.argmax(-1), num_classes=logits.shape[2]).float().to(logits.device)
            onehot_ = scores - scores.detach() + onehot

            embeds2 = onehot_ @ self.backbone.shared.weight
            embeds1_2 = torch.cat([embeds1, embeds2], dim=1)
            hidden_units = self.backbone(inputs_embeds=embeds1_2,
                    attention_mask=att_mask).last_hidden_state.mean(1)
        elif 'sentence2_input_ids' in batch:
            input_ids = torch.cat([batch['sentence1_input_ids'], sep, batch['sentence2_input_ids']], dim=1)
            hidden_units = self.backbone(input_ids=input_ids,
                    attention_mask=att_mask).last_hidden_state.mean(1)
        probs = self.projection(hidden_units)
        return probs

def prepare_inputs_for_generation(
        combine_method,
        ling2_only,
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        sent1_ling=None,
        sent2_ling=None,
        **kwargs
    ):
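        """Patched onto the backbone via types.MethodType so that generation
        injects linguistic-attribute embeddings into the decoder inputs at
        each step, mirroring the training-time combine_method."""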

        # cut decoder_input_ids if past is used
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        input_ids = input_ids.clone()
        decoder_inputs_embeds = self.shared(input_ids)

        if combine_method == 'decoder_add_first':
            sent2_ling = torch.cat([sent2_ling,
                torch.repeat_interleave(torch.zeros_like(sent2_ling), input_ids.shape[1] - 1, dim=1)], dim = 1)
        if combine_method == 'decoder_concat':
            if ling2_only:
                decoder_inputs_embeds = torch.cat([sent2_ling, decoder_inputs_embeds], dim=1)
            else:
                decoder_inputs_embeds = torch.cat([sent1_ling, sent2_ling, decoder_inputs_embeds], dim=1)
        elif combine_method == 'decoder_add' or (past_key_values is None and combine_method == 'decoder_add_first'):
            if ling2_only:
                decoder_inputs_embeds = decoder_inputs_embeds + sent2_ling
            else:
                decoder_inputs_embeds = decoder_inputs_embeds + sent1_ling + sent2_ling

        return {
            "decoder_inputs_embeds": decoder_inputs_embeds,
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

class LogitsAdd(LogitsProcessor):
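    """Adds a fixed bias (the target linguistic embedding) to the vocabulary
    logits at every decoding step; used by the 'logits_add' combine method."""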
    def __init__(self, sent2_ling):
        super().__init__()
        self.sent2_ling = sent2_ling

    def __call__(self, input_ids, scores):
        return scores + self.sent2_ling

class EncoderDecoderVAE(nn.Module):
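    """T5-based generator conditioned on target linguistic attributes.

    Supports several conditioning strategies (combine_method): concatenating
    or adding attribute embeddings at the encoder input, encoder output, or
    decoder input; fusing them into encoder states ('fusion1'/'fusion2'); or
    biasing the output logits ('logits_add'). Attribute embeddings may
    optionally pass through a small VAE (ling_vae).
    """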
    def __init__(self, args, pad_token_id, sepeos_token_id, vocab_size = 32128):
        super().__init__()
        self.backbone = T5ForConditionalGeneration.from_pretrained(args.model_name)
        self.backbone.prepare_inputs_for_generation = types.MethodType(
                partial(prepare_inputs_for_generation, args.combine_method, args.ling2_only),
                self.backbone)
        self.args = args
        self.pad_token_id = pad_token_id
        self.eos_token_id = sepeos_token_id
        hidden_dim = self.backbone.config.d_model if 'logits' not in args.combine_method else vocab_size
        if args.combine_method == 'fusion1':
            self.fusion = nn.Sequential(
                    nn.Linear(hidden_dim + 2 * args.lng_dim, hidden_dim),
                    )
        elif args.combine_method == 'fusion2':
            self.fusion = nn.Sequential(
                    nn.Linear(hidden_dim + 2 * args.lng_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, hidden_dim),
                    )
        elif 'concat' in args.combine_method or 'add' in args.combine_method:
            if args.ling_embed_type == 'two-layer':
                self.ling_embed = nn.Sequential(
                        nn.Linear(args.lng_dim, args.lng_dim),
                        nn.ReLU(),
                        nn.Linear(args.lng_dim, hidden_dim),
                        )
            else:
                self.ling_embed = nn.Linear(args.lng_dim, hidden_dim)
            self.ling_dropout = nn.Dropout(args.ling_dropout)

        if args.ling_vae:
            self.ling_mu = nn.Linear(hidden_dim, hidden_dim)
            self.ling_logvar = nn.Linear(hidden_dim, hidden_dim)
            # NOTE: assumes ling_embed is a single nn.Linear; the 'two-layer'
            # Sequential variant has no .weight attribute.
            nn.init.xavier_uniform_(self.ling_embed.weight)
            nn.init.xavier_uniform_(self.ling_mu.weight)
            nn.init.xavier_uniform_(self.ling_logvar.weight)


        generate_with_grad = unwrap(self.backbone.generate)
        self.backbone.generate_with_grad = MethodType(generate_with_grad, self.backbone)

    def get_fusion_layer(self):
        if 'fusion' in self.args.combine_method:
            return self.fusion
        elif 'concat' in self.args.combine_method or 'add' in self.args.combine_method:
            return self.ling_embed
        else:
            return None

    def sample(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def encode(self, batch):
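        """Encode sentence1, optionally combining linguistic-attribute
        embeddings with the input ('input_concat' / 'input_add').

        Returns the encoder output, the (possibly extended) attention mask,
        and a cache of intermediate ling tensors.
        """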
        if 'inputs_embeds' in batch:
            inputs_embeds = batch['inputs_embeds']
        else:
            inputs_embeds = self.backbone.shared(batch['sentence1_input_ids'])
        inputs_att_mask = batch['sentence1_attention_mask']
        bs = inputs_embeds.shape[0]
        cache = {}
        if self.args.combine_method in ('input_concat', 'input_add'):
            if 'sent1_ling_embed' in batch:
                sent1_ling = batch['sent1_ling_embed']
            else:
                sent1_ling = self.ling_embed(self.ling_dropout(batch['sentence1_ling']))
            if 'sent2_ling_embed' in batch:
                sent2_ling = batch['sent2_ling_embed']
            else:
                sent2_ling = self.ling_embed(self.ling_dropout(batch['sentence2_ling']))
            if self.args.ling_vae:
                sent1_ling = F.leaky_relu(sent1_ling)
                sent1_mu, sent1_logvar = self.ling_mu(sent1_ling), self.ling_logvar(sent1_ling)
                sent1_ling = self.sample(sent1_mu, sent1_logvar)

                sent2_ling = F.leaky_relu(sent2_ling)
                sent2_mu, sent2_logvar = self.ling_mu(sent2_ling), self.ling_logvar(sent2_ling)
                sent2_ling = self.sample(sent2_mu, sent2_logvar)
                cache.update({'sent1_mu': sent1_mu, 'sent1_logvar': sent1_logvar,
                    'sent2_mu': sent2_mu, 'sent2_logvar': sent2_logvar,
                    'sent1_ling': sent1_ling, 'sent2_ling': sent2_ling})
            else:
                cache.update({'sent1_ling': sent1_ling, 'sent2_ling': sent2_ling})
            sent1_ling = sent1_ling.view(bs, 1, -1)
            sent2_ling = sent2_ling.view(bs, 1, -1)
            if self.args.combine_method == 'input_concat':
                if self.args.ling2_only:
                    inputs_embeds = torch.cat([inputs_embeds, sent2_ling], dim=1)
                    inputs_att_mask = torch.cat([inputs_att_mask,
                        torch.ones((bs, 1)).to(inputs_embeds.device)], dim=1)
                else:
                    inputs_embeds = torch.cat([inputs_embeds, sent1_ling, sent2_ling], dim=1)
                    inputs_att_mask = torch.cat([inputs_att_mask,
                        torch.ones((bs, 2)).to(inputs_embeds.device)], dim=1)
            elif self.args.combine_method == 'input_add':
                if self.args.ling2_only:
                    inputs_embeds = inputs_embeds + sent2_ling
                else:
                    inputs_embeds = inputs_embeds + sent1_ling + sent2_ling
        return self.backbone.encoder(inputs_embeds=inputs_embeds,
                attention_mask=inputs_att_mask), inputs_att_mask, cache

    def decode(self, batch, enc_output, inputs_att_mask, generate):
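        """Decode conditioned on the encoder output, injecting linguistic
        embeddings according to combine_method; generates autoregressively
        when `generate` is True, otherwise teacher-forces against sentence2."""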
        bs = inputs_att_mask.shape[0]
        cache = {}
        if self.args.combine_method in ('embed_concat', 'decoder_concat', 'decoder_add', 'logits_add', 'decoder_add_first'):
            if 'sent1_ling_embed' in batch:
                sent1_ling = batch['sent1_ling_embed']
            elif 'sentence1_ling' in batch:
                sent1_ling = self.ling_embed(self.ling_dropout(batch['sentence1_ling']))
            else:
                sent1_ling = None
            if 'sent2_ling_embed' in batch:
                sent2_ling = batch['sent2_ling_embed']
            else:
                sent2_ling = self.ling_embed(self.ling_dropout(batch['sentence2_ling']))
            if self.args.ling_vae:
                # NOTE: sent1_ling may be None here (no sentence1 attributes);
                # guard it before applying the VAE reparameterization.
                if sent1_ling is not None:
                    sent1_ling = F.leaky_relu(sent1_ling)
                    sent1_mu, sent1_logvar = self.ling_mu(sent1_ling), self.ling_logvar(sent1_ling)
                    sent1_ling = self.sample(sent1_mu, sent1_logvar)
                    cache.update({'sent1_mu': sent1_mu, 'sent1_logvar': sent1_logvar,
                        'sent1_ling': sent1_ling})

                sent2_ling = F.leaky_relu(sent2_ling)
                sent2_mu, sent2_logvar = self.ling_mu(sent2_ling), self.ling_logvar(sent2_ling)
                sent2_ling = self.sample(sent2_mu, sent2_logvar)
                cache.update({'sent2_mu': sent2_mu, 'sent2_logvar': sent2_logvar,
                    'sent2_ling': sent2_ling})
            else:
                cache.update({'sent2_ling': sent2_ling})
                if sent1_ling is not None:
                    cache.update({'sent1_ling': sent1_ling})
            if sent1_ling is not None:
                sent1_ling = sent1_ling.view(bs, 1, -1)
            sent2_ling = sent2_ling.view(bs, 1, -1)
            if self.args.combine_method == 'decoder_add_first' and not generate:
                sent2_ling = torch.cat([sent2_ling,
                    torch.repeat_interleave(torch.zeros_like(sent2_ling), batch['sentence2_input_ids'].shape[1] - 1, dim=1)], dim = 1)
        else:
            sent1_ling, sent2_ling = None, None

        if self.args.combine_method == 'embed_concat':
            enc_output.last_hidden_state = torch.cat([enc_output.last_hidden_state,
                sent1_ling, sent2_ling], dim=1)
            inputs_att_mask = torch.cat([inputs_att_mask,
                torch.ones((bs, 2)).to(inputs_att_mask.device)], dim=1)
        elif 'fusion' in self.args.combine_method:
            sent1_ling = batch['sentence1_ling'].unsqueeze(1)\
                    .expand(-1, enc_output.last_hidden_state.shape[1], -1)
            sent2_ling = batch['sentence2_ling'].unsqueeze(1)\
                    .expand(-1, enc_output.last_hidden_state.shape[1], -1)
            if self.args.ling2_only:
                combined_embedding = torch.cat([enc_output.last_hidden_state, sent2_ling], dim=2)
            else:
                combined_embedding = torch.cat([enc_output.last_hidden_state, sent1_ling, sent2_ling], dim=2)
            enc_output.last_hidden_state = self.fusion(combined_embedding)

        if generate:
            if self.args.combine_method == 'logits_add':
                logits_processor = LogitsProcessorList([LogitsAdd(sent2_ling.view(bs, -1))])
            else:
                logits_processor = LogitsProcessorList()

            dec_output = self.backbone.generate_with_grad(
                    attention_mask=inputs_att_mask,
                    encoder_outputs=enc_output,
                    sent1_ling=sent1_ling,
                    sent2_ling=sent2_ling,
                    return_dict_in_generate=True,
                    output_scores=True,
                    logits_processor=logits_processor,
                    # renormalize_logits=True,
                    # do_sample=True,
                    # top_p=0.8,
                    eos_token_id=self.eos_token_id,
                    # min_new_tokens=3,
                    # repetition_penalty=1.2,
                    max_length=self.args.max_length,
                    )
            scores = torch.stack(dec_output.scores, 1)
            cache.update({'scores': scores})
            return dec_output.sequences, cache

        decoder_input_ids = self.backbone._shift_right(batch['sentence2_input_ids'])
        decoder_inputs_embeds = self.backbone.shared(decoder_input_ids)
        decoder_att_mask = batch['sentence2_attention_mask']
        labels = batch['sentence2_input_ids'].clone()
        labels[labels == self.pad_token_id] = -100

        if self.args.combine_method == 'decoder_concat':
            if self.args.ling2_only:
                decoder_inputs_embeds = torch.cat([sent2_ling, decoder_inputs_embeds], dim=1)
                decoder_att_mask = torch.cat([torch.ones((bs, 1)).to(decoder_inputs_embeds.device), decoder_att_mask], dim=1)
                labels = torch.cat([torch.ones((bs, 1), dtype=torch.int64).to(decoder_inputs_embeds.device) * self.pad_token_id,
                    labels], dim=1)
            else:
                decoder_inputs_embeds = torch.cat([sent1_ling, sent2_ling, decoder_inputs_embeds], dim=1)
                decoder_att_mask = torch.cat([torch.ones((bs, 2)).to(decoder_inputs_embeds.device), decoder_att_mask], dim=1)
                labels = torch.cat([torch.ones((bs, 2), dtype=torch.int64).to(decoder_inputs_embeds.device) * self.pad_token_id,
                    labels], dim=1)
        elif self.args.combine_method in ('decoder_add', 'decoder_add_first'):
            if self.args.ling2_only:
                decoder_inputs_embeds = decoder_inputs_embeds + self.args.combine_weight * sent2_ling
            else:
                decoder_inputs_embeds = decoder_inputs_embeds + sent1_ling + sent2_ling

        dec_output = self.backbone(
                decoder_inputs_embeds=decoder_inputs_embeds,
                decoder_attention_mask=decoder_att_mask,
                encoder_outputs=enc_output,
                attention_mask=inputs_att_mask,
                labels=labels,
                )
        if self.args.combine_method == 'logits_add':
            dec_output.logits = dec_output.logits + self.args.combine_weight * sent2_ling
            vocab_size = dec_output.logits.size(-1)
            dec_output.loss = F.cross_entropy(dec_output.logits.view(-1, vocab_size), labels.view(-1))
        return dec_output, cache


    def forward(self, batch, generate=False):
        enc_output, enc_att_mask, cache = self.encode(batch)
        dec_output, cache2 = self.decode(batch, enc_output, enc_att_mask, generate)
        cache.update(cache2)
        return dec_output, enc_output, cache

    def infer_with_cache(self, batch):
        dec_output, _, cache = self(batch, generate = True)
        return dec_output, cache

    def infer(self, batch):
        dec_output, _ = self.infer_with_cache(batch)
        return dec_output

    def infer_with_feedback_BP(self, ling_disc, sem_emb, batch, tokenizer, scaler):
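        """Inference-time feedback via backpropagation: repeatedly
        differentiates the discriminator loss w.r.t. a chosen input (ling
        embedding, source embeddings, or logits, per args.feedback_param) and
        line-searches the step size, accepting steps that also preserve
        semantics under sem_emb."""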
        from torch.autograd import grad
        interpolations = []
        def line_search():
            best_val = None
            best_loss = None
            eta = 1e3
            sem_prob = 1
            patience = 4
            while patience > 0:
                param_ = param - eta * grads
                with torch.no_grad():
                    new_loss, pred = get_loss(param_)
                max_len = pred.shape[1]
                lens = torch.where(pred == self.eos_token_id, 1, 0).argmax(-1) + 1
                batch.update({
                    'sentence2_input_ids': pred,
                    'sentence2_attention_mask': sequence_mask(lens, max_len=max_len)
                    })
                sem_prob = torch.sigmoid(sem_emb(**batch)).item()
                # Accept the step only if the loss improves, semantics are
                # preserved, and the prediction is longer than one token.
                if new_loss < loss and sem_prob >= 0.90 and lens.item() > 1:
                    return param_
                eta *= 2.25
                patience -= 1
            return False

        def get_loss(param):
            if self.args.feedback_param == 'l':
                batch.update({'sent2_ling_embed': param})
            elif self.args.feedback_param == 's':
                batch.update({'inputs_embeds': param})

            if self.args.feedback_param == 'logits':
                logits = param
                pred = param.argmax(-1)
            else:
                pred, cache = self.infer_with_cache(batch)
                logits = cache['scores']
            out = ling_disc(logits=logits)
            if ling_disc.quant:
                loss = F.cross_entropy(out, batch['sentence2_discr'])
            else:
                loss = F.mse_loss(out, batch['sentence2_ling'])
            return loss, pred

        if self.args.feedback_param == 'l':
            ling2_embed = self.ling_embed(batch['sentence2_ling'])
            param = torch.nn.Parameter(ling2_embed, requires_grad=True)
        elif self.args.feedback_param == 's':
            inputs_embeds = self.backbone.shared(batch['sentence1_input_ids'])
            param = torch.nn.Parameter(inputs_embeds, requires_grad=True)
        elif self.args.feedback_param == 'logits':
            logits = self.infer_with_cache(batch)[1]['scores']
            param = torch.nn.Parameter(logits, requires_grad=True)
        f = open(self.args.fb_log, 'a') if self.args.fb_log else None
        while True:
            loss, pred = get_loss(param)
            pred_text = tokenizer.batch_decode(pred.cpu().numpy(),
                    skip_special_tokens=True)[0]
            if f:
                f.write(f'*** [{loss.item():.2f}], {pred_text}\n')
            interpolations.append(pred_text)
            if loss < 1:
                break
            self.zero_grad()
            grads = grad(loss, param)[0]
            param = line_search()
            if param is False:
                break
        if f:
            f.write(f'[return] {pred_text}\n\n')
            f.close()
        return pred, [pred_text, interpolations]

    def infer_with_feedback(self, ling_disc, batch, tokenizer, scaler, approx=False):
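        """Iterative, gradient-free feedback: nudges the target linguistic
        embedding toward the attributes measured on the current prediction
        until the feature-space MSE converges or the round limit is reached."""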
        interpolations = []
        converged = False
        c = 0
        eta = 0.3
        use_embed = True
        if use_embed:
            ling1_embed = self.ling_embed(batch['sentence1_ling'])
            ling2_embed = self.ling_embed(batch['sentence2_ling'])
            batch.update({
                    'sent1_ling_embed': ling1_embed,
                    'sent2_ling_embed': ling2_embed,
                    })
        else:
            ling2 = batch['sentence2_ling']
        ling2_orig = batch['sentence2_ling'].clone()
        while not converged:
            with torch.no_grad():
                pred = self.infer(batch)
                inputs_pred = batch.copy()
                inputs_pred.update({'input_ids': pred,
                    'attention_mask': torch.ones_like(pred)})
                pred_text = tokenizer.batch_decode(pred.cpu().numpy(),
                        skip_special_tokens=True)[0]
                if approx:
                    ling_pred = ling_disc(**inputs_pred)
                else:
                    ling_pred = compute_lng(pred_text)
                    ling_pred = scaler.transform([ling_pred])[0]
                    ling_pred = torch.tensor(ling_pred).to(pred.device).float()
                if use_embed:
                    ling_pred_embed = self.ling_embed(ling_pred)
                # Convergence is always measured in raw feature space against
                # the original target attributes.
                diff = torch.mean((ling2_orig - ling_pred)**2)

            if diff < 1e-1 or c == 6:
                converged = True
            elif use_embed:
                ling2_embed = ling2_embed + eta * (ling_pred_embed - ling2_embed)
                batch.update({'sent2_ling_embed': ling2_embed})
            else:
                ling2 = ling2 + eta * (ling_pred - ling2)
                batch.update({'sentence2_ling': ling2})
                
            c += 1

            if len(interpolations) == 0 or pred_text != interpolations[-1]:
                interpolations.append(pred_text)

        return [pred_text, interpolations]

def set_grad(module, state):
    if module is not None:
        for p in module.parameters():
            p.requires_grad = state

def set_grad_except(model, name, state):
    for n, p in model.named_parameters():
        if not name in n:
            p.requires_grad = state

class SemEmbPipeline:
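    """Loads a trained SemEmb checkpoint and scores whether two sentences are
    semantically equivalent, returning a probability."""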
    def __init__(self,
            ckpt = "/data/mohamed/checkpoints/ling_conversion_sem_emb_best.pt"):
        self.tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
        self.model = SemEmb(T5EncoderModel.from_pretrained('google/flan-t5-base'), self.tokenizer.get_vocab()['</s>'])
        state = torch.load(ckpt)
        self.model.load_state_dict(state['model'], strict=False)
        self.model.eval()
        self.model.cuda()

    def __call__(self, sentence1, sentence2):
        sentence1 = self.tokenizer(sentence1, return_attention_mask=True, return_tensors='pt')
        sentence2 = self.tokenizer(sentence2, return_attention_mask=True, return_tensors='pt')
        sem_logit = self.model(
                sentence1_input_ids=sentence1.input_ids.cuda(),
                sentence1_attention_mask=sentence1.attention_mask.cuda(),
                sentence2_input_ids=sentence2.input_ids.cuda(),
                sentence2_attention_mask=sentence2.attention_mask.cuda(),
                )
        sem_prob = torch.sigmoid(sem_logit).item()
        return sem_prob

class LingDiscPipeline:
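    """Wraps LingDisc (plus its tokenizer) to predict linguistic attributes
    of raw text."""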
    def __init__(self,
                 model_name="google/flan-t5-base",
                 disc_type='deberta',
                 disc_ckpt='/data/mohamed/checkpoints/ling_disc/deberta-v3-small_flan-t5-base_40',
                 # disc_type='t5',
                 # disc_ckpt='/data/mohamed/checkpoints/ling_conversion_ling_disc.pt',
                 ):
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = LingDisc(model_name, disc_type, disc_ckpt)
        self.model.eval()
        self.model.cuda()

    def __call__(self, sentence):
        inputs = self.tokenizer(sentence, return_tensors='pt')
        with torch.no_grad():
            ling_pred = self.model(input_ids=inputs.input_ids.cuda())
        return ling_pred
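

if __name__ == '__main__':
    # Minimal smoke test; a sketch assuming the default LingDisc checkpoint
    # path above exists and a CUDA device is available.
    disc = LingDiscPipeline()
    preds = disc('The quick brown fox jumps over the lazy dog.')
    print(preds.shape)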