Zaixi commited on
Commit
e8bea69
·
1 Parent(s): 3012314
Files changed (1) hide show
  1. app.py +459 -459
app.py CHANGED
@@ -1,506 +1,506 @@
1
- import spaces
2
- import logging
3
- import gradio as gr
4
- import os
5
- import uuid
6
- from datetime import datetime
7
- import numpy as np
8
- from configs.configs_base import configs as configs_base
9
- from configs.configs_data import data_configs
10
- from configs.configs_inference import inference_configs
11
- from runner.inference import download_infercence_cache, update_inference_configs, infer_predict, infer_detect, InferenceRunner
12
- from protenix.config import parse_configs, parse_sys_args
13
- from runner.msa_search import update_infer_json
14
- from protenix.web_service.prediction_visualization import plot_best_confidence_measure, PredictionLoader
15
- from process_data import process_data
16
- import json
17
- from typing import Dict, List
18
- from Bio.PDB import MMCIFParser, PDBIO
19
- import tempfile
20
- import shutil
21
- from Bio import PDB
22
- from gradio_molecule3d import Molecule3D
23
-
24
- EXAMPLE_PATH = './examples/example.json'
25
- example_json=[{'sequences': [{'proteinChain': {'sequence': 'MAEVIRSSAFWRSFPIFEEFDSETLCELSGIASYRKWSAGTVIFQRGDQGDYMIVVVSGRIKLSLFTPQGRELMLRQHEAGALFGEMALLDGQPRSADATAVTAAEGYVIGKKDFLALITQRPKTAEAVIRFLCAQLRDTTDRLETIALYDLNARVARFFLATLRQIHGSEMPQSANLRLTLSQTDIASILGASRPKVNRAILSLEESGAIKRADGIICCNVGRLLSIADPEEDLEHHHHHHHH', 'count': 2}}, {'dnaSequence': {'sequence': 'CTAGGTAACATTACTCGCG', 'count': 2}}, {'dnaSequence': {'sequence': 'GCGAGTAATGTTAC', 'count': 2}}, {'ligand': {'ligand': 'CCD_PCG', 'count': 2}}], 'name': '7pzb_need_search_msa'}]
26
-
27
- # Custom CSS for styling
28
- custom_css = """
29
- #logo {
30
- width: 50%;
31
- }
32
- .title {
33
- font-size: 32px;
34
- font-weight: bold;
35
- color: #4CAF50;
36
- display: flex;
37
- align-items: center; /* Vertically center the logo and text */
38
- }
39
- """
40
-
41
-
42
- os.environ["LAYERNORM_TYPE"] = "fast_layernorm"
43
- os.environ["USE_DEEPSPEED_EVO_ATTTENTION"] = "False"
44
- # Set environment variable in the script
45
- #os.environ['CUTLASS_PATH'] = './cutlass'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  # reps = [
48
  # {
49
  # "model": 0,
50
  # "chain": "",
51
  # "resname": "",
52
- # "style": "cartoon", # Use cartoon style
53
  # "color": "whiteCarbon",
54
  # "residue_range": "",
55
  # "around": 0,
56
  # "byres": False,
57
- # "visible": True # Ensure this representation is visible
 
 
 
 
 
 
 
 
 
 
 
58
  # }
59
  # ]
 
60
 
61
- reps = [
62
- {
63
- "model": 0,
64
- "chain": "",
65
- "resname": "",
66
- "style": "cartoon",
67
- "color": "whiteCarbon",
68
- "residue_range": "",
69
- "around": 0,
70
- "byres": False,
71
- "opacity": 0.2,
72
- },
73
- {
74
- "model": 1,
75
- "chain": "",
76
- "resname": "",
77
- "style": "cartoon",
78
- "color": "cyanCarbon",
79
- "residue_range": "",
80
- "around": 0,
81
- "byres": False,
82
- "opacity": 0.8,
83
- }
84
- ]
85
- ##
86
-
87
-
88
- def align_pdb_files(pdb_file_1, pdb_file_2):
89
- # Load the structures
90
- parser = PDB.PPBuilder()
91
- io = PDB.PDBIO()
92
- structure_1 = PDB.PDBParser(QUIET=True).get_structure('Structure_1', pdb_file_1)
93
- structure_2 = PDB.PDBParser(QUIET=True).get_structure('Structure_2', pdb_file_2)
94
-
95
- # Superimpose the second structure onto the first
96
- super_imposer = PDB.Superimposer()
97
- model_1 = structure_1[0]
98
- model_2 = structure_2[0]
99
-
100
- # Extract the coordinates from the two structures
101
- atoms_1 = [atom for atom in model_1.get_atoms() if atom.get_name() == "CA"] # Use CA atoms
102
- atoms_2 = [atom for atom in model_2.get_atoms() if atom.get_name() == "CA"]
103
-
104
- # Align the structures based on the CA atoms
105
- coord_1 = [atom.get_coord() for atom in atoms_1]
106
- coord_2 = [atom.get_coord() for atom in atoms_2]
107
 
