FoldMark / app.py
Zaixi's picture
Update app.py
5347f77 verified
raw
history blame
18.3 kB
import spaces
import logging
import gradio as gr
import os
import uuid
from datetime import datetime
import numpy as np
from configs.configs_base import configs as configs_base
from configs.configs_data import data_configs
from configs.configs_inference import inference_configs
from runner.inference import download_infercence_cache, update_inference_configs, infer_predict, infer_detect, InferenceRunner
from protenix.config import parse_configs, parse_sys_args
from runner.msa_search import update_infer_json
from protenix.web_service.prediction_visualization import plot_best_confidence_measure, PredictionLoader
from process_data import process_data
import json
from typing import Dict, List
from Bio.PDB import MMCIFParser, PDBIO
import tempfile
import shutil
from Bio import PDB
from gradio_molecule3d import Molecule3D
# Path of the example input file offered through gr.Examples on the upload tab.
EXAMPLE_PATH = "./examples/example.json"

# In-memory example input. predict_structure compares submitted JSON against
# this value to decide whether the precomputed example output can be served.
# NOTE(review): presumably mirrors the contents of EXAMPLE_PATH - confirm.
example_json = [
    {
        "sequences": [
            {
                "proteinChain": {
                    "sequence": "MAEVIRSSAFWRSFPIFEEFDSETLCELSGIASYRKWSAGTVIFQRGDQGDYMIVVVSGRIKLSLFTPQGRELMLRQHEAGALFGEMALLDGQPRSADATAVTAAEGYVIGKKDFLALITQRPKTAEAVIRFLCAQLRDTTDRLETIALYDLNARVARFFLATLRQIHGSEMPQSANLRLTLSQTDIASILGASRPKVNRAILSLEESGAIKRADGIICCNVGRLLSIADPEEDLEHHHHHHHH",
                    "count": 2,
                }
            },
            {"dnaSequence": {"sequence": "CTAGGTAACATTACTCGCG", "count": 2}},
            {"dnaSequence": {"sequence": "GCGAGTAATGTTAC", "count": 2}},
            {"ligand": {"ligand": "CCD_PCG", "count": 2}},
        ],
        "name": "7pzb_need_search_msa",
    }
]
# Custom CSS for styling. The string is handed to gr.Blocks(css=...) below;
# "#logo" targets the header image, which is created with elem_id="logo".
# The text is consumed verbatim at runtime, so it is kept exactly as written.
custom_css = """
#logo {
width: 50%;
}
.title {
font-size: 32px;
font-weight: bold;
color: #4CAF50;
display: flex;
align-items: center; /* Vertically center the logo and text */
}
"""
# Runtime switches exported through the process environment.
os.environ.update(
    {
        "LAYERNORM_TYPE": "fast_layernorm",
        # The triple-T spelling is deliberate: this file reads the flag back
        # under exactly this key when it patches configs_base.
        "USE_DEEPSPEED_EVO_ATTTENTION": "False",
    }
)
# Set environment variable in the script
#os.environ['CUTLASS_PATH'] = './cutlass'
# reps = [
# {
# "model": 0,
# "chain": "",
# "resname": "",
# "style": "cartoon", # Use cartoon style
# "color": "whiteCarbon",
# "residue_range": "",
# "around": 0,
# "byres": False,
# "visible": True # Ensure this representation is visible
# }
# ]
def _cartoon_rep(model, color, opacity):
    """Build one whole-model cartoon representation entry.

    Only the model index, the colour scheme and the opacity vary between
    the entries used by this app; every selector field is left empty.
    """
    return {
        "model": model,
        "chain": "",
        "resname": "",
        "style": "cartoon",
        "color": color,
        "residue_range": "",
        "around": 0,
        "byres": False,
        "opacity": opacity,
    }


