# jfaustin's picture
# add dockerfile and folding studio cli
# 44459bb
"""Chai-1 folding submission command."""
from datetime import datetime
from pathlib import Path
from typing import Annotated, Optional
import typer
from rich.json import JSON
from rich.panel import Panel
from folding_studio.client import Client
from folding_studio.commands.utils import (
success_fail_catch_print,
success_fail_catch_spinner,
)
from folding_studio.console import console
from folding_studio.query.chai import ChaiQuery
def chai(
    source: Annotated[
        Path,
        typer.Argument(
            help=(
                "Path to the data source. Either a fasta file, a directory of fasta files "
                "or a csv/json file describing a batch prediction request."
            ),
            exists=True,
        ),
    ],
    # NOTE: `exists=True` was dropped here; it is a filesystem-path check and
    # has no meaning for a plain string option.
    project_code: Annotated[
        str,
        typer.Option(
            help="Project code. If unknown, contact your PM or the Folding Studio team.",
            envvar="FOLDING_PROJECT_CODE",
        ),
    ],
    use_msa_server: Annotated[
        bool,
        typer.Option(
            help="Flag to enable MSA features. MSA search is performed by InstaDeep's MMseqs2 server.",
            is_flag=True,
        ),
    ] = True,
    use_templates_server: Annotated[
        bool,
        typer.Option(
            help="Flag to enable templates. Templates search is performed by InstaDeep's MMseqs2 server.",
            is_flag=True,
        ),
    ] = False,
    num_trunk_recycles: Annotated[
        int, typer.Option(help="Number of trunk recycles during inference.")
    ] = 3,
    seed: Annotated[int, typer.Option(help="Random seed for inference.")] = 0,
    num_diffn_timesteps: Annotated[
        int, typer.Option(help="Number of diffusion timesteps to run.")
    ] = 200,
    restraints: Annotated[
        Optional[str],
        typer.Option(help="Restraints information."),
    ] = None,
    recycle_msa_subsample: Annotated[
        int,
        typer.Option(help="Subsample parameter for recycling MSA during inference."),
    ] = 0,
    num_trunk_samples: Annotated[
        int, typer.Option(help="Number of trunk samples to generate during inference.")
    ] = 1,
    msa_path: Annotated[
        Optional[str],
        typer.Option(
            help="Path to the custom MSAs. It can be a .a3m or .aligned.pqt file, or a directory containing these files."
        ),
    ] = None,
    # The default is a real `Path` (not a bare string) so that the function
    # also works when called directly from Python, outside of Typer's
    # argument conversion.
    output: Annotated[
        Path,
        typer.Option(
            help="Local path to download the result zip and query parameters to. "
            "Default to 'chai_results'."
        ),
    ] = Path("chai_results"),
    force: Annotated[
        bool,
        typer.Option(
            help=(
                "Forces the download to overwrite any existing file "
                "with the same name in the specified location."
            )
        ),
    ] = False,
    unzip: Annotated[
        bool, typer.Option(help="Unzip the file after its download.")
    ] = False,
    spinner: Annotated[
        bool, typer.Option(help="Use live spinner in log output.")
    ] = True,
) -> None:
    """Synchronous Chai-1 folding submission."""
    # The docstring above is intentionally left as a one-liner: Typer may use
    # a command function's docstring as its CLI help text, so extended notes
    # live in comments instead.
    #
    # Flow: authenticate -> create a timestamped output directory -> build and
    # save the query -> send it -> print confidence data -> download results.

    # If a custom MSA path is provided, disable automated MSA search.
    if msa_path is not None:
        console.print(
            "\n[yellow]:warning: Custom MSA path provided. Disabling automated MSA search.[/yellow]"
        )
        use_msa_server = False

    console.print(
        Panel("[bold cyan]:dna: Chai-1 Folding submission [/bold cyan]", expand=False)
    )

    # Choose how each step reports progress: live spinner or plain prints.
    success_fail_catch = (
        success_fail_catch_spinner if spinner else success_fail_catch_print
    )

    # Create a client using API key or JWT
    with success_fail_catch(":key: Authenticating client"):
        client = Client.authenticate()

    # Coerce defensively so a plain string passed by a direct Python caller
    # still supports the `/` path operator below.
    output_root = Path(output)
    # Each run gets its own second-resolution timestamped sub-directory.
    output_dir = output_root / f"submission_{datetime.now().strftime('%Y%m%d%H%M%S')}"
    output_dir.mkdir(parents=True, exist_ok=True)

    # Define a query
    with success_fail_catch(":package: Generating query"):
        # A single file and a directory of files use different constructors.
        query_builder = (
            ChaiQuery.from_file if source.is_file() else ChaiQuery.from_directory
        )
        query: ChaiQuery = query_builder(
            source,
            restraints=restraints,
            use_msa_server=use_msa_server,
            use_templates_server=use_templates_server,
            num_trunk_recycles=num_trunk_recycles,
            seed=seed,
            num_diffn_timesteps=num_diffn_timesteps,
            recycle_msa_subsample=recycle_msa_subsample,
            num_trunk_samples=num_trunk_samples,
            custom_msa_paths=msa_path,
        )
        query.save_parameters(output_dir)

    console.print("[blue]Generated query: [/blue]", end="")
    console.print(JSON.from_data(query.payload), style="blue")

    # Send a request
    with success_fail_catch(":brain: Processing folding job"):
        response = client.send_request(query, project_code)

    # Access confidence data
    console.print("[blue]Confidence Data:[/blue]", end=" ")
    console.print(JSON.from_data(response.confidence_data), style="blue")

    # Download results
    with success_fail_catch(
        f":floppy_disk: Downloading results to `[green]{output_dir}[/green]`"
    ):
        response.download_results(output_dir=output_dir, force=force, unzip=unzip)