FoldMark / app.py
Zaixi's picture
add
c6eeef3
raw
history blame
17.7 kB
import logging
import gradio as gr
import os
import uuid
from datetime import datetime
import numpy as np
from configs.configs_base import configs as configs_base
from configs.configs_data import data_configs
from configs.configs_inference import inference_configs
from runner.inference import download_infercence_cache, update_inference_configs, infer_predict, infer_detect, InferenceRunner
from protenix.config import parse_configs, parse_sys_args
from runner.msa_search import update_infer_json
from protenix.web_service.prediction_visualization import plot_best_confidence_measure, PredictionLoader
from process_data import process_data
import json
from typing import Dict, List
from Bio.PDB import MMCIFParser, PDBIO
import tempfile
import shutil
from Bio import PDB
from gradio_molecule3d import Molecule3D
#import spaces # Import spaces for ZeroGPU compatibility
# Path of the example job file; offered through gr.Examples in the JSON upload tab.
EXAMPLE_PATH = './examples/example.json'
# In-memory description of the example job (one protein chain, two DNA strands, one ligand,
# each with count 2). predict_structure compares submitted input against this literal.
example_json=[{'sequences': [{'proteinChain': {'sequence': 'MAEVIRSSAFWRSFPIFEEFDSETLCELSGIASYRKWSAGTVIFQRGDQGDYMIVVVSGRIKLSLFTPQGRELMLRQHEAGALFGEMALLDGQPRSADATAVTAAEGYVIGKKDFLALITQRPKTAEAVIRFLCAQLRDTTDRLETIALYDLNARVARFFLATLRQIHGSEMPQSANLRLTLSQTDIASILGASRPKVNRAILSLEESGAIKRADGIICCNVGRLLSIADPEEDLEHHHHHHHH', 'count': 2}}, {'dnaSequence': {'sequence': 'CTAGGTAACATTACTCGCG', 'count': 2}}, {'dnaSequence': {'sequence': 'GCGAGTAATGTTAC', 'count': 2}}, {'ligand': {'ligand': 'CCD_PCG', 'count': 2}}], 'name': '7pzb_need_search_msa'}]
# Custom CSS for styling; handed to gr.Blocks(css=...) when the UI is built.
# The "#logo" rule matches the header image created with elem_id="logo".
custom_css = """
#logo {
width: 50%;
}
.title {
font-size: 32px;
font-weight: bold;
color: #4CAF50;
display: flex;
align-items: center; /* Vertically center the logo and text */
}
"""
# Process-wide switches exported through the environment.
# NOTE(review): these assignments run after the project modules are imported above; if any of
# those modules read the variables at import time, the values set here arrive too late - confirm.
os.environ["LAYERNORM_TYPE"] = "fast_layernorm"
# The variable name is spelled with three T's; the same spelling is used where it is read back
# further down in this file, so the two must stay in sync.
os.environ["USE_DEEPSPEED_EVO_ATTTENTION"] = "False"
# Set environment variable in the script
#os.environ['CUTLASS_PATH'] = './cutlass'
# reps = [
# {
# "model": 0,
# "chain": "",
# "resname": "",
# "style": "cartoon", # Use cartoon style
# "color": "whiteCarbon",
# "residue_range": "",
# "around": 0,
# "byres": False,
# "visible": True # Ensure this representation is visible
# }
# ]
def _cartoon_rep(model_index, carbon_color, opacity):
    """Build one cartoon-style representation entry for the Molecule3D viewer."""
    return {
        "model": model_index,
        "chain": "",
        "resname": "",
        "style": "cartoon",
        "color": carbon_color,
        "residue_range": "",
        "around": 0,
        "byres": False,
        "opacity": opacity,
    }