108
- super_imposer.set_atoms(atoms_1, atoms_2)
109
- super_imposer.apply(model_2) # Apply the transformation to model_2
110
-
111
- # Save the aligned structure back to the original file
112
- io.set_structure(structure_2) # Save the aligned structure to the second file (original file)
113
- io.save(pdb_file_2)
114
-
115
- # Function to convert .cif to .pdb and save as a temporary file
116
- def convert_cif_to_pdb(cif_path):
117
- """
118
- Convert a CIF file to a PDB file and save it as a temporary file.
119
-
120
- Args:
121
- cif_path (str): Path to the input CIF file.
122
-
123
- Returns:
124
- str: Path to the temporary PDB file.
125
- """
126
- # Initialize the MMCIF parser
127
- parser = MMCIFParser()
128
- structure = parser.get_structure("protein", cif_path)
129
-
130
- # Create a temporary file for the PDB output
131
- with tempfile.NamedTemporaryFile(suffix=".pdb", delete=False) as temp_file:
132
- temp_pdb_path = temp_file.name
133
-
134
- # Save the structure as a PDB file
135
- io = PDBIO()
136
- io.set_structure(structure)
137
- io.save(temp_pdb_path)
138
-
139
- return temp_pdb_path
140
-
141
- def plot_3d(pred_loader):
142
- # Get the CIF file path for the given prediction ID
143
- cif_path = sorted(pred_loader.cif_paths)[0]
144
-
145
- # Convert the CIF file to a temporary PDB file
146
- temp_pdb_path = convert_cif_to_pdb(cif_path)
147
-
148
- return temp_pdb_path, cif_path
149
-
150
- def parse_json_input(json_data: List[Dict]) -> Dict:
151
- """Convert Protenix JSON format to UI-friendly structure"""
152
- components = {
153
- "protein_chains": [],
154
- "dna_sequences": [],
155
- "ligands": [],
156
- "complex_name": ""
157
- }
158
 
159
- for entry in json_data:
160
- components["complex_name"] = entry.get("name", "")
161
- for seq in entry["sequences"]:
162
- if "proteinChain" in seq:
163
- components["protein_chains"].append({
164
- "sequence": seq["proteinChain"]["sequence"],
165
- "count": seq["proteinChain"]["count"]
166
- })
167
- elif "dnaSequence" in seq:
168
- components["dna_sequences"].append({
169
- "sequence": seq["dnaSequence"]["sequence"],
170
- "count": seq["dnaSequence"]["count"]
171
- })
172
- elif "ligand" in seq:
173
- components["ligands"].append({
174
- "type": seq["ligand"]["ligand"],
175
- "count": seq["ligand"]["count"]
176
- })
177
- return components
178
-
179
- def create_protenix_json(input_data: Dict) -> List[Dict]:
180
- """Convert UI inputs to Protenix JSON format"""
181
- sequences = []
182
 
183
- for pc in input_data["protein_chains"]:
184
- sequences.append({
185
- "proteinChain": {
186
- "sequence": pc["sequence"],
187
- "count": pc["count"]
188
- }
189
- })
190
 
191
- for dna in input_data["dna_sequences"]:
192
- sequences.append({
193
- "dnaSequence": {
194
- "sequence": dna["sequence"],
195
- "count": dna["count"]
196
- }
197
- })
198
 
199
- for lig in input_data["ligands"]:
200
- sequences.append({
201
- "ligand": {
202
- "ligand": lig["type"],
203
- "count": lig["count"]
204
- }
205
- })
206
 
207
- return [{
208
- "sequences": sequences,
209
- "name": input_data["complex_name"]
210
- }]
211
-
212
-
213
- #@torch.inference_mode()
214
- @spaces.GPU(duration=120) # Specify a duration to avoid timeout
215
- def predict_structure(input_collector: dict):
216
- #first initialize runner
217
- runner = InferenceRunner(configs)
218
- """Handle both input types"""
219
- os.makedirs("./output", exist_ok=True)
220
 
