Zaixi committed on
Commit
b09b599
·
1 Parent(s): e8bea69
Files changed (1) hide show
  1. app.py +459 -459
app.py CHANGED
@@ -1,506 +1,506 @@
1
- # import spaces
2
- # import logging
3
- # import gradio as gr
4
- # import os
5
- # import uuid
6
- # from datetime import datetime
7
- # import numpy as np
8
- # from configs.configs_base import configs as configs_base
9
- # from configs.configs_data import data_configs
10
- # from configs.configs_inference import inference_configs
11
- # from runner.inference import download_infercence_cache, update_inference_configs, infer_predict, infer_detect, InferenceRunner
12
- # from protenix.config import parse_configs, parse_sys_args
13
- # from runner.msa_search import update_infer_json
14
- # from protenix.web_service.prediction_visualization import plot_best_confidence_measure, PredictionLoader
15
- # from process_data import process_data
16
- # import json
17
- # from typing import Dict, List
18
- # from Bio.PDB import MMCIFParser, PDBIO
19
- # import tempfile
20
- # import shutil
21
- # from Bio import PDB
22
- # from gradio_molecule3d import Molecule3D
23
-
24
- # EXAMPLE_PATH = './examples/example.json'
25
- # example_json=[{'sequences': [{'proteinChain': {'sequence': 'MAEVIRSSAFWRSFPIFEEFDSETLCELSGIASYRKWSAGTVIFQRGDQGDYMIVVVSGRIKLSLFTPQGRELMLRQHEAGALFGEMALLDGQPRSADATAVTAAEGYVIGKKDFLALITQRPKTAEAVIRFLCAQLRDTTDRLETIALYDLNARVARFFLATLRQIHGSEMPQSANLRLTLSQTDIASILGASRPKVNRAILSLEESGAIKRADGIICCNVGRLLSIADPEEDLEHHHHHHHH', 'count': 2}}, {'dnaSequence': {'sequence': 'CTAGGTAACATTACTCGCG', 'count': 2}}, {'dnaSequence': {'sequence': 'GCGAGTAATGTTAC', 'count': 2}}, {'ligand': {'ligand': 'CCD_PCG', 'count': 2}}], 'name': '7pzb_need_search_msa'}]
26
-
27
- # # Custom CSS for styling
28
- # custom_css = """
29
- # #logo {
30
- # width: 50%;
31
- # }
32
- # .title {
33
- # font-size: 32px;
34
- # font-weight: bold;
35
- # color: #4CAF50;
36
- # display: flex;
37
- # align-items: center; /* Vertically center the logo and text */
38
- # }
39
- # """
40
-
41
-
42
- # os.environ["LAYERNORM_TYPE"] = "fast_layernorm"
43
- # os.environ["USE_DEEPSPEED_EVO_ATTTENTION"] = "False"
44
- # # Set environment variable in the script
45
- # #os.environ['CUTLASS_PATH'] = './cutlass'
46
-
47
- # # reps = [
48
- # # {
49
- # # "model": 0,
50
- # # "chain": "",
51
- # # "resname": "",
52
- # # "style": "cartoon", # Use cartoon style
53
- # # "color": "whiteCarbon",
54
- # # "residue_range": "",
55
- # # "around": 0,
56
- # # "byres": False,
57
- # # "visible": True # Ensure this representation is visible
58
- # # }
59
- # # ]
60
 
61
  # reps = [
62
  # {
63
  # "model": 0,
64
  # "chain": "",
65
  # "resname": "",
66
- # "style": "cartoon",
67
  # "color": "whiteCarbon",
68
  # "residue_range": "",
69
  # "around": 0,
70
  # "byres": False,
71
- # "opacity": 0.2,
72
- # },
73
- # {
74
- # "model": 1,
75
- # "chain": "",
76
- # "resname": "",
77
- # "style": "cartoon",
78
- # "color": "cyanCarbon",
79
- # "residue_range": "",
80
- # "around": 0,
81
- # "byres": False,
82
- # "opacity": 0.8,
83
  # }
84
  # ]
85
- # ##
86
 
87
-
88
- # def align_pdb_files(pdb_file_1, pdb_file_2):
89
- # # Load the structures
90
- # parser = PDB.PPBuilder()
91
- # io = PDB.PDBIO()
92
- # structure_1 = PDB.PDBParser(QUIET=True).get_structure('Structure_1', pdb_file_1)
93
- # structure_2 = PDB.PDBParser(QUIET=True).get_structure('Structure_2', pdb_file_2)
94
-
95
- # # Superimpose the second structure onto the first
96
- # super_imposer = PDB.Superimposer()
97
- # model_1 = structure_1[0]
98
- # model_2 = structure_2[0]
99
-
100
- # # Extract the coordinates from the two structures
101
- # atoms_1 = [atom for atom in model_1.get_atoms() if atom.get_name() == "CA"] # Use CA atoms
102
- # atoms_2 = [atom for atom in model_2.get_atoms() if atom.get_name() == "CA"]
103
-
104
- # # Align the structures based on the CA atoms
105
- # coord_1 = [atom.get_coord() for atom in atoms_1]
106
- # coord_2 = [atom.get_coord() for atom in atoms_2]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
- # super_imposer.set_atoms(atoms_1, atoms_2)
109
- # super_imposer.apply(model_2) # Apply the transformation to model_2
110
-
111
- # # Save the aligned structure back to the original file
112
- # io.set_structure(structure_2) # Save the aligned structure to the second file (original file)
113
- # io.save(pdb_file_2)
114
-
115
- # # Function to convert .cif to .pdb and save as a temporary file
116
- # def convert_cif_to_pdb(cif_path):
117
- # """
118
- # Convert a CIF file to a PDB file and save it as a temporary file.
119
-
120
- # Args:
121
- # cif_path (str): Path to the input CIF file.
122
-
123
- # Returns:
124
- # str: Path to the temporary PDB file.
125
- # """
126
- # # Initialize the MMCIF parser
127
- # parser = MMCIFParser()
128
- # structure = parser.get_structure("protein", cif_path)
129
-
130
- # # Create a temporary file for the PDB output
131
- # with tempfile.NamedTemporaryFile(suffix=".pdb", delete=False) as temp_file:
132
- # temp_pdb_path = temp_file.name
133
-
134
- # # Save the structure as a PDB file
135
- # io = PDBIO()
136
- # io.set_structure(structure)
137
- # io.save(temp_pdb_path)
138
-
139
- # return temp_pdb_path
140
-
141
- # def plot_3d(pred_loader):
142
- # # Get the CIF file path for the given prediction ID
143
- # cif_path = sorted(pred_loader.cif_paths)[0]
144
-
145
- # # Convert the CIF file to a temporary PDB file
146
- # temp_pdb_path = convert_cif_to_pdb(cif_path)
147
-
148
- # return temp_pdb_path, cif_path
149
-
150
- # def parse_json_input(json_data: List[Dict]) -> Dict:
151
- # """Convert Protenix JSON format to UI-friendly structure"""
152
- # components = {
153
- # "protein_chains": [],
154
- # "dna_sequences": [],
155
- # "ligands": [],
156
- # "complex_name": ""
157
- # }
158
 
