Update app.py
Browse files
app.py
CHANGED
@@ -29,26 +29,37 @@ def process_sequence(sequence, domain_bounds, n):
|
|
29 |
all_logits = []
|
30 |
|
31 |
for i in range(len(sequence)):
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
-
# filter out non-amino acid tokens
|
42 |
-
filtered_indices = list(range(4, 23 + 1))
|
43 |
-
filtered_logits = mask_token_logits[:, filtered_indices]
|
44 |
-
# Decode top n tokens
|
45 |
-
top_n_tokens = torch.topk(filtered_logits, n, dim=1).indices[0].tolist()
|
46 |
-
mutation = [tokenizer.decode([token]) for token in top_n_tokens]
|
47 |
-
top_n_mutations[(sequence[i], i)] = mutation
|
48 |
-
|
49 |
-
logits_array = mask_token_logits.cpu().numpy()
|
50 |
-
filtered_logits = logits_array[:, filtered_indices]
|
51 |
-
all_logits.append(filtered_logits)
|
52 |
|
53 |
token_indices = torch.arange(logits.size(-1))
|
54 |
tokens = [tokenizer.decode([idx]) for idx in token_indices]
|
@@ -63,7 +74,7 @@ def process_sequence(sequence, domain_bounds, n):
|
|
63 |
x_tick_labels = [str(pos + 1) for pos in x_tick_positions]
|
64 |
|
65 |
plt.figure(figsize=(15, 8))
|
66 |
-
plt.rcParams.update({'font.size':
|
67 |
|
68 |
sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=x_tick_labels, yticklabels=filtered_tokens)
|
69 |
plt.title('Token Probability Heatmap')
|
|
|
29 |
all_logits = []
|
30 |
|
31 |
for i in range(len(sequence)):
|
32 |
+
if start_index <= i <= (end_index - 1):
|
33 |
+
masked_seq = sequence[:i] + '<mask>' + sequence[i+1:]
|
34 |
+
inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=2000)
|
35 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
36 |
+
with torch.no_grad():
|
37 |
+
logits = model(**inputs).logits
|
38 |
+
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
|
39 |
+
mask_token_logits = logits[0, mask_token_index, :]
|
40 |
+
|
41 |
+
# Define amino acid tokens
|
42 |
+
AAs_tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']
|
43 |
+
all_tokens_logits = mask_token_logits.squeeze(0)
|
44 |
+
top_tokens_indices = torch.argsort(all_tokens_logits, dim=0, descending=True)
|
45 |
+
top_tokens_logits = all_tokens_logits[top_tokens_indices]
|
46 |
+
mutation = []
|
47 |
+
# make sure we don't include non-AA tokens
|
48 |
+
for token_index in top_tokens_indices:
|
49 |
+
decoded_token = tokenizer.decode([token_index.item()])
|
50 |
+
if decoded_token in AAs_tokens:
|
51 |
+
mutation.append(decoded_token)
|
52 |
+
if len(mutation) == n:
|
53 |
+
break
|
54 |
+
top_n_mutations[(sequence[i], i)] = mutation
|
55 |
+
|
56 |
+
# collecting logits for the heatmap
|
57 |
+
logits_array = mask_token_logits.cpu().numpy()
|
58 |
+
# filter out non-amino acid tokens
|
59 |
+
filtered_indices = list(range(4, 23 + 1))
|
60 |
+
filtered_logits = logits_array[:, filtered_indices]
|
61 |
+
all_logits.append(filtered_logits)
|
62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
|
64 |
token_indices = torch.arange(logits.size(-1))
|
65 |
tokens = [tokenizer.decode([idx]) for idx in token_indices]
|
|
|
74 |
x_tick_labels = [str(pos + 1) for pos in x_tick_positions]
|
75 |
|
76 |
plt.figure(figsize=(15, 8))
|
77 |
+
plt.rcParams.update({'font.size': 18})
|
78 |
|
79 |
sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=x_tick_labels, yticklabels=filtered_tokens)
|
80 |
plt.title('Token Probability Heatmap')
|