gelnesr commited on
Commit
d8524ba
·
verified ·
1 Parent(s): ff70084

move model run

Browse files
Files changed (1) hide show
  1. app.py +11 -7
app.py CHANGED
@@ -139,7 +139,14 @@ def handle_name(name=None, pdb_input=None, model_version="ESM3"):
139
  pdb_name = str(random.randint(0, 100000))
140
  return f'{pdb_name}-Dyna1{"" if model_version == "ESM3" else "-ESM2"}'
141
 
142
- @spaces.GPU(duration=1200)
 
 
 
 
 
 
 
143
  def predict_dynamics(sequence=None, pdb_input=None, chain_id='A', use_pdb_seq=False, model_version="ESM3", name=None, oauth_token: Optional[str] = None):
144
  try:
145
  # Validate ESM2 requires sequence
@@ -151,7 +158,6 @@ def predict_dynamics(sequence=None, pdb_input=None, chain_id='A', use_pdb_seq=Fa
151
 
152
  base_name = handle_name(name, pdb_input, model_version)
153
 
154
-
155
  if model_version == "ESM3":
156
  model = ESM_model(method='esm3').to(DEVICE)
157
  model.load_state_dict(torch.load('Dyna-1/model/weights/dyna1.pt', map_location=DEVICE), strict=False)
@@ -187,11 +193,9 @@ def predict_dynamics(sequence=None, pdb_input=None, chain_id='A', use_pdb_seq=Fa
187
 
188
  if not (sequence or (pdb_input and model_version == "ESM3")):
189
  raise ValueError('Please provide a sequence' + (' or structure input' if model_version == "ESM3" else ''))
190
-
191
- if model_version == "ESM3":
192
- logits = model((seq_input, struct_input), sequence_id)
193
- else:
194
- logits = model(seq_input, sequence_id)
195
  probabilities = utils.prob_adjusted(logits).cpu().detach().numpy()
196
 
197
  seq_to_use = sequence if sequence else pdb_seq if pdb_input else sequence
 
139
  pdb_name = str(random.randint(0, 100000))
140
  return f'{pdb_name}-Dyna1{"" if model_version == "ESM3" else "-ESM2"}'
141
 
142
+ @spaces.GPU(duration=300)
143
+ def run_model(model_version='ESM2', seq_input=None, struct_input=None, sequence_id=None):
144
+ if model_version == "ESM3":
145
+ logits = model((seq_input, struct_input), sequence_id)
146
+ else:
147
+ logits = model(seq_input, sequence_id)
148
+ return logits
149
+
150
  def predict_dynamics(sequence=None, pdb_input=None, chain_id='A', use_pdb_seq=False, model_version="ESM3", name=None, oauth_token: Optional[str] = None):
151
  try:
152
  # Validate ESM2 requires sequence
 
158
 
159
  base_name = handle_name(name, pdb_input, model_version)
160
 
 
161
  if model_version == "ESM3":
162
  model = ESM_model(method='esm3').to(DEVICE)
163
  model.load_state_dict(torch.load('Dyna-1/model/weights/dyna1.pt', map_location=DEVICE), strict=False)
 
193
 
194
  if not (sequence or (pdb_input and model_version == "ESM3")):
195
  raise ValueError('Please provide a sequence' + (' or structure input' if model_version == "ESM3" else ''))
196
+
197
+ logits = run_model(model_version, seq_input, struct_input, sequence_id)
198
+
 
 
199
  probabilities = utils.prob_adjusted(logits).cpu().detach().numpy()
200
 
201
  seq_to_use = sequence if sequence else pdb_seq if pdb_input else sequence