159
- # for entry in json_data:
160
- # components["complex_name"] = entry.get("name", "")
161
- # for seq in entry["sequences"]:
162
- # if "proteinChain" in seq:
163
- # components["protein_chains"].append({
164
- # "sequence": seq["proteinChain"]["sequence"],
165
- # "count": seq["proteinChain"]["count"]
166
- # })
167
- # elif "dnaSequence" in seq:
168
- # components["dna_sequences"].append({
169
- # "sequence": seq["dnaSequence"]["sequence"],
170
- # "count": seq["dnaSequence"]["count"]
171
- # })
172
- # elif "ligand" in seq:
173
- # components["ligands"].append({
174
- # "type": seq["ligand"]["ligand"],
175
- # "count": seq["ligand"]["count"]
176
- # })
177
- # return components
178
-
179
- # def create_protenix_json(input_data: Dict) -> List[Dict]:
180
- # """Convert UI inputs to Protenix JSON format"""
181
- # sequences = []
182
 
183
- # for pc in input_data["protein_chains"]:
184
- # sequences.append({
185
- # "proteinChain": {
186
- # "sequence": pc["sequence"],
187
- # "count": pc["count"]
188
- # }
189
- # })
190
 
191
- # for dna in input_data["dna_sequences"]:
192
- # sequences.append({
193
- # "dnaSequence": {
194
- # "sequence": dna["sequence"],
195
- # "count": dna["count"]
196
- # }
197
- # })
198
 
199
- # for lig in input_data["ligands"]:
200
- # sequences.append({
201
- # "ligand": {
202
- # "ligand": lig["type"],
203
- # "count": lig["count"]
204
- # }
205
- # })
206
 
207
- # return [{
208
- # "sequences": sequences,
209
- # "name": input_data["complex_name"]
210
- # }]
211
-
212
-
213
- # #@torch.inference_mode()
214
- # @spaces.GPU(duration=120) # Specify a duration to avoid timeout
215
- # def predict_structure(input_collector: dict):
216
- # #first initialize runner
217
- # runner = InferenceRunner(configs)
218
- # """Handle both input types"""
219
- # os.makedirs("./output", exist_ok=True)
220
 
