Update app.py
Browse files
app.py
CHANGED
@@ -176,40 +176,112 @@ def parse_json_input(json_data: List[Dict]) -> Dict:
|
|
176 |
})
|
177 |
return components
|
178 |
|
|
|
179 |
def create_protenix_json(input_data: Dict) -> List[Dict]:
|
180 |
-
"""Convert UI inputs to Protenix JSON format"""
|
181 |
sequences = []
|
182 |
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
|
|
|
|
|
|
190 |
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
"
|
196 |
-
|
197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
"
|
204 |
-
|
205 |
-
|
|
|
|
|
206 |
|
207 |
return [{
|
208 |
"sequences": sequences,
|
209 |
-
"name": input_data
|
210 |
}]
|
211 |
|
212 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
213 |
#@torch.inference_mode()
|
214 |
@spaces.GPU(duration=120) # Specify a duration to avoid timeout
|
215 |
def predict_structure(input_collector: dict):
|
@@ -225,7 +297,7 @@ def predict_structure(input_collector: dict):
|
|
225 |
print(input_collector)
|
226 |
|
227 |
# Handle JSON input
|
228 |
-
if
|
229 |
# Handle different input types
|
230 |
if isinstance(input_collector["json"], str): # Example JSON case (file path)
|
231 |
input_data = json.load(open(input_collector["json"]))
|
@@ -406,31 +478,44 @@ with gr.Blocks(title="FoldMark", css=custom_css) as demo:
|
|
406 |
headers=["Sequence", "Count"],
|
407 |
datatype=["str", "number"],
|
408 |
row_count=1,
|
409 |
-
col_count=(2, "fixed")
|
|
|
410 |
)
|
411 |
|
412 |
# Repeat for other groups
|
413 |
-
with gr.Accordion(label="DNA Sequences", open=True):
|
414 |
dna_sequences = gr.Dataframe(
|
415 |
headers=["Sequence", "Count"],
|
416 |
datatype=["str", "number"],
|
417 |
-
row_count=1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
418 |
)
|
419 |
|
420 |
with gr.Accordion(label="Ligands", open=True):
|
421 |
ligands = gr.Dataframe(
|
422 |
headers=["Ligand Type", "Count"],
|
423 |
datatype=["str", "number"],
|
424 |
-
row_count=1
|
|
|
425 |
)
|
426 |
|
427 |
manual_output = gr.JSON(label="Generated JSON")
|
428 |
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
|
|
|
|
434 |
|
435 |
# Shared prediction components
|
436 |
with gr.Row():
|
@@ -450,8 +535,8 @@ with gr.Blocks(title="FoldMark", css=custom_css) as demo:
|
|
450 |
|
451 |
# Map inputs to a dictionary
|
452 |
submit_btn.click(
|
453 |
-
fn=lambda c, p, d, l, w: {"data": {"complex_name": c, "protein_chains": p, "dna_sequences": d, "ligands": l}, "watermark": w},
|
454 |
-
inputs=[complex_name, protein_chains, dna_sequences, ligands, add_watermark1],
|
455 |
outputs=input_collector
|
456 |
).then(
|
457 |
fn=predict_structure,
|
|
|
176 |
})
|
177 |
return components
|
178 |
|
179 |
+
|
180 |
def create_protenix_json(input_data: Dict) -> List[Dict]:
|
|
|
181 |
sequences = []
|
182 |
|
183 |
+
# Process protein chains
|
184 |
+
for pc in input_data.get("protein_chains", []):
|
185 |
+
# Check that the row has both columns and the sequence is nonempty.
|
186 |
+
if len(pc) >= 2 and pc[0].strip():
|
187 |
+
sequences.append({
|
188 |
+
"proteinChain": {
|
189 |
+
"sequence": pc[0].strip(),
|
190 |
+
"count": int(pc[1]) if pc[1] else 1
|
191 |
+
}
|
192 |
+
})
|
193 |
|
194 |
+
# Process DNA sequences
|
195 |
+
for dna in input_data.get("dna_sequences", []):
|
196 |
+
if len(dna) >= 2 and dna[0].strip():
|
197 |
+
sequences.append({
|
198 |
+
"dnaSequence": {
|
199 |
+
"sequence": dna[0].strip(),
|
200 |
+
"count": int(dna[1]) if dna[1] else 1
|
201 |
+
}
|
202 |
+
})
|
203 |
+
|
204 |
+
# Process RNA sequences
|
205 |
+
for rna in input_data.get("rna_sequences", []):
|
206 |
+
if len(rna) >= 2 and rna[0].strip():
|
207 |
+
sequences.append({
|
208 |
+
"rnaSequence": {
|
209 |
+
"sequence": rna[0].strip(),
|
210 |
+
"count": int(rna[1]) if rna[1] else 1
|
211 |
+
}
|
212 |
+
})
|
213 |
|
214 |
+
# Process ligands
|
215 |
+
for lig in input_data.get("ligands", []):
|
216 |
+
if len(lig) >= 2 and lig[0].strip():
|
217 |
+
sequences.append({
|
218 |
+
"ligand": {
|
219 |
+
"ligand": lig[0].strip(),
|
220 |
+
"count": int(lig[1]) if lig[1] else 1
|
221 |
+
}
|
222 |
+
})
|
223 |
|
224 |
return [{
|
225 |
"sequences": sequences,
|
226 |
+
"name": input_data.get("complex_name")+f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:3]}"
|
227 |
}]
|
228 |
|
229 |
|
230 |
+
def update_json(complex_name, protein_chains, dna_sequences, rna_sequences, ligands):
|
231 |
+
sequences_list = []
|
232 |
+
|
233 |
+
# Process protein chains (DataFrame with headers: ["Sequence", "Count"])
|
234 |
+
if protein_chains:
|
235 |
+
for row in protein_chains:
|
236 |
+
# Check if the row is valid and non-empty
|
237 |
+
if row and len(row) >= 2 and row[0]:
|
238 |
+
sequences_list.append({
|
239 |
+
"proteinChain": {
|
240 |
+
"sequence": row[0],
|
241 |
+
"count": row[1]
|
242 |
+
}
|
243 |
+
})
|
244 |
+
|
245 |
+
# Process DNA sequences
|
246 |
+
if dna_sequences:
|
247 |
+
for row in dna_sequences:
|
248 |
+
if row and len(row) >= 2 and row[0]:
|
249 |
+
sequences_list.append({
|
250 |
+
"dnaSequence": {
|
251 |
+
"sequence": row[0],
|
252 |
+
"count": row[1]
|
253 |
+
}
|
254 |
+
})
|
255 |
+
|
256 |
+
# Process RNA sequences
|
257 |
+
if rna_sequences:
|
258 |
+
for row in rna_sequences:
|
259 |
+
if row and len(row) >= 2 and row[0]:
|
260 |
+
sequences_list.append({
|
261 |
+
"rnaSequence": {
|
262 |
+
"sequence": row[0],
|
263 |
+
"count": row[1]
|
264 |
+
}
|
265 |
+
})
|
266 |
+
|
267 |
+
# Process ligands (DataFrame with headers: ["Ligand Type", "Count"])
|
268 |
+
if ligands:
|
269 |
+
for row in ligands:
|
270 |
+
if row and len(row) >= 2 and row[0]:
|
271 |
+
sequences_list.append({
|
272 |
+
"ligand": {
|
273 |
+
"ligand": row[0],
|
274 |
+
"count": row[1]
|
275 |
+
}
|
276 |
+
})
|
277 |
+
|
278 |
+
return {
|
279 |
+
"sequences": sequences_list,
|
280 |
+
"name": complex_name
|
281 |
+
}
|
282 |
+
|
283 |
+
|
284 |
+
|
285 |
#@torch.inference_mode()
|
286 |
@spaces.GPU(duration=120) # Specify a duration to avoid timeout
|
287 |
def predict_structure(input_collector: dict):
|
|
|
297 |
print(input_collector)
|
298 |
|
299 |
# Handle JSON input
|
300 |
+
if "json" in input_collector:
|
301 |
# Handle different input types
|
302 |
if isinstance(input_collector["json"], str): # Example JSON case (file path)
|
303 |
input_data = json.load(open(input_collector["json"]))
|
|
|
478 |
headers=["Sequence", "Count"],
|
479 |
datatype=["str", "number"],
|
480 |
row_count=1,
|
481 |
+
col_count=(2, "fixed"),
|
482 |
+
type="array"
|
483 |
)
|
484 |
|
485 |
# Repeat for other groups
|
486 |
+
with gr.Accordion(label="DNA Sequences (A T G C)", open=True):
|
487 |
dna_sequences = gr.Dataframe(
|
488 |
headers=["Sequence", "Count"],
|
489 |
datatype=["str", "number"],
|
490 |
+
row_count=1,
|
491 |
+
type="array"
|
492 |
+
)
|
493 |
+
|
494 |
+
with gr.Accordion(label="RNA Sequences (A U G C)", open=True):
|
495 |
+
rna_sequences = gr.Dataframe(
|
496 |
+
headers=["Sequence", "Count"],
|
497 |
+
datatype=["str", "number"],
|
498 |
+
row_count=1,
|
499 |
+
type="array"
|
500 |
)
|
501 |
|
502 |
with gr.Accordion(label="Ligands", open=True):
|
503 |
ligands = gr.Dataframe(
|
504 |
headers=["Ligand Type", "Count"],
|
505 |
datatype=["str", "number"],
|
506 |
+
row_count=1,
|
507 |
+
type="array"
|
508 |
)
|
509 |
|
510 |
manual_output = gr.JSON(label="Generated JSON")
|
511 |
|
512 |
+
# Attach a change event to all widgets so that any change updates the JSON output.
|
513 |
+
for widget in [complex_name, protein_chains, dna_sequences, rna_sequences, ligands]:
|
514 |
+
widget.change(
|
515 |
+
fn=update_json,
|
516 |
+
inputs=[complex_name, protein_chains, dna_sequences, rna_sequences, ligands],
|
517 |
+
outputs=manual_output
|
518 |
+
)
|
519 |
|
520 |
# Shared prediction components
|
521 |
with gr.Row():
|
|
|
535 |
|
536 |
# Map inputs to a dictionary
|
537 |
submit_btn.click(
|
538 |
+
fn=lambda c, p, d, r, l, w: {"data": {"complex_name": c, "protein_chains": p, "dna_sequences": d, "rna_sequences": r, "ligands": l}, "watermark": w},
|
539 |
+
inputs=[complex_name, protein_chains, dna_sequences, rna_sequences, ligands, add_watermark1],
|
540 |
outputs=input_collector
|
541 |
).then(
|
542 |
fn=predict_structure,
|