Update app.py
Browse files
app.py
CHANGED
@@ -37,14 +37,16 @@ def process_sequence(sequence, domain_bounds, n):
|
|
37 |
logits = model(**inputs).logits
|
38 |
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
|
39 |
mask_token_logits = logits[0, mask_token_index, :]
|
|
|
|
|
|
|
|
|
40 |
# Decode top n tokens
|
41 |
-
top_n_tokens = torch.topk(
|
42 |
mutation = [tokenizer.decode([token]) for token in top_n_tokens]
|
43 |
top_n_mutations[(sequence[i], i)] = mutation
|
44 |
|
45 |
logits_array = mask_token_logits.cpu().numpy()
|
46 |
-
# filter out non-amino acid tokens
|
47 |
-
filtered_indices = list(range(4, 23 + 1))
|
48 |
filtered_logits = logits_array[:, filtered_indices]
|
49 |
all_logits.append(filtered_logits)
|
50 |
|
@@ -66,7 +68,7 @@ def process_sequence(sequence, domain_bounds, n):
|
|
66 |
plt.ylabel('Token')
|
67 |
plt.xlabel('Residue Index')
|
68 |
plt.yticks(rotation=0)
|
69 |
-
plt.xticks(x_tick_positions - start_index, x_tick_labels, rotation=0)
|
70 |
|
71 |
# Save the figure to a BytesIO object
|
72 |
buf = BytesIO()
|
|
|
37 |
logits = model(**inputs).logits
|
38 |
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
|
39 |
mask_token_logits = logits[0, mask_token_index, :]
|
40 |
+
|
41 |
+
# filter out non-amino acid tokens
|
42 |
+
filtered_indices = list(range(4, 23 + 1))
|
43 |
+
filtered_logits = mask_token_logits[:, filtered_indices]
|
44 |
# Decode top n tokens
|
45 |
+
top_n_tokens = torch.topk(filtered_logits, n, dim=1).indices[0].tolist()
|
46 |
mutation = [tokenizer.decode([token]) for token in top_n_tokens]
|
47 |
top_n_mutations[(sequence[i], i)] = mutation
|
48 |
|
49 |
logits_array = mask_token_logits.cpu().numpy()
|
|
|
|
|
50 |
filtered_logits = logits_array[:, filtered_indices]
|
51 |
all_logits.append(filtered_logits)
|
52 |
|
|
|
68 |
plt.ylabel('Token')
|
69 |
plt.xlabel('Residue Index')
|
70 |
plt.yticks(rotation=0)
|
71 |
+
plt.xticks(x_tick_positions - start_index + 0.5, x_tick_labels, rotation=0)
|
72 |
|
73 |
# Save the figure to a BytesIO object
|
74 |
buf = BytesIO()
|