221
- # # Generate random filename with timestamp
222
- # random_name = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
223
- # save_path = os.path.join("./output", f"{random_name}.json")
224
-
225
- # print(input_collector)
226
-
227
- # # Handle JSON input
228
- # if input_collector["json"]:
229
- # # Handle different input types
230
- # if isinstance(input_collector["json"], str): # Example JSON case (file path)
231
- # input_data = json.load(open(input_collector["json"]))
232
- # elif hasattr(input_collector["json"], "name"): # File upload case
233
- # input_data = json.load(open(input_collector["json"].name))
234
- # else: # Direct JSON data case
235
- # input_data = input_collector["json"]
236
- # else: # Manual input case
237
- # input_data = create_protenix_json(input_collector["data"])
238
-
239
- # with open(save_path, "w") as f:
240
- # json.dump(input_data, f, indent=2)
241
-
242
- # if input_data==example_json and input_collector['watermark']==True:
243
- # configs.saved_path = './output/example_output/'
244
- # else:
245
- # # run msa
246
- # json_file = update_infer_json(save_path, './output', True)
247
-
248
- # # Run prediction
249
- # configs.input_json_path = json_file
250
- # configs.watermark = input_collector['watermark']
251
- # configs.saved_path = os.path.join("./output/", random_name)
252
- # infer_predict(runner, configs)
253
- # #saved_path = os.path.join('./output', f"{sample_name}", f"seed_{seed}", 'predictions')
254
-
255
- # # Generate visualizations
256
- # pred_loader = PredictionLoader(os.path.join(configs.saved_path, 'predictions'))
257
- # view3d, cif_path = plot_3d(pred_loader=pred_loader)
258
- # if configs.watermark:
259
- # pred_loader = PredictionLoader(os.path.join(configs.saved_path, 'predictions_orig'))
260
- # view3d_orig, _ = plot_3d(pred_loader=pred_loader)
261
- # align_pdb_files(view3d, view3d_orig)
262
- # view3d = [view3d, view3d_orig]
263
- # plot_best_confidence_measure(os.path.join(configs.saved_path, 'predictions'))
264
- # confidence_img_path = os.path.join(os.path.join(configs.saved_path, 'predictions'), "best_sample_confidence.png")
265
-
266
- # return view3d, confidence_img_path, cif_path
267
-
268
-
269
- # logger = logging.getLogger(__name__)
270
- # LOG_FORMAT = "%(asctime)s,%(msecs)-3d %(levelname)-8s [%(filename)s:%(lineno)s %(funcName)s] %(message)s"
271
- # logging.basicConfig(
272
- # format=LOG_FORMAT,
273
- # level=logging.INFO,
274
- # datefmt="%Y-%m-%d %H:%M:%S",
275
- # filemode="w",
276
- # )
277
- # configs_base["use_deepspeed_evo_attention"] = (
278
- # os.environ.get("USE_DEEPSPEED_EVO_ATTTENTION", False) == "False"
279
- # )
280
- # arg_str = "--seeds 101 --dump_dir ./output --input_json_path ./examples/example.json --model.N_cycle 10 --sample_diffusion.N_sample 5 --sample_diffusion.N_step 200 "
281
- # configs = {**configs_base, **{"data": data_configs}, **inference_configs}
282
- # configs = parse_configs(
283
- # configs=configs,
284
- # arg_str=arg_str,
285
- # fill_required_with_null=True,
286
- # )
287
- # configs.load_checkpoint_path='./checkpoint.pt'
288
- # download_infercence_cache()
289
- # configs.use_deepspeed_evo_attention=False
290
-
291
- # add_watermark = gr.Checkbox(label="Add Watermark", value=True)
292
- # add_watermark1 = gr.Checkbox(label="Add Watermark", value=True)
293
-
294
-
295
- # with gr.Blocks(title="FoldMark", css=custom_css) as demo:
296
- # with gr.Row():
297
- # # Use a Column to align the logo and title horizontally
298
- # gr.Image(value="./assets/foldmark_head.png", elem_id="logo", label="Logo", height=150, show_label=False)
299
-
300
- # with gr.Tab("Structure Predictor (JSON Upload)"):
301
- # # First create the upload component
302
- # json_upload = gr.File(label="Upload JSON", file_types=[".json"])
303
 
304
- # # Then create the example component that references it
305
- # gr.Examples(
306
- # examples=[[EXAMPLE_PATH]],
307
- # inputs=[json_upload],
308
- # label="Click to use example JSON:",
309
- # examples_per_page=1
310
- # )
311
 
312
- # # Rest of the components
313
- # upload_name = gr.Textbox(label="Complex Name (optional)")
314
- # upload_output = gr.JSON(label="Parsed Components")
315
 
316
- # json_upload.upload(
317
- # fn=lambda f: parse_json_input(json.load(open(f.name))),
318
- # inputs=json_upload,
319
- # outputs=upload_output
320
- # )
321
-
322
- # # Shared prediction components
323
- # with gr.Row():
324
- # add_watermark.render()
325
- # submit_btn = gr.Button("Predict Structure", variant="primary")
326
- # #structure_view = gr.HTML(label="3D Visualization")
327
-
328
- # with gr.Row():
329
- # view3d = Molecule3D(label="3D Visualization", reps=reps)
330
- # legend = gr.Markdown("""
331
- # **Color Legend:**
332
-
333
- # - <span style="color:grey">Unwatermarked Structure</span>
334
- # - <span style="color:cyan">Watermarked Structure</span>
335
- # """)
336
- # with gr.Row():
337
- # cif_file = gr.File(label="Download CIF File")
338
- # with gr.Row():
339
- # confidence_plot_image = gr.Image(label="Confidence Measures")
340
 
341
- # input_collector = gr.JSON(visible=False)
342
-
343
- # # Map inputs to a dictionary
344
- # submit_btn.click(
345
- # fn=lambda j, w: {"json": j, "watermark": w},
346
- # inputs=[json_upload, add_watermark],
347
- # outputs=input_collector
348
- # ).then(
349
- # fn=predict_structure,
350
- # inputs=input_collector,
351
- # outputs=[view3d, confidence_plot_image, cif_file]
352
- # )
353
-
354
- # gr.Markdown("""
355
- # The example of the uploaded json file for structure prediction.
356
- # <pre>
357
- # [{
358
- # "sequences": [
359
- # {
360
- # "proteinChain": {
361
- # "sequence": "MAEVIRSSAFWRSFPIFEEFDSETLCELSGIASYRKWSAGTVIFQRGDQGDYMIVVVSGRIKLSLFTPQGRELMLRQHEAGALFGEMALLDGQPRSADATAVTAAEGYVIGKKDFLALITQRPKTAEAVIRFLCAQLRDTTDRLETIALYDLNARVARFFLATLRQIHGSEMPQSANLRLTLSQTDIASILGASRPKVNRAILSLEESGAIKRADGIICCNVGRLLSIADPEEDLEHHHHHHHH",
362
- # "count": 2
363
- # }
364
- # },
365
- # {
366
- # "dnaSequence": {
367
- # "sequence": "CTAGGTAACATTACTCGCG",
368
- # "count": 2
369
- # }
370
- # },
371
- # {
372
- # "dnaSequence": {
373
- # "sequence": "GCGAGTAATGTTAC",
374
- # "count": 2
375
- # }
376
- # },
377
- # {
378
- # "ligand": {
379
- # "ligand": "CCD_PCG",
380
- # "count": 2
381
- # }
382
- # }
383
- # ],
384
- # "name": "7pzb"
385
- # }]
386
- # </pre>
387
- # """)
388
 
