Zaixi committed on
Commit
604af1c
·
verified ·
1 Parent(s): 0b938ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -35
app.py CHANGED
@@ -176,40 +176,112 @@ def parse_json_input(json_data: List[Dict]) -> Dict:
176
  })
177
  return components
178
 
 
179
  def create_protenix_json(input_data: Dict) -> List[Dict]:
180
- """Convert UI inputs to Protenix JSON format"""
181
  sequences = []
182
 
183
- for pc in input_data["protein_chains"]:
184
- sequences.append({
185
- "proteinChain": {
186
- "sequence": pc["sequence"],
187
- "count": pc["count"]
188
- }
189
- })
 
 
 
190
 
191
- for dna in input_data["dna_sequences"]:
192
- sequences.append({
193
- "dnaSequence": {
194
- "sequence": dna["sequence"],
195
- "count": dna["count"]
196
- }
197
- })
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
- for lig in input_data["ligands"]:
200
- sequences.append({
201
- "ligand": {
202
- "ligand": lig["type"],
203
- "count": lig["count"]
204
- }
205
- })
 
 
206
 
207
  return [{
208
  "sequences": sequences,
209
- "name": input_data["complex_name"]
210
  }]
211
 
212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  #@torch.inference_mode()
214
  @spaces.GPU(duration=120) # Specify a duration to avoid timeout
215
  def predict_structure(input_collector: dict):
@@ -225,7 +297,7 @@ def predict_structure(input_collector: dict):
225
  print(input_collector)
226
 
227
  # Handle JSON input
228
- if input_collector["json"]:
229
  # Handle different input types
230
  if isinstance(input_collector["json"], str): # Example JSON case (file path)
231
  input_data = json.load(open(input_collector["json"]))
@@ -406,31 +478,44 @@ with gr.Blocks(title="FoldMark", css=custom_css) as demo:
406
  headers=["Sequence", "Count"],
407
  datatype=["str", "number"],
408
  row_count=1,
409
- col_count=(2, "fixed")
 
410
  )
411
 
412
  # Repeat for other groups
413
- with gr.Accordion(label="DNA Sequences", open=True):
414
  dna_sequences = gr.Dataframe(
415
  headers=["Sequence", "Count"],
416
  datatype=["str", "number"],
417
- row_count=1
 
 
 
 
 
 
 
 
 
418
  )
419
 
420
  with gr.Accordion(label="Ligands", open=True):
421
  ligands = gr.Dataframe(
422
  headers=["Ligand Type", "Count"],
423
  datatype=["str", "number"],
424
- row_count=1
 
425
  )
426
 
427
  manual_output = gr.JSON(label="Generated JSON")
428
 
429
- complex_name.change(
430
- fn=lambda x: {"complex_name": x},
431
- inputs=complex_name,
432
- outputs=manual_output
433
- )
 
 
434
 
435
  # Shared prediction components
436
  with gr.Row():
@@ -450,8 +535,8 @@ with gr.Blocks(title="FoldMark", css=custom_css) as demo:
450
 
451
  # Map inputs to a dictionary
452
  submit_btn.click(
453
- fn=lambda c, p, d, l, w: {"data": {"complex_name": c, "protein_chains": p, "dna_sequences": d, "ligands": l}, "watermark": w},
454
- inputs=[complex_name, protein_chains, dna_sequences, ligands, add_watermark1],
455
  outputs=input_collector
456
  ).then(
457
  fn=predict_structure,
 
176
  })
177
  return components
178
 
179
+
def _protenix_count(raw) -> int:
    """Coerce a table "Count" cell to a positive integer.

    Missing, empty, non-numeric, non-finite and non-positive values all fall
    back to 1, so a half-filled table row never aborts the conversion.
    Going through ``float`` first lets string cells such as ``"2.0"`` parse.
    """
    try:
        count = int(float(raw))
    except (TypeError, ValueError, OverflowError):
        # None / "" / "abc" / NaN / inf -> default copy number.
        return 1
    return max(count, 1)


def _protenix_entries(rows, wrapper_key: str, value_key: str) -> List[Dict]:
    """Turn ``[value, count]`` table rows into Protenix sequence entries.

    Rows that are empty, shorter than two cells, or whose first cell is
    missing or blank after stripping are skipped.
    """
    entries = []
    for row in rows or []:
        if not row or len(row) < 2:
            continue
        cell = row[0]
        # Cells may arrive as None or as a non-string; normalise before strip().
        text = "" if cell is None else str(cell).strip()
        if not text:
            continue
        entries.append({
            wrapper_key: {
                value_key: text,
                "count": _protenix_count(row[1]),
            }
        })
    return entries


def create_protenix_json(input_data: Dict) -> List[Dict]:
    """Convert UI inputs to Protenix JSON format.

    Args:
        input_data: Mapping with an optional ``"complex_name"`` and the
            optional tables ``"protein_chains"``, ``"dna_sequences"``,
            ``"rna_sequences"`` and ``"ligands"``. Each table is a list of
            ``[value, count]`` rows; a missing or ``None`` table is treated
            as empty.

    Returns:
        A one-element list holding a dict with ``"sequences"`` (entries in
        protein, DNA, RNA, ligand order) and ``"name"``. The name is the
        complex name (empty when missing) immediately followed by a
        ``YYYYmmdd_HHMMSS`` timestamp, an underscore and three hex characters
        from a random UUID.
    """
    # (table key in input_data, wrapper key in output, key holding the value)
    tables = (
        ("protein_chains", "proteinChain", "sequence"),
        ("dna_sequences", "dnaSequence", "sequence"),
        ("rna_sequences", "rnaSequence", "sequence"),
        ("ligands", "ligand", "ligand"),
    )

    sequences = []
    for table_key, wrapper_key, value_key in tables:
        sequences.extend(
            _protenix_entries(input_data.get(table_key), wrapper_key, value_key)
        )

    # A missing/None complex name must not break the string concatenation.
    prefix = input_data.get("complex_name") or ""
    suffix = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:3]}"

    return [{
        "sequences": sequences,
        "name": f"{prefix}{suffix}",
    }]
