"""Boltz-1 folding submission command."""

import json
from datetime import datetime
from pathlib import Path
from typing import Annotated, Any, Optional

import typer
from rich.json import JSON
from rich.panel import Panel

from folding_studio.client import Client
from folding_studio.commands.utils import (
    success_fail_catch_print,
    success_fail_catch_spinner,
)
from folding_studio.config import FOLDING_API_KEY
from folding_studio.console import console
from folding_studio.query.boltz import BoltzQuery


def _load_json_parameters(parameters_json: Path) -> dict[str, Any]:
    """Read a JSON file of Boltz inference parameters.

    Args:
        parameters_json: Path to a JSON file whose top-level value is an object.

    Returns:
        The decoded mapping of parameter names to values.

    Raises:
        ValueError: If the file cannot be read, is not valid JSON, or its
            top-level value is not a JSON object.
    """
    try:
        with open(parameters_json, "r", encoding="utf-8") as f:
            json_parameters = json.load(f)
    except (OSError, ValueError) as e:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError;
        # chain the cause so the original traceback is not lost.
        raise ValueError(f"Error reading JSON file: {e}") from e
    # The result is merged with dict.update(), so anything but an object
    # (list, number, null, ...) would fail later with an opaque error.
    if not isinstance(json_parameters, dict):
        raise ValueError(
            "Error reading JSON file: expected a JSON object mapping parameter "
            f"names to values, got {type(json_parameters).__name__}."
        )
    return json_parameters


# CLI command: builds a BoltzQuery from `source` (a single file or a directory),
# submits it synchronously under `project_code`, prints the returned confidence
# data and downloads the results into a timestamped sub-directory of `output`.
# NOTE: the docstring below is left short on purpose; Typer renders it as the
# command's --help text.
def boltz(
    source: Annotated[
        Path,
        typer.Argument(
            help=(
                "Path to the data source. Either a FASTA file, a YAML file, "
                "or a directory containing FASTA and YAML files."
            ),
            exists=True,
        ),
    ],
    project_code: Annotated[
        str,
        typer.Option(
            help="Project code. If unknown, contact your PM or the Folding Studio team.",
            envvar="FOLDING_PROJECT_CODE",
        ),
    ],
    parameters_json: Annotated[
        Path | None,
        typer.Option(help="Path to JSON file containing Boltz inference parameters."),
    ] = None,
    recycling_steps: Annotated[
        int, typer.Option(help="Number of recycling steps for prediction.")
    ] = 3,
    sampling_steps: Annotated[
        int, typer.Option(help="Number of sampling steps for prediction.")
    ] = 200,
    diffusion_samples: Annotated[
        int, typer.Option(help="Number of diffusion samples for prediction.")
    ] = 1,
    step_scale: Annotated[
        float,
        typer.Option(
            help="Step size related to the temperature at which the diffusion process samples the distribution."
        ),
    ] = 1.638,
    msa_pairing_strategy: Annotated[
        str, typer.Option(help="Pairing strategy for MSA generation.")
    ] = "greedy",
    write_full_pae: Annotated[
        bool, typer.Option(help="Whether to save the full PAE matrix as a file.")
    ] = False,
    write_full_pde: Annotated[
        bool, typer.Option(help="Whether to save the full PDE matrix as a file.")
    ] = False,
    use_msa_server: Annotated[
        bool,
        typer.Option(help="Flag to use the MSA server for inference.", is_flag=True),
    ] = True,
    msa_path: Annotated[
        Optional[str],
        typer.Option(
            help="Path to the custom MSAs. It can be a .a3m or .aligned.pqt file, or a directory containing these files."
        ),
    ] = None,
    seed: Annotated[
        int | None, typer.Option(help="Seed for random number generation.")
    ] = 0,
    output: Annotated[
        Path,
        typer.Option(
            help="Local path to download the result zip and query parameters to. "
            "Default to 'boltz_results'."
        ),
    ] = Path("boltz_results"),
    force: Annotated[
        bool,
        typer.Option(
            help=(
                "Forces the download to overwrite any existing file "
                "with the same name in the specified location."
            )
        ),
    ] = False,
    unzip: Annotated[
        bool, typer.Option(help="Unzip the file after its download.")
    ] = False,
    spinner: Annotated[
        bool, typer.Option(help="Use live spinner in log output.")
    ] = True,
):
    """Synchronous Boltz-1 folding submission."""
    success_fail_catch = (
        success_fail_catch_spinner if spinner else success_fail_catch_print
    )

    # If a custom MSA path is provided, disable automated MSA search
    # (overrides --use-msa-server; a parameters JSON may still override it).
    if msa_path is not None:
        console.print(
            "\n[yellow]:warning: Custom MSA path provided. Disabling automated MSA search.[/yellow]"
        )
        use_msa_server = False

    console.print(
        Panel("[bold cyan]:dna: Boltz1 Folding submission [/bold cyan]", expand=False)
    )

    # Path() also accepts a plain str, for callers invoking this function
    # directly from Python rather than through Typer.
    output_dir = (
        Path(output) / f"submission_{datetime.now().strftime('%Y%m%d%H%M%S')}"
    )

    # Initialize parameters with CLI-provided values
    parameters = {
        "recycling_steps": recycling_steps,
        "sampling_steps": sampling_steps,
        "diffusion_samples": diffusion_samples,
        "step_scale": step_scale,
        "msa_pairing_strategy": msa_pairing_strategy,
        "write_full_pae": write_full_pae,
        "write_full_pde": write_full_pde,
        "use_msa_server": use_msa_server,
        "seed": seed,
        "custom_msa_paths": msa_path,
    }

    if parameters_json:
        json_parameters = _load_json_parameters(parameters_json)
        console.print(
            ":warning: Parameters specified in the configuration file will "
            "take precedence over the CLI options."
        )
        # JSON values win over the CLI values set above.
        parameters.update(json_parameters)

    # Create a client using API key or JWT
    with success_fail_catch(":key: Authenticating client"):
        client = Client.authenticate()

    # Define query: single file vs. directory of input files
    with success_fail_catch(":package: Generating query"):
        query_builder = (
            BoltzQuery.from_file if source.is_file() else BoltzQuery.from_directory
        )
        query: BoltzQuery = query_builder(source, **parameters)
        query.save_parameters(output_dir)

    console.print("[blue]Generated query: [/blue]", end="")
    console.print(JSON.from_data(query.payload), style="blue")

    # Send a request
    with success_fail_catch(":brain: Processing folding job"):
        response = client.send_request(query, project_code)

    # Access confidence data
    console.print("[blue]Confidence data:[/blue]", end=" ")
    console.print(JSON.from_data(response.confidence_data), style="blue")

    with success_fail_catch(
        f":floppy_disk: Downloading results to `[green]{output_dir}[/green]`"
    ):
        response.download_results(output_dir=output_dir, force=force, unzip=unzip)