389
- # with gr.Tab("Structure Predictor (Manual Input)"):
390
- # with gr.Row():
391
- # complex_name = gr.Textbox(label="Complex Name")
392
 
393
- # # Replace gr.Group with gr.Accordion
394
- # with gr.Accordion(label="Protein Chains", open=True):
395
- # protein_chains = gr.Dataframe(
396
- # headers=["Sequence", "Count"],
397
- # datatype=["str", "number"],
398
- # row_count=1,
399
- # col_count=(2, "fixed")
400
- # )
401
 
402
- # # Repeat for other groups
403
- # with gr.Accordion(label="DNA Sequences", open=True):
404
- # dna_sequences = gr.Dataframe(
405
- # headers=["Sequence", "Count"],
406
- # datatype=["str", "number"],
407
- # row_count=1
408
- # )
409
 
410
- # with gr.Accordion(label="Ligands", open=True):
411
- # ligands = gr.Dataframe(
412
- # headers=["Ligand Type", "Count"],
413
- # datatype=["str", "number"],
414
- # row_count=1
415
- # )
416
 
417
- # manual_output = gr.JSON(label="Generated JSON")
418
 
419
- # complex_name.change(
420
- # fn=lambda x: {"complex_name": x},
421
- # inputs=complex_name,
422
- # outputs=manual_output
423
- # )
424
-
425
- # # Shared prediction components
426
- # with gr.Row():
427
- # add_watermark1.render()
428
- # submit_btn = gr.Button("Predict Structure", variant="primary")
429
- # #structure_view = gr.HTML(label="3D Visualization")
430
-
431
- # with gr.Row():
432
- # view3d = Molecule3D(label="3D Visualization (Gray: Unwatermarked; Cyan: Watermarked)", reps=reps)
433
-
434
- # with gr.Row():
435
- # cif_file = gr.File(label="Download CIF File")
436
- # with gr.Row():
437
- # confidence_plot_image = gr.Image(label="Confidence Measures")
438
 
439
- # input_collector = gr.JSON(visible=False)
440
-
441
- # # Map inputs to a dictionary
442
- # submit_btn.click(
443
- # fn=lambda c, p, d, l, w: {"data": {"complex_name": c, "protein_chains": p, "dna_sequences": d, "ligands": l}, "watermark": w},
444
- # inputs=[complex_name, protein_chains, dna_sequences, ligands, add_watermark1],
445
- # outputs=input_collector
446
- # ).then(
447
- # fn=predict_structure,
448
- # inputs=input_collector,
449
- # outputs=[view3d, confidence_plot_image, cif_file]
450
- # )
451
-
452
- # @spaces.GPU(duration=120)
453
- # def is_watermarked(file):
454
- # #first initialize runner
455
- # runner = InferenceRunner(configs)
456
- # # Generate a unique subdirectory and filename
457
- # unique_id = str(uuid.uuid4().hex[:8])
458
- # subdir = os.path.join('./output', unique_id)
459
- # os.makedirs(subdir, exist_ok=True)
460
- # filename = f"{unique_id}.cif"
461
- # file_path = os.path.join(subdir, filename)
462
 
463
- # # Save the uploaded file to the new location
464
- # shutil.copy(file.name, file_path)
465
 
466
- # # Call your processing functions
467
- # configs.process_success = process_data(subdir)
468
- # configs.subdir = subdir
469
- # result = infer_detect(runner, configs)
470
- # # This function should return 'Watermarked' or 'Not Watermarked'
471
- # temp_pdb_path = convert_cif_to_pdb(file_path)
472
- # if result==False:
473
- # return "Not Watermarked", temp_pdb_path
474
- # else:
475
- # return "Watermarked", temp_pdb_path
476
 
477
 
478
 
479
- # with gr.Tab("Watermark Detector"):
480
- # # First create the upload component
481
- # cif_upload = gr.File(label="Upload .cif", file_types=["..cif"])
482
 
483
- # with gr.Row():
484
- # cif_3d_view = Molecule3D(label="3D Visualization of Input", reps=reps)
485
 
486
- # # Prediction output
487
- # prediction_output = gr.Textbox(label="Prediction")
488
 
489
- # # Define the interaction
490
- # cif_upload.change(is_watermarked, inputs=cif_upload, outputs=[prediction_output, cif_3d_view])
491
 
492
- # # Example files
493
- # example_files = [
494
- # "./examples/7r6r_watermarked.cif",
495
- # "./examples/7pzb_unwatermarked.cif"
496
- # ]
497
 
498
- # gr.Examples(examples=example_files, inputs=cif_upload)
499
 
500
 
501
 
502
 
503
 
504
 
505
- # if __name__ == "__main__":
506
- # demo.launch(share=True)
 