# Representations passed to every Molecule3D viewer in this file: model 0 as
# a faint whiteCarbon cartoon, model 1 as a more opaque cyanCarbon cartoon.
reps = [
    _cartoon_rep(0, "whiteCarbon", 0.2),
    _cartoon_rep(1, "cyanCarbon", 0.8),
]
##
def align_pdb_files(pdb_file_1, pdb_file_2):
    """Superimpose the structure in ``pdb_file_2`` onto the one in ``pdb_file_1``.

    The first model of each file is paired on its atoms named "CA", the
    resulting transformation is applied to the whole first model of the second
    structure, and that structure is written back to ``pdb_file_2``.

    Args:
        pdb_file_1: Path to the reference PDB file. It is only read.
        pdb_file_2: Path to the PDB file to move. It is overwritten in place
            with the aligned coordinates.

    Raises:
        Any exception raised by Biopython while parsing, pairing the atoms or
        saving. NOTE(review): ``Superimposer.set_atoms`` is expected to reject
        CA lists of different length - confirm against the Biopython docs.
    """
    pdb_parser = PDB.PDBParser(QUIET=True)
    reference_structure = pdb_parser.get_structure('Structure_1', pdb_file_1)
    moving_structure = pdb_parser.get_structure('Structure_2', pdb_file_2)

    reference_model = reference_structure[0]
    moving_model = moving_structure[0]

    # Pair the structures on their CA atoms only.
    reference_atoms = [
        atom for atom in reference_model.get_atoms() if atom.get_name() == "CA"
    ]
    moving_atoms = [
        atom for atom in moving_model.get_atoms() if atom.get_name() == "CA"
    ]

    super_imposer = PDB.Superimposer()
    super_imposer.set_atoms(reference_atoms, moving_atoms)
    # The fit is computed on CA atoms but applied to the entire model.
    super_imposer.apply(moving_model)

    # Persist the moved structure over the second input file.
    writer = PDB.PDBIO()
    writer.set_structure(moving_structure)
    writer.save(pdb_file_2)
# Function to convert .cif to .pdb and save as a temporary file
def convert_cif_to_pdb(cif_path):
    """
    Write a PDB copy of a CIF structure into a fresh temporary file.

    Args:
        cif_path (str): Path to the input CIF file.

    Returns:
        str: Path to the temporary PDB file. The file is created with
        ``delete=False``, so it is not removed by this function.
    """
    # Parse first, so a parse failure happens before any temp file exists.
    parsed = MMCIFParser().get_structure("protein", cif_path)

    # Only the generated name is needed; the handle is closed before writing.
    with tempfile.NamedTemporaryFile(suffix=".pdb", delete=False) as pdb_handle:
        pdb_path = pdb_handle.name

    writer = PDBIO()
    writer.set_structure(parsed)
    writer.save(pdb_path)
    return pdb_path
def plot_3d(pred_loader):
    """Pick the first CIF of a prediction and convert it for the 3D viewer.

    Args:
        pred_loader: Object exposing ``cif_paths``, an iterable of CIF paths.

    Returns:
        tuple: ``(pdb_path, cif_path)`` where ``cif_path`` is the first entry
        of the sorted ``cif_paths`` and ``pdb_path`` is its temporary PDB
        conversion.
    """
    ordered_cifs = sorted(pred_loader.cif_paths)
    first_cif = ordered_cifs[0]
    return convert_cif_to_pdb(first_cif), first_cif
def parse_json_input(json_data: List[Dict]) -> Dict:
    """Convert Protenix JSON format to UI-friendly structure.

    Args:
        json_data: List of entries, each with a "sequences" list and an
            optional "name".

    Returns:
        Dict with the lists "protein_chains", "dna_sequences" and "ligands",
        plus "complex_name". Items of other kinds are ignored; when several
        entries are given, the name of the last one is kept.
    """
    # (key in the input item, output bucket, payload field, output field name)
    layout = (
        ("proteinChain", "protein_chains", "sequence", "sequence"),
        ("dnaSequence", "dna_sequences", "sequence", "sequence"),
        ("ligand", "ligands", "ligand", "type"),
    )
    components = {bucket: [] for _, bucket, _, _ in layout}
    components["complex_name"] = ""

    for entry in json_data:
        components["complex_name"] = entry.get("name", "")
        for item in entry["sequences"]:
            for json_key, bucket, src_field, dst_field in layout:
                if json_key not in item:
                    continue
                payload = item[json_key]
                components[bucket].append(
                    {dst_field: payload[src_field], "count": payload["count"]}
                )
                # Only the first matching kind of an item is recorded.
                break
    return components
