|
"""Boltz-1 folding submission command.""" |
|
|
|
import json |
|
from datetime import datetime |
|
from pathlib import Path |
|
from typing import Annotated, Any, Optional |
|
|
|
import typer |
|
from rich.json import JSON |
|
from rich.panel import Panel |
|
|
|
from folding_studio.client import Client |
|
from folding_studio.commands.utils import ( |
|
success_fail_catch_print, |
|
success_fail_catch_spinner, |
|
) |
|
from folding_studio.config import FOLDING_API_KEY |
|
from folding_studio.console import console |
|
from folding_studio.query.boltz import BoltzQuery |
|
|
|
|
|
def _load_json_parameters(parameters_json: Path) -> dict[str, Any]:
    """Read Boltz inference parameters from a JSON file.

    Args:
        parameters_json: Path to a JSON file whose top-level value is an object.

    Returns:
        The decoded JSON object as a dictionary.

    Raises:
        ValueError: If the file cannot be read or decoded, is not valid JSON,
            or its top-level value is not a JSON object.
    """
    try:
        # JSON is read as UTF-8 explicitly rather than with the locale default.
        with open(parameters_json, encoding="utf-8") as f:
            json_parameters = json.load(f)
    except (OSError, ValueError) as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError
        # subclasses; chain the cause so the original traceback is preserved.
        raise ValueError(f"Error reading JSON file: {e}") from e

    if not isinstance(json_parameters, dict):
        # Anything else would fail later in dict.update() with an opaque error.
        raise ValueError(
            "Error reading JSON file: expected a JSON object at the top level, "
            f"got {type(json_parameters).__name__}"
        )
    return json_parameters


def boltz(
    source: Annotated[
        Path,
        typer.Argument(
            help=(
                "Path to the data source. Either a FASTA file, a YAML file, "
                "or a directory containing FASTA and YAML files."
            ),
            exists=True,
        ),
    ],
    project_code: Annotated[
        str,
        typer.Option(
            help="Project code. If unknown, contact your PM or the Folding Studio team.",
            envvar="FOLDING_PROJECT_CODE",
        ),
    ],
    parameters_json: Annotated[
        Path | None,
        typer.Option(help="Path to JSON file containing Boltz inference parameters."),
    ] = None,
    recycling_steps: Annotated[
        int, typer.Option(help="Number of recycling steps for prediction.")
    ] = 3,
    sampling_steps: Annotated[
        int, typer.Option(help="Number of sampling steps for prediction.")
    ] = 200,
    diffusion_samples: Annotated[
        int, typer.Option(help="Number of diffusion samples for prediction.")
    ] = 1,
    step_scale: Annotated[
        float,
        typer.Option(
            help="Step size related to the temperature at which the diffusion process samples the distribution."
        ),
    ] = 1.638,
    msa_pairing_strategy: Annotated[
        str, typer.Option(help="Pairing strategy for MSA generation.")
    ] = "greedy",
    write_full_pae: Annotated[
        bool, typer.Option(help="Whether to save the full PAE matrix as a file.")
    ] = False,
    write_full_pde: Annotated[
        bool, typer.Option(help="Whether to save the full PDE matrix as a file.")
    ] = False,
    use_msa_server: Annotated[
        bool,
        typer.Option(help="Flag to use the MSA server for inference.", is_flag=True),
    ] = True,
    msa_path: Annotated[
        Optional[str],
        typer.Option(
            help="Path to the custom MSAs. It can be a .a3m or .aligned.pqt file, or a directory containing these files."
        ),
    ] = None,
    seed: Annotated[
        int | None, typer.Option(help="Seed for random number generation.")
    ] = 0,
    output: Annotated[
        Path,
        typer.Option(
            help="Local path to download the result zip and query parameters to. "
            "Default to 'boltz_results'."
        ),
    ] = Path("boltz_results"),
    force: Annotated[
        bool,
        typer.Option(
            help=(
                "Forces the download to overwrite any existing file "
                "with the same name in the specified location."
            )
        ),
    ] = False,
    unzip: Annotated[
        bool, typer.Option(help="Unzip the file after its download.")
    ] = False,
    spinner: Annotated[
        bool, typer.Option(help="Use live spinner in log output.")
    ] = True,
) -> None:
    """Synchronous Boltz-1 folding submission."""
    # NOTE(review): the docstring above is presumably surfaced as the CLI help
    # text by Typer (command registration is not visible here), so it is kept
    # to the original one-liner and the details live in comments instead.
    #
    # Flow: build the inference parameters from the CLI options (optionally
    # overridden by a JSON file), authenticate, build and save the query, send
    # it, print the returned confidence data, then download the results into a
    # timestamped sub-directory of `output`.
    #
    # Raises ValueError when `parameters_json` cannot be read or does not
    # contain a JSON object.

    # Pick the progress reporter: live spinner or plain printed status lines.
    success_fail_catch = (
        success_fail_catch_spinner if spinner else success_fail_catch_print
    )

    # A custom MSA path takes priority over the automated MSA server search.
    if msa_path is not None:
        console.print(
            "\n[yellow]:warning: Custom MSA path provided. Disabling automated MSA search.[/yellow]"
        )
        use_msa_server = False

    console.print(
        Panel("[bold cyan]:dna: Boltz1 Folding submission [/bold cyan]", expand=False)
    )

    # Each run gets its own sub-directory named after the local submission time.
    output_dir = output / f"submission_{datetime.now().strftime('%Y%m%d%H%M%S')}"

    parameters: dict[str, Any] = {
        "recycling_steps": recycling_steps,
        "sampling_steps": sampling_steps,
        "diffusion_samples": diffusion_samples,
        "step_scale": step_scale,
        "msa_pairing_strategy": msa_pairing_strategy,
        "write_full_pae": write_full_pae,
        "write_full_pde": write_full_pde,
        "use_msa_server": use_msa_server,
        "seed": seed,
        "custom_msa_paths": msa_path,
    }

    if parameters_json is not None:
        json_parameters = _load_json_parameters(parameters_json)
        console.print(
            ":warning: Parameters specified in the configuration file will "
            "take precedence over the CLI options."
        )
        # Keys present in the JSON file overwrite the CLI-derived values.
        parameters.update(json_parameters)

    with success_fail_catch(":key: Authenticating client"):
        client = Client.authenticate()

    with success_fail_catch(":package: Generating query"):
        # A single file and a directory of files use different constructors.
        query_builder = (
            BoltzQuery.from_file if source.is_file() else BoltzQuery.from_directory
        )
        query: BoltzQuery = query_builder(source, **parameters)
        query.save_parameters(output_dir)

    console.print("[blue]Generated query: [/blue]", end="")
    console.print(JSON.from_data(query.payload), style="blue")

    with success_fail_catch(":brain: Processing folding job"):
        response = client.send_request(query, project_code)

    console.print("[blue]Confidence data:[/blue]", end=" ")
    console.print(JSON.from_data(response.confidence_data), style="blue")

    with success_fail_catch(
        f":floppy_disk: Downloading results to `[green]{output_dir}[/green]`"
    ):
        response.download_results(output_dir=output_dir, force=force, unzip=unzip)
|
|