1
+ import spaces
2
+ import logging
3
+ import gradio as gr
4
+ import os
5
+ import uuid
6
+ from datetime import datetime
7
+ import numpy as np
8
+ from configs.configs_base import configs as configs_base
9
+ from configs.configs_data import data_configs
10
+ from configs.configs_inference import inference_configs
11
+ from runner.inference import download_infercence_cache, update_inference_configs, infer_predict, infer_detect, InferenceRunner
12
+ from protenix.config import parse_configs, parse_sys_args
13
+ from runner.msa_search import update_infer_json
14
+ from protenix.web_service.prediction_visualization import plot_best_confidence_measure, PredictionLoader
15
+ from process_data import process_data
16
+ import json
17
+ from typing import Dict, List
18
+ from Bio.PDB import MMCIFParser, PDBIO
19
+ import tempfile
20
+ import shutil
21
+ from Bio import PDB
22
+ from gradio_molecule3d import Molecule3D
23
+
24
+ EXAMPLE_PATH = './examples/example.json'
25
+ example_json=[{'sequences': [{'proteinChain': {'sequence': 'MAEVIRSSAFWRSFPIFEEFDSETLCELSGIASYRKWSAGTVIFQRGDQGDYMIVVVSGRIKLSLFTPQGRELMLRQHEAGALFGEMALLDGQPRSADATAVTAAEGYVIGKKDFLALITQRPKTAEAVIRFLCAQLRDTTDRLETIALYDLNARVARFFLATLRQIHGSEMPQSANLRLTLSQTDIASILGASRPKVNRAILSLEESGAIKRADGIICCNVGRLLSIADPEEDLEHHHHHHHH', 'count': 2}}, {'dnaSequence': {'sequence': 'CTAGGTAACATTACTCGCG', 'count': 2}}, {'dnaSequence': {'sequence': 'GCGAGTAATGTTAC', 'count': 2}}, {'ligand': {'ligand': 'CCD_PCG', 'count': 2}}], 'name': '7pzb_need_search_msa'}]
26
+
27
+ # Custom CSS for styling
28
+ custom_css = """
29
+ #logo {
30
+ width: 50%;
31
+ }
32
+ .title {
33
+ font-size: 32px;
34
+ font-weight: bold;
35
+ color: #4CAF50;
36
+ display: flex;
37
+ align-items: center; /* Vertically center the logo and text */
38
+ }
39
+ """
40
+
41
+
42
+ os.environ["LAYERNORM_TYPE"] = "fast_layernorm"
43
+ os.environ["USE_DEEPSPEED_EVO_ATTTENTION"] = "False"
44
+ # Set environment variable in the script
45
+ #os.environ['CUTLASS_PATH'] = './cutlass'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  # reps = [
48
  # {
49
  # "model": 0,
50
  # "chain": "",
51
  # "resname": "",
52
+ # "style": "cartoon", # Use cartoon style
53
  # "color": "whiteCarbon",
54
  # "residue_range": "",
55
  # "around": 0,
56
  # "byres": False,
57
+ # "visible": True # Ensure this representation is visible
 
 
 
 
 
 
 
 
 
 
 
58
  # }
59
  # ]
 
60
 
61
def _cartoon_rep(model, color, opacity):
    """Return one cartoon-style representation entry for the given model index."""
    return {
        "model": model,
        "chain": "",
        "resname": "",
        "style": "cartoon",
        "color": color,
        "residue_range": "",
        "around": 0,
        "byres": False,
        "opacity": opacity,
    }


# Representations handed to every Molecule3D viewer below: one entry per model
# index, differing only in colour scheme and opacity.
reps = [
    _cartoon_rep(0, "whiteCarbon", 0.2),
    _cartoon_rep(1, "cyanCarbon", 0.8),
]
85
+ ##
86
+
87
+
88
def align_pdb_files(pdb_file_1, pdb_file_2):
    """Superimpose the structure in ``pdb_file_2`` onto ``pdb_file_1`` in place.

    Both files are parsed with ``PDB.PDBParser``. Atoms named "CA" in the first
    model of each structure are paired in iteration order, the resulting
    transformation is applied to the whole first model of the second
    structure, and that structure is written back to ``pdb_file_2``.

    Args:
        pdb_file_1: Path to the reference PDB file (only read).
        pdb_file_2: Path to the PDB file to move; overwritten with the
            aligned coordinates.

    Raises:
        Whatever ``Superimposer.set_atoms`` raises when the two "CA" atom
        lists have different lengths.
    """
    pdb_parser = PDB.PDBParser(QUIET=True)
    structure_1 = pdb_parser.get_structure('Structure_1', pdb_file_1)
    structure_2 = pdb_parser.get_structure('Structure_2', pdb_file_2)

    # Only the first model of each structure takes part in the alignment.
    model_1 = structure_1[0]
    model_2 = structure_2[0]

    atoms_1 = [atom for atom in model_1.get_atoms() if atom.get_name() == "CA"]
    atoms_2 = [atom for atom in model_2.get_atoms() if atom.get_name() == "CA"]

    # Without any "CA" atoms (for instance an input with no protein chain)
    # there is nothing to fit on; leave the second file exactly as it is.
    if not atoms_1 or not atoms_2:
        logging.getLogger(__name__).warning(
            "No CA atoms found in %s or %s; skipping alignment.",
            pdb_file_1,
            pdb_file_2,
        )
        return

    super_imposer = PDB.Superimposer()
    super_imposer.set_atoms(atoms_1, atoms_2)
    super_imposer.apply(model_2)  # Moves every atom of model_2, not just the CA atoms

    # Save the aligned structure back over the second input file.
    io = PDB.PDBIO()
    io.set_structure(structure_2)
    io.save(pdb_file_2)
114
+
115
+ # Function to convert .cif to .pdb and save as a temporary file
116
def convert_cif_to_pdb(cif_path):
    """
    Parse a CIF file and write the same structure out as a temporary PDB file.

    Args:
        cif_path (str): Path to the input CIF file.

    Returns:
        str: Path to the newly created ".pdb" file. The file is created with
        ``delete=False``, so it is not removed automatically.
    """
    # Parse first, so no temporary file is created if parsing fails.
    cif_structure = MMCIFParser().get_structure("protein", cif_path)

    # Only the unique file name is needed; the handle is closed straight away.
    tmp_handle = tempfile.NamedTemporaryFile(suffix=".pdb", delete=False)
    try:
        pdb_path = tmp_handle.name
    finally:
        tmp_handle.close()

    # Write the parsed structure to that path in PDB format.
    writer = PDBIO()
    writer.set_structure(cif_structure)
    writer.save(pdb_path)

    return pdb_path
