Kseniia-Kholina commited on
Commit
71532a4
·
verified ·
1 Parent(s): 431091a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -3
app.py CHANGED
@@ -1,10 +1,73 @@
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
 
 
 
 
3
 
4
- def greet(name):
5
- return "Hello " + name
 
 
 
 
6
 
 
 
 
 
7
 
8
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  demo.launch()
 
1
+ import transformers
2
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
3
+ import logging
4
+ import torch
5
+ import matplotlib.pyplot as plt
6
+ import seaborn as sns
7
+ import numpy as np
8
  import gradio as gr
9
 
10
+ def get_heatmap(sequence):
11
+ logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ print(f"Using device: {device}")
14
 
15
+ # Load the tokenizer and model
16
+ model_name = "ChatterjeeLab/FusOn-pLM"
17
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
18
+ model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
19
+ model.to(device)
20
+ model.eval()
21
 
22
+ all_logits = []
23
+ for i in range(len(sequence)):
24
+ # add a masked token
25
+ masked_seq = sequence[:i] + '<mask>' + sequence[i+1:]
26
 
27
+ # tokenize masked sequence
28
+ inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True,max_length=2000)
29
+ inputs = {k: v.to(device) for k, v in inputs.items()}
30
+
31
+ # predict logits for the masked token
32
+ with torch.no_grad():
33
+ logits = model(**inputs).logits
34
+ mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
35
+ mask_token_logits = logits[0, mask_token_index, :]
36
+ top_1_tokens = torch.topk(mask_token_logits, 1, dim=1).indices[0].item()
37
+ logits_array = mask_token_logits.cpu().numpy()
38
+
39
+ # filter out non-amino acid tokens
40
+ filtered_indices = list(range(4, 23 + 1))
41
+ filtered_logits = logits_array[:, filtered_indices]
42
+ all_logits.append(filtered_logits)
43
+
44
+ token_indices = torch.arange(logits.size(-1))
45
+ tokens = [tokenizer.decode([idx]) for idx in token_indices]
46
+ filtered_tokens = [tokens[i] for i in filtered_indices]
47
+
48
+ all_logits_array = np.vstack(all_logits)
49
+ normalized_logits_array = (all_logits_array - all_logits_array.min()) / (all_logits_array.max() - all_logits_array.min())
50
+ transposed_logits_array = normalized_logits_array.T
51
+
52
+
53
+
54
+ # Plotting the heatmap
55
+ step = 50
56
+ y_tick_positions = np.arange(0, len(sequence), step)
57
+ y_tick_labels = [str(pos) for pos in y_tick_positions]
58
+
59
+ plt.figure(figsize=(15, 8))
60
+ sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=y_tick_labels, yticklabels=filtered_tokens)
61
+ plt.title('Logits for masked per residue tokens')
62
+ plt.ylabel('Token')
63
+ plt.xlabel('Residue Index')
64
+ plt.yticks(rotation=0)
65
+ plt.xticks(y_tick_positions, y_tick_labels, rotation = 0)
66
+
67
+ return plt
68
+
69
+
70
+
71
+ demo = gr.Interface(fn=get_heatmap, inputs="text", outputs="image")
72
 
73
  demo.launch()