Update app.py
Browse files
app.py
CHANGED
@@ -28,7 +28,7 @@ def process_sequence(sequence, domain_bounds, n):
|
|
28 |
all_logits = []
|
29 |
|
30 |
for i in range(len(sequence)):
|
31 |
-
if start_index <= i <= end_index:
|
32 |
masked_seq = sequence[:i] + '<mask>' + sequence[i+1:]
|
33 |
inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=2000)
|
34 |
inputs = {k: v.to(device) for k, v in inputs.items()}
|
@@ -56,8 +56,7 @@ def process_sequence(sequence, domain_bounds, n):
|
|
56 |
transposed_logits_array = normalized_logits_array.T
|
57 |
|
58 |
# Plotting the heatmap
|
59 |
-
|
60 |
-
y_tick_positions = np.arange(0, len(sequence), step)
|
61 |
y_tick_labels = [str(pos) for pos in y_tick_positions]
|
62 |
|
63 |
plt.figure(figsize=(15, 8))
|
|
|
28 |
all_logits = []
|
29 |
|
30 |
for i in range(len(sequence)):
|
31 |
+
if start_index <= i <= (end_index - 1):
|
32 |
masked_seq = sequence[:i] + '<mask>' + sequence[i+1:]
|
33 |
inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=2000)
|
34 |
inputs = {k: v.to(device) for k, v in inputs.items()}
|
|
|
56 |
transposed_logits_array = normalized_logits_array.T
|
57 |
|
58 |
# Plotting the heatmap
|
59 |
+
y_tick_positions = np.arange((start_index+1), end_index + 1, 10)
|
|
|
60 |
y_tick_labels = [str(pos) for pos in y_tick_positions]
|
61 |
|
62 |
plt.figure(figsize=(15, 8))
|