140
+
141
def plot_3d(pred_loader):
    """Pick the first CIF path (in sorted order) and convert it to a PDB file.

    Args:
        pred_loader: Object exposing a ``cif_paths`` collection of file paths.

    Returns:
        Tuple ``(temp_pdb_path, cif_path)``: the temporary PDB file produced by
        ``convert_cif_to_pdb`` and the CIF path it was converted from.
    """
    first_cif = sorted(pred_loader.cif_paths)[0]
    return convert_cif_to_pdb(first_cif), first_cif
149
+
150
def parse_json_input(json_data: List[Dict]) -> Dict:
    """Convert Protenix JSON format to UI-friendly structure.

    Args:
        json_data: List of entries, each optionally holding a "name" and a
            "sequences" list whose items are keyed by "proteinChain",
            "dnaSequence" or "ligand".

    Returns:
        Dict with the lists "protein_chains", "dna_sequences" and "ligands"
        (accumulated over all entries) and "complex_name", which holds the
        name of the last entry ("" when it has none or there are no entries).
        Sequence items of any other kind are ignored.

    Raises:
        KeyError: If a recognised item lacks its value field or its "count".
    """
    # (key in the input item, output list, input value field, output value field)
    layout = (
        ("proteinChain", "protein_chains", "sequence", "sequence"),
        ("dnaSequence", "dna_sequences", "sequence", "sequence"),
        ("ligand", "ligands", "ligand", "type"),
    )
    components = {
        "protein_chains": [],
        "dna_sequences": [],
        "ligands": [],
        "complex_name": "",
    }

    for entry in json_data:
        components["complex_name"] = entry.get("name", "")
        # An entry without "sequences" (or with null) contributes nothing
        # instead of aborting the whole parse with a KeyError.
        for seq in entry.get("sequences") or []:
            for json_key, bucket, src_field, dst_field in layout:
                if json_key in seq:
                    item = seq[json_key]
                    components[bucket].append({
                        dst_field: item[src_field],
                        "count": item["count"],
                    })
                    break  # Only the first matching kind is used per item
    return components
178
+
179
def create_protenix_json(input_data: Dict) -> List[Dict]:
    """Convert UI inputs to Protenix JSON format.

    Args:
        input_data: Dict that may hold the lists "protein_chains" and
            "dna_sequences" (items with "sequence" and "count"), "ligands"
            (items with "type" and "count") and a "complex_name". A missing
            or ``None`` list is treated as empty; a missing name becomes "".

    Returns:
        A one-element list ``[{"sequences": [...], "name": ...}]`` with the
        protein chains first, then DNA sequences, then ligands.

    Raises:
        KeyError: If an item lacks one of the fields listed above.
    """
    sequences = [
        {"proteinChain": {"sequence": pc["sequence"], "count": pc["count"]}}
        for pc in input_data.get("protein_chains") or []
    ]
    sequences += [
        {"dnaSequence": {"sequence": dna["sequence"], "count": dna["count"]}}
        for dna in input_data.get("dna_sequences") or []
    ]
    sequences += [
        {"ligand": {"ligand": lig["type"], "count": lig["count"]}}
        for lig in input_data.get("ligands") or []
    ]

    return [{
        "sequences": sequences,
        "name": input_data.get("complex_name", ""),
    }]
211
+
212
+
213
+ #@torch.inference_mode()
214
@spaces.GPU(duration=120)  # Specify a duration to avoid timeout
def predict_structure(input_collector: dict):
    """Run a structure prediction for either a JSON upload or manual input.

    Args:
        input_collector: Dict assembled by the UI. The optional "json" entry
            may be a file path, an object with a ``name`` attribute, or
            already-parsed data; when it is absent or empty the "data" entry
            is converted with ``create_protenix_json``. "watermark" is
            always read.

    Returns:
        Tuple ``(view3d, confidence_img_path, cif_path)``. ``view3d`` is one
        PDB path, or a two-element list ``[predicted, original]`` when
        ``configs.watermark`` is truthy.
    """
    # Initialise the runner before any of the config fields below are changed.
    runner = InferenceRunner(configs)
    os.makedirs("./output", exist_ok=True)

    # Generate random filename with timestamp
    random_name = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
    save_path = os.path.join("./output", f"{random_name}.json")

    print(input_collector)

    # The manual-input tab sends no "json" key at all, so look it up leniently.
    json_input = input_collector.get("json")
    if json_input:
        if isinstance(json_input, str):  # Example JSON case (file path)
            with open(json_input) as f:
                input_data = json.load(f)
        elif hasattr(json_input, "name"):  # File upload case
            with open(json_input.name) as f:
                input_data = json.load(f)
        else:  # Direct JSON data case
            input_data = json_input
    else:  # Manual input case
        input_data = create_protenix_json(input_collector["data"])

    with open(save_path, "w") as f:
        json.dump(input_data, f, indent=2)

    if input_data == example_json and input_collector['watermark'] == True:
        # Bundled example with watermarking on: skip the MSA search and the
        # inference run and read results from the fixed example directory.
        # The watermark flag is still recorded because it is read below.
        configs.watermark = True
        configs.saved_path = './output/example_output/'
    else:
        # run msa
        json_file = update_infer_json(save_path, './output', True)

        # Run prediction
        configs.input_json_path = json_file
        configs.watermark = input_collector['watermark']
        configs.saved_path = os.path.join("./output/", random_name)
        infer_predict(runner, configs)

    # Generate visualizations
    predictions_dir = os.path.join(configs.saved_path, 'predictions')
    pred_loader = PredictionLoader(predictions_dir)
    view3d, cif_path = plot_3d(pred_loader=pred_loader)
    if configs.watermark:
        # Also load the structure stored under "predictions_orig" and
        # superimpose it onto the main prediction so both can be shown together.
        pred_loader = PredictionLoader(os.path.join(configs.saved_path, 'predictions_orig'))
        view3d_orig, _ = plot_3d(pred_loader=pred_loader)
        align_pdb_files(view3d, view3d_orig)
        view3d = [view3d, view3d_orig]
    plot_best_confidence_measure(predictions_dir)
    confidence_img_path = os.path.join(predictions_dir, "best_sample_confidence.png")

    return view3d, confidence_img_path, cif_path