def create_protenix_json(input_data: Dict) -> List[Dict]:
    """Convert UI inputs to Protenix JSON format.

    Args:
        input_data: Dict with "protein_chains" and "dna_sequences" (records
            carrying "sequence" and "count"), "ligands" (records carrying
            "type" and "count") and "complex_name".

    Returns:
        A one-element list holding a dict with "sequences" (proteins first,
        then DNA, then ligands) and "name".
    """
    protein_items = [
        {"proteinChain": {"sequence": row["sequence"], "count": row["count"]}}
        for row in input_data["protein_chains"]
    ]
    dna_items = [
        {"dnaSequence": {"sequence": row["sequence"], "count": row["count"]}}
        for row in input_data["dna_sequences"]
    ]
    ligand_items = [
        {"ligand": {"ligand": row["type"], "count": row["count"]}}
        for row in input_data["ligands"]
    ]
    complex_entry = {
        "sequences": protein_items + dna_items + ligand_items,
        "name": input_data["complex_name"],
    }
    return [complex_entry]
#@torch.inference_mode()
@spaces.GPU(duration=120)  # Specify a duration to avoid timeout
def predict_structure(input_collector: dict):
    """Run a structure prediction for either the upload tab or the manual tab.

    Args:
        input_collector: Dict assembled by the submit handlers. It carries
            "watermark" (the checkbox value) and either "json" - a file path
            string, an object with a ``name`` attribute pointing at a file, or
            already-parsed JSON data - or "data", the manual inputs accepted
            by ``create_protenix_json``. A missing or empty "json" selects the
            manual path.

    Returns:
        tuple: ``(view3d, confidence_img_path, cif_path)``. ``view3d`` is the
        PDB path of the prediction or, when watermarking is enabled, the list
        ``[prediction_pdb, aligned_pdb_from_predictions_orig]``.
    """
    os.makedirs("./output", exist_ok=True)
    # Generate random filename with timestamp
    random_name = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
    save_path = os.path.join("./output", f"{random_name}.json")
    print(input_collector)

    watermark = input_collector["watermark"]
    # The manual tab does not necessarily send a "json" key, so look it up
    # without raising.
    json_source = input_collector.get("json")
    if json_source:
        if isinstance(json_source, str):  # Example JSON case (file path)
            with open(json_source) as json_file_handle:
                input_data = json.load(json_file_handle)
        elif hasattr(json_source, "name"):  # File upload case
            with open(json_source.name) as json_file_handle:
                input_data = json.load(json_file_handle)
        else:  # Direct JSON data case
            input_data = json_source
    else:  # Manual input case
        input_data = create_protenix_json(input_collector["data"])

    with open(save_path, "w") as f:
        json.dump(input_data, f, indent=2)

    if watermark and input_data == example_json:
        # Precomputed output exists for the example: no MSA search, no
        # inference, and therefore no model runner is needed.
        configs.saved_path = './output/example_output/'
    else:
        # Build the runner before this request's settings are written into
        # the shared configs, as the original ordering did.
        runner = InferenceRunner(configs)
        # run msa
        json_file = update_infer_json(save_path, './output', True)
        # Run prediction
        configs.input_json_path = json_file
        configs.watermark = watermark
        configs.saved_path = os.path.join("./output/", random_name)
        infer_predict(runner, configs)

    # Generate visualizations
    predictions_dir = os.path.join(configs.saved_path, 'predictions')
    pred_loader = PredictionLoader(predictions_dir)
    view3d, cif_path = plot_3d(pred_loader=pred_loader)
    # Decide on this request's own flag: configs.watermark may still hold the
    # value left behind by an earlier request.
    if watermark:
        orig_loader = PredictionLoader(os.path.join(configs.saved_path, 'predictions_orig'))
        view3d_orig, _ = plot_3d(pred_loader=orig_loader)
        align_pdb_files(view3d, view3d_orig)
        view3d = [view3d, view3d_orig]
    plot_best_confidence_measure(predictions_dir)
    confidence_img_path = os.path.join(predictions_dir, "best_sample_confidence.png")
    return view3d, confidence_img_path, cif_path