221
- # Generate random filename with timestamp
222
- random_name = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
223
- save_path = os.path.join("./output", f"{random_name}.json")
224
-
225
- print(input_collector)
226
-
227
- # Handle JSON input
228
- if input_collector["json"]:
229
- # Handle different input types
230
- if isinstance(input_collector["json"], str): # Example JSON case (file path)
231
- input_data = json.load(open(input_collector["json"]))
232
- elif hasattr(input_collector["json"], "name"): # File upload case
233
- input_data = json.load(open(input_collector["json"].name))
234
- else: # Direct JSON data case
235
- input_data = input_collector["json"]
236
- else: # Manual input case
237
- input_data = create_protenix_json(input_collector["data"])
238
-
239
- with open(save_path, "w") as f:
240
- json.dump(input_data, f, indent=2)
241
-
242
- if input_data==example_json and input_collector['watermark']==True:
243
- configs.saved_path = './output/example_output/'
244
- else:
245
- # run msa
246
- json_file = update_infer_json(save_path, './output', True)
247
-
248
- # Run prediction
249
- configs.input_json_path = json_file
250
- configs.watermark = input_collector['watermark']
251
- configs.saved_path = os.path.join("./output/", random_name)
252
- infer_predict(runner, configs)
253
- #saved_path = os.path.join('./output', f"{sample_name}", f"seed_{seed}", 'predictions')
254
-
255
- # Generate visualizations
256
- pred_loader = PredictionLoader(os.path.join(configs.saved_path, 'predictions'))
257
- view3d, cif_path = plot_3d(pred_loader=pred_loader)
258
- if configs.watermark:
259
- pred_loader = PredictionLoader(os.path.join(configs.saved_path, 'predictions_orig'))
260
- view3d_orig, _ = plot_3d(pred_loader=pred_loader)
261
- align_pdb_files(view3d, view3d_orig)
262
- view3d = [view3d, view3d_orig]
263
- plot_best_confidence_measure(os.path.join(configs.saved_path, 'predictions'))
264
- confidence_img_path = os.path.join(os.path.join(configs.saved_path, 'predictions'), "best_sample_confidence.png")
265
-
266
- return view3d, confidence_img_path, cif_path
267
-
268
-
269
- logger = logging.getLogger(__name__)
270
- LOG_FORMAT = "%(asctime)s,%(msecs)-3d %(levelname)-8s [%(filename)s:%(lineno)s %(funcName)s] %(message)s"
271
- logging.basicConfig(
272
- format=LOG_FORMAT,
273
- level=logging.INFO,
274
- datefmt="%Y-%m-%d %H:%M:%S",
275
- filemode="w",
276
- )
277
- configs_base["use_deepspeed_evo_attention"] = (
278
- os.environ.get("USE_DEEPSPEED_EVO_ATTTENTION", False) == "False"
279
- )
280
- arg_str = "--seeds 101 --dump_dir ./output --input_json_path ./examples/example.json --model.N_cycle 10 --sample_diffusion.N_sample 5 --sample_diffusion.N_step 200 "
281
- configs = {**configs_base, **{"data": data_configs}, **inference_configs}
282
- configs = parse_configs(
283
- configs=configs,
284
- arg_str=arg_str,
285
- fill_required_with_null=True,
286
- )
287
- configs.load_checkpoint_path='./checkpoint.pt'
288
- download_infercence_cache()
289
- configs.use_deepspeed_evo_attention=False
290
-
291
- add_watermark = gr.Checkbox(label="Add Watermark", value=True)
292
- add_watermark1 = gr.Checkbox(label="Add Watermark", value=True)
293
-
294
-
295
- with gr.Blocks(title="FoldMark", css=custom_css) as demo:
296
- with gr.Row():
297
- # Use a Column to align the logo and title horizontally
298
- gr.Image(value="./assets/foldmark_head.png", elem_id="logo", label="Logo", height=150, show_label=False)
299
-
300
- with gr.Tab("Structure Predictor (JSON Upload)"):
301
- # First create the upload component
302
- json_upload = gr.File(label="Upload JSON", file_types=[".json"])
303
 
304
- # Then create the example component that references it
305
- gr.Examples(
306
- examples=[[EXAMPLE_PATH]],
307
- inputs=[json_upload],
308
- label="Click to use example JSON:",
309
- examples_per_page=1
310
- )
311
 
312
- # Rest of the components
313
- upload_name = gr.Textbox(label="Complex Name (optional)")
314
- upload_output = gr.JSON(label="Parsed Components")
315
 
316
- json_upload.upload(
317
- fn=lambda f: parse_json_input(json.load(open(f.name))),
318
- inputs=json_upload,
319
- outputs=upload_output
320
- )
321
-
322
- # Shared prediction components
323
- with gr.Row():
324
- add_watermark.render()
325
- submit_btn = gr.Button("Predict Structure", variant="primary")
326
- #structure_view = gr.HTML(label="3D Visualization")
327
-
328
- with gr.Row():
329
- view3d = Molecule3D(label="3D Visualization", reps=reps)
330
- legend = gr.Markdown("""
331
- **Color Legend:**
332
-
333
- - <span style="color:grey">Unwatermarked Structure</span>
334
- - <span style="color:cyan">Watermarked Structure</span>
335
- """)
336
- with gr.Row():
337
- cif_file = gr.File(label="Download CIF File")
338
- with gr.Row():
339
- confidence_plot_image = gr.Image(label="Confidence Measures")
340
 
341
- input_collector = gr.JSON(visible=False)
342
-
343
- # Map inputs to a dictionary
344
- submit_btn.click(
345
- fn=lambda j, w: {"json": j, "watermark": w},
346
- inputs=[json_upload, add_watermark],
347
- outputs=input_collector
348
- ).then(
349
- fn=predict_structure,
350
- inputs=input_collector,
351
- outputs=[view3d, confidence_plot_image, cif_file]
352
- )
353
-
354
- gr.Markdown("""
355
- The example of the uploaded json file for structure prediction.
356
- <pre>
357
- [{
358
- "sequences": [
359
- {
360
- "proteinChain": {
361
- "sequence": "MAEVIRSSAFWRSFPIFEEFDSETLCELSGIASYRKWSAGTVIFQRGDQGDYMIVVVSGRIKLSLFTPQGRELMLRQHEAGALFGEMALLDGQPRSADATAVTAAEGYVIGKKDFLALITQRPKTAEAVIRFLCAQLRDTTDRLETIALYDLNARVARFFLATLRQIHGSEMPQSANLRLTLSQTDIASILGASRPKVNRAILSLEESGAIKRADGIICCNVGRLLSIADPEEDLEHHHHHHHH",
362
- "count": 2
363
- }
364
- },
365
- {
366
- "dnaSequence": {
367
- "sequence": "CTAGGTAACATTACTCGCG",
368
- "count": 2
369
- }
370
- },
371
- {
372
- "dnaSequence": {
373
- "sequence": "GCGAGTAATGTTAC",
374
- "count": 2
375
- }
376
- },
377
- {
378
- "ligand": {
379
- "ligand": "CCD_PCG",
380
- "count": 2
381
- }
382
- }
383
- ],
384
- "name": "7pzb"
385
- }]
386
- </pre>
387
- """)
388
 
