"""Predict protein structure using Folding Studio.""" import hashlib import logging from io import StringIO from pathlib import Path import gradio as gr import numpy as np import plotly.graph_objects as go from Bio import SeqIO from Bio.PDB import PDBIO, MMCIFParser, PDBParser, Superimposer from folding_studio_data_models import FoldingModel from folding_studio_demo.models import BoltzModel, ChaiModel, ProtenixModel logger = logging.getLogger(__name__) SEQUENCE_DIR = Path("sequences") SEQUENCE_DIR.mkdir(parents=True, exist_ok=True) OUTPUT_DIR = Path("output") OUTPUT_DIR.mkdir(parents=True, exist_ok=True) THREE_TO_ONE_LETTER = { "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C", "GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "ILE": "I", "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", "PRO": "P", "SER": "S", "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V", "SEC": "U", "PYL": "O", "ASX": "B", "GLX": "Z", "XAA": "X", "XLE": "J", "UNK": "X", } def convert_to_one_letter(resname: str) -> str: """Convert three-letter amino acid code to one-letter code. Args: resname (str): Three-letter amino acid code Returns: str: One-letter amino acid code """ return THREE_TO_ONE_LETTER.get(resname, "X") def convert_cif_to_pdb(cif_path: str, pdb_path: str) -> None: """Convert a .cif file to .pdb format using Biopython. Args: cif_path (str): Path to input .cif file pdb_path (str): Path to output .pdb file """ # Parse the CIF file parser = MMCIFParser() structure = parser.get_structure("structure", cif_path) # Save as PDB io = PDBIO() io.set_structure(structure) io.save(pdb_path) def create_plddt_figure( plddt_vals: list[list[float]], model_name: str, residue_codes: list[list[str]] = None, ) -> go.Figure: """Create a plot of metrics.""" plddt_traces = [] for i, plddt_val in enumerate(plddt_vals): # Create hover text with residue codes if available if residue_codes and i < len(residue_codes): hover_text = [ f"pLDDT: {plddt:.2f}
Residue: {code} {idx}" for idx, (plddt, code) in enumerate(zip(plddt_val, residue_codes[i])) ] else: hover_text = [ f"pLDDT: {plddt:.2f}
Residue index: {idx}" for idx, plddt in enumerate(plddt_val) ] plddt_traces.append( go.Scatter( x=np.arange(len(plddt_val)), y=plddt_val, hovertemplate="%{text}", text=hover_text, name=f"{model_name} {i}", visible=True, ) ) plddt_fig = go.Figure(data=plddt_traces) plddt_fig.update_layout( title="pLDDT", xaxis_title="Residue", yaxis_title="pLDDT", height=500, template="simple_white", legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99), ) return plddt_fig def _write_fasta_file( sequence: str, directory: Path = SEQUENCE_DIR ) -> tuple[str, Path]: """Write sequence to FASTA file. Args: sequence (str): Sequence to write to FASTA file directory (Path): Directory to write FASTA file to (default: SEQUENCE_DIR) Returns: tuple[str, Path]: Tuple containing the sequence ID and the path to the FASTA file """ input_rep = list(SeqIO.parse(StringIO(sequence), "fasta")) if not input_rep: raise gr.Error("No sequence found") seq_id = hashlib.sha256( "_".join([str(records.seq) for records in input_rep]).encode() ).hexdigest() seq_file = directory / f"sequence_{seq_id}.fasta" with open(seq_file, "w") as f: f.write(sequence) return seq_id, seq_file def extract_plddt_from_cif(cif_path): structure = MMCIFParser().get_structure("structure", cif_path) # Lists to store pLDDT values and residue codes plddt_values = [] residue_codes = [] # Iterate through all atoms for model in structure: for chain in model: for residue in chain: # Get the first atom of each residue (usually CA atom) if "CA" in residue: # The B-factor contains the pLDDT value plddt = residue["CA"].get_bfactor() plddt_values.append(plddt) # Get residue code and convert to one-letter code residue_codes.append(convert_to_one_letter(residue.get_resname())) return plddt_values, residue_codes def predict( sequence: str, api_key: str, model_type: FoldingModel, format_fasta: bool = False, progress=gr.Progress(), ) -> tuple[str, str]: """Predict protein structure from amino acid sequence using Boltz model. Args: sequence (str): Amino acid sequence to predict structure for api_key (str): Folding API key model (FoldingModel): Folding model to use format_fasta (bool): Whether to format the FASTA file progress (gr.Progress): Gradio progress tracker Returns: tuple[str, str]: Tuple containing the path to the PDB file and the pLDDT plot """ if not api_key: raise gr.Error("Missing API key, please enter a valid API key") progress(0, desc="Setting up prediction...") # Set up unique output directory based on sequence hash seq_id, seq_file = _write_fasta_file(sequence) output_dir = OUTPUT_DIR / seq_id / model_type output_dir.mkdir(parents=True, exist_ok=True) if model_type == FoldingModel.BOLTZ: model = BoltzModel(api_key) elif model_type == FoldingModel.CHAI: model = ChaiModel(api_key) elif model_type == FoldingModel.PROTENIX: model = ProtenixModel(api_key) else: raise ValueError(f"Model {model_type} not supported") # Check if prediction already exists if not model.has_prediction(output_dir): progress(0.2, desc="Running prediction...") # Run prediction logger.info(f"Predicting {seq_id}") model.call(seq_file=seq_file, output_dir=output_dir, format_fasta=format_fasta) logger.info("Prediction done. Output directory: %s", output_dir) else: progress(0.2, desc="Using existing prediction...") logger.info("Prediction already exists. Output directory: %s", output_dir) progress(0.4, desc="Processing results...") # Convert output CIF to PDB if not model.has_prediction(output_dir): raise gr.Error("No prediction found") predictions = model.predictions(output_dir) pdb_paths = [] model_plddt_vals = [] model_residue_codes = [] total_predictions = len(predictions) for i, (model_idx, prediction) in enumerate(predictions.items()): progress( 0.4 + (0.4 * i / total_predictions), desc=f"Converting model {model_idx}..." ) cif_path = prediction["prediction_path"] logger.info(f"CIF file: {cif_path}") converted_pdb_path = str( output_dir / f"{model.model_name}_prediction_{model_idx}.pdb" ) convert_cif_to_pdb(str(cif_path), str(converted_pdb_path)) plddt_vals, residue_codes = extract_plddt_from_cif(cif_path) pdb_paths.append(converted_pdb_path) model_plddt_vals.append(plddt_vals) model_residue_codes.append(residue_codes) progress(0.8, desc="Generating plots...") plddt_fig = create_plddt_figure( plddt_vals=model_plddt_vals, model_name=model.model_name, residue_codes=model_residue_codes, ) progress(1.0, desc="Done!") return pdb_paths, plddt_fig def align_structures(pdb_paths: list[str]) -> list[str]: """Align multiple PDB structures to the first structure. Args: pdb_paths (list[str]): List of paths to PDB files to align Returns: list[str]: List of paths to aligned PDB files """ parser = PDBParser() io = PDBIO() # Parse the reference structure (first one) ref_structure = parser.get_structure("reference", pdb_paths[0]) ref_atoms = [atom for atom in ref_structure.get_atoms() if atom.get_name() == "CA"] aligned_paths = [pdb_paths[0]] # First structure is already aligned # Align each subsequent structure to the reference for i, pdb_path in enumerate(pdb_paths[1:], start=1): # Parse the structure to align structure = parser.get_structure(f"model_{i}", pdb_path) atoms = [atom for atom in structure.get_atoms() if atom.get_name() == "CA"] # Create superimposer sup = Superimposer() # Set the reference and moving atoms sup.set_atoms(ref_atoms, atoms) # Apply the transformation to all atoms in the structure sup.apply(structure.get_atoms()) # Save the aligned structure aligned_path = str(Path(pdb_path).parent / f"aligned_{Path(pdb_path).name}") io.set_structure(structure) io.save(aligned_path) aligned_paths.append(aligned_path) return aligned_paths def filter_predictions( aligned_paths: list[str], plddt_fig: go.Figure, chai_selected: list[int], boltz_selected: list[int], protenix_selected: list[int], ) -> tuple[list[str], go.Figure]: """Filter predictions based on selected checkboxes. Args: aligned_paths (list[str]): List of aligned PDB paths plddt_fig (go.Figure): Original pLDDT plot chai_selected (list[int]): Selected Chai model indices boltz_selected (list[int]): Selected Boltz model indices protenix_selected (list[int]): Selected Protenix model indices model_predictions (dict[FoldingModel, list[int]]): Dictionary mapping models to their prediction indices Returns: tuple[list[str], go.Figure]: Filtered PDB paths and updated pLDDT plot """ # Create a new figure with only selected traces filtered_fig = go.Figure() # Keep track of which traces to show visible_paths = [] # Helper function to check if a trace should be visible def should_show_trace(trace_name: str) -> bool: model_name = trace_name.split()[0] model_idx = int(trace_name.split()[1]) if model_name == "Chai" and model_idx in chai_selected: return True if model_name == "Boltz" and model_idx in boltz_selected: return True if model_name == "Protenix" and model_idx in protenix_selected: return True return False # Filter traces and paths for i, trace in enumerate(plddt_fig.data): if should_show_trace(trace.name): filtered_fig.add_trace(trace) visible_paths.append(aligned_paths[i]) # Update layout filtered_fig.update_layout( title="pLDDT", xaxis_title="Residue index", yaxis_title="pLDDT", height=500, template="simple_white", legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99), ) return visible_paths, filtered_fig def predict_comparison( sequence: str, api_key: str, model_types: list[FoldingModel], progress=gr.Progress() ) -> tuple[ list[str], go.Figure, gr.CheckboxGroup, gr.CheckboxGroup, gr.CheckboxGroup, list[str], go.Figure, dict, ]: """Predict protein structure from amino acid sequence using multiple models. Args: sequence (str): Amino acid sequence to predict structure for api_key (str): Folding API key model_types (list[FoldingModel]): List of folding models to use progress (gr.Progress): Gradio progress tracker Returns: tuple containing: - list[str]: Aligned PDB paths - go.Figure: pLDDT plot - gr.CheckboxGroup: Chai predictions checkbox group - gr.CheckboxGroup: Boltz predictions checkbox group - gr.CheckboxGroup: Protenix predictions checkbox group - list[str]: Original PDB paths - go.Figure: Original pLDDT plot - dict: Model predictions mapping """ if not api_key: raise gr.Error("Missing API key, please enter a valid API key") # Set up unique output directory based on sequence hash pdb_paths = [] plddt_traces = [] total_models = len(model_types) model_predictions = {} for i, model_type in enumerate(model_types): progress(i / total_models, desc=f"Running {model_type} prediction...") model_pdb_paths, model_plddt_traces = predict( sequence, api_key, model_type, format_fasta=True ) pdb_paths += model_pdb_paths plddt_traces += model_plddt_traces.data model_predictions[model_type] = [int(Path(p).stem[-1]) for p in model_pdb_paths] progress(0.9, desc="Aligning structures...") aligned_paths = align_structures(pdb_paths) plddt_fig = go.Figure(data=plddt_traces) plddt_fig.update_layout( title="pLDDT", xaxis_title="Residue index", yaxis_title="pLDDT", height=500, template="simple_white", legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99), ) progress(1.0, desc="Done!") # Create checkbox groups for each model type chai_predictions = gr.CheckboxGroup( visible=model_predictions.get(FoldingModel.CHAI) is not None, choices=model_predictions.get(FoldingModel.CHAI, []), value=model_predictions.get(FoldingModel.CHAI, []), ) boltz_predictions = gr.CheckboxGroup( visible=model_predictions.get(FoldingModel.BOLTZ) is not None, choices=model_predictions.get(FoldingModel.BOLTZ, []), value=model_predictions.get(FoldingModel.BOLTZ, []), ) protenix_predictions = gr.CheckboxGroup( visible=model_predictions.get(FoldingModel.PROTENIX) is not None, choices=model_predictions.get(FoldingModel.PROTENIX, []), value=model_predictions.get(FoldingModel.PROTENIX, []), ) return ( chai_predictions, boltz_predictions, protenix_predictions, aligned_paths, plddt_fig, )