# Module-level bootstrap: logging, the shared inference configs and the two
# watermark checkboxes that the tabs render later.
logger = logging.getLogger(__name__)
LOG_FORMAT = "%(asctime)s,%(msecs)-3d %(levelname)-8s [%(filename)s:%(lineno)s %(funcName)s] %(message)s"
# No filename is configured, so no filemode is passed either (it only applies
# to file logging).
logging.basicConfig(
    format=LOG_FORMAT,
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)
# Enable DeepSpeed attention only when the environment asks for it. The former
# comparison against "False" switched it ON exactly when the variable said
# "False"; that was only masked by the explicit override further down.
configs_base["use_deepspeed_evo_attention"] = (
    os.environ.get("USE_DEEPSPEED_EVO_ATTTENTION", "False").lower() == "true"
)
arg_str = "--seeds 101 --dump_dir ./output --input_json_path ./examples/example.json --model.N_cycle 10 --sample_diffusion.N_sample 5 --sample_diffusion.N_step 200 "
configs = {**configs_base, "data": data_configs, **inference_configs}
configs = parse_configs(
    configs=configs,
    arg_str=arg_str,
    fill_required_with_null=True,
)
configs.load_checkpoint_path = './checkpoint.pt'
download_infercence_cache()
# DeepSpeed attention is always off for this app, whatever was parsed above.
configs.use_deepspeed_evo_attention = False
# One checkbox per prediction tab; they are created here and rendered inside
# the tabs with .render().
add_watermark = gr.Checkbox(label="Add Watermark", value=True)
add_watermark1 = gr.Checkbox(label="Add Watermark", value=True)
def _parse_uploaded_json(uploaded):
    """Load an uploaded JSON file and convert it for the "Parsed Components" view.

    The file is opened in a ``with`` block so the handle is closed again.
    """
    with open(uploaded.name) as json_file_handle:
        return parse_json_input(json.load(json_file_handle))


def _table_to_records(table, value_key):
    """Turn a table value into ``[{value_key: str, "count": int}, ...]``.

    Accepts an object exposing ``.values.tolist()``, a dict with a "data" list
    of rows, or a plain list of rows / ready-made record dicts.
    NOTE(review): gr.Dataframe is assumed to deliver a pandas DataFrame here -
    confirm against the installed Gradio version.
    Rows whose first cell is empty are skipped; a count that cannot be read as
    a number falls back to 1.
    """
    if table is None:
        return []
    if isinstance(table, dict):
        rows = table.get("data") or []
    elif hasattr(table, "values"):
        rows = table.values.tolist()
    else:
        rows = list(table)

    records = []
    for row in rows:
        if isinstance(row, dict):  # already in record form
            records.append(row)
            continue
        if not row:
            continue
        value = row[0]
        is_nan = isinstance(value, float) and value != value
        if value is None or is_nan or not str(value).strip():
            continue
        raw_count = row[1] if len(row) > 1 else None
        try:
            count = int(float(raw_count))
        except (TypeError, ValueError, OverflowError):
            count = 1
        records.append({value_key: str(value).strip(), "count": count})
    return records


def _collect_manual_inputs(name, proteins, dnas, ligand_rows, watermark):
    """Bundle the manual-tab widgets into the payload for predict_structure.

    "json" is sent explicitly as None so the manual path is selected without a
    missing-key error, and the tables are converted to the record lists that
    create_protenix_json reads ("sequence"/"count" and "type"/"count").
    """
    return {
        "json": None,
        "data": {
            "complex_name": name,
            "protein_chains": _table_to_records(proteins, "sequence"),
            "dna_sequences": _table_to_records(dnas, "sequence"),
            "ligands": _table_to_records(ligand_rows, "type"),
        },
        "watermark": watermark,
    }