# Two overlaid cartoon representations: model 0 is drawn faintly in white,
# model 1 more opaquely in cyan.
reps = [
    _cartoon_rep(0, "whiteCarbon", 0.2),
    _cartoon_rep(1, "cyanCarbon", 0.8),
]
def align_pdb_files(pdb_file_1, pdb_file_2):
    """Superimpose the structure in ``pdb_file_2`` onto the one in ``pdb_file_1``.

    Only the first model of each file is used. The fit is computed on the CA
    atoms of both models, paired by their position in iteration order, and the
    resulting transformation is applied to the whole second model. The moved
    structure is then written back to ``pdb_file_2``, overwriting it.

    When the two models do not expose the same, non-zero number of CA atoms a
    pairwise fit is not possible. In that case a warning is logged and
    ``pdb_file_2`` is left untouched rather than aborting the caller.

    Args:
        pdb_file_1 (str): Path to the reference PDB file (never modified).
        pdb_file_2 (str): Path to the PDB file that is moved and overwritten.
    """
    pdb_parser = PDB.PDBParser(QUIET=True)
    structure_1 = pdb_parser.get_structure('Structure_1', pdb_file_1)
    structure_2 = pdb_parser.get_structure('Structure_2', pdb_file_2)
    model_1 = structure_1[0]
    model_2 = structure_2[0]

    # Use CA atoms as the anchor points of the fit.
    atoms_1 = [atom for atom in model_1.get_atoms() if atom.get_name() == "CA"]
    atoms_2 = [atom for atom in model_2.get_atoms() if atom.get_name() == "CA"]

    # Superimposer.set_atoms needs two equally sized, non-empty atom lists.
    if not atoms_1 or len(atoms_1) != len(atoms_2):
        logging.getLogger(__name__).warning(
            "Skipping alignment of %s onto %s: %d vs %d CA atoms",
            pdb_file_2, pdb_file_1, len(atoms_2), len(atoms_1),
        )
        return

    super_imposer = PDB.Superimposer()
    super_imposer.set_atoms(atoms_1, atoms_2)
    super_imposer.apply(model_2)  # Apply the transformation to model_2

    # Save the aligned structure back to the second (original) file.
    io = PDB.PDBIO()
    io.set_structure(structure_2)
    io.save(pdb_file_2)
# Function to convert .cif to .pdb and save as a temporary file
def convert_cif_to_pdb(cif_path):
    """
    Parse an mmCIF file and write the structure out as a temporary PDB file.

    The temporary file is created with ``delete=False``, so it stays on disk
    after this function returns; nothing here removes it.

    Args:
        cif_path (str): Path to the input CIF file.
    Returns:
        str: Path to the temporary PDB file.
    """
    # Read the structure first so no temporary file is created if parsing fails.
    structure = MMCIFParser().get_structure("protein", cif_path)

    # Reserve a ".pdb" path; the handle is closed before PDBIO writes to it.
    reserved = tempfile.NamedTemporaryFile(suffix=".pdb", delete=False)
    reserved.close()
    temp_pdb_path = reserved.name

    # Write the parsed structure in PDB format to the reserved path.
    writer = PDBIO()
    writer.set_structure(structure)
    writer.save(temp_pdb_path)
    return temp_pdb_path
def plot_3d(pred_loader):
    """Select the first CIF file of a prediction and produce a PDB copy of it.

    Args:
        pred_loader: Object exposing a ``cif_paths`` collection.
    Returns:
        tuple: ``(temp_pdb_path, cif_path)`` where ``cif_path`` is the first
        entry of the sorted ``cif_paths`` and ``temp_pdb_path`` is the
        temporary PDB file converted from it.
    """
    ordered_cif_paths = sorted(pred_loader.cif_paths)
    first_cif = ordered_cif_paths[0]
    return convert_cif_to_pdb(first_cif), first_cif
