from collections.abc import Iterable, Sequence
from pathlib import Path
import hashlib
import html
import logging
import os

from Bio.PDB import MMCIFParser, PDBIO
from folding_studio.client import Client
from folding_studio.query.boltz import BoltzQuery, BoltzParameters
import gradio as gr
import numpy as np
import plotly.graph_objects as go

from molecule import molecule

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
    ]
)
logger = logging.getLogger(__name__)


def convert_cif_to_pdb(cif_path: str, pdb_path: str) -> None:
    """Convert a .cif file to .pdb format using Biopython.

    Args:
        cif_path (str): Path to input .cif file
        pdb_path (str): Path to output .pdb file (overwritten if it exists)
    """
    # Parse the CIF file
    parser = MMCIFParser()
    structure = parser.get_structure("structure", cif_path)

    # Save as PDB
    io = PDBIO()
    io.set_structure(structure)
    io.save(pdb_path)


def call_boltz(seq_file: Path | str, api_key: str, output_dir: Path) -> None:
    """Run a Boltz prediction through the Folding Studio API and download results.

    Args:
        seq_file (Path | str): FASTA file containing the query sequence.
        api_key (str): Folding Studio API key used to authenticate the client.
        output_dir (Path): Directory that receives the saved query parameters and
            the unzipped prediction results (existing results are overwritten).

    Raises:
        KeyError: If the ``FOLDING_PROJECT_CODE`` environment variable is unset.
    """
    # Fixed Boltz settings used for every request sent from this app.
    parameters = {
        "recycling_steps": 3,
        "sampling_steps": 200,
        "diffusion_samples": 1,
        "step_scale": 1.638,
        "msa_pairing_strategy": "greedy",
        "write_full_pae": False,
        "write_full_pde": False,
        "use_msa_server": True,
        "seed": 0,
        "custom_msa_paths": None,
    }

    # Create a client using API key
    logger.info("Authenticating client with API key")
    client = Client.from_api_key(api_key=api_key)

    # Define query
    seq_file = Path(seq_file)
    query = BoltzQuery.from_file(seq_file, query_name="gradio", parameters=BoltzParameters(**parameters))
    query.save_parameters(output_dir)

    logger.info("Payload: %s", query.payload)

    # Send a request
    logger.info("Sending request to Folding Studio API")
    response = client.send_request(query, project_code=os.environ["FOLDING_PROJECT_CODE"])

    # Access confidence data
    logger.info("Confidence data: %s", response.confidence_data)

    response.download_results(output_dir=output_dir, force=True, unzip=True)
    logger.info("Results downloaded to %s", output_dir)


def _first_match(paths: Iterable[Path], description: str) -> Path:
    """Return the lexicographically first path, or raise a user-facing error if none."""
    # Sorting makes the choice deterministic; glob order depends on the filesystem.
    matches = sorted(paths)
    if not matches:
        raise gr.Error(f"Could not find the {description} in the prediction output.")
    return matches[0]


def predict(sequence: str, api_key: str) -> tuple[str, go.Figure]:
    """Predict protein structure from amino acid sequence using Boltz model.

    Outputs are cached on disk in a directory named after the SHA-1 of the
    sequence, so submitting the same sequence again reuses the earlier
    prediction instead of calling the API.

    Args:
        sequence (str): Amino acid sequence to predict structure for. All
            whitespace (e.g. line breaks from a pasted sequence) is removed.
        api_key (str): Folding API key. Only required when no cached
            prediction exists for the sequence.

    Returns:
        tuple[str, go.Figure]: HTML iframe containing the 3D molecular
        visualization, and a per-residue pLDDT plot.

    Raises:
        gr.Error: If the sequence is empty, if an API key is needed but missing,
            or if the expected prediction files are not found.
    """
    sequence = "".join(sequence.split())
    if not sequence:
        raise gr.Error("Please provide an amino acid sequence.")

    # Set up unique output directory based on sequence hash (a cache key, not a security use)
    seq_id = hashlib.sha1(sequence.encode()).hexdigest()
    seq_file = Path(f"sequence_{seq_id}.fasta")
    _write_fasta_file(seq_file, sequence)
    output_dir = Path(f"sequence_{seq_id}")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Check if prediction already exists
    if any(output_dir.rglob("*_model_0.cif")):
        logger.info("Prediction already exists. Output directory: %s", output_dir)
    else:
        if not api_key:
            raise gr.Error("A Folding API key is required to run a new prediction.")
        # Run Boltz prediction
        logger.info("Predicting %s", seq_file.stem)
        call_boltz(seq_file=seq_file, api_key=api_key, output_dir=output_dir)
        logger.info("Prediction done. Output directory: %s", output_dir)

    # Convert output CIF to PDB
    pred_cif = _first_match(output_dir.rglob("*_model_0.cif"), "predicted structure (*_model_0.cif)")
    logger.info("Output file: %s", pred_cif)
    converted_pdb_path = output_dir / "pred.pdb"
    convert_cif_to_pdb(str(pred_cif), str(converted_pdb_path))
    logger.info("Converted PDB file: %s", converted_pdb_path)

    # Generate molecular visualization
    mol = _create_molecule_visualization(converted_pdb_path, sequence)

    plddt_file = _first_match(pred_cif.parent.glob("plddt_*.npz"), "pLDDT scores (plddt_*.npz)")
    logger.info("plddt file: %s", plddt_file)
    # NpzFile keeps the archive open; the context manager closes it once the array is read.
    with np.load(plddt_file) as npz:
        plddt_vals = npz["plddt"]

    return _wrap_in_iframe(mol), add_plddt_plot(plddt_vals=plddt_vals)