228
 
229
 
def _preview_count(raw) -> int:
    """Coerce a table "Count" cell to a positive integer, defaulting to 1.

    Missing, empty, non-numeric, non-finite and non-positive values all
    become 1. Parsing via ``float`` accepts string cells such as ``"2.0"``.
    """
    try:
        count = int(float(raw))
    except (TypeError, ValueError, OverflowError):
        return 1
    return max(count, 1)


def _preview_entries(rows, wrapper_key, value_key):
    """Build preview entries from ``[value, count]`` table rows.

    Rows that are empty, shorter than two cells, or whose first cell is
    missing or blank after stripping are left out of the preview.
    """
    entries = []
    for row in rows or []:
        if not row or len(row) < 2:
            continue
        cell = row[0]
        text = "" if cell is None else str(cell).strip()
        if not text:
            continue
        entries.append({
            wrapper_key: {
                value_key: text,
                "count": _preview_count(row[1]),
            }
        })
    return entries


def update_json(complex_name, protein_chains, dna_sequences, rna_sequences, ligands):
    """Build the JSON preview from the current widget values.

    Values are stripped and counts are normalised to positive integers
    (default 1), so blank rows do not show up as entries.

    Args:
        complex_name: Name of the complex; returned unchanged under ``"name"``.
        protein_chains: Rows of ``[sequence, count]``, or ``None``.
        dna_sequences: Rows of ``[sequence, count]``, or ``None``.
        rna_sequences: Rows of ``[sequence, count]``, or ``None``.
        ligands: Rows of ``[ligand type, count]``, or ``None``.

    Returns:
        A dict with ``"sequences"`` (entries in protein, DNA, RNA, ligand
        order) and ``"name"``.
    """
    # (rows, wrapper key in output, key holding the value)
    tables = (
        (protein_chains, "proteinChain", "sequence"),
        (dna_sequences, "dnaSequence", "sequence"),
        (rna_sequences, "rnaSequence", "sequence"),
        (ligands, "ligand", "ligand"),
    )

    sequences_list = []
    for rows, wrapper_key, value_key in tables:
        sequences_list.extend(_preview_entries(rows, wrapper_key, value_key))

    return {
        "sequences": sequences_list,
        "name": complex_name,
    }
282
+
283
+
284
+
285
  #@torch.inference_mode()
286
  @spaces.GPU(duration=120) # Specify a duration to avoid timeout
287
  def predict_structure(input_collector: dict):
 
297
  print(input_collector)
298
 
299
  # Handle JSON input
300
+ if "json" in input_collector:
301
  # Handle different input types
302
  if isinstance(input_collector["json"], str): # Example JSON case (file path)
303
  input_data = json.load(open(input_collector["json"]))
 
478
  headers=["Sequence", "Count"],
479
  datatype=["str", "number"],
480
  row_count=1,
481
+ col_count=(2, "fixed"),
482
+ type="array"
483
  )
484
 
485
  # Repeat for other groups
486
+ with gr.Accordion(label="DNA Sequences (A T G C)", open=True):
487
  dna_sequences = gr.Dataframe(
488
  headers=["Sequence", "Count"],
489
  datatype=["str", "number"],
490
+ row_count=1,
491
+ type="array"
492
+ )
493
+
494
+ with gr.Accordion(label="RNA Sequences (A U G C)", open=True):
495
+ rna_sequences = gr.Dataframe(
496
+ headers=["Sequence", "Count"],
497
+ datatype=["str", "number"],
498
+ row_count=1,
499
+ type="array"
500
  )
501
 
502
  with gr.Accordion(label="Ligands", open=True):
503
  ligands = gr.Dataframe(
504
  headers=["Ligand Type", "Count"],
505
  datatype=["str", "number"],
506
+ row_count=1,
507
+ type="array"
508
  )
509
 
510
  manual_output = gr.JSON(label="Generated JSON")
511
 
512
+ # Attach a change event to all widgets so that any change updates the JSON output.
513
+ for widget in [complex_name, protein_chains, dna_sequences, rna_sequences, ligands]:
514
+ widget.change(
515
+ fn=update_json,
516
+ inputs=[complex_name, protein_chains, dna_sequences, rna_sequences, ligands],
517
+ outputs=manual_output
518
+ )
519
 
520
  # Shared prediction components
521
  with gr.Row():
 
535
 
536
  # Map inputs to a dictionary
537
  submit_btn.click(
538
+ fn=lambda c, p, d, r, l, w: {"data": {"complex_name": c, "protein_chains": p, "dna_sequences": d, "rna_sequences": r, "ligands": l}, "watermark": w},
539
+ inputs=[complex_name, protein_chains, dna_sequences, rna_sequences, ligands, add_watermark1],
540
  outputs=input_collector
541
  ).then(
542
  fn=predict_structure,