@spaces.GPU(duration=120)
def is_watermarked(file):
    """Classify an uploaded CIF file as watermarked or not.

    Args:
        file: Uploaded file object exposing ``name`` (its path), or None when
            the upload component has been cleared.

    Returns:
        tuple: ``(label, pdb_path)`` with label "Watermarked" or
        "Not Watermarked" and the temporary PDB conversion of the upload;
        ``(None, None)`` when no file is present.
    """
    # The change event also fires when the file is removed.
    if file is None:
        return None, None

    # Generate a unique subdirectory and filename
    unique_id = str(uuid.uuid4().hex[:8])
    subdir = os.path.join('./output', unique_id)
    os.makedirs(subdir, exist_ok=True)
    file_path = os.path.join(subdir, f"{unique_id}.cif")
    # Save the uploaded file to the new location
    shutil.copy(file.name, file_path)

    # just for fast demonstration, otherwise it takes around 100 seconds
    if '7r6r_watermarked' in file.name:
        result = True
    elif '7pzb_unwatermarked' in file.name:
        result = False
    else:
        # The runner is only needed for real detection; it is built before
        # the shared configs are updated, as in the original ordering.
        runner = InferenceRunner(configs)
        configs.process_success = process_data(subdir)
        configs.subdir = subdir
        # NOTE(review): infer_detect presumably returns a bool - confirm.
        result = infer_detect(runner, configs)

    temp_pdb_path = convert_cif_to_pdb(file_path)
    # Kept as an equality test so only a value equal to False maps to
    # "Not Watermarked", exactly as before.
    if result == False:  # noqa: E712
        return "Not Watermarked", temp_pdb_path
    return "Watermarked", temp_pdb_path