def _write_fasta_file(filepath: Path, sequence: str) -> None:
    """Write ``sequence`` as a single-record FASTA file (overwriting ``filepath``).

    NOTE(review): the ``>A|protein`` header is presumably Boltz's
    ``chain_id|entity_type`` convention — confirm against the Boltz input docs.
    """
    filepath.write_text(f">A|protein\n{sequence}\n", encoding="utf-8")


def _create_molecule_visualization(pdb_path: Path | str, sequence: str) -> str:
    """Create molecular visualization HTML for a single-chain structure.

    Args:
        pdb_path (Path | str): PDB file to render.
        sequence (str): Sequence of the chain; every residue is selected.

    Returns:
        str: HTML document produced by ``molecule``.
    """
    return molecule(
        str(pdb_path),
        lenSeqs=1,
        num_res=len(sequence),
        selectedResidues=list(range(1, len(sequence) + 1)),
        allSeqs=[sequence],
        # Placeholder metrics: this app computes none of these values.
        sequences=[{
            "Score": 0,
            "RMSD": 0,
            "Recovery": 0,
            "Mean pLDDT": 0,
            "seq": sequence
        }],
    )


def _wrap_in_iframe(content: str) -> str:
    """Wrap content in an HTML iframe with appropriate styling and permissions.

    The document is HTML-escaped into the ``srcdoc`` attribute so quotes or
    ampersands inside it cannot terminate the attribute early; the browser
    unescapes the value before rendering, so the page is displayed unchanged.
    Scripts are allowed in the sandbox because the viewer page runs JavaScript.

    NOTE(review): the iframe markup here was empty in the reviewed source and
    has been reconstructed — confirm height and sandbox flags match the intended design.
    """
    escaped = html.escape(content, quote=True)
    return (
        '<iframe style="width: 100%; height: 600px" name="result" '
        'allow="midi; geolocation; microphone; camera; display-capture; encrypted-media;" '
        'sandbox="allow-modals allow-forms allow-scripts allow-same-origin allow-popups '
        'allow-top-navigation-by-user-activation allow-downloads" '
        'allowfullscreen="" allowpaymentrequest="" frameborder="0" '
        f'srcdoc="{escaped}"></iframe>'
    )


def add_plddt_plot(plddt_vals: Sequence[float] | np.ndarray) -> go.Figure:
    """Create a line plot of per-residue pLDDT values.

    Args:
        plddt_vals (Sequence[float] | np.ndarray): One pLDDT value per residue;
            the x axis is the 0-based position in this sequence.

    Returns:
        go.Figure: Plotly figure for the Gradio ``Plot`` component.
    """
    plddt_trace = go.Scatter(
        x=np.arange(len(plddt_vals)),
        y=plddt_vals,
        hovertemplate="pLDDT: %{y:.2f}<br>Residue index: %{x}<br>",
        name="seq",
        visible=True,
    )
    plddt_fig = go.Figure(data=[plddt_trace])
    plddt_fig.update_layout(
        title="pLDDT",
        xaxis_title="Residue index",
        yaxis_title="pLDDT",
        height=500,
        template="simple_white",
        legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99),
    )
    return plddt_fig


demo = gr.Blocks(title="Folding Studio: structure prediction with Boltz-1")

with demo:
    gr.Markdown("# Input")
    with gr.Row():
        with gr.Column():
            sequence = gr.Textbox(label="Sequence", value="")
            api_key = gr.Textbox(label="Folding API Key", type="password")
    gr.Markdown("# Output")
    with gr.Row():
        predict_btn = gr.Button("Predict")
    with gr.Row():
        with gr.Column():
            mol_output = gr.HTML()
        with gr.Column():
            metrics_plot = gr.Plot(label="pLDDT")

    predict_btn.click(
        fn=predict,
        inputs=[sequence, api_key],
        outputs=[mol_output, metrics_plot]
    )

demo.launch()