FoldMark / app.py
Zaixi's picture
thnaks
e8bea69
raw
history blame
18.5 kB
# import spaces
# import logging
# import gradio as gr
# import os
# import uuid
# from datetime import datetime
# import numpy as np
# from configs.configs_base import configs as configs_base
# from configs.configs_data import data_configs
# from configs.configs_inference import inference_configs
# from runner.inference import download_infercence_cache, update_inference_configs, infer_predict, infer_detect, InferenceRunner
# from protenix.config import parse_configs, parse_sys_args
# from runner.msa_search import update_infer_json
# from protenix.web_service.prediction_visualization import plot_best_confidence_measure, PredictionLoader
# from process_data import process_data
# import json
# from typing import Dict, List
# from Bio.PDB import MMCIFParser, PDBIO
# import tempfile
# import shutil
# from Bio import PDB
# from gradio_molecule3d import Molecule3D
# # Path of the bundled example input, offered as a one-click example in the JSON tab.
# EXAMPLE_PATH = './examples/example.json'
# # In-memory copy of the example input. predict_structure compares the submitted input
# # against this value and, when watermarking is on, reads precomputed results from
# # ./output/example_output/ instead of running the MSA search.
# # NOTE(review): must stay identical to the contents of EXAMPLE_PATH for that shortcut to
# # trigger -- confirm. Its "name" here is '7pzb_need_search_msa', while the Markdown
# # sample shown in the JSON tab uses "7pzb".
# example_json=[{'sequences': [{'proteinChain': {'sequence': 'MAEVIRSSAFWRSFPIFEEFDSETLCELSGIASYRKWSAGTVIFQRGDQGDYMIVVVSGRIKLSLFTPQGRELMLRQHEAGALFGEMALLDGQPRSADATAVTAAEGYVIGKKDFLALITQRPKTAEAVIRFLCAQLRDTTDRLETIALYDLNARVARFFLATLRQIHGSEMPQSANLRLTLSQTDIASILGASRPKVNRAILSLEESGAIKRADGIICCNVGRLLSIADPEEDLEHHHHHHHH', 'count': 2}}, {'dnaSequence': {'sequence': 'CTAGGTAACATTACTCGCG', 'count': 2}}, {'dnaSequence': {'sequence': 'GCGAGTAATGTTAC', 'count': 2}}, {'ligand': {'ligand': 'CCD_PCG', 'count': 2}}], 'name': '7pzb_need_search_msa'}]
# # Custom CSS for styling
# # "#logo" targets the header gr.Image created with elem_id="logo".
# # NOTE(review): no component in this file appears to use the "title" class -- the
# # ".title" rule looks unused; confirm before relying on it.
# custom_css = """
# #logo {
# width: 50%;
# }
# .title {
# font-size: 32px;
# font-weight: bold;
# color: #4CAF50;
# display: flex;
# align-items: center; /* Vertically center the logo and text */
# }
# """
# # Environment switches, set before the model config is built.
# # NOTE(review): presumably read by the Protenix model code (layer-norm kernel choice) --
# # confirm against the protenix package.
# os.environ["LAYERNORM_TYPE"] = "fast_layernorm"
# # The triple-T "ATTTENTION" spelling is deliberate: it must match the name read back
# # when configs_base["use_deepspeed_evo_attention"] is computed below.
# os.environ["USE_DEEPSPEED_EVO_ATTTENTION"] = "False"
# # Set environment variable in the script
# #os.environ['CUTLASS_PATH'] = './cutlass'
# # Superseded single-representation version, kept for reference.
# # reps = [
# # {
# # "model": 0,
# # "chain": "",
# # "resname": "",
# # "style": "cartoon", # Use cartoon style
# # "color": "whiteCarbon",
# # "residue_range": "",
# # "around": 0,
# # "byres": False,
# # "visible": True # Ensure this representation is visible
# # }
# # ]
# # Molecule3D representations, one per loaded model:
# # - model 0: faint white-carbon cartoon (opacity 0.2)
# # - model 1: cyan-carbon cartoon (opacity 0.8)
# # NOTE(review): predict_structure returns [predictions/ PDB, predictions_orig/ PDB], so
# # model 0 is the "predictions" structure. The JSON tab's legend calls grey
# # "Unwatermarked" and cyan "Watermarked" -- confirm which directory holds the
# # watermarked output.
# reps = [
# {
# "model": 0,
# "chain": "",
# "resname": "",
# "style": "cartoon",
# "color": "whiteCarbon",
# "residue_range": "",
# "around": 0,
# "byres": False,
# "opacity": 0.2,
# },
# {
# "model": 1,
# "chain": "",
# "resname": "",
# "style": "cartoon",
# "color": "cyanCarbon",
# "residue_range": "",
# "around": 0,
# "byres": False,
# "opacity": 0.8,
# }
# ]
# ##
# def align_pdb_files(pdb_file_1, pdb_file_2):
#     """
#     Superimpose the structure in ``pdb_file_2`` onto ``pdb_file_1``; overwrite ``pdb_file_2``.
#
#     Only the first model of each file is used. CA atoms are paired by
#     (chain id, residue id), so both atom lists always have the same length and each
#     pair refers to the same residue. Superimposer.set_atoms rejects lists of
#     different sizes.
#
#     Args:
#         pdb_file_1 (str): Path to the reference (fixed) PDB file; not modified.
#         pdb_file_2 (str): Path to the PDB file to move; rewritten in place with the
#             aligned coordinates.
#
#     Raises:
#         ValueError: If the two structures share no CA atoms to align on.
#     """
#     model_1 = PDB.PDBParser(QUIET=True).get_structure('Structure_1', pdb_file_1)[0]
#     structure_2 = PDB.PDBParser(QUIET=True).get_structure('Structure_2', pdb_file_2)
#     model_2 = structure_2[0]
#
#     def _ca_atoms_by_residue(model):
#         # Key each CA atom by chain id + full residue id (hetero flag, number, insertion code).
#         return {
#             (atom.get_parent().get_parent().id, atom.get_parent().id): atom
#             for atom in model.get_atoms()
#             if atom.get_name() == "CA"
#         }
#
#     ca_1 = _ca_atoms_by_residue(model_1)
#     ca_2 = _ca_atoms_by_residue(model_2)
#     shared = [key for key in ca_1 if key in ca_2]
#     if not shared:
#         raise ValueError(f"No matching CA atoms between {pdb_file_1} and {pdb_file_2}")
#
#     super_imposer = PDB.Superimposer()
#     super_imposer.set_atoms([ca_1[key] for key in shared], [ca_2[key] for key in shared])
#     # Move every atom of model 2, not only the CA atoms used for the fit.
#     super_imposer.apply(list(model_2.get_atoms()))
#
#     # Save the aligned structure back over the second (moving) file.
#     io = PDB.PDBIO()
#     io.set_structure(structure_2)
#     io.save(pdb_file_2)
# # Function to convert .cif to .pdb and save as a temporary file
# def convert_cif_to_pdb(cif_path):
#     """
#     Convert a CIF file to a PDB file and save it as a temporary file.
#
#     The temporary file is created with ``delete=False`` because its path is returned
#     for display in the 3D viewer. Nothing in this module deletes it afterwards.
#
#     Args:
#         cif_path (str): Path to the input mmCIF file.
#     Returns:
#         str: Path to the temporary PDB file.
#     """
#     # QUIET=True matches the PDBParser calls in align_pdb_files and keeps
#     # PDBConstructionWarning noise out of the logs.
#     structure = MMCIFParser(QUIET=True).get_structure("protein", cif_path)
#     # Only reserve a unique path here; PDBIO reopens it for writing below.
#     with tempfile.NamedTemporaryFile(suffix=".pdb", delete=False) as temp_file:
#         temp_pdb_path = temp_file.name
#     io = PDBIO()
#     io.set_structure(structure)
#     io.save(temp_pdb_path)
#     return temp_pdb_path
# def plot_3d(pred_loader):
#     """
#     Pick the first predicted CIF (lexicographic order) and convert it to PDB for display.
#
#     Args:
#         pred_loader: A PredictionLoader; only its ``cif_paths`` attribute is read
#             (presumably a list of CIF file paths -- confirm).
#     Returns:
#         tuple[str, str]: (temporary PDB path for the 3D viewer, the chosen CIF path).
#     Raises:
#         FileNotFoundError: If the loader reported no CIF files.
#     """
#     cif_paths = sorted(pred_loader.cif_paths)
#     if not cif_paths:
#         # Without this check, indexing [0] would fail with a bare IndexError.
#         raise FileNotFoundError("No predicted CIF files were found to visualise.")
#     cif_path = cif_paths[0]
#     return convert_cif_to_pdb(cif_path), cif_path
# def parse_json_input(json_data: List[Dict]) -> Dict:
#     """
#     Flatten Protenix-format input (a list of jobs) into the structure shown in the UI.
#
#     Protein chains and DNA sequences become ``{"sequence", "count"}`` records; ligands
#     become ``{"type", "count"}`` records; other entity kinds are ignored. Entities from
#     all jobs are concatenated, and ``complex_name`` keeps the name of the last job
#     ("" when it has none).
#     """
#     # (entity key in the Protenix JSON, UI list name, field read, field written)
#     entity_specs = (
#         ("proteinChain", "protein_chains", "sequence", "sequence"),
#         ("dnaSequence", "dna_sequences", "sequence", "sequence"),
#         ("ligand", "ligands", "ligand", "type"),
#     )
#     parsed = {"protein_chains": [], "dna_sequences": [], "ligands": [], "complex_name": ""}
#     for job in json_data:
#         parsed["complex_name"] = job.get("name", "")
#         for entity in job["sequences"]:
#             for entity_key, bucket, source_field, target_field in entity_specs:
#                 if entity_key in entity:
#                     body = entity[entity_key]
#                     parsed[bucket].append({target_field: body[source_field], "count": body["count"]})
#                     break  # first matching kind wins, as in an if/elif chain
#     return parsed
# def _table_rows(table):
#     """
#     Yield ``(value, count)`` pairs from one manual-input table, skipping blank rows.
#
#     Accepts None, a pandas DataFrame (anything with ``to_dict``), a
#     ``{"headers": ..., "data": ...}`` payload, a list of dicts, or a list of
#     ``[value, count]`` rows. A missing, blank or NaN count becomes 1.
#     """
#     if table is None:
#         return
#     if hasattr(table, "to_dict"):  # pandas DataFrame
#         table = table.to_dict(orient="records")
#     elif isinstance(table, dict):  # Dataframe payload with "headers"/"data"
#         table = table.get("data") or []
#     for row in table:
#         if isinstance(row, dict):
#             # Match keys case-insensitively. This accepts both "sequence"/"type"/"count"
#             # and the table headers "Sequence"/"Ligand Type"/"Count".
#             cells = {str(key).strip().lower(): cell for key, cell in row.items()}
#             value = next(
#                 (cells[key] for key in ("sequence", "type", "ligand type", "ligand") if key in cells),
#                 None,
#             )
#             count = cells.get("count")
#         else:
#             row = list(row)
#             value = row[0] if len(row) > 0 else None
#             count = row[1] if len(row) > 1 else None
#         # NaN (pandas' empty cell) is the only value that is not equal to itself.
#         if value is None or value != value or not str(value).strip():
#             continue
#         if count is None or count != count or str(count).strip() == "":
#             count = 1
#         yield str(value).strip(), int(float(count))
#
#
# def create_protenix_json(input_data: Dict) -> List[Dict]:
#     """
#     Convert the manual-input tab's tables into a single-job Protenix JSON list.
#
#     ``protein_chains``, ``dna_sequences`` and ``ligands`` may each be any shape
#     accepted by ``_table_rows``. That includes the ``[value, count]`` rows a
#     gr.Dataframe produces and dicts keyed "sequence"/"type" + "count".
#     Blank rows (the tables start with one empty row) are dropped, and counts are
#     cast to int.
#
#     Returns:
#         list[dict]: ``[{"sequences": [...], "name": complex_name}]``, with proteins,
#         then DNA, then ligands.
#     """
#     sequences = []
#     for table_key, entity_key, value_field in (
#         ("protein_chains", "proteinChain", "sequence"),
#         ("dna_sequences", "dnaSequence", "sequence"),
#         ("ligands", "ligand", "ligand"),
#     ):
#         for value, count in _table_rows(input_data.get(table_key)):
#             sequences.append({entity_key: {value_field: value, "count": count}})
#     return [{
#         "sequences": sequences,
#         "name": input_data.get("complex_name") or "",
#     }]
# #@torch.inference_mode()
# @spaces.GPU(duration=120) # Specify a duration to avoid timeout
# def predict_structure(input_collector: dict):
#     """
#     Run (or replay) a structure prediction and prepare the outputs for the UI.
#
#     Handles both input tabs:
#       * JSON tab: ``{"json": <path | uploaded file | parsed JSON>, "watermark": bool}``
#       * manual tab: ``{"data": <tables for create_protenix_json>, "watermark": bool}``.
#         This payload has no "json" key, so that key is read with ``.get``.
#
#     The normalised input is saved to ``./output/<timestamp>_<id>.json``.
#     - If it equals ``example_json`` and watermarking is on, results are read from
#       ``./output/example_output/``.
#     - Otherwise MSA search and inference run, writing to ``./output/<timestamp>_<id>/``.
#
#     Returns:
#         tuple: (PDB path, or [predictions/ PDB, predictions_orig/ PDB aligned onto it]
#         when watermarking is on; confidence plot PNG path; CIF path for download).
#     """
#     os.makedirs("./output", exist_ok=True)
#     # Timestamp + short uuid so concurrent requests do not collide on disk.
#     random_name = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
#     save_path = os.path.join("./output", f"{random_name}.json")
#     logger.info("Prediction request: %s", input_collector)
#
#     json_input = input_collector.get("json")
#     if json_input:
#         if isinstance(json_input, str):  # example JSON case (file path)
#             with open(json_input) as f:
#                 input_data = json.load(f)
#         elif hasattr(json_input, "name"):  # file upload case
#             with open(json_input.name) as f:
#                 input_data = json.load(f)
#         else:  # already-parsed JSON data
#             input_data = json_input
#     else:  # manual input case
#         input_data = create_protenix_json(input_collector["data"])
#     with open(save_path, "w") as f:
#         json.dump(input_data, f, indent=2)
#
#     watermark = bool(input_collector.get("watermark"))
#     # Set on every call: configs is module-global, so the example branch must not
#     # inherit a value left over from a previous request.
#     # NOTE(review): configs is shared mutable state across concurrent requests.
#     # NOTE(review): assumes ./output/example_output/ also contains predictions_orig/.
#     configs.watermark = watermark
#     if watermark and input_data == example_json:
#         configs.saved_path = './output/example_output/'
#     else:
#         # NOTE(review): the original indentation was lost in the scrape. Inference is
#         # assumed to belong to this branch, since the example branch reads precomputed
#         # output.
#         runner = InferenceRunner(configs)
#         json_file = update_infer_json(save_path, './output', True)  # run MSA search
#         configs.input_json_path = json_file
#         configs.saved_path = os.path.join("./output/", random_name)
#         infer_predict(runner, configs)
#
#     # Generate visualizations
#     predictions_dir = os.path.join(configs.saved_path, 'predictions')
#     view3d, cif_path = plot_3d(pred_loader=PredictionLoader(predictions_dir))
#     if watermark:
#         orig_loader = PredictionLoader(os.path.join(configs.saved_path, 'predictions_orig'))
#         view3d_orig, _ = plot_3d(pred_loader=orig_loader)
#         # Rewrites view3d_orig in place, superimposed onto view3d.
#         align_pdb_files(view3d, view3d_orig)
#         view3d = [view3d, view3d_orig]
#     plot_best_confidence_measure(predictions_dir)
#     confidence_img_path = os.path.join(predictions_dir, "best_sample_confidence.png")
#     return view3d, confidence_img_path, cif_path
# logger = logging.getLogger(__name__)
# LOG_FORMAT = "%(asctime)s,%(msecs)-3d %(levelname)-8s [%(filename)s:%(lineno)s %(funcName)s] %(message)s"
# # Log to stderr. filemode is not passed: basicConfig only uses it together with filename.
# logging.basicConfig(
#     format=LOG_FORMAT,
#     level=logging.INFO,
#     datefmt="%Y-%m-%d %H:%M:%S",
# )
# # Opt-in: only the string "true" (any case) enables DeepSpeed Evo attention, so the
# # "False" value set earlier in this file leaves it off. The "ATTTENTION" spelling
# # matches the name used where the variable is set.
# configs_base["use_deepspeed_evo_attention"] = (
#     str(os.environ.get("USE_DEEPSPEED_EVO_ATTTENTION", "false")).lower() == "true"
# )
# arg_str = "--seeds 101 --dump_dir ./output --input_json_path ./examples/example.json --model.N_cycle 10 --sample_diffusion.N_sample 5 --sample_diffusion.N_step 200 "
# configs = {**configs_base, **{"data": data_configs}, **inference_configs}
# configs = parse_configs(
#     configs=configs,
#     arg_str=arg_str,
#     fill_required_with_null=True,
# )
# configs.load_checkpoint_path='./checkpoint.pt'
# download_infercence_cache()
# # Explicit override: DeepSpeed Evo attention stays disabled regardless of the env var.
# configs.use_deepspeed_evo_attention=False
# # Watermark toggles are created outside gr.Blocks and placed later with .render():
# # add_watermark goes in the JSON-upload tab, add_watermark1 in the manual-input tab.
# add_watermark = gr.Checkbox(label="Add Watermark", value=True)
# add_watermark1 = gr.Checkbox(label="Add Watermark", value=True)
# # Top-level UI: a logo row followed by the JSON-upload, manual-input and
# # watermark-detector tabs.
# with gr.Blocks(title="FoldMark", css=custom_css) as demo:
# with gr.Row():
# # Use a Column to align the logo and title horizontally
# # NOTE(review): no Column is used; this row holds only the logo image, which
# # custom_css styles through elem_id="logo".
# gr.Image(value="./assets/foldmark_head.png", elem_id="logo", label="Logo", height=150, show_label=False)
# with gr.Tab("Structure Predictor (JSON Upload)"):
# # First create the upload component
# json_upload = gr.File(label="Upload JSON", file_types=[".json"])
# # Then create the example component that references it
# gr.Examples(
# examples=[[EXAMPLE_PATH]],
# inputs=[json_upload],
# label="Click to use example JSON:",
# examples_per_page=1
# )
# # Rest of the components
# upload_name = gr.Textbox(label="Complex Name (optional)")
# upload_output = gr.JSON(label="Parsed Components")
# # Show the parsed components as soon as a file is uploaded.
# # NOTE(review): the file opened in this lambda is never closed; upload_name is not
# # wired to any handler in this tab.
# json_upload.upload(
# fn=lambda f: parse_json_input(json.load(open(f.name))),
# inputs=json_upload,
# outputs=upload_output
# )
# # Shared prediction components
# with gr.Row():
# add_watermark.render()
# submit_btn = gr.Button("Predict Structure", variant="primary")
# #structure_view = gr.HTML(label="3D Visualization")
# with gr.Row():
# view3d = Molecule3D(label="3D Visualization", reps=reps)
# # NOTE(review): reps draws model 0 white and model 1 cyan. predict_structure returns
# # the predictions/ structure first and predictions_orig/ second -- confirm this
# # legend matches.
# legend = gr.Markdown("""
# **Color Legend:**
# - <span style="color:grey">Unwatermarked Structure</span>
# - <span style="color:cyan">Watermarked Structure</span>
# """)
# with gr.Row():
# cif_file = gr.File(label="Download CIF File")
# with gr.Row():
# confidence_plot_image = gr.Image(label="Confidence Measures")
# input_collector = gr.JSON(visible=False)
# # Map inputs to a dictionary
# # Two steps: pack the upload and checkbox into the hidden JSON component, then run
# # predict_structure on that payload.
# submit_btn.click(
# fn=lambda j, w: {"json": j, "watermark": w},
# inputs=[json_upload, add_watermark],
# outputs=input_collector
# ).then(
# fn=predict_structure,
# inputs=input_collector,
# outputs=[view3d, confidence_plot_image, cif_file]
# )
# gr.Markdown("""
# The example of the uploaded json file for structure prediction.
# <pre>
# [{
# "sequences": [
# {
# "proteinChain": {
# "sequence": "MAEVIRSSAFWRSFPIFEEFDSETLCELSGIASYRKWSAGTVIFQRGDQGDYMIVVVSGRIKLSLFTPQGRELMLRQHEAGALFGEMALLDGQPRSADATAVTAAEGYVIGKKDFLALITQRPKTAEAVIRFLCAQLRDTTDRLETIALYDLNARVARFFLATLRQIHGSEMPQSANLRLTLSQTDIASILGASRPKVNRAILSLEESGAIKRADGIICCNVGRLLSIADPEEDLEHHHHHHHH",
# "count": 2
# }
# },
# {
# "dnaSequence": {
# "sequence": "CTAGGTAACATTACTCGCG",
# "count": 2
# }
# },
# {
# "dnaSequence": {
# "sequence": "GCGAGTAATGTTAC",
# "count": 2
# }
# },
# {
# "ligand": {
# "ligand": "CCD_PCG",
# "count": 2
# }
# }
# ],
# "name": "7pzb"
# }]
# </pre>
# """)
# with gr.Tab("Structure Predictor (Manual Input)"):
# with gr.Row():
# complex_name = gr.Textbox(label="Complex Name")
# # Replace gr.Group with gr.Accordion
# # NOTE(review): each gr.Dataframe yields table rows under the headers below
# # ("Sequence"/"Count", "Ligand Type"/"Count"), not dicts with lowercase
# # "sequence"/"count"/"type" keys -- confirm create_protenix_json accepts that shape.
# with gr.Accordion(label="Protein Chains", open=True):
# protein_chains = gr.Dataframe(
# headers=["Sequence", "Count"],
# datatype=["str", "number"],
# row_count=1,
# col_count=(2, "fixed")
# )
# # Repeat for other groups
# with gr.Accordion(label="DNA Sequences", open=True):
# dna_sequences = gr.Dataframe(
# headers=["Sequence", "Count"],
# datatype=["str", "number"],
# row_count=1
# )
# with gr.Accordion(label="Ligands", open=True):
# ligands = gr.Dataframe(
# headers=["Ligand Type", "Count"],
# datatype=["str", "number"],
# row_count=1
# )
# manual_output = gr.JSON(label="Generated JSON")
# # NOTE(review): this preview only echoes the complex name, not the tables.
# complex_name.change(
# fn=lambda x: {"complex_name": x},
# inputs=complex_name,
# outputs=manual_output
# )
# # Shared prediction components
# # These names rebind the JSON tab's components. That tab's click handler already
# # holds its own references, so rebinding is harmless.
# with gr.Row():
# add_watermark1.render()
# submit_btn = gr.Button("Predict Structure", variant="primary")
# #structure_view = gr.HTML(label="3D Visualization")
# with gr.Row():
# view3d = Molecule3D(label="3D Visualization (Gray: Unwatermarked; Cyan: Watermarked)", reps=reps)
# with gr.Row():
# cif_file = gr.File(label="Download CIF File")
# with gr.Row():
# confidence_plot_image = gr.Image(label="Confidence Measures")
# input_collector = gr.JSON(visible=False)
# # Map inputs to a dictionary
# # This payload has "data" and "watermark" keys but no "json" key. predict_structure
# # must therefore read "json" with .get(); indexing it directly raises KeyError here.
# submit_btn.click(
# fn=lambda c, p, d, l, w: {"data": {"complex_name": c, "protein_chains": p, "dna_sequences": d, "ligands": l}, "watermark": w},
# inputs=[complex_name, protein_chains, dna_sequences, ligands, add_watermark1],
# outputs=input_collector
# ).then(
# fn=predict_structure,
# inputs=input_collector,
# outputs=[view3d, confidence_plot_image, cif_file]
# )
#     @spaces.GPU(duration=120)
#     def is_watermarked(file):
#         """
#         Run the watermark detector on an uploaded CIF file.
#
#         The upload is copied to ``./output/<id>/<id>.cif``, preprocessed with
#         ``process_data`` and scored with ``infer_detect``.
#
#         Args:
#             file: The uploaded file -- an object with a ``.name`` path, a plain path
#                 string, or None when the upload is cleared.
#         Returns:
#             tuple: ("Watermarked" | "Not Watermarked", temporary PDB path for the 3D
#             viewer), or ("", None) when there is no file.
#         """
#         if file is None:
#             # The upload's .change event also fires when it is cleared; nothing to score.
#             return "", None
#         source_path = getattr(file, "name", file)
#
#         # Unique subdirectory and filename per request.
#         unique_id = uuid.uuid4().hex[:8]
#         subdir = os.path.join('./output', unique_id)
#         os.makedirs(subdir, exist_ok=True)
#         file_path = os.path.join(subdir, f"{unique_id}.cif")
#         shutil.copy(source_path, file_path)
#
#         runner = InferenceRunner(configs)
#         # NOTE(review): configs is module-global and mutated per request.
#         configs.process_success = process_data(subdir)
#         configs.subdir = subdir
#         result = infer_detect(runner, configs)
#
#         temp_pdb_path = convert_cif_to_pdb(file_path)
#         # Keeps the original equality test: only a result equal to False (False/0) is
#         # "Not Watermarked"; anything else, including None, reports "Watermarked".
#         # NOTE(review): confirm infer_detect never returns None on failure.
#         if result == False:  # noqa: E712
#             return "Not Watermarked", temp_pdb_path
#         return "Watermarked", temp_pdb_path
#     with gr.Tab("Watermark Detector"):
#         # Upload component; Gradio extension filters take a single leading dot.
#         cif_upload = gr.File(label="Upload .cif", file_types=[".cif"])
#         with gr.Row():
#             cif_3d_view = Molecule3D(label="3D Visualization of Input", reps=reps)
#         # Prediction output
#         prediction_output = gr.Textbox(label="Prediction")
#         # Run detection whenever the uploaded file changes.
#         cif_upload.change(is_watermarked, inputs=cif_upload, outputs=[prediction_output, cif_3d_view])
#         # Example files
#         example_files = [
#             "./examples/7r6r_watermarked.cif",
#             "./examples/7pzb_unwatermarked.cif"
#         ]
#         gr.Examples(examples=example_files, inputs=cif_upload)
# # Script entry point. The UI is built at import time above; this only starts the server.
# # NOTE(review): share=True asks Gradio for a public share link when run locally.
# if __name__ == "__main__":
# demo.launch(share=True)