# import spaces | |
# import logging | |
# import gradio as gr | |
# import os | |
# import uuid | |
# from datetime import datetime | |
# import numpy as np | |
# from configs.configs_base import configs as configs_base | |
# from configs.configs_data import data_configs | |
# from configs.configs_inference import inference_configs | |
# from runner.inference import download_infercence_cache, update_inference_configs, infer_predict, infer_detect, InferenceRunner | |
# from protenix.config import parse_configs, parse_sys_args | |
# from runner.msa_search import update_infer_json | |
# from protenix.web_service.prediction_visualization import plot_best_confidence_measure, PredictionLoader | |
# from process_data import process_data | |
# import json | |
# from typing import Dict, List | |
# from Bio.PDB import MMCIFParser, PDBIO | |
# import tempfile | |
# import shutil | |
# from Bio import PDB | |
# from gradio_molecule3d import Molecule3D | |
# EXAMPLE_PATH = './examples/example.json' | |
# example_json=[{'sequences': [{'proteinChain': {'sequence': 'MAEVIRSSAFWRSFPIFEEFDSETLCELSGIASYRKWSAGTVIFQRGDQGDYMIVVVSGRIKLSLFTPQGRELMLRQHEAGALFGEMALLDGQPRSADATAVTAAEGYVIGKKDFLALITQRPKTAEAVIRFLCAQLRDTTDRLETIALYDLNARVARFFLATLRQIHGSEMPQSANLRLTLSQTDIASILGASRPKVNRAILSLEESGAIKRADGIICCNVGRLLSIADPEEDLEHHHHHHHH', 'count': 2}}, {'dnaSequence': {'sequence': 'CTAGGTAACATTACTCGCG', 'count': 2}}, {'dnaSequence': {'sequence': 'GCGAGTAATGTTAC', 'count': 2}}, {'ligand': {'ligand': 'CCD_PCG', 'count': 2}}], 'name': '7pzb_need_search_msa'}] | |
# # Custom CSS for styling | |
# custom_css = """ | |
# #logo { | |
# width: 50%; | |
# } | |
# .title { | |
# font-size: 32px; | |
# font-weight: bold; | |
# color: #4CAF50; | |
# display: flex; | |
# align-items: center; /* Vertically center the logo and text */ | |
# } | |
# """ | |
# os.environ["LAYERNORM_TYPE"] = "fast_layernorm" | |
# os.environ["USE_DEEPSPEED_EVO_ATTTENTION"] = "False" | |
# # Set environment variable in the script | |
# #os.environ['CUTLASS_PATH'] = './cutlass' | |
# # reps = [ | |
# # { | |
# # "model": 0, | |
# # "chain": "", | |
# # "resname": "", | |
# # "style": "cartoon", # Use cartoon style | |
# # "color": "whiteCarbon", | |
# # "residue_range": "", | |
# # "around": 0, | |
# # "byres": False, | |
# # "visible": True # Ensure this representation is visible | |
# # } | |
# # ] | |
# reps = [ | |
# { | |
# "model": 0, | |
# "chain": "", | |
# "resname": "", | |
# "style": "cartoon", | |
# "color": "whiteCarbon", | |
# "residue_range": "", | |
# "around": 0, | |
# "byres": False, | |
# "opacity": 0.2, | |
# }, | |
# { | |
# "model": 1, | |
# "chain": "", | |
# "resname": "", | |
# "style": "cartoon", | |
# "color": "cyanCarbon", | |
# "residue_range": "", | |
# "around": 0, | |
# "byres": False, | |
# "opacity": 0.8, | |
# } | |
# ] | |
# ## | |
# def align_pdb_files(pdb_file_1, pdb_file_2): | |
# # Load the structures | |
# parser = PDB.PPBuilder() | |
# io = PDB.PDBIO() | |
# structure_1 = PDB.PDBParser(QUIET=True).get_structure('Structure_1', pdb_file_1) | |
# structure_2 = PDB.PDBParser(QUIET=True).get_structure('Structure_2', pdb_file_2) | |
# # Superimpose the second structure onto the first | |
# super_imposer = PDB.Superimposer() | |
# model_1 = structure_1[0] | |
# model_2 = structure_2[0] | |
# # Select the C-alpha (CA) atoms from each model | |
# atoms_1 = [atom for atom in model_1.get_atoms() if atom.get_name() == "CA"] # Use CA atoms | |
# atoms_2 = [atom for atom in model_2.get_atoms() if atom.get_name() == "CA"] | |
# # Gather the CA coordinates, then fit model 2 onto model 1 using the CA atom lists | |
# coord_1 = [atom.get_coord() for atom in atoms_1] | |
# coord_2 = [atom.get_coord() for atom in atoms_2] | |
# super_imposer.set_atoms(atoms_1, atoms_2) | |
# super_imposer.apply(model_2) # Apply the transformation to model_2 | |
# # Save the aligned structure back to the original file | |
# io.set_structure(structure_2) # Save the aligned structure to the second file (original file) | |
# io.save(pdb_file_2) | |
# # Function to convert .cif to .pdb and save as a temporary file | |
# def convert_cif_to_pdb(cif_path): | |
# """ | |
# Convert a CIF file to a PDB file and save it as a temporary file. | |
# Args: | |
# cif_path (str): Path to the input CIF file. | |
# Returns: | |
# str: Path to the temporary PDB file. | |
# """ | |
# # Initialize the MMCIF parser | |
# parser = MMCIFParser() | |
# structure = parser.get_structure("protein", cif_path) | |
# # Create a temporary file for the PDB output | |
# with tempfile.NamedTemporaryFile(suffix=".pdb", delete=False) as temp_file: | |
# temp_pdb_path = temp_file.name | |
# # Save the structure as a PDB file | |
# io = PDBIO() | |
# io.set_structure(structure) | |
# io.save(temp_pdb_path) | |
# return temp_pdb_path | |
# def plot_3d(pred_loader): | |
# # Take the first CIF file path (in sorted order) from the prediction loader | |
# cif_path = sorted(pred_loader.cif_paths)[0] | |
# # Convert the CIF file to a temporary PDB file | |
# temp_pdb_path = convert_cif_to_pdb(cif_path) | |
# return temp_pdb_path, cif_path | |
# def parse_json_input(json_data: List[Dict]) -> Dict: | |
# """Convert Protenix JSON format to UI-friendly structure""" | |
# components = { | |
# "protein_chains": [], | |
# "dna_sequences": [], | |
# "ligands": [], | |
# "complex_name": "" | |
# } | |
# for entry in json_data: | |
# components["complex_name"] = entry.get("name", "") | |
# for seq in entry["sequences"]: | |
# if "proteinChain" in seq: | |
# components["protein_chains"].append({ | |
# "sequence": seq["proteinChain"]["sequence"], | |
# "count": seq["proteinChain"]["count"] | |
# }) | |
# elif "dnaSequence" in seq: | |
# components["dna_sequences"].append({ | |
# "sequence": seq["dnaSequence"]["sequence"], | |
# "count": seq["dnaSequence"]["count"] | |
# }) | |
# elif "ligand" in seq: | |
# components["ligands"].append({ | |
# "type": seq["ligand"]["ligand"], | |
# "count": seq["ligand"]["count"] | |
# }) | |
# return components | |
# def create_protenix_json(input_data: Dict) -> List[Dict]: | |
# """Convert UI inputs to Protenix JSON format""" | |
# sequences = [] | |
# for pc in input_data["protein_chains"]: | |
# sequences.append({ | |
# "proteinChain": { | |
# "sequence": pc["sequence"], | |
# "count": pc["count"] | |
# } | |
# }) | |
# for dna in input_data["dna_sequences"]: | |
# sequences.append({ | |
# "dnaSequence": { | |
# "sequence": dna["sequence"], | |
# "count": dna["count"] | |
# } | |
# }) | |
# for lig in input_data["ligands"]: | |
# sequences.append({ | |
# "ligand": { | |
# "ligand": lig["type"], | |
# "count": lig["count"] | |
# } | |
# }) | |
# return [{ | |
# "sequences": sequences, | |
# "name": input_data["complex_name"] | |
# }] | |
# #@torch.inference_mode() | |
# @spaces.GPU(duration=120) # Specify a duration to avoid timeout | |
# def predict_structure(input_collector: dict): | |
# #first initialize runner | |
# runner = InferenceRunner(configs) | |
# """Handle both input types""" | |
# os.makedirs("./output", exist_ok=True) | |
# # Generate random filename with timestamp | |
# random_name = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}" | |
# save_path = os.path.join("./output", f"{random_name}.json") | |
# print(input_collector) | |
# # Handle JSON input | |
# if input_collector["json"]: | |
# # Handle different input types | |
# if isinstance(input_collector["json"], str): # Example JSON case (file path) | |
# input_data = json.load(open(input_collector["json"])) | |
# elif hasattr(input_collector["json"], "name"): # File upload case | |
# input_data = json.load(open(input_collector["json"].name)) | |
# else: # Direct JSON data case | |
# input_data = input_collector["json"] | |
# else: # Manual input case | |
# input_data = create_protenix_json(input_collector["data"]) | |
# with open(save_path, "w") as f: | |
# json.dump(input_data, f, indent=2) | |
# if input_data==example_json and input_collector['watermark']==True: | |
# configs.saved_path = './output/example_output/' | |
# else: | |
# # run msa | |
# json_file = update_infer_json(save_path, './output', True) | |
# # Run prediction | |
# configs.input_json_path = json_file | |
# configs.watermark = input_collector['watermark'] | |
# configs.saved_path = os.path.join("./output/", random_name) | |
# infer_predict(runner, configs) | |
# #saved_path = os.path.join('./output', f"{sample_name}", f"seed_{seed}", 'predictions') | |
# # Generate visualizations | |
# pred_loader = PredictionLoader(os.path.join(configs.saved_path, 'predictions')) | |
# view3d, cif_path = plot_3d(pred_loader=pred_loader) | |
# if configs.watermark: | |
# pred_loader = PredictionLoader(os.path.join(configs.saved_path, 'predictions_orig')) | |
# view3d_orig, _ = plot_3d(pred_loader=pred_loader) | |
# align_pdb_files(view3d, view3d_orig) | |
# view3d = [view3d, view3d_orig] | |
# plot_best_confidence_measure(os.path.join(configs.saved_path, 'predictions')) | |
# confidence_img_path = os.path.join(os.path.join(configs.saved_path, 'predictions'), "best_sample_confidence.png") | |
# return view3d, confidence_img_path, cif_path | |
# logger = logging.getLogger(__name__) | |
# LOG_FORMAT = "%(asctime)s,%(msecs)-3d %(levelname)-8s [%(filename)s:%(lineno)s %(funcName)s] %(message)s" | |
# logging.basicConfig( | |
# format=LOG_FORMAT, | |
# level=logging.INFO, | |
# datefmt="%Y-%m-%d %H:%M:%S", | |
# filemode="w", | |
# ) | |
# configs_base["use_deepspeed_evo_attention"] = ( | |
# os.environ.get("USE_DEEPSPEED_EVO_ATTTENTION", False) == "False" | |
# ) | |
# arg_str = "--seeds 101 --dump_dir ./output --input_json_path ./examples/example.json --model.N_cycle 10 --sample_diffusion.N_sample 5 --sample_diffusion.N_step 200 " | |
# configs = {**configs_base, **{"data": data_configs}, **inference_configs} | |
# configs = parse_configs( | |
# configs=configs, | |
# arg_str=arg_str, | |
# fill_required_with_null=True, | |
# ) | |
# configs.load_checkpoint_path='./checkpoint.pt' | |
# download_infercence_cache() | |
# configs.use_deepspeed_evo_attention=False | |
# add_watermark = gr.Checkbox(label="Add Watermark", value=True) | |
# add_watermark1 = gr.Checkbox(label="Add Watermark", value=True) | |
# with gr.Blocks(title="FoldMark", css=custom_css) as demo: | |
# with gr.Row(): | |
# # Show the logo image in a row at the top of the page | |
# gr.Image(value="./assets/foldmark_head.png", elem_id="logo", label="Logo", height=150, show_label=False) | |
# with gr.Tab("Structure Predictor (JSON Upload)"): | |
# # First create the upload component | |
# json_upload = gr.File(label="Upload JSON", file_types=[".json"]) | |
# # Then create the example component that references it | |
# gr.Examples( | |
# examples=[[EXAMPLE_PATH]], | |
# inputs=[json_upload], | |
# label="Click to use example JSON:", | |
# examples_per_page=1 | |
# ) | |
# # Rest of the components | |
# upload_name = gr.Textbox(label="Complex Name (optional)") | |
# upload_output = gr.JSON(label="Parsed Components") | |
# json_upload.upload( | |
# fn=lambda f: parse_json_input(json.load(open(f.name))), | |
# inputs=json_upload, | |
# outputs=upload_output | |
# ) | |
# # Shared prediction components | |
# with gr.Row(): | |
# add_watermark.render() | |
# submit_btn = gr.Button("Predict Structure", variant="primary") | |
# #structure_view = gr.HTML(label="3D Visualization") | |
# with gr.Row(): | |
# view3d = Molecule3D(label="3D Visualization", reps=reps) | |
# legend = gr.Markdown(""" | |
# **Color Legend:** | |
# - <span style="color:grey">Unwatermarked Structure</span> | |
# - <span style="color:cyan">Watermarked Structure</span> | |
# """) | |
# with gr.Row(): | |
# cif_file = gr.File(label="Download CIF File") | |
# with gr.Row(): | |
# confidence_plot_image = gr.Image(label="Confidence Measures") | |
# input_collector = gr.JSON(visible=False) | |
# # Map inputs to a dictionary | |
# submit_btn.click( | |
# fn=lambda j, w: {"json": j, "watermark": w}, | |
# inputs=[json_upload, add_watermark], | |
# outputs=input_collector | |
# ).then( | |
# fn=predict_structure, | |
# inputs=input_collector, | |
# outputs=[view3d, confidence_plot_image, cif_file] | |
# ) | |
# gr.Markdown(""" | |
# The example of the uploaded json file for structure prediction. | |
# <pre> | |
# [{ | |
# "sequences": [ | |
# { | |
# "proteinChain": { | |
# "sequence": "MAEVIRSSAFWRSFPIFEEFDSETLCELSGIASYRKWSAGTVIFQRGDQGDYMIVVVSGRIKLSLFTPQGRELMLRQHEAGALFGEMALLDGQPRSADATAVTAAEGYVIGKKDFLALITQRPKTAEAVIRFLCAQLRDTTDRLETIALYDLNARVARFFLATLRQIHGSEMPQSANLRLTLSQTDIASILGASRPKVNRAILSLEESGAIKRADGIICCNVGRLLSIADPEEDLEHHHHHHHH", | |
# "count": 2 | |
# } | |
# }, | |
# { | |
# "dnaSequence": { | |
# "sequence": "CTAGGTAACATTACTCGCG", | |
# "count": 2 | |
# } | |
# }, | |
# { | |
# "dnaSequence": { | |
# "sequence": "GCGAGTAATGTTAC", | |
# "count": 2 | |
# } | |
# }, | |
# { | |
# "ligand": { | |
# "ligand": "CCD_PCG", | |
# "count": 2 | |
# } | |
# } | |
# ], | |
# "name": "7pzb" | |
# }] | |
# </pre> | |
# """) | |
# with gr.Tab("Structure Predictor (Manual Input)"): | |
# with gr.Row(): | |
# complex_name = gr.Textbox(label="Complex Name") | |
# # Protein chain table inside an accordion that is open by default | |
# with gr.Accordion(label="Protein Chains", open=True): | |
# protein_chains = gr.Dataframe( | |
# headers=["Sequence", "Count"], | |
# datatype=["str", "number"], | |
# row_count=1, | |
# col_count=(2, "fixed") | |
# ) | |
# # Repeat for other groups | |
# with gr.Accordion(label="DNA Sequences", open=True): | |
# dna_sequences = gr.Dataframe( | |
# headers=["Sequence", "Count"], | |
# datatype=["str", "number"], | |
# row_count=1 | |
# ) | |
# with gr.Accordion(label="Ligands", open=True): | |
# ligands = gr.Dataframe( | |
# headers=["Ligand Type", "Count"], | |
# datatype=["str", "number"], | |
# row_count=1 | |
# ) | |
# manual_output = gr.JSON(label="Generated JSON") | |
# complex_name.change( | |
# fn=lambda x: {"complex_name": x}, | |
# inputs=complex_name, | |
# outputs=manual_output | |
# ) | |
# # Shared prediction components | |
# with gr.Row(): | |
# add_watermark1.render() | |
# submit_btn = gr.Button("Predict Structure", variant="primary") | |
# #structure_view = gr.HTML(label="3D Visualization") | |
# with gr.Row(): | |
# view3d = Molecule3D(label="3D Visualization (Gray: Unwatermarked; Cyan: Watermarked)", reps=reps) | |
# with gr.Row(): | |
# cif_file = gr.File(label="Download CIF File") | |
# with gr.Row(): | |
# confidence_plot_image = gr.Image(label="Confidence Measures") | |
# input_collector = gr.JSON(visible=False) | |
# # Map inputs to a dictionary | |
# submit_btn.click( | |
# fn=lambda c, p, d, l, w: {"data": {"complex_name": c, "protein_chains": p, "dna_sequences": d, "ligands": l}, "watermark": w}, | |
# inputs=[complex_name, protein_chains, dna_sequences, ligands, add_watermark1], | |
# outputs=input_collector | |
# ).then( | |
# fn=predict_structure, | |
# inputs=input_collector, | |
# outputs=[view3d, confidence_plot_image, cif_file] | |
# ) | |
# @spaces.GPU(duration=120) | |
# def is_watermarked(file): | |
# #first initialize runner | |
# runner = InferenceRunner(configs) | |
# # Generate a unique subdirectory and filename | |
# unique_id = str(uuid.uuid4().hex[:8]) | |
# subdir = os.path.join('./output', unique_id) | |
# os.makedirs(subdir, exist_ok=True) | |
# filename = f"{unique_id}.cif" | |
# file_path = os.path.join(subdir, filename) | |
# # Save the uploaded file to the new location | |
# shutil.copy(file.name, file_path) | |
# # Call your processing functions | |
# configs.process_success = process_data(subdir) | |
# configs.subdir = subdir | |
# result = infer_detect(runner, configs) | |
# # Convert the uploaded CIF to PDB for the 3D viewer; the label returned below is 'Not Watermarked' when result is False, otherwise 'Watermarked' | |
# temp_pdb_path = convert_cif_to_pdb(file_path) | |
# if result==False: | |
# return "Not Watermarked", temp_pdb_path | |
# else: | |
# return "Watermarked", temp_pdb_path | |
# with gr.Tab("Watermark Detector"): | |
# # First create the upload component | |
# cif_upload = gr.File(label="Upload .cif", file_types=["..cif"]) | |
# with gr.Row(): | |
# cif_3d_view = Molecule3D(label="3D Visualization of Input", reps=reps) | |
# # Prediction output | |
# prediction_output = gr.Textbox(label="Prediction") | |
# # Define the interaction | |
# cif_upload.change(is_watermarked, inputs=cif_upload, outputs=[prediction_output, cif_3d_view]) | |
# # Example files | |
# example_files = [ | |
# "./examples/7r6r_watermarked.cif", | |
# "./examples/7pzb_unwatermarked.cif" | |
# ] | |
# gr.Examples(examples=example_files, inputs=cif_upload) | |
# if __name__ == "__main__": | |
# demo.launch(share=True) |