def parse_json_input(json_data: List[Dict]) -> Dict:
    """Convert Protenix JSON format to UI-friendly structure.

    Each item of every entry's ``"sequences"`` list is sorted into one of
    three buckets depending on the first recognised key it contains
    (``proteinChain``, then ``dnaSequence``, then ``ligand``); items with none
    of these keys are ignored. ``complex_name`` ends up as the ``"name"`` of
    the last entry (empty string when absent).
    """
    # (input key, output bucket, field read from the input, field written to the output),
    # listed in matching priority order.
    layout = (
        ("proteinChain", "protein_chains", "sequence", "sequence"),
        ("dnaSequence", "dna_sequences", "sequence", "sequence"),
        ("ligand", "ligands", "ligand", "type"),
    )

    components = {bucket: [] for _, bucket, _, _ in layout}
    components["complex_name"] = ""

    for entry in json_data:
        components["complex_name"] = entry.get("name", "")
        for item in entry["sequences"]:
            for key, bucket, source_field, target_field in layout:
                if key not in item:
                    continue
                payload = item[key]
                components[bucket].append({
                    target_field: payload[source_field],
                    "count": payload["count"],
                })
                break  # only the first matching key is used
    return components
def create_protenix_json(input_data: Dict) -> List[Dict]:
    """Convert UI inputs to Protenix JSON format.

    Produces a one-element list whose single entry holds the complex name and
    a ``"sequences"`` list made of all protein chains, followed by all DNA
    sequences, followed by all ligands.
    """
    protein_entries = [
        {"proteinChain": {"sequence": row["sequence"], "count": row["count"]}}
        for row in input_data["protein_chains"]
    ]
    dna_entries = [
        {"dnaSequence": {"sequence": row["sequence"], "count": row["count"]}}
        for row in input_data["dna_sequences"]
    ]
    ligand_entries = [
        {"ligand": {"ligand": row["type"], "count": row["count"]}}
        for row in input_data["ligands"]
    ]
    job = {
        "sequences": protein_entries + dna_entries + ligand_entries,
        "name": input_data["complex_name"],
    }
    return [job]
def _load_input_data(input_collector: dict):
    """Return the Protenix job description carried by ``input_collector``.

    A truthy ``"json"`` entry wins and may be a file path, an object with a
    ``name`` attribute pointing at a file, or already-parsed JSON data.
    Otherwise the ``"data"`` entry (manual form values) is converted with
    ``create_protenix_json``.

    Raises:
        ValueError: If neither ``"json"`` nor ``"data"`` is provided.
    """
    json_source = input_collector.get("json")
    if json_source:
        if isinstance(json_source, str):  # Example JSON case (file path)
            with open(json_source) as f:
                return json.load(f)
        if hasattr(json_source, "name"):  # File upload case
            with open(json_source.name) as f:
                return json.load(f)
        return json_source  # Direct JSON data case

    manual_data = input_collector.get("data")
    if manual_data is None:
        raise ValueError(
            "No input provided: upload a JSON file or fill in the manual input form."
        )
    return create_protenix_json(manual_data)  # Manual input case


#@torch.inference_mode()
#@spaces.GPU(duration=120) # Specify a duration to avoid timeout
def predict_structure(input_collector: dict):
    """Run a structure prediction for either the JSON-upload or the manual tab.

    Args:
        input_collector (dict): Request payload. Holds ``"watermark"`` plus
            either ``"json"`` (path / uploaded file / parsed data) or
            ``"data"`` (manual form values). The ``"json"`` key may be absent.
    Returns:
        tuple: ``(view3d, confidence_img_path, cif_path)``. ``view3d`` is the
        path of a temporary PDB file, or a two-element list
        ``[watermarked_pdb, original_pdb]`` when watermarking is enabled.
    Raises:
        ValueError: If the payload contains neither JSON nor manual input.
    """
    logging.getLogger(__name__).info("Prediction request: %s", input_collector)

    # Resolve and validate the input before touching the filesystem or the shared configs.
    input_data = _load_input_data(input_collector)
    watermark = bool(input_collector.get("watermark", False))

    os.makedirs("./output", exist_ok=True)
    # Generate random filename with timestamp
    random_name = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
    save_path = os.path.join("./output", f"{random_name}.json")
    with open(save_path, "w") as f:
        json.dump(input_data, f, indent=2)

    # Always record the flag for this request; otherwise the example shortcut below
    # would act on whatever value an earlier request left on the shared configs.
    configs.watermark = watermark
    if input_data == example_json and watermark:
        # Example job with watermarking: use the output directory shipped with the app
        # instead of running MSA search and inference.
        saved_path = './output/example_output/'
        configs.saved_path = saved_path
    else:
        # run msa
        json_file = update_infer_json(save_path, './output', True)
        # Run prediction
        configs.input_json_path = json_file
        saved_path = os.path.join("./output/", random_name)
        configs.saved_path = saved_path
        infer_predict(runner, configs)

    # Generate visualizations
    predictions_dir = os.path.join(saved_path, 'predictions')
    view3d, cif_path = plot_3d(pred_loader=PredictionLoader(predictions_dir))
    if watermark:
        # Overlay the prediction stored under "predictions_orig" on the one from
        # "predictions", after aligning it onto the latter.
        orig_loader = PredictionLoader(os.path.join(saved_path, 'predictions_orig'))
        view3d_orig, _ = plot_3d(pred_loader=orig_loader)
        align_pdb_files(view3d, view3d_orig)
        view3d = [view3d, view3d_orig]

    plot_best_confidence_measure(predictions_dir)
    confidence_img_path = os.path.join(predictions_dir, "best_sample_confidence.png")
    return view3d, confidence_img_path, cif_path
