jfaustin's picture
add dockerfile and folding studio cli
44459bb
"""Boltz-1 folding submission command."""
import json
from datetime import datetime
from pathlib import Path
from typing import Annotated, Any, Optional
import typer
from rich.json import JSON
from rich.panel import Panel
from folding_studio.client import Client
from folding_studio.commands.utils import (
success_fail_catch_print,
success_fail_catch_spinner,
)
from folding_studio.config import FOLDING_API_KEY
from folding_studio.console import console
from folding_studio.query.boltz import BoltzQuery
def _load_parameters_json(parameters_json: Path) -> dict[str, Any]:
    """Read Boltz inference parameters from a JSON file.

    Args:
        parameters_json: Path to a JSON file whose top-level value is an object
            mapping parameter names to values.

    Returns:
        The decoded JSON object.

    Raises:
        ValueError: If the file cannot be opened, is not valid UTF-8 JSON, or
            its top-level value is not a JSON object.
    """
    try:
        with open(parameters_json, "r", encoding="utf-8") as f:
            json_parameters = json.load(f)
    except (OSError, ValueError) as e:
        # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        raise ValueError(f"Error reading JSON file: {e}") from e
    # The caller merges this into a dict with .update(); anything other than an
    # object would fail there with a much less helpful message.
    if not isinstance(json_parameters, dict):
        raise ValueError(
            "Error reading JSON file: expected a JSON object mapping parameter "
            f"names to values, got {type(json_parameters).__name__}."
        )
    return json_parameters


def boltz(
    source: Annotated[
        Path,
        typer.Argument(
            help=(
                "Path to the data source. Either a FASTA file, a YAML file, "
                "or a directory containing FASTA and YAML files."
            ),
            exists=True,
        ),
    ],
    project_code: Annotated[
        str,
        typer.Option(
            help="Project code. If unknown, contact your PM or the Folding Studio team.",
            envvar="FOLDING_PROJECT_CODE",
        ),
    ],
    parameters_json: Annotated[
        Path | None,
        typer.Option(help="Path to JSON file containing Boltz inference parameters."),
    ] = None,
    recycling_steps: Annotated[
        int, typer.Option(help="Number of recycling steps for prediction.")
    ] = 3,
    sampling_steps: Annotated[
        int, typer.Option(help="Number of sampling steps for prediction.")
    ] = 200,
    diffusion_samples: Annotated[
        int, typer.Option(help="Number of diffusion samples for prediction.")
    ] = 1,
    step_scale: Annotated[
        float,
        typer.Option(
            help="Step size related to the temperature at which the diffusion process samples the distribution."
        ),
    ] = 1.638,
    msa_pairing_strategy: Annotated[
        str, typer.Option(help="Pairing strategy for MSA generation.")
    ] = "greedy",
    write_full_pae: Annotated[
        bool, typer.Option(help="Whether to save the full PAE matrix as a file.")
    ] = False,
    write_full_pde: Annotated[
        bool, typer.Option(help="Whether to save the full PDE matrix as a file.")
    ] = False,
    use_msa_server: Annotated[
        bool,
        typer.Option(help="Flag to use the MSA server for inference.", is_flag=True),
    ] = True,
    msa_path: Annotated[
        Optional[str],
        typer.Option(
            help="Path to the custom MSAs. It can be a .a3m or .aligned.pqt file, or a directory containing these files."
        ),
    ] = None,
    seed: Annotated[
        int | None, typer.Option(help="Seed for random number generation.")
    ] = 0,
    output: Annotated[
        Path,
        typer.Option(
            help="Local path to download the result zip and query parameters to. "
            "Default to 'boltz_results'."
        ),
    ] = "boltz_results",
    force: Annotated[
        bool,
        typer.Option(
            help=(
                "Forces the download to overwrite any existing file "
                "with the same name in the specified location."
            )
        ),
    ] = False,
    unzip: Annotated[
        bool, typer.Option(help="Unzip the file after its download.")
    ] = False,
    spinner: Annotated[
        bool, typer.Option(help="Use live spinner in log output.")
    ] = True,
):
    """Synchronous Boltz-1 folding submission."""
    # Flow: authenticate a Client, build a BoltzQuery from `source` (a single
    # file or a directory), save the query parameters under a timestamped
    # subdirectory of `output`, send the request under `project_code`, print
    # the returned confidence data, then download the results.
    #
    # Raises ValueError if `parameters_json` cannot be read or is not a JSON
    # object. Failure handling for the later steps is delegated to the
    # success_fail_catch_* context managers (defined in
    # folding_studio.commands.utils; not visible here).
    #
    # NOTE: the docstring above doubles as the Typer command help text, so it
    # is intentionally kept short; documentation lives in these comments.
    success_fail_catch = (
        success_fail_catch_spinner if spinner else success_fail_catch_print
    )

    # If a custom MSA path is provided, disable automated MSA search.
    # NOTE(review): this check only sees the CLI option; a `custom_msa_paths`
    # supplied solely through the JSON file (merged below) does not turn off
    # `use_msa_server`. Confirm whether that is intended.
    if msa_path is not None:
        console.print(
            "\n[yellow]:warning: Custom MSA path provided. Disabling automated MSA search.[/yellow]"
        )
        use_msa_server = False

    console.print(
        Panel("[bold cyan]:dna: Boltz1 Folding submission [/bold cyan]", expand=False)
    )

    # The declared default for `output` is a plain str, so wrap it in Path to
    # keep the `/` operator working when this function is called directly
    # rather than through the CLI.
    output_dir = Path(output) / f"submission_{datetime.now().strftime('%Y%m%d%H%M%S')}"

    # Initialize parameters with CLI-provided values.
    parameters: dict[str, Any] = {
        "recycling_steps": recycling_steps,
        "sampling_steps": sampling_steps,
        "diffusion_samples": diffusion_samples,
        "step_scale": step_scale,
        "msa_pairing_strategy": msa_pairing_strategy,
        "write_full_pae": write_full_pae,
        "write_full_pde": write_full_pde,
        "use_msa_server": use_msa_server,
        "seed": seed,
        "custom_msa_paths": msa_path,
    }

    if parameters_json:
        json_parameters = _load_parameters_json(parameters_json)
        console.print(
            ":warning: Parameters specified in the configuration file will "
            "take precedence over the CLI options."
        )
        # JSON values overwrite the CLI values for any key present in both.
        parameters.update(json_parameters)

    # Create a client using API key or JWT.
    with success_fail_catch(":key: Authenticating client"):
        client = Client.authenticate()

    # Define query: a single file uses from_file, otherwise from_directory.
    with success_fail_catch(":package: Generating query"):
        query_builder = (
            BoltzQuery.from_file if source.is_file() else BoltzQuery.from_directory
        )
        query: BoltzQuery = query_builder(source, **parameters)
        query.save_parameters(output_dir)

    console.print("[blue]Generated query: [/blue]", end="")
    console.print(JSON.from_data(query.payload), style="blue")

    # Send a request.
    with success_fail_catch(":brain: Processing folding job"):
        response = client.send_request(query, project_code)

    # Access confidence data.
    console.print("[blue]Confidence data:[/blue]", end=" ")
    console.print(JSON.from_data(response.confidence_data), style="blue")

    with success_fail_catch(
        f":floppy_disk: Downloading results to `[green]{output_dir}[/green]`"
    ):
        response.download_results(output_dir=output_dir, force=force, unzip=unzip)