Update app.py
Browse files
app.py
CHANGED
@@ -48,24 +48,24 @@ def process_sequence(sequence, domain_bounds, n):
|
|
48 |
|
49 |
# checking domain bounds inputs
|
50 |
try:
|
51 |
-
start = int(domain_bounds['start'][0])
|
52 |
end = int(domain_bounds['end'][0])
|
53 |
except ValueError:
|
54 |
raise gr.Error("Error: Start and end indices must be integers.")
|
55 |
return None, None, None
|
56 |
if start >= end:
|
57 |
-
raise gr.Error("Start index must be smaller than end index.")
|
58 |
return None, None, None
|
59 |
if start == 0 and end != 0:
|
60 |
-
raise gr.Error("Indexing starts at 1. Please enter valid domain bounds.")
|
61 |
return None, None, None
|
62 |
-
if start
|
63 |
-
raise gr.Error("Domain bounds
|
64 |
return None, None, None
|
65 |
if start > len(sequence) or end > len(sequence):
|
66 |
raise gr.Error("Domain bounds exceed sequence length.")
|
67 |
return None, None, None
|
68 |
-
|
69 |
# checking n inputs
|
70 |
if n == None:
|
71 |
raise gr.Error("Choose Top N Tokens from the dropdown menu.")
|
@@ -117,10 +117,12 @@ def process_sequence(sequence, domain_bounds, n):
|
|
117 |
normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()
|
118 |
transposed_logits_array = normalized_logits_array.T
|
119 |
|
120 |
-
|
121 |
domain_len = end - start
|
122 |
-
if domain_len > 100:
|
123 |
step_size = 50
|
|
|
|
|
124 |
elif domain_len < 10:
|
125 |
step_size = 1
|
126 |
else:
|
|
|
48 |
|
49 |
# checking domain bounds inputs
|
50 |
try:
|
51 |
+
start = int(domain_bounds['start'][0])
|
52 |
end = int(domain_bounds['end'][0])
|
53 |
except ValueError:
|
54 |
raise gr.Error("Error: Start and end indices must be integers.")
|
55 |
return None, None, None
|
56 |
if start >= end:
|
57 |
+
raise gr.Error("Start index must be smaller than end index.")
|
58 |
return None, None, None
|
59 |
if start == 0 and end != 0:
|
60 |
+
raise gr.Error("Indexing starts at 1. Please enter valid domain bounds.")
|
61 |
return None, None, None
|
62 |
+
if start <= 0 or end <= 0:
|
63 |
+
raise gr.Error("Domain bounds must be positive integers. Please enter valid domain bounds.")
|
64 |
return None, None, None
|
65 |
if start > len(sequence) or end > len(sequence):
|
66 |
raise gr.Error("Domain bounds exceed sequence length.")
|
67 |
return None, None, None
|
68 |
+
|
69 |
# checking n inputs
|
70 |
if n == None:
|
71 |
raise gr.Error("Choose Top N Tokens from the dropdown menu.")
|
|
|
117 |
normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()
|
118 |
transposed_logits_array = normalized_logits_array.T
|
119 |
|
120 |
+
# Plotting the heatmap
|
121 |
domain_len = end - start
|
122 |
+
if 500 > domain_len > 100:
|
123 |
step_size = 50
|
124 |
+
elif 500 <= domain_len:
|
125 |
+
step_size = 100
|
126 |
elif domain_len < 10:
|
127 |
step_size = 1
|
128 |
else:
|