# Module-level logger and logging configuration for the app.
logger = logging.getLogger(__name__)
LOG_FORMAT = "%(asctime)s,%(msecs)-3d %(levelname)-8s [%(filename)s:%(lineno)s %(funcName)s] %(message)s"
logging.basicConfig(
    format=LOG_FORMAT,
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    filemode="w",  # only meaningful together with a filename; kept as in the original call
)

# Derive the flag from the environment variable set near the top of this file.
# The previous comparison against "False" evaluated to True exactly when the variable
# said "False"; the flag is now enabled only for an explicit "true" value, which agrees
# with the explicit `configs.use_deepspeed_evo_attention = False` further down.
configs_base["use_deepspeed_evo_attention"] = (
    os.environ.get("USE_DEEPSPEED_EVO_ATTTENTION", "False").lower() == "true"
)

# Default command-line style arguments used to fill the inference configuration.
arg_str = "--seeds 101 --dump_dir ./output --input_json_path ./examples/example.json --model.N_cycle 10 --sample_diffusion.N_sample 5 --sample_diffusion.N_step 200 "
configs = {**configs_base, **{"data": data_configs}, **inference_configs}
configs = parse_configs(
    configs=configs,
    arg_str=arg_str,
    fill_required_with_null=True,
)
configs.load_checkpoint_path = './checkpoint.pt'
download_infercence_cache(configs, model_version="v0.2.0")
configs.use_deepspeed_evo_attention = False