389
- with gr.Tab("Structure Predictor (Manual Input)"):
390
- with gr.Row():
391
- complex_name = gr.Textbox(label="Complex Name")
392
 
393
- # Replace gr.Group with gr.Accordion
394
- with gr.Accordion(label="Protein Chains", open=True):
395
- protein_chains = gr.Dataframe(
396
- headers=["Sequence", "Count"],
397
- datatype=["str", "number"],
398
- row_count=1,
399
- col_count=(2, "fixed")
400
- )
401
 
402
- # Repeat for other groups
403
- with gr.Accordion(label="DNA Sequences", open=True):
404
- dna_sequences = gr.Dataframe(
405
- headers=["Sequence", "Count"],
406
- datatype=["str", "number"],
407
- row_count=1
408
- )
409
 
410
- with gr.Accordion(label="Ligands", open=True):
411
- ligands = gr.Dataframe(
412
- headers=["Ligand Type", "Count"],
413
- datatype=["str", "number"],
414
- row_count=1
415
- )
416
 
417
- manual_output = gr.JSON(label="Generated JSON")
418
 
419
- complex_name.change(
420
- fn=lambda x: {"complex_name": x},
421
- inputs=complex_name,
422
- outputs=manual_output
423
- )
424
-
425
- # Shared prediction components
426
- with gr.Row():
427
- add_watermark1.render()
428
- submit_btn = gr.Button("Predict Structure", variant="primary")
429
- #structure_view = gr.HTML(label="3D Visualization")
430
-
431
- with gr.Row():
432
- view3d = Molecule3D(label="3D Visualization (Gray: Unwatermarked; Cyan: Watermarked)", reps=reps)
433
-
434
- with gr.Row():
435
- cif_file = gr.File(label="Download CIF File")
436
- with gr.Row():
437
- confidence_plot_image = gr.Image(label="Confidence Measures")
438
 
439
- input_collector = gr.JSON(visible=False)
440
-
441
- # Map inputs to a dictionary
442
- submit_btn.click(
443
- fn=lambda c, p, d, l, w: {"data": {"complex_name": c, "protein_chains": p, "dna_sequences": d, "ligands": l}, "watermark": w},
444
- inputs=[complex_name, protein_chains, dna_sequences, ligands, add_watermark1],
445
- outputs=input_collector
446
- ).then(
447
- fn=predict_structure,
448
- inputs=input_collector,
449
- outputs=[view3d, confidence_plot_image, cif_file]
450
- )
451
-
452
- @spaces.GPU(duration=120)
453
- def is_watermarked(file):
454
- #first initialize runner
455
- runner = InferenceRunner(configs)
456
- # Generate a unique subdirectory and filename
457
- unique_id = str(uuid.uuid4().hex[:8])
458
- subdir = os.path.join('./output', unique_id)
459
- os.makedirs(subdir, exist_ok=True)
460
- filename = f"{unique_id}.cif"
461
- file_path = os.path.join(subdir, filename)
462
 
463
- # Save the uploaded file to the new location
464
- shutil.copy(file.name, file_path)
465
 
466
- # Call your processing functions
467
- configs.process_success = process_data(subdir)
468
- configs.subdir = subdir
469
- result = infer_detect(runner, configs)
470
- # This function should return 'Watermarked' or 'Not Watermarked'
471
- temp_pdb_path = convert_cif_to_pdb(file_path)
472
- if result==False:
473
- return "Not Watermarked", temp_pdb_path
474
- else:
475
- return "Watermarked", temp_pdb_path
476
 
477
 
478
 
479
- with gr.Tab("Watermark Detector"):
480
- # First create the upload component
481
- cif_upload = gr.File(label="Upload .cif", file_types=["..cif"])
482
 
483
- with gr.Row():
484
- cif_3d_view = Molecule3D(label="3D Visualization of Input", reps=reps)
485
 
486
- # Prediction output
487
- prediction_output = gr.Textbox(label="Prediction")
488
 
489
- # Define the interaction
490
- cif_upload.change(is_watermarked, inputs=cif_upload, outputs=[prediction_output, cif_3d_view])
491
 
492
- # Example files
493
- example_files = [
494
- "./examples/7r6r_watermarked.cif",
495
- "./examples/7pzb_unwatermarked.cif"
496
- ]
497
 
498
- gr.Examples(examples=example_files, inputs=cif_upload)
499
 
500
 
501
 
502
 
503
 
504
 
505
- if __name__ == "__main__":
506
- demo.launch(share=True)
 
