import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch.nn.functional as F
import logging
import numpy as np 
import matplotlib.pyplot as plt
import seaborn as sns
from io import BytesIO
from PIL import Image

# Only ERROR-level records from transformers.modeling_utils get through, which
# hides that logger's warnings while the model is being loaded.
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

# Choose the compute device once, at import time: CUDA when PyTorch can see it, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Load the FusOn-pLM tokenizer and masked-LM model by hub id.
# NOTE(review): trust_remote_code=True executes Python code shipped with the model
# repository; consider pinning `revision=` to a reviewed commit.
model_name = "ChatterjeeLab/FusOn-pLM"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Module.to() and Module.eval() both return the module itself, so they chain;
# eval() switches layers such as dropout into inference behaviour.
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True).to(device).eval()

# Amino-acid tokens scored at each masked position. This order is used both to
# map ranked indices back to letters and as the row order of the heatmap.
_AA_TOKENS = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']


def _domain_range(domain_bounds, seq_len):
    """Turn the UI's 1-based inclusive (start, end) into a clamped 0-based half-open range.

    Reads the first row of the ``start`` and ``end`` columns of *domain_bounds*
    and clips the result to ``[0, seq_len)``.

    Raises:
        gr.Error: If either bound is missing or not an integer, or if the
            clipped range selects no residues.
    """
    try:
        start = int(domain_bounds['start'][0])
        end = int(domain_bounds['end'][0])
    except (KeyError, IndexError, TypeError, ValueError) as err:
        raise gr.Error("Domain start and end must both be integers.") from err
    start_index = max(start - 1, 0)
    end_index = min(end, seq_len)
    if start_index >= end_index:
        raise gr.Error(
            f"Domain {start}-{end} selects no residues of the {seq_len}-residue sequence."
        )
    return start_index, end_index


def _masked_position_logits(sequence, i):
    """Mask residue *i* (0-based) of *sequence* and return the model's logits at the mask.

    Returns the 1-D logit vector (one entry per vocabulary token) on ``device``.

    Raises:
        gr.Error: If the mask token does not appear exactly once in the
            tokenized input, e.g. because truncation to 2000 tokens cut it off.
    """
    masked_seq = sequence[:i] + tokenizer.mask_token + sequence[i + 1:]
    inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=2000)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    mask_positions = torch.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0]
    if mask_positions.numel() != 1:
        # Without this check an empty index would silently produce empty logits.
        raise gr.Error(
            f"Could not mask position {i + 1}; the sequence may exceed the model's 2000-token input limit."
        )
    return logits[0, mask_positions[0]]


def _heatmap_image(probs, tokens, start_index, end_index):
    """Render a ``(len(tokens), n_positions)`` probability matrix as a PNG heatmap.

    Column ``c`` corresponds to 0-based residue ``start_index + c``. X ticks are
    placed every 10 residues and labelled with 1-based positions. Returns a PIL
    image backed by an in-memory PNG buffer.
    """
    x_tick_positions = np.arange(start_index, end_index, 10)
    x_tick_labels = [str(pos + 1) for pos in x_tick_positions]

    buf = BytesIO()
    # rc_context scopes the larger font to this figure instead of mutating the
    # process-wide rcParams on every request.
    with plt.rc_context({'font.size': 18}):
        fig, ax = plt.subplots(figsize=(15, 8))
        try:
            sns.heatmap(probs, cmap='plasma', xticklabels=False, yticklabels=tokens, ax=ax)
            ax.set_title('Token Probability Heatmap')
            ax.set_ylabel('Token')
            ax.set_xlabel('Residue Index')
            ax.tick_params(axis='y', labelrotation=0)
            # Heatmap cells are centred at column index + 0.5.
            ax.set_xticks(x_tick_positions - start_index + 0.5)
            ax.set_xticklabels(x_tick_labels, rotation=0)
            fig.savefig(buf, format='png', dpi=300)
        finally:
            plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def process_sequence(sequence, domain_bounds, n):
    """Predict likely amino acids for each residue of a domain via masked-LM scoring.

    Each position in the domain is masked one at a time and scored by the
    model. Only the 20 amino-acid tokens in ``_AA_TOKENS`` are considered.

    Args:
        sequence: Protein sequence, indexed one character per residue.
        domain_bounds: Table (as passed by the ``gr.Dataframe`` input) whose
            first row holds the 1-based, inclusive ``start`` and ``end`` of the
            domain. Bounds that extend past the sequence are clipped to it.
        n: Number of top-ranked amino acids to report per position. ``None``
            or a non-positive value reports all 20.

    Returns:
        ``(df, img)``. *df* has one row per domain position, with columns
        ``Original Residue``, ``Predicted Residues`` (most likely first) and
        1-based ``Position``. *img* is a PIL heatmap of per-position
        probabilities, softmax-normalised over the 20 amino acids.

    Raises:
        gr.Error: If the bounds are invalid or a position cannot be masked.
    """
    start_index, end_index = _domain_range(domain_bounds, len(sequence))
    # Look up the amino-acid vocabulary ids instead of assuming a fixed id range.
    aa_ids = torch.tensor(tokenizer.convert_tokens_to_ids(_AA_TOKENS), device=device)
    top_k = len(_AA_TOKENS) if n is None or int(n) <= 0 else int(n)

    positions = range(start_index, end_index)
    predictions = []
    aa_logit_rows = []
    for i in positions:
        aa_logits = _masked_position_logits(sequence, i)[aa_ids]
        aa_logit_rows.append(aa_logits)
        ranked = torch.argsort(aa_logits, descending=True)[:top_k]
        predictions.append([_AA_TOKENS[j] for j in ranked.tolist()])

    # Transpose so rows are amino acids and columns are domain positions.
    probs = torch.stack(aa_logit_rows).float().softmax(dim=-1).cpu().numpy().T
    img = _heatmap_image(probs, _AA_TOKENS, start_index, end_index)

    df = pd.DataFrame({
        'Original Residue': [sequence[i] for i in positions],
        'Predicted Residues': predictions,
        'Position': [i + 1 for i in positions],
    })
    return df, img

# Input widgets, passed positionally to process_sequence as (sequence, domain_bounds, n).
_sequence_input = gr.Textbox(label="Sequence", placeholder="Enter the protein sequence here")
_bounds_input = gr.Dataframe(
    headers=["start", "end"],
    datatype=["number", "number"],
    row_count=(1, "fixed"),
    col_count=(2, "fixed"),
    label="Domain Bounds",
)
_top_n_input = gr.Dropdown(list(range(1, 21)), label="Top N Tokens")

# Output widgets, matching process_sequence's (DataFrame, PIL image) return value.
_prediction_output = gr.Dataframe(label="Predicted Tokens (in order of decreasing likelihood)")
_heatmap_output = gr.Image(type="pil", label="Heatmap")

demo = gr.Interface(
    fn=process_sequence,
    inputs=[_sequence_input, _bounds_input, _top_n_input],
    outputs=[_prediction_output, _heatmap_output],
    description="Choose a number from the dropdown to predict N tokens for each position. Choose the start and end index of the domain of interest (indexing starts at 1).",
)

if __name__ == "__main__":
    demo.launch()