# import spaces # import logging # import gradio as gr # import os # import uuid # from datetime import datetime # import numpy as np # from configs.configs_base import configs as configs_base # from configs.configs_data import data_configs # from configs.configs_inference import inference_configs # from runner.inference import download_infercence_cache, update_inference_configs, infer_predict, infer_detect, InferenceRunner # from protenix.config import parse_configs, parse_sys_args # from runner.msa_search import update_infer_json # from protenix.web_service.prediction_visualization import plot_best_confidence_measure, PredictionLoader # from process_data import process_data # import json # from typing import Dict, List # from Bio.PDB import MMCIFParser, PDBIO # import tempfile # import shutil # from Bio import PDB # from gradio_molecule3d import Molecule3D # EXAMPLE_PATH = './examples/example.json' # example_json=[{'sequences': [{'proteinChain': {'sequence': 'MAEVIRSSAFWRSFPIFEEFDSETLCELSGIASYRKWSAGTVIFQRGDQGDYMIVVVSGRIKLSLFTPQGRELMLRQHEAGALFGEMALLDGQPRSADATAVTAAEGYVIGKKDFLALITQRPKTAEAVIRFLCAQLRDTTDRLETIALYDLNARVARFFLATLRQIHGSEMPQSANLRLTLSQTDIASILGASRPKVNRAILSLEESGAIKRADGIICCNVGRLLSIADPEEDLEHHHHHHHH', 'count': 2}}, {'dnaSequence': {'sequence': 'CTAGGTAACATTACTCGCG', 'count': 2}}, {'dnaSequence': {'sequence': 'GCGAGTAATGTTAC', 'count': 2}}, {'ligand': {'ligand': 'CCD_PCG', 'count': 2}}], 'name': '7pzb_need_search_msa'}] # # Custom CSS for styling # custom_css = """ # #logo { # width: 50%; # } # .title { # font-size: 32px; # font-weight: bold; # color: #4CAF50; # display: flex; # align-items: center; /* Vertically center the logo and text */ # } # """ # os.environ["LAYERNORM_TYPE"] = "fast_layernorm" # os.environ["USE_DEEPSPEED_EVO_ATTTENTION"] = "False" # # Set environment variable in the script # #os.environ['CUTLASS_PATH'] = './cutlass' # # reps = [ # # { # # "model": 0, # # "chain": "", # # "resname": "", # # "style": "cartoon", # Use cartoon style # # "color": 
"whiteCarbon", # # "residue_range": "", # # "around": 0, # # "byres": False, # # "visible": True # Ensure this representation is visible # # } # # ] # reps = [ # { # "model": 0, # "chain": "", # "resname": "", # "style": "cartoon", # "color": "whiteCarbon", # "residue_range": "", # "around": 0, # "byres": False, # "opacity": 0.2, # }, # { # "model": 1, # "chain": "", # "resname": "", # "style": "cartoon", # "color": "cyanCarbon", # "residue_range": "", # "around": 0, # "byres": False, # "opacity": 0.8, # } # ] # ## # def align_pdb_files(pdb_file_1, pdb_file_2): # # Load the structures # parser = PDB.PPBuilder() # io = PDB.PDBIO() # structure_1 = PDB.PDBParser(QUIET=True).get_structure('Structure_1', pdb_file_1) # structure_2 = PDB.PDBParser(QUIET=True).get_structure('Structure_2', pdb_file_2) # # Superimpose the second structure onto the first # super_imposer = PDB.Superimposer() # model_1 = structure_1[0] # model_2 = structure_2[0] # # Extract the coordinates from the two structures # atoms_1 = [atom for atom in model_1.get_atoms() if atom.get_name() == "CA"] # Use CA atoms # atoms_2 = [atom for atom in model_2.get_atoms() if atom.get_name() == "CA"] # # Align the structures based on the CA atoms # coord_1 = [atom.get_coord() for atom in atoms_1] # coord_2 = [atom.get_coord() for atom in atoms_2] # super_imposer.set_atoms(atoms_1, atoms_2) # super_imposer.apply(model_2) # Apply the transformation to model_2 # # Save the aligned structure back to the original file # io.set_structure(structure_2) # Save the aligned structure to the second file (original file) # io.save(pdb_file_2) # # Function to convert .cif to .pdb and save as a temporary file # def convert_cif_to_pdb(cif_path): # """ # Convert a CIF file to a PDB file and save it as a temporary file. # Args: # cif_path (str): Path to the input CIF file. # Returns: # str: Path to the temporary PDB file. 
# """ # # Initialize the MMCIF parser # parser = MMCIFParser() # structure = parser.get_structure("protein", cif_path) # # Create a temporary file for the PDB output # with tempfile.NamedTemporaryFile(suffix=".pdb", delete=False) as temp_file: # temp_pdb_path = temp_file.name # # Save the structure as a PDB file # io = PDBIO() # io.set_structure(structure) # io.save(temp_pdb_path) # return temp_pdb_path # def plot_3d(pred_loader): # # Get the CIF file path for the given prediction ID # cif_path = sorted(pred_loader.cif_paths)[0] # # Convert the CIF file to a temporary PDB file # temp_pdb_path = convert_cif_to_pdb(cif_path) # return temp_pdb_path, cif_path # def parse_json_input(json_data: List[Dict]) -> Dict: # """Convert Protenix JSON format to UI-friendly structure""" # components = { # "protein_chains": [], # "dna_sequences": [], # "ligands": [], # "complex_name": "" # } # for entry in json_data: # components["complex_name"] = entry.get("name", "") # for seq in entry["sequences"]: # if "proteinChain" in seq: # components["protein_chains"].append({ # "sequence": seq["proteinChain"]["sequence"], # "count": seq["proteinChain"]["count"] # }) # elif "dnaSequence" in seq: # components["dna_sequences"].append({ # "sequence": seq["dnaSequence"]["sequence"], # "count": seq["dnaSequence"]["count"] # }) # elif "ligand" in seq: # components["ligands"].append({ # "type": seq["ligand"]["ligand"], # "count": seq["ligand"]["count"] # }) # return components # def create_protenix_json(input_data: Dict) -> List[Dict]: # """Convert UI inputs to Protenix JSON format""" # sequences = [] # for pc in input_data["protein_chains"]: # sequences.append({ # "proteinChain": { # "sequence": pc["sequence"], # "count": pc["count"] # } # }) # for dna in input_data["dna_sequences"]: # sequences.append({ # "dnaSequence": { # "sequence": dna["sequence"], # "count": dna["count"] # } # }) # for lig in input_data["ligands"]: # sequences.append({ # "ligand": { # "ligand": lig["type"], # "count": 
lig["count"] # } # }) # return [{ # "sequences": sequences, # "name": input_data["complex_name"] # }] # #@torch.inference_mode() # @spaces.GPU(duration=120) # Specify a duration to avoid timeout # def predict_structure(input_collector: dict): # #first initialize runner # runner = InferenceRunner(configs) # """Handle both input types""" # os.makedirs("./output", exist_ok=True) # # Generate random filename with timestamp # random_name = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}" # save_path = os.path.join("./output", f"{random_name}.json") # print(input_collector) # # Handle JSON input # if input_collector["json"]: # # Handle different input types # if isinstance(input_collector["json"], str): # Example JSON case (file path) # input_data = json.load(open(input_collector["json"])) # elif hasattr(input_collector["json"], "name"): # File upload case # input_data = json.load(open(input_collector["json"].name)) # else: # Direct JSON data case # input_data = input_collector["json"] # else: # Manual input case # input_data = create_protenix_json(input_collector["data"]) # with open(save_path, "w") as f: # json.dump(input_data, f, indent=2) # if input_data==example_json and input_collector['watermark']==True: # configs.saved_path = './output/example_output/' # else: # # run msa # json_file = update_infer_json(save_path, './output', True) # # Run prediction # configs.input_json_path = json_file # configs.watermark = input_collector['watermark'] # configs.saved_path = os.path.join("./output/", random_name) # infer_predict(runner, configs) # #saved_path = os.path.join('./output', f"{sample_name}", f"seed_{seed}", 'predictions') # # Generate visualizations # pred_loader = PredictionLoader(os.path.join(configs.saved_path, 'predictions')) # view3d, cif_path = plot_3d(pred_loader=pred_loader) # if configs.watermark: # pred_loader = PredictionLoader(os.path.join(configs.saved_path, 'predictions_orig')) # view3d_orig, _ = plot_3d(pred_loader=pred_loader) # 
align_pdb_files(view3d, view3d_orig) # view3d = [view3d, view3d_orig] # plot_best_confidence_measure(os.path.join(configs.saved_path, 'predictions')) # confidence_img_path = os.path.join(os.path.join(configs.saved_path, 'predictions'), "best_sample_confidence.png") # return view3d, confidence_img_path, cif_path # logger = logging.getLogger(__name__) # LOG_FORMAT = "%(asctime)s,%(msecs)-3d %(levelname)-8s [%(filename)s:%(lineno)s %(funcName)s] %(message)s" # logging.basicConfig( # format=LOG_FORMAT, # level=logging.INFO, # datefmt="%Y-%m-%d %H:%M:%S", # filemode="w", # ) # configs_base["use_deepspeed_evo_attention"] = ( # os.environ.get("USE_DEEPSPEED_EVO_ATTTENTION", False) == "True" # ) # arg_str = "--seeds 101 --dump_dir ./output --input_json_path ./examples/example.json --model.N_cycle 10 --sample_diffusion.N_sample 5 --sample_diffusion.N_step 200 " # configs = {**configs_base, **{"data": data_configs}, **inference_configs} # configs = parse_configs( # configs=configs, # arg_str=arg_str, # fill_required_with_null=True, # ) # configs.load_checkpoint_path='./checkpoint.pt' # download_infercence_cache() # configs.use_deepspeed_evo_attention=False # add_watermark = gr.Checkbox(label="Add Watermark", value=True) # add_watermark1 = gr.Checkbox(label="Add Watermark", value=True) # with gr.Blocks(title="FoldMark", css=custom_css) as demo: # with gr.Row(): # # Use a Column to align the logo and title horizontally # gr.Image(value="./assets/foldmark_head.png", elem_id="logo", label="Logo", height=150, show_label=False) # with gr.Tab("Structure Predictor (JSON Upload)"): # # First create the upload component # json_upload = gr.File(label="Upload JSON", file_types=[".json"]) # # Then create the example component that references it # gr.Examples( # examples=[[EXAMPLE_PATH]], # inputs=[json_upload], # label="Click to use example JSON:", # examples_per_page=1 # ) # # Rest of the components # upload_name = gr.Textbox(label="Complex Name (optional)") # upload_output = 
gr.JSON(label="Parsed Components") # json_upload.upload( # fn=lambda f: parse_json_input(json.load(open(f.name))), # inputs=json_upload, # outputs=upload_output # ) # # Shared prediction components # with gr.Row(): # add_watermark.render() # submit_btn = gr.Button("Predict Structure", variant="primary") # #structure_view = gr.HTML(label="3D Visualization") # with gr.Row(): # view3d = Molecule3D(label="3D Visualization", reps=reps) # legend = gr.Markdown(""" # **Color Legend:** # - Unwatermarked Structure # - Watermarked Structure # """) # with gr.Row(): # cif_file = gr.File(label="Download CIF File") # with gr.Row(): # confidence_plot_image = gr.Image(label="Confidence Measures") # input_collector = gr.JSON(visible=False) # # Map inputs to a dictionary # submit_btn.click( # fn=lambda j, w: {"json": j, "watermark": w}, # inputs=[json_upload, add_watermark], # outputs=input_collector # ).then( # fn=predict_structure, # inputs=input_collector, # outputs=[view3d, confidence_plot_image, cif_file] # ) # gr.Markdown(""" # The example of the uploaded json file for structure prediction. #
#             [{
#         "sequences": [
#             {
#                 "proteinChain": {
#                     "sequence": "MAEVIRSSAFWRSFPIFEEFDSETLCELSGIASYRKWSAGTVIFQRGDQGDYMIVVVSGRIKLSLFTPQGRELMLRQHEAGALFGEMALLDGQPRSADATAVTAAEGYVIGKKDFLALITQRPKTAEAVIRFLCAQLRDTTDRLETIALYDLNARVARFFLATLRQIHGSEMPQSANLRLTLSQTDIASILGASRPKVNRAILSLEESGAIKRADGIICCNVGRLLSIADPEEDLEHHHHHHHH",
#                     "count": 2
#                 }
#             },
#             {
#                 "dnaSequence": {
#                     "sequence": "CTAGGTAACATTACTCGCG",
#                     "count": 2
#                 }
#             },
#             {
#                 "dnaSequence": {
#                     "sequence": "GCGAGTAATGTTAC",
#                     "count": 2
#                 }
#             },
#             {
#                 "ligand": {
#                     "ligand": "CCD_PCG",
#                     "count": 2
#                 }
#             }
#         ],
#         "name": "7pzb"
#         }]
#         
# """) # with gr.Tab("Structure Predictor (Manual Input)"): # with gr.Row(): # complex_name = gr.Textbox(label="Complex Name") # # Replace gr.Group with gr.Accordion # with gr.Accordion(label="Protein Chains", open=True): # protein_chains = gr.Dataframe( # headers=["Sequence", "Count"], # datatype=["str", "number"], # row_count=1, # col_count=(2, "fixed") # ) # # Repeat for other groups # with gr.Accordion(label="DNA Sequences", open=True): # dna_sequences = gr.Dataframe( # headers=["Sequence", "Count"], # datatype=["str", "number"], # row_count=1 # ) # with gr.Accordion(label="Ligands", open=True): # ligands = gr.Dataframe( # headers=["Ligand Type", "Count"], # datatype=["str", "number"], # row_count=1 # ) # manual_output = gr.JSON(label="Generated JSON") # complex_name.change( # fn=lambda x: {"complex_name": x}, # inputs=complex_name, # outputs=manual_output # ) # # Shared prediction components # with gr.Row(): # add_watermark1.render() # submit_btn = gr.Button("Predict Structure", variant="primary") # #structure_view = gr.HTML(label="3D Visualization") # with gr.Row(): # view3d = Molecule3D(label="3D Visualization (Gray: Unwatermarked; Cyan: Watermarked)", reps=reps) # with gr.Row(): # cif_file = gr.File(label="Download CIF File") # with gr.Row(): # confidence_plot_image = gr.Image(label="Confidence Measures") # input_collector = gr.JSON(visible=False) # # Map inputs to a dictionary # submit_btn.click( # fn=lambda c, p, d, l, w: {"data": {"complex_name": c, "protein_chains": p, "dna_sequences": d, "ligands": l}, "watermark": w}, # inputs=[complex_name, protein_chains, dna_sequences, ligands, add_watermark1], # outputs=input_collector # ).then( # fn=predict_structure, # inputs=input_collector, # outputs=[view3d, confidence_plot_image, cif_file] # ) # @spaces.GPU(duration=120) # def is_watermarked(file): # #first initialize runner # runner = InferenceRunner(configs) # # Generate a unique subdirectory and filename # unique_id = str(uuid.uuid4().hex[:8]) # subdir = 
os.path.join('./output', unique_id) # os.makedirs(subdir, exist_ok=True) # filename = f"{unique_id}.cif" # file_path = os.path.join(subdir, filename) # # Save the uploaded file to the new location # shutil.copy(file.name, file_path) # # Call your processing functions # configs.process_success = process_data(subdir) # configs.subdir = subdir # result = infer_detect(runner, configs) # # This function should return 'Watermarked' or 'Not Watermarked' # temp_pdb_path = convert_cif_to_pdb(file_path) # if result==False: # return "Not Watermarked", temp_pdb_path # else: # return "Watermarked", temp_pdb_path # with gr.Tab("Watermark Detector"): # # First create the upload component # cif_upload = gr.File(label="Upload .cif", file_types=[".cif"]) # with gr.Row(): # cif_3d_view = Molecule3D(label="3D Visualization of Input", reps=reps) # # Prediction output # prediction_output = gr.Textbox(label="Prediction") # # Define the interaction # cif_upload.change(is_watermarked, inputs=cif_upload, outputs=[prediction_output, cif_3d_view]) # # Example files # example_files = [ # "./examples/7r6r_watermarked.cif", # "./examples/7pzb_unwatermarked.cif" # ] # gr.Examples(examples=example_files, inputs=cif_upload) # if __name__ == "__main__": # demo.launch(share=True)