1
+ # import spaces
2
+ # import logging
3
+ # import gradio as gr
4
+ # import os
5
+ # import uuid
6
+ # from datetime import datetime
7
+ # import numpy as np
8
+ # from configs.configs_base import configs as configs_base
9
+ # from configs.configs_data import data_configs
10
+ # from configs.configs_inference import inference_configs
11
+ # from runner.inference import download_infercence_cache, update_inference_configs, infer_predict, infer_detect, InferenceRunner
12
+ # from protenix.config import parse_configs, parse_sys_args
13
+ # from runner.msa_search import update_infer_json
14
+ # from protenix.web_service.prediction_visualization import plot_best_confidence_measure, PredictionLoader
15
+ # from process_data import process_data
16
+ # import json
17
+ # from typing import Dict, List
18
+ # from Bio.PDB import MMCIFParser, PDBIO
19
+ # import tempfile
20
+ # import shutil
21
+ # from Bio import PDB
22
+ # from gradio_molecule3d import Molecule3D
23
+
24
+ # EXAMPLE_PATH = './examples/example.json'
25
+ # example_json=[{'sequences': [{'proteinChain': {'sequence': 'MAEVIRSSAFWRSFPIFEEFDSETLCELSGIASYRKWSAGTVIFQRGDQGDYMIVVVSGRIKLSLFTPQGRELMLRQHEAGALFGEMALLDGQPRSADATAVTAAEGYVIGKKDFLALITQRPKTAEAVIRFLCAQLRDTTDRLETIALYDLNARVARFFLATLRQIHGSEMPQSANLRLTLSQTDIASILGASRPKVNRAILSLEESGAIKRADGIICCNVGRLLSIADPEEDLEHHHHHHHH', 'count': 2}}, {'dnaSequence': {'sequence': 'CTAGGTAACATTACTCGCG', 'count': 2}}, {'dnaSequence': {'sequence': 'GCGAGTAATGTTAC', 'count': 2}}, {'ligand': {'ligand': 'CCD_PCG', 'count': 2}}], 'name': '7pzb_need_search_msa'}]
26
+
27
+ # # Custom CSS for styling
28
+ # custom_css = """
29
+ # #logo {
30
+ # width: 50%;
31
+ # }
32
+ # .title {
33
+ # font-size: 32px;
34
+ # font-weight: bold;
35
+ # color: #4CAF50;
36
+ # display: flex;
37
+ # align-items: center; /* Vertically center the logo and text */
38
+ # }
39
+ # """
40
+
41
+
42
+ # os.environ["LAYERNORM_TYPE"] = "fast_layernorm"
43
+ # os.environ["USE_DEEPSPEED_EVO_ATTTENTION"] = "False"
44
+ # # Set environment variable in the script
45
+ # #os.environ['CUTLASS_PATH'] = './cutlass'
46
+
47
+ # # reps = [
48
+ # # {
49
+ # # "model": 0,
50
+ # # "chain": "",
51
+ # # "resname": "",
52
+ # # "style": "cartoon", # Use cartoon style
53
+ # # "color": "whiteCarbon",
54
+ # # "residue_range": "",
55
+ # # "around": 0,
56
+ # # "byres": False,
57
+ # # "visible": True # Ensure this representation is visible
58
+ # # }
59
+ # # ]
60
 
61
  # reps = [
62
  # {
63
  # "model": 0,
64
  # "chain": "",
65
  # "resname": "",
66
+ # "style": "cartoon",
67
  # "color": "whiteCarbon",
68
  # "residue_range": "",
69
  # "around": 0,
70
  # "byres": False,
71
+ # "opacity": 0.2,
72
+ # },
73
+ # {
74
+ # "model": 1,
75
+ # "chain": "",
76
+ # "resname": "",
77
+ # "style": "cartoon",
78
+ # "color": "cyanCarbon",
79
+ # "residue_range": "",
80
+ # "around": 0,
81
+ # "byres": False,
82
+ # "opacity": 0.8,
83
  # }
84
  # ]
85
+ # ##
86
 
87
+
88
+ # def align_pdb_files(pdb_file_1, pdb_file_2):
89
+ # # Load the structures
90
+ # parser = PDB.PPBuilder()
91
+ # io = PDB.PDBIO()
92
+ # structure_1 = PDB.PDBParser(QUIET=True).get_structure('Structure_1', pdb_file_1)
93
+ # structure_2 = PDB.PDBParser(QUIET=True).get_structure('Structure_2', pdb_file_2)
94
+
95
+ # # Superimpose the second structure onto the first
96
+ # super_imposer = PDB.Superimposer()
97
+ # model_1 = structure_1[0]
98
+ # model_2 = structure_2[0]
99
+
100
+ # # Extract the coordinates from the two structures
101
+ # atoms_1 = [atom for atom in model_1.get_atoms() if atom.get_name() == "CA"] # Use CA atoms
102
+ # atoms_2 = [atom for atom in model_2.get_atoms() if atom.get_name() == "CA"]
103
+
104
+ # # Align the structures based on the CA atoms
105
+ # coord_1 = [atom.get_coord() for atom in atoms_1]
106
+ # coord_2 = [atom.get_coord() for atom in atoms_2]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
+ # super_imposer.set_atoms(atoms_1, atoms_2)
109
+ # super_imposer.apply(model_2) # Apply the transformation to model_2
110
+
111
+ # # Save the aligned structure back to the original file
112
+ # io.set_structure(structure_2) # Save the aligned structure to the second file (original file)
113
+ # io.save(pdb_file_2)
114
+
115
+ # # Function to convert .cif to .pdb and save as a temporary file
116
+ # def convert_cif_to_pdb(cif_path):
117
+ # """
118
+ # Convert a CIF file to a PDB file and save it as a temporary file.
119
+
120
+ # Args:
121
+ # cif_path (str): Path to the input CIF file.
122
+
123
+ # Returns:
124
+ # str: Path to the temporary PDB file.
125
+ # """
126
+ # # Initialize the MMCIF parser
127
+ # parser = MMCIFParser()
128
+ # structure = parser.get_structure("protein", cif_path)
129
+
130
+ # # Create a temporary file for the PDB output
131
+ # with tempfile.NamedTemporaryFile(suffix=".pdb", delete=False) as temp_file:
132
+ # temp_pdb_path = temp_file.name
133
+
134
+ # # Save the structure as a PDB file
135
+ # io = PDBIO()
136
+ # io.set_structure(structure)
137
+ # io.save(temp_pdb_path)
138
+
139
+ # return temp_pdb_path
140
+
141
+ # def plot_3d(pred_loader):
142
+ # # Get the CIF file path for the given prediction ID
143
+ # cif_path = sorted(pred_loader.cif_paths)[0]
144
+
145
+ # # Convert the CIF file to a temporary PDB file
146
+ # temp_pdb_path = convert_cif_to_pdb(cif_path)
147
+
148
+ # return temp_pdb_path, cif_path
149
+
150
+ # def parse_json_input(json_data: List[Dict]) -> Dict:
151
+ # """Convert Protenix JSON format to UI-friendly structure"""
152
+ # components = {
153
+ # "protein_chains": [],
154
+ # "dna_sequences": [],
155
+ # "ligands": [],
156
+ # "complex_name": ""
157
+ # }
158
 