# Single runner instance shared by the prediction and detection handlers below.
runner = InferenceRunner(configs)
# The two watermark checkboxes are created up front and placed later with .render(),
# one in each predictor tab.
add_watermark = gr.Checkbox(label="Add Watermark", value=True)
add_watermark1 = gr.Checkbox(label="Add Watermark", value=True)
# Top-level UI container; `demo` is launched at the bottom of the file.
with gr.Blocks(title="FoldMark", css=custom_css) as demo:
with gr.Row():
# Header row showing the logo image (styled through the "#logo" CSS rule)
gr.Image(value="./assets/foldmark_head.png", elem_id="logo", label="Logo", height=150, show_label=False)
with gr.Tab("Structure Predictor (JSON Upload)"):
# First create the upload component
json_upload = gr.File(label="Upload JSON", file_types=[".json"])
# Then create the example component that references it
gr.Examples(
examples=[[EXAMPLE_PATH]],
inputs=[json_upload],
label="Click to use example JSON:",
examples_per_page=1
)
# Rest of the components
# NOTE(review): upload_name is not passed to any handler in this section - confirm it is intentional.
upload_name = gr.Textbox(label="Complex Name (optional)")
upload_output = gr.JSON(label="Parsed Components")
# On upload, parse the file and show its components in upload_output.
# NOTE(review): the handler reads `f.name`, i.e. it assumes the upload arrives as an object with
# a `name` attribute, and the file opened here is never explicitly closed.
json_upload.upload(
fn=lambda f: parse_json_input(json.load(open(f.name))),
inputs=json_upload,
outputs=upload_output
)
# Shared prediction components
with gr.Row():
add_watermark.render()
submit_btn = gr.Button("Predict Structure", variant="primary")
#structure_view = gr.HTML(label="3D Visualization")
with gr.Row():
view3d = Molecule3D(label="3D Visualization", reps=reps)
legend = gr.Markdown("""
**Color Legend:**
- <span style="color:grey">Unwatermarked Structure</span>
- <span style="color:cyan">Watermarked Structure</span>
""")
with gr.Row():
cif_file = gr.File(label="Download CIF File")
with gr.Row():
confidence_plot_image = gr.Image(label="Confidence Measures")
# Hidden component used to hand the packed request from the first click step to the second.
input_collector = gr.JSON(visible=False)
# Map inputs to a dictionary
# Step 1 packs the upload and the checkbox into {"json": ..., "watermark": ...};
# step 2 feeds that dictionary to predict_structure and fills the three outputs.
submit_btn.click(
fn=lambda j, w: {"json": j, "watermark": w},
inputs=[json_upload, add_watermark],
outputs=input_collector
).then(
fn=predict_structure,
inputs=input_collector,
outputs=[view3d, confidence_plot_image, cif_file]
)
# Static help text showing the expected layout of an uploaded job file.
# The JSON below is display-only text inside the Markdown string; the app does not parse it.
gr.Markdown("""
The example of the uploaded json file for structure prediction.
<pre>
[{
"sequences": [
{
"proteinChain": {
"sequence": "MAEVIRSSAFWRSFPIFEEFDSETLCELSGIASYRKWSAGTVIFQRGDQGDYMIVVVSGRIKLSLFTPQGRELMLRQHEAGALFGEMALLDGQPRSADATAVTAAEGYVIGKKDFLALITQRPKTAEAVIRFLCAQLRDTTDRLETIALYDLNARVARFFLATLRQIHGSEMPQSANLRLTLSQTDIASILGASRPKVNRAILSLEESGAIKRADGIICCNVGRLLSIADPEEDLEHHHHHHHH",
"count": 2
}
},
{
"dnaSequence": {
"sequence": "CTAGGTAACATTACTCGCG",
"count": 2
}
},
{
"dnaSequence": {
"sequence": "GCGAGTAATGTTAC",
"count": 2
}
},
{
"ligand": {
"ligand": "CCD_PCG",
"count": 2
}
}
],
"name": "7pzb"
}]
</pre>
""")
# Second predictor tab: the job is described through form fields instead of a JSON file.
with gr.Tab("Structure Predictor (Manual Input)"):
with gr.Row():
complex_name = gr.Textbox(label="Complex Name")
# Replace gr.Group with gr.Accordion
# One two-column table (value, count) per component type.
with gr.Accordion(label="Protein Chains", open=True):
protein_chains = gr.Dataframe(
headers=["Sequence", "Count"],
datatype=["str", "number"],
row_count=1,
col_count=(2, "fixed")
)
# Repeat for other groups
with gr.Accordion(label="DNA Sequences", open=True):
dna_sequences = gr.Dataframe(
headers=["Sequence", "Count"],
datatype=["str", "number"],
row_count=1
)
with gr.Accordion(label="Ligands", open=True):
ligands = gr.Dataframe(
headers=["Ligand Type", "Count"],
datatype=["str", "number"],
row_count=1
)
manual_output = gr.JSON(label="Generated JSON")
# Only the complex name is mirrored into manual_output; no handler in this section
# writes the table contents there.
complex_name.change(
fn=lambda x: {"complex_name": x},
inputs=complex_name,
outputs=manual_output
)
# Shared prediction components
# The names below (submit_btn, view3d, legend, cif_file, confidence_plot_image,
# input_collector) are rebound to new components that belong to this tab.
with gr.Row():
add_watermark1.render()
submit_btn = gr.Button("Predict Structure", variant="primary")
#structure_view = gr.HTML(label="3D Visualization")
with gr.Row():
view3d = Molecule3D(label="3D Visualization", reps=reps)
legend = gr.Markdown("""
**Color Legend:**
- <span style="color:grey">Unwatermarked Structure</span>
- <span style="color:cyan">Watermarked Structure</span>
""")
with gr.Row():
cif_file = gr.File(label="Download CIF File")
with gr.Row():
confidence_plot_image = gr.Image(label="Confidence Measures")
input_collector = gr.JSON(visible=False)
# Map inputs to a dictionary
# Step 1 packs the form into {"data": {...}, "watermark": ...} (this payload carries no "json"
# key); step 2 feeds it to predict_structure.
# NOTE(review): create_protenix_json reads each row through "sequence"/"count"/"type" keys;
# confirm the Dataframe values reach it in that shape.
submit_btn.click(
fn=lambda c, p, d, l, w: {"data": {"complex_name": c, "protein_chains": p, "dna_sequences": d, "ligands": l}, "watermark": w},
inputs=[complex_name, protein_chains, dna_sequences, ligands, add_watermark1],
outputs=input_collector
).then(
fn=predict_structure,
inputs=input_collector,
outputs=[view3d, confidence_plot_image, cif_file]
)
#@spaces.GPU
def is_watermarked(file):
    """Run watermark detection on an uploaded CIF file.

    The upload is copied to ``./output/<id>/<id>.cif`` (``<id>`` being a fresh
    8-character hex string), the directory is passed through ``process_data``
    and ``infer_detect``, and the copied CIF is converted to a temporary PDB
    file for display.

    Args:
        file: The uploaded file, either as a path string or as an object with
            a ``name`` attribute holding the path. ``None`` (for example when
            the upload component is cleared) is accepted and yields empty
            outputs without running detection.
    Returns:
        tuple: ``(label, pdb_path)`` where ``label`` is ``"Watermarked"`` or
        ``"Not Watermarked"``; ``("", None)`` when ``file`` is ``None``.
    """
    # The change event also fires when the upload is cleared; there is nothing to analyse then.
    if file is None:
        return "", None
    source_path = file if isinstance(file, str) else file.name

    # Generate a unique subdirectory and filename
    unique_id = uuid.uuid4().hex[:8]
    subdir = os.path.join('./output', unique_id)
    os.makedirs(subdir, exist_ok=True)
    file_path = os.path.join(subdir, f"{unique_id}.cif")
    # Save the uploaded file to the new location
    shutil.copy(source_path, file_path)

    # Call your processing functions
    configs.process_success = process_data(subdir)
    configs.subdir = subdir
    result = infer_detect(runner, configs)

    temp_pdb_path = convert_cif_to_pdb(file_path)
    # Equality (not truthiness) is kept on purpose: only a result equal to False
    # counts as "Not Watermarked"; any other value, including None, is reported as watermarked.
    label = "Not Watermarked" if result == False else "Watermarked"  # noqa: E712
    return label, temp_pdb_path
with gr.Tab("Watermark Detector"):
# First create the upload component
cif_upload = gr.File(label="Upload .cif", file_types=["..cif"])
with gr.Row():
cif_3d_view = Molecule3D(label="3D Visualization of Input", reps=reps)
# Prediction output
prediction_output = gr.Textbox(label="Prediction")
# Define the interaction
cif_upload.change(is_watermarked, inputs=cif_upload, outputs=[prediction_output, cif_3d_view])
# Example files
example_files = [
"./examples/7r6r_watermarked.cif",
"./examples/7pzb_unwatermarked.cif"
]
gr.Examples(examples=example_files, inputs=cif_upload)
# Start the Gradio app only when the file is executed as a script.
if __name__ == "__main__":
demo.launch(share=True)