267
+
268
+
269
# Module-level logger plus a console logging format with timestamps and call site.
logger = logging.getLogger(__name__)
LOG_FORMAT = "%(asctime)s,%(msecs)-3d %(levelname)-8s [%(filename)s:%(lineno)s %(funcName)s] %(message)s"
logging.basicConfig(
    format=LOG_FORMAT,
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)
# Enable the DeepSpeed attention path only when the environment variable says
# "true". The previous comparison against "False" was inverted and switched
# the flag on whenever the variable was set to "False".
configs_base["use_deepspeed_evo_attention"] = (
    os.environ.get("USE_DEEPSPEED_EVO_ATTTENTION", "false").lower() == "true"
)
# Default command-line style arguments merged into the parsed configuration.
arg_str = "--seeds 101 --dump_dir ./output --input_json_path ./examples/example.json --model.N_cycle 10 --sample_diffusion.N_sample 5 --sample_diffusion.N_step 200 "
configs = {**configs_base, **{"data": data_configs}, **inference_configs}
configs = parse_configs(
    configs=configs,
    arg_str=arg_str,
    fill_required_with_null=True,
)
configs.load_checkpoint_path = './checkpoint.pt'
download_infercence_cache()
# Forced off regardless of the environment setting above.
configs.use_deepspeed_evo_attention = False
290
+
291
+ add_watermark = gr.Checkbox(label="Add Watermark", value=True)
292
+ add_watermark1 = gr.Checkbox(label="Add Watermark", value=True)
293
+
294
+
295
+ with gr.Blocks(title="FoldMark", css=custom_css) as demo:
296
+ with gr.Row():
297
+ # Use a Column to align the logo and title horizontally
298
+ gr.Image(value="./assets/foldmark_head.png", elem_id="logo", label="Logo", height=150, show_label=False)
299
+
300
+ with gr.Tab("Structure Predictor (JSON Upload)"):
301
+ # First create the upload component
302
+ json_upload = gr.File(label="Upload JSON", file_types=[".json"])
303
 
304
+ # Then create the example component that references it
305
+ gr.Examples(
306
+ examples=[[EXAMPLE_PATH]],
307
+ inputs=[json_upload],
308
+ label="Click to use example JSON:",
309
+ examples_per_page=1
310
+ )
311
 
312
+ # Rest of the components
313
+ upload_name = gr.Textbox(label="Complex Name (optional)")
314
+ upload_output = gr.JSON(label="Parsed Components")
315
 
316
+ json_upload.upload(
317
+ fn=lambda f: parse_json_input(json.load(open(f.name))),
318
+ inputs=json_upload,
319
+ outputs=upload_output
320
+ )
321
+
322
+ # Shared prediction components
323
+ with gr.Row():
324
+ add_watermark.render()
325
+ submit_btn = gr.Button("Predict Structure", variant="primary")
326
+ #structure_view = gr.HTML(label="3D Visualization")
327
+
328
+ with gr.Row():
329
+ view3d = Molecule3D(label="3D Visualization", reps=reps)
330
+ legend = gr.Markdown("""
331
+ **Color Legend:**
332
+
333
+ - <span style="color:grey">Unwatermarked Structure</span>
334
+ - <span style="color:cyan">Watermarked Structure</span>
335
+ """)
336
+ with gr.Row():
337
+ cif_file = gr.File(label="Download CIF File")
338
+ with gr.Row():
339
+ confidence_plot_image = gr.Image(label="Confidence Measures")
340
 
341
+ input_collector = gr.JSON(visible=False)
342
+
343
+ # Map inputs to a dictionary
344
+ submit_btn.click(
345
+ fn=lambda j, w: {"json": j, "watermark": w},
346
+ inputs=[json_upload, add_watermark],
347
+ outputs=input_collector
348
+ ).then(
349
+ fn=predict_structure,
350
+ inputs=input_collector,
351
+ outputs=[view3d, confidence_plot_image, cif_file]
352
+ )
353
+
354
+ gr.Markdown("""
355
+ The example of the uploaded json file for structure prediction.
356
+ <pre>
357
+ [{
358
+ "sequences": [
359
+ {
360
+ "proteinChain": {
361
+ "sequence": "MAEVIRSSAFWRSFPIFEEFDSETLCELSGIASYRKWSAGTVIFQRGDQGDYMIVVVSGRIKLSLFTPQGRELMLRQHEAGALFGEMALLDGQPRSADATAVTAAEGYVIGKKDFLALITQRPKTAEAVIRFLCAQLRDTTDRLETIALYDLNARVARFFLATLRQIHGSEMPQSANLRLTLSQTDIASILGASRPKVNRAILSLEESGAIKRADGIICCNVGRLLSIADPEEDLEHHHHHHHH",
362
+ "count": 2
363
+ }
364
+ },
365
+ {
366
+ "dnaSequence": {
367
+ "sequence": "CTAGGTAACATTACTCGCG",
368
+ "count": 2
369
+ }
370
+ },
371
+ {
372
+ "dnaSequence": {
373
+ "sequence": "GCGAGTAATGTTAC",
374
+ "count": 2
375
+ }
376
+ },
377
+ {
378
+ "ligand": {
379
+ "ligand": "CCD_PCG",
380
+ "count": 2
381
+ }
382
+ }
383
+ ],
384
+ "name": "7pzb"
385
+ }]
386
+ </pre>
387
+ """)
388
 