159
+ # for entry in json_data:
160
+ # components["complex_name"] = entry.get("name", "")
161
+ # for seq in entry["sequences"]:
162
+ # if "proteinChain" in seq:
163
+ # components["protein_chains"].append({
164
+ # "sequence": seq["proteinChain"]["sequence"],
165
+ # "count": seq["proteinChain"]["count"]
166
+ # })
167
+ # elif "dnaSequence" in seq:
168
+ # components["dna_sequences"].append({
169
+ # "sequence": seq["dnaSequence"]["sequence"],
170
+ # "count": seq["dnaSequence"]["count"]
171
+ # })
172
+ # elif "ligand" in seq:
173
+ # components["ligands"].append({
174
+ # "type": seq["ligand"]["ligand"],
175
+ # "count": seq["ligand"]["count"]
176
+ # })
177
+ # return components
178
+
179
+ # def create_protenix_json(input_data: Dict) -> List[Dict]:
180
+ # """Convert UI inputs to Protenix JSON format"""
181
+ # sequences = []
182
 
183
+ # for pc in input_data["protein_chains"]:
184
+ # sequences.append({
185
+ # "proteinChain": {
186
+ # "sequence": pc["sequence"],
187
+ # "count": pc["count"]
188
+ # }
189
+ # })
190
 
191
+ # for dna in input_data["dna_sequences"]:
192
+ # sequences.append({
193
+ # "dnaSequence": {
194
+ # "sequence": dna["sequence"],
195
+ # "count": dna["count"]
196
+ # }
197
+ # })
198
 
199
+ # for lig in input_data["ligands"]:
200
+ # sequences.append({
201
+ # "ligand": {
202
+ # "ligand": lig["type"],
203
+ # "count": lig["count"]
204
+ # }
205
+ # })
206
 
207
+ # return [{
208
+ # "sequences": sequences,
209
+ # "name": input_data["complex_name"]
210
+ # }]
211
+
212
+
213
+ # #@torch.inference_mode()
214
+ # @spaces.GPU(duration=120) # Specify a duration to avoid timeout
215
+ # def predict_structure(input_collector: dict):
216
+ # #first initialize runner
217
+ # runner = InferenceRunner(configs)
218
+ # """Handle both input types"""
219
+ # os.makedirs("./output", exist_ok=True)
220
 
