Kseniia-Kholina commited on
Commit
e7999a2
·
verified ·
1 Parent(s): 0224cdb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -4
app.py CHANGED
@@ -26,8 +26,51 @@ model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
26
  model.to(device)
27
  model.eval()
28
 
 
 
 
 
 
 
 
 
 
29
 
30
  def process_sequence(sequence, domain_bounds, n):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  start_index = int(domain_bounds['start'][0]) - 1
32
  end_index = int(domain_bounds['end'][0])
33
 
@@ -45,7 +88,6 @@ def process_sequence(sequence, domain_bounds, n):
45
  mask_token_logits = logits[0, mask_token_index, :]
46
 
47
  # Define amino acid tokens
48
- AAs_tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']
49
  all_tokens_logits = mask_token_logits.squeeze(0)
50
  top_tokens_indices = torch.argsort(all_tokens_logits, dim=0, descending=True)
51
  top_tokens_logits = all_tokens_logits[top_tokens_indices]
@@ -91,7 +133,7 @@ def process_sequence(sequence, domain_bounds, n):
91
 
92
  # Save the figure to a BytesIO object
93
  buf = BytesIO()
94
- plt.savefig(buf, format='png', dpi=(300, 300))
95
  buf.seek(0)
96
  plt.close()
97
 
@@ -114,7 +156,7 @@ def process_sequence(sequence, domain_bounds, n):
114
  'Position': positions
115
  })
116
  df.to_csv("predicted_tokens.csv", index=False)
117
- img.save("heatmap.png", dpi = 300)
118
  zip_path = "outputs.zip"
119
  with zipfile.ZipFile(zip_path, 'w') as zipf:
120
  zipf.write("predicted_tokens.csv")
@@ -143,4 +185,5 @@ demo = gr.Interface(
143
  ],
144
  )
145
  if __name__ == "__main__":
146
- demo.launch()
 
 
26
  model.to(device)
27
  model.eval()
28
 
29
+ @contextmanager
30
+ def suppress_output():
31
+ with open(os.devnull, 'w') as devnull:
32
+ old_stdout = sys.stdout
33
+ sys.stdout = devnull
34
+ try:
35
+ yield
36
+ finally:
37
+ sys.stdout = old_stdout
38
 
39
  def process_sequence(sequence, domain_bounds, n):
40
+ AAs_tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']
41
+ # checking sequence inputs
42
+ if not sequence.strip():
43
+ raise gr.Error("Error: The sequence input is empty. Please enter a valid protein sequence.")
44
+ return None, None, None
45
+ if any(char not in AAs_tokens for char in sequence):
46
+ raise gr.Error("Error: The sequence input contains non-amino acid characters. Please enter a valid protein sequence.")
47
+ return None, None, None
48
+
49
+ # checking domain bounds inputs
50
+ try:
51
+ start = int(domain_bounds['start'][0])
52
+ end = int(domain_bounds['end'][0])
53
+ except ValueError:
54
+ raise gr.Error("Error: Start and end indices must be integers.")
55
+ return None, None, None
56
+ if start >= end:
57
+ raise gr.Error("Start index must be smaller than end index.")
58
+ return None, None, None
59
+ if start == 0 and end != 0:
60
+ raise gr.Error("Indexing starts at 1. Please enter valid domain bounds.")
61
+ return None, None, None
62
+ if start == 0 or end == 0:
63
+ raise gr.Error("Domain bounds cannot be zero. Please enter valid domain bounds.")
64
+ return None, None, None
65
+ if start > len(sequence) or end > len(sequence):
66
+ raise gr.Error("Domain bounds exceed sequence length.")
67
+ return None, None, None
68
+
69
+ # checking n inputs
70
+ if n == None:
71
+ raise gr.Error("Choose Top N Tokens from the dropdown menu.")
72
+ return None, None, None
73
+
74
  start_index = int(domain_bounds['start'][0]) - 1
75
  end_index = int(domain_bounds['end'][0])
76
 
 
88
  mask_token_logits = logits[0, mask_token_index, :]
89
 
90
  # Define amino acid tokens
 
91
  all_tokens_logits = mask_token_logits.squeeze(0)
92
  top_tokens_indices = torch.argsort(all_tokens_logits, dim=0, descending=True)
93
  top_tokens_logits = all_tokens_logits[top_tokens_indices]
 
133
 
134
  # Save the figure to a BytesIO object
135
  buf = BytesIO()
136
+ plt.savefig(buf, format='png', dpi = 300)
137
  buf.seek(0)
138
  plt.close()
139
 
 
156
  'Position': positions
157
  })
158
  df.to_csv("predicted_tokens.csv", index=False)
159
+ img.save("heatmap.png", dpi=(300, 300))
160
  zip_path = "outputs.zip"
161
  with zipfile.ZipFile(zip_path, 'w') as zipf:
162
  zipf.write("predicted_tokens.csv")
 
185
  ],
186
  )
187
  if __name__ == "__main__":
188
+ with suppress_output():
189
+ demo.launch()