with gr.Blocks(title="FoldMark", css=custom_css) as demo:
    with gr.Row():
        # Use a Column to align the logo and title horizontally
        gr.Image(value="./assets/foldmark_head.png", elem_id="logo", label="Logo", height=150, show_label=False)

    with gr.Tab("Structure Predictor (JSON Upload)"):
        # First create the upload component
        json_upload = gr.File(label="Upload JSON", file_types=[".json"])
        # Then create the example component that references it
        gr.Examples(
            examples=[[EXAMPLE_PATH]],
            inputs=[json_upload],
            label="Click to use example JSON:",
            examples_per_page=1
        )
        # Rest of the components
        upload_name = gr.Textbox(label="Complex Name (optional)")
        upload_output = gr.JSON(label="Parsed Components")
        json_upload.upload(
            fn=_parse_uploaded_json,
            inputs=json_upload,
            outputs=upload_output
        )
        # Shared prediction components
        with gr.Row():
            add_watermark.render()
            submit_btn = gr.Button("Predict Structure", variant="primary")
        #structure_view = gr.HTML(label="3D Visualization")
        with gr.Row():
            view3d = Molecule3D(label="3D Visualization(Gray: Unwatermarked; Cyan: Watermarked)", reps=reps)
            # legend = gr.Markdown("""
            # **Color Legend:**
            # - <span style="color:grey">Gray: Unwatermarked Structure</span>
            # - <span style="color:cyan">Cyan: Watermarked Structure</span>
            # """)
            legend = gr.HTML("""
<div>
<strong>Color Legend:</strong><br>
- <span style="color:grey;">Gray: Unwatermarked Structure</span><br>
- <span style="color:cyan;">Cyan: Watermarked Structure</span>
</div>
""")
        with gr.Row():
            cif_file = gr.File(label="Download CIF File")
        with gr.Row():
            confidence_plot_image = gr.Image(label="Confidence Measures")
        input_collector = gr.JSON(visible=False)
        # Map inputs to a dictionary
        submit_btn.click(
            fn=lambda j, w: {"json": j, "watermark": w},
            inputs=[json_upload, add_watermark],
            outputs=input_collector
        ).then(
            fn=predict_structure,
            inputs=input_collector,
            outputs=[view3d, confidence_plot_image, cif_file]
        )
        gr.Markdown("""
The example of the uploaded json file for structure prediction.
<pre>
[{
"sequences": [
{
"proteinChain": {
"sequence": "MAEVIRSSAFWRSFPIFEEFDSETLCELSGIASYRKWSAGTVIFQRGDQGDYMIVVVSGRIKLSLFTPQGRELMLRQHEAGALFGEMALLDGQPRSADATAVTAAEGYVIGKKDFLALITQRPKTAEAVIRFLCAQLRDTTDRLETIALYDLNARVARFFLATLRQIHGSEMPQSANLRLTLSQTDIASILGASRPKVNRAILSLEESGAIKRADGIICCNVGRLLSIADPEEDLEHHHHHHHH",
"count": 2
}
},
{
"dnaSequence": {
"sequence": "CTAGGTAACATTACTCGCG",
"count": 2
}
},
{
"dnaSequence": {
"sequence": "GCGAGTAATGTTAC",
"count": 2
}
},
{
"ligand": {
"ligand": "CCD_PCG",
"count": 2
}
}
],
"name": "7pzb"
}]
</pre>
""")

    with gr.Tab("Structure Predictor (Manual Input)"):
        with gr.Row():
            complex_name = gr.Textbox(label="Complex Name")
        # Replace gr.Group with gr.Accordion
        with gr.Accordion(label="Protein Chains", open=True):
            protein_chains = gr.Dataframe(
                headers=["Sequence", "Count"],
                datatype=["str", "number"],
                row_count=1,
                col_count=(2, "fixed")
            )
        # Repeat for other groups
        with gr.Accordion(label="DNA Sequences", open=True):
            dna_sequences = gr.Dataframe(
                headers=["Sequence", "Count"],
                datatype=["str", "number"],
                row_count=1
            )
        with gr.Accordion(label="Ligands", open=True):
            ligands = gr.Dataframe(
                headers=["Ligand Type", "Count"],
                datatype=["str", "number"],
                row_count=1
            )
        manual_output = gr.JSON(label="Generated JSON")
        complex_name.change(
            fn=lambda x: {"complex_name": x},
            inputs=complex_name,
            outputs=manual_output
        )
        # Shared prediction components
        with gr.Row():
            add_watermark1.render()
            submit_btn = gr.Button("Predict Structure", variant="primary")
        #structure_view = gr.HTML(label="3D Visualization")
        with gr.Row():
            view3d = Molecule3D(label="3D Visualization (Gray: Unwatermarked; Cyan: Watermarked)", reps=reps)
        with gr.Row():
            cif_file = gr.File(label="Download CIF File")
        with gr.Row():
            confidence_plot_image = gr.Image(label="Confidence Measures")
        input_collector = gr.JSON(visible=False)
        # Map inputs to a dictionary
        submit_btn.click(
            fn=_collect_manual_inputs,
            inputs=[complex_name, protein_chains, dna_sequences, ligands, add_watermark1],
            outputs=input_collector
        ).then(
            fn=predict_structure,
            inputs=input_collector,
            outputs=[view3d, confidence_plot_image, cif_file]
        )

    with gr.Tab("Watermark Detector"):
        # First create the upload component (".cif" with a single leading dot)
        cif_upload = gr.File(label="Upload .cif", file_types=[".cif"])
        with gr.Row():
            cif_3d_view = Molecule3D(label="3D Visualization of Input", reps=reps)
        # Prediction output
        prediction_output = gr.Textbox(label="Prediction")
        # Define the interaction
        cif_upload.change(is_watermarked, inputs=cif_upload, outputs=[prediction_output, cif_3d_view])
        # Example files
        example_files = [
            "./examples/7r6r_watermarked.cif",
            "./examples/7pzb_unwatermarked.cif"
        ]
        gr.Examples(examples=example_files, inputs=cif_upload)
# Script entry point: start the Gradio app only when this file is executed
# directly, not when it is imported. share=True is forwarded to demo.launch().
if __name__ == "__main__":
    demo.launch(share=True)