221
+ # # Generate random filename with timestamp
222
+ # random_name = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
223
+ # save_path = os.path.join("./output", f"{random_name}.json")
224
+
225
+ # print(input_collector)
226
+
227
+ # # Handle JSON input
228
+ # if input_collector["json"]:
229
+ # # Handle different input types
230
+ # if isinstance(input_collector["json"], str): # Example JSON case (file path)
231
+ # input_data = json.load(open(input_collector["json"]))
232
+ # elif hasattr(input_collector["json"], "name"): # File upload case
233
+ # input_data = json.load(open(input_collector["json"].name))
234
+ # else: # Direct JSON data case
235
+ # input_data = input_collector["json"]
236
+ # else: # Manual input case
237
+ # input_data = create_protenix_json(input_collector["data"])
238
+
239
+ # with open(save_path, "w") as f:
240
+ # json.dump(input_data, f, indent=2)
241
+
242
+ # if input_data==example_json and input_collector['watermark']==True:
243
+ # configs.saved_path = './output/example_output/'
244
+ # else:
245
+ # # run msa
246
+ # json_file = update_infer_json(save_path, './output', True)
247
+
248
+ # # Run prediction
249
+ # configs.input_json_path = json_file
250
+ # configs.watermark = input_collector['watermark']
251
+ # configs.saved_path = os.path.join("./output/", random_name)
252
+ # infer_predict(runner, configs)
253
+ # #saved_path = os.path.join('./output', f"{sample_name}", f"seed_{seed}", 'predictions')
254
+
255
+ # # Generate visualizations
256
+ # pred_loader = PredictionLoader(os.path.join(configs.saved_path, 'predictions'))
257
+ # view3d, cif_path = plot_3d(pred_loader=pred_loader)
258
+ # if configs.watermark:
259
+ # pred_loader = PredictionLoader(os.path.join(configs.saved_path, 'predictions_orig'))
260
+ # view3d_orig, _ = plot_3d(pred_loader=pred_loader)
261
+ # align_pdb_files(view3d, view3d_orig)
262
+ # view3d = [view3d, view3d_orig]
263
+ # plot_best_confidence_measure(os.path.join(configs.saved_path, 'predictions'))
264
+ # confidence_img_path = os.path.join(os.path.join(configs.saved_path, 'predictions'), "best_sample_confidence.png")
265
+
266
+ # return view3d, confidence_img_path, cif_path
267
+
268
+
269
+ # logger = logging.getLogger(__name__)
270
+ # LOG_FORMAT = "%(asctime)s,%(msecs)-3d %(levelname)-8s [%(filename)s:%(lineno)s %(funcName)s] %(message)s"
271
+ # logging.basicConfig(
272
+ # format=LOG_FORMAT,
273
+ # level=logging.INFO,
274
+ # datefmt="%Y-%m-%d %H:%M:%S",
275
+ # filemode="w",
276
+ # )
277
+ # configs_base["use_deepspeed_evo_attention"] = (
278
+ # os.environ.get("USE_DEEPSPEED_EVO_ATTTENTION", False) == "False"
279
+ # )
280
+ # arg_str = "--seeds 101 --dump_dir ./output --input_json_path ./examples/example.json --model.N_cycle 10 --sample_diffusion.N_sample 5 --sample_diffusion.N_step 200 "
281
+ # configs = {**configs_base, **{"data": data_configs}, **inference_configs}
282
+ # configs = parse_configs(
283
+ # configs=configs,
284
+ # arg_str=arg_str,
285
+ # fill_required_with_null=True,
286
+ # )
287
+ # configs.load_checkpoint_path='./checkpoint.pt'
288
+ # download_infercence_cache()
289
+ # configs.use_deepspeed_evo_attention=False
290
+
291
+ # add_watermark = gr.Checkbox(label="Add Watermark", value=True)
292
+ # add_watermark1 = gr.Checkbox(label="Add Watermark", value=True)
293
+
294
+
295
+ # with gr.Blocks(title="FoldMark", css=custom_css) as demo:
296
+ # with gr.Row():
297
+ # # Use a Column to align the logo and title horizontally
298
+ # gr.Image(value="./assets/foldmark_head.png", elem_id="logo", label="Logo", height=150, show_label=False)
299
+
300
+ # with gr.Tab("Structure Predictor (JSON Upload)"):
301
+ # # First create the upload component
302
+ # json_upload = gr.File(label="Upload JSON", file_types=[".json"])
303
 
304
+ # # Then create the example component that references it
305
+ # gr.Examples(
306
+ # examples=[[EXAMPLE_PATH]],
307
+ # inputs=[json_upload],
308
+ # label="Click to use example JSON:",
309
+ # examples_per_page=1
310
+ # )
311
 
312
+ # # Rest of the components
313
+ # upload_name = gr.Textbox(label="Complex Name (optional)")
314
+ # upload_output = gr.JSON(label="Parsed Components")
315
 
316
+ # json_upload.upload(
317
+ # fn=lambda f: parse_json_input(json.load(open(f.name))),
318
+ # inputs=json_upload,
319
+ # outputs=upload_output
320
+ # )
321
+
322
+ # # Shared prediction components
323
+ # with gr.Row():
324
+ # add_watermark.render()
325
+ # submit_btn = gr.Button("Predict Structure", variant="primary")
326
+ # #structure_view = gr.HTML(label="3D Visualization")
327
+
328
+ # with gr.Row():
329
+ # view3d = Molecule3D(label="3D Visualization", reps=reps)
330
+ # legend = gr.Markdown("""
331
+ # **Color Legend:**
332
+
333
+ # - <span style="color:grey">Unwatermarked Structure</span>
334
+ # - <span style="color:cyan">Watermarked Structure</span>
335
+ # """)
336
+ # with gr.Row():
337
+ # cif_file = gr.File(label="Download CIF File")
338
+ # with gr.Row():
339
+ # confidence_plot_image = gr.Image(label="Confidence Measures")
340
 
341
+ # input_collector = gr.JSON(visible=False)
342
+
343
+ # # Map inputs to a dictionary
344
+ # submit_btn.click(
345
+ # fn=lambda j, w: {"json": j, "watermark": w},
346
+ # inputs=[json_upload, add_watermark],
347
+ # outputs=input_collector
348
+ # ).then(
349
+ # fn=predict_structure,
350
+ # inputs=input_collector,
351
+ # outputs=[view3d, confidence_plot_image, cif_file]
352
+ # )
353
+
354
+ # gr.Markdown("""
355
+ # The example of the uploaded json file for structure prediction.
356
+ # <pre>
357
+ # [{
358
+ # "sequences": [
359
+ # {
360
+ # "proteinChain": {
361
+ # "sequence": "MAEVIRSSAFWRSFPIFEEFDSETLCELSGIASYRKWSAGTVIFQRGDQGDYMIVVVSGRIKLSLFTPQGRELMLRQHEAGALFGEMALLDGQPRSADATAVTAAEGYVIGKKDFLALITQRPKTAEAVIRFLCAQLRDTTDRLETIALYDLNARVARFFLATLRQIHGSEMPQSANLRLTLSQTDIASILGASRPKVNRAILSLEESGAIKRADGIICCNVGRLLSIADPEEDLEHHHHHHHH",
362
+ # "count": 2
363
+ # }
364
+ # },
365
+ # {
366
+ # "dnaSequence": {
367
+ # "sequence": "CTAGGTAACATTACTCGCG",
368
+ # "count": 2
369
+ # }
370
+ # },
371
+ # {
372
+ # "dnaSequence": {
373
+ # "sequence": "GCGAGTAATGTTAC",
374
+ # "count": 2
375
+ # }
376
+ # },
377
+ # {
378
+ # "ligand": {
379
+ # "ligand": "CCD_PCG",
380
+ # "count": 2
381
+ # }
382
+ # }
383
+ # ],
384
+ # "name": "7pzb"
385
+ # }]
386
+ # </pre>
387
+ # """)
388
 