389
+ with gr.Tab("Structure Predictor (Manual Input)"):
390
+ with gr.Row():
391
+ complex_name = gr.Textbox(label="Complex Name")
392
 
393
+ # Replace gr.Group with gr.Accordion
394
+ with gr.Accordion(label="Protein Chains", open=True):
395
+ protein_chains = gr.Dataframe(
396
+ headers=["Sequence", "Count"],
397
+ datatype=["str", "number"],
398
+ row_count=1,
399
+ col_count=(2, "fixed")
400
+ )
401
 
402
+ # Repeat for other groups
403
+ with gr.Accordion(label="DNA Sequences", open=True):
404
+ dna_sequences = gr.Dataframe(
405
+ headers=["Sequence", "Count"],
406
+ datatype=["str", "number"],
407
+ row_count=1
408
+ )
409
 
410
+ with gr.Accordion(label="Ligands", open=True):
411
+ ligands = gr.Dataframe(
412
+ headers=["Ligand Type", "Count"],
413
+ datatype=["str", "number"],
414
+ row_count=1
415
+ )
416
 
417
+ manual_output = gr.JSON(label="Generated JSON")
418
 
419
+ complex_name.change(
420
+ fn=lambda x: {"complex_name": x},
421
+ inputs=complex_name,
422
+ outputs=manual_output
423
+ )
424
+
425
+ # Shared prediction components
426
+ with gr.Row():
427
+ add_watermark1.render()
428
+ submit_btn = gr.Button("Predict Structure", variant="primary")
429
+ #structure_view = gr.HTML(label="3D Visualization")
430
+
431
+ with gr.Row():
432
+ view3d = Molecule3D(label="3D Visualization (Gray: Unwatermarked; Cyan: Watermarked)", reps=reps)
433
+
434
+ with gr.Row():
435
+ cif_file = gr.File(label="Download CIF File")
436
+ with gr.Row():
437
+ confidence_plot_image = gr.Image(label="Confidence Measures")
438
 
439
+ input_collector = gr.JSON(visible=False)
440
+
441
+ # Map inputs to a dictionary
442
+ submit_btn.click(
443
+ fn=lambda c, p, d, l, w: {"data": {"complex_name": c, "protein_chains": p, "dna_sequences": d, "ligands": l}, "watermark": w},
444
+ inputs=[complex_name, protein_chains, dna_sequences, ligands, add_watermark1],
445
+ outputs=input_collector
446
+ ).then(
447
+ fn=predict_structure,
448
+ inputs=input_collector,
449
+ outputs=[view3d, confidence_plot_image, cif_file]
450
+ )
451
+
452
@spaces.GPU(duration=120)
def is_watermarked(file):
    """Run watermark detection on an uploaded CIF file.

    Args:
        file: The uploaded file, either as a path string or as an object whose
            ``name`` attribute is the path. ``None`` (no file selected) is
            accepted and yields empty outputs.

    Returns:
        Tuple ``(label, pdb_path)`` where ``label`` is "Watermarked" or
        "Not Watermarked" and ``pdb_path`` is a temporary PDB copy of the
        upload; ``("", None)`` when ``file`` is ``None``.
    """
    # NOTE(review): the change event presumably also fires when the upload is
    # cleared; without this guard `file.name` fails on None. Confirm against Gradio.
    if file is None:
        return "", None
    source_path = file if isinstance(file, str) else file.name

    # first initialize runner
    runner = InferenceRunner(configs)

    # Generate a unique subdirectory and filename
    unique_id = str(uuid.uuid4().hex[:8])
    subdir = os.path.join('./output', unique_id)
    os.makedirs(subdir, exist_ok=True)
    file_path = os.path.join(subdir, f"{unique_id}.cif")

    # Save the uploaded file to the new location
    shutil.copy(source_path, file_path)

    # Preprocess the directory, then run detection with the result recorded on configs.
    configs.process_success = process_data(subdir)
    configs.subdir = subdir
    result = infer_detect(runner, configs)

    temp_pdb_path = convert_cif_to_pdb(file_path)
    # Kept as an equality test (not `not result`) so that only a value equal
    # to False is reported as "Not Watermarked", exactly as before.
    label = "Not Watermarked" if result == False else "Watermarked"  # noqa: E712
    return label, temp_pdb_path
476
 
477
 
478
 
479
+ with gr.Tab("Watermark Detector"):
480
+ # First create the upload component
481
# Restrict uploads to CIF files (the extension previously had a doubled dot).
cif_upload = gr.File(label="Upload .cif", file_types=[".cif"])
482
 
483
+ with gr.Row():
484
+ cif_3d_view = Molecule3D(label="3D Visualization of Input", reps=reps)
485
 
486
+ # Prediction output
487
+ prediction_output = gr.Textbox(label="Prediction")
488
 
489
+ # Define the interaction
490
+ cif_upload.change(is_watermarked, inputs=cif_upload, outputs=[prediction_output, cif_3d_view])
491
 
492
+ # Example files
493
+ example_files = [
494
+ "./examples/7r6r_watermarked.cif",
495
+ "./examples/7pzb_unwatermarked.cif"
496
+ ]
497
 
498
+ gr.Examples(examples=example_files, inputs=cif_upload)
499
 
500
 
501
 
502
 
503
 
504
 
505
+ if __name__ == "__main__":
506
+ demo.launch(share=True)