389
+ # with gr.Tab("Structure Predictor (Manual Input)"):
390
+ # with gr.Row():
391
+ # complex_name = gr.Textbox(label="Complex Name")
392
 
393
+ # # Replace gr.Group with gr.Accordion
394
+ # with gr.Accordion(label="Protein Chains", open=True):
395
+ # protein_chains = gr.Dataframe(
396
+ # headers=["Sequence", "Count"],
397
+ # datatype=["str", "number"],
398
+ # row_count=1,
399
+ # col_count=(2, "fixed")
400
+ # )
401
 
402
+ # # Repeat for other groups
403
+ # with gr.Accordion(label="DNA Sequences", open=True):
404
+ # dna_sequences = gr.Dataframe(
405
+ # headers=["Sequence", "Count"],
406
+ # datatype=["str", "number"],
407
+ # row_count=1
408
+ # )
409
 
410
+ # with gr.Accordion(label="Ligands", open=True):
411
+ # ligands = gr.Dataframe(
412
+ # headers=["Ligand Type", "Count"],
413
+ # datatype=["str", "number"],
414
+ # row_count=1
415
+ # )
416
 
417
+ # manual_output = gr.JSON(label="Generated JSON")
418
 
419
+ # complex_name.change(
420
+ # fn=lambda x: {"complex_name": x},
421
+ # inputs=complex_name,
422
+ # outputs=manual_output
423
+ # )
424
+
425
+ # # Shared prediction components
426
+ # with gr.Row():
427
+ # add_watermark1.render()
428
+ # submit_btn = gr.Button("Predict Structure", variant="primary")
429
+ # #structure_view = gr.HTML(label="3D Visualization")
430
+
431
+ # with gr.Row():
432
+ # view3d = Molecule3D(label="3D Visualization (Gray: Unwatermarked; Cyan: Watermarked)", reps=reps)
433
+
434
+ # with gr.Row():
435
+ # cif_file = gr.File(label="Download CIF File")
436
+ # with gr.Row():
437
+ # confidence_plot_image = gr.Image(label="Confidence Measures")
438
 
439
+ # input_collector = gr.JSON(visible=False)
440
+
441
+ # # Map inputs to a dictionary
442
+ # submit_btn.click(
443
+ # fn=lambda c, p, d, l, w: {"data": {"complex_name": c, "protein_chains": p, "dna_sequences": d, "ligands": l}, "watermark": w},
444
+ # inputs=[complex_name, protein_chains, dna_sequences, ligands, add_watermark1],
445
+ # outputs=input_collector
446
+ # ).then(
447
+ # fn=predict_structure,
448
+ # inputs=input_collector,
449
+ # outputs=[view3d, confidence_plot_image, cif_file]
450
+ # )
451
+
452
+ # @spaces.GPU(duration=120)
453
+ # def is_watermarked(file):
454
+ # #first initialize runner
455
+ # runner = InferenceRunner(configs)
456
+ # # Generate a unique subdirectory and filename
457
+ # unique_id = str(uuid.uuid4().hex[:8])
458
+ # subdir = os.path.join('./output', unique_id)
459
+ # os.makedirs(subdir, exist_ok=True)
460
+ # filename = f"{unique_id}.cif"
461
+ # file_path = os.path.join(subdir, filename)
462
 
463
+ # # Save the uploaded file to the new location
464
+ # shutil.copy(file.name, file_path)
465
 
466
+ # # Call your processing functions
467
+ # configs.process_success = process_data(subdir)
468
+ # configs.subdir = subdir
469
+ # result = infer_detect(runner, configs)
470
+ # # This function should return 'Watermarked' or 'Not Watermarked'
471
+ # temp_pdb_path = convert_cif_to_pdb(file_path)
472
+ # if result==False:
473
+ # return "Not Watermarked", temp_pdb_path
474
+ # else:
475
+ # return "Watermarked", temp_pdb_path
476
 
477
 
478
 
479
+ # with gr.Tab("Watermark Detector"):
480
+ # # First create the upload component
481
+ # cif_upload = gr.File(label="Upload .cif", file_types=["..cif"])
482
 
483
+ # with gr.Row():
484
+ # cif_3d_view = Molecule3D(label="3D Visualization of Input", reps=reps)
485
 
486
+ # # Prediction output
487
+ # prediction_output = gr.Textbox(label="Prediction")
488
 
489
+ # # Define the interaction
490
+ # cif_upload.change(is_watermarked, inputs=cif_upload, outputs=[prediction_output, cif_3d_view])
491
 
492
+ # # Example files
493
+ # example_files = [
494
+ # "./examples/7r6r_watermarked.cif",
495
+ # "./examples/7pzb_unwatermarked.cif"
496
+ # ]
497
 
498
+ # gr.Examples(examples=example_files, inputs=cif_upload)
499
 
500
 
501
 
502
 
503
 
504
 
505
+ # if __name__ == "__main__":
506
+ # demo.launch(share=True)