File size: 5,265 Bytes
44459bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
"""Chai-1 folding submission command."""
from datetime import datetime
from pathlib import Path
from typing import Annotated, Optional
import typer
from rich.json import JSON
from rich.panel import Panel
from folding_studio.client import Client
from folding_studio.commands.utils import (
success_fail_catch_print,
success_fail_catch_spinner,
)
from folding_studio.console import console
from folding_studio.query.chai import ChaiQuery
def chai(
    source: Annotated[
        Path,
        typer.Argument(
            help=(
                "Path to the data source. Either a fasta file, a directory of fasta files "
                "or a csv/json file describing a batch prediction request."
            ),
            exists=True,
        ),
    ],
    project_code: Annotated[
        str,
        typer.Option(
            help="Project code. If unknown, contact your PM or the Folding Studio team.",
            envvar="FOLDING_PROJECT_CODE",
        ),
    ],
    use_msa_server: Annotated[
        bool,
        typer.Option(
            help="Flag to enable MSA features. MSA search is performed by InstaDeep's MMseqs2 server.",
            is_flag=True,
        ),
    ] = True,
    use_templates_server: Annotated[
        bool,
        typer.Option(
            help="Flag to enable templates. Templates search is performed by InstaDeep's MMseqs2 server.",
            is_flag=True,
        ),
    ] = False,
    num_trunk_recycles: Annotated[
        int, typer.Option(help="Number of trunk recycles during inference.")
    ] = 3,
    seed: Annotated[int, typer.Option(help="Random seed for inference.")] = 0,
    num_diffn_timesteps: Annotated[
        int, typer.Option(help="Number of diffusion timesteps to run.")
    ] = 200,
    restraints: Annotated[
        Optional[str],
        typer.Option(help="Restraints information."),
    ] = None,
    recycle_msa_subsample: Annotated[
        int,
        typer.Option(help="Subsample parameter for recycling MSA during inference."),
    ] = 0,
    num_trunk_samples: Annotated[
        int, typer.Option(help="Number of trunk samples to generate during inference.")
    ] = 1,
    msa_path: Annotated[
        Optional[str],
        typer.Option(
            help="Path to the custom MSAs. It can be a .a3m or .aligned.pqt file, or a directory containing these files."
        ),
    ] = None,
    output: Annotated[
        Path,
        typer.Option(
            help="Local path to download the result zip and query parameters to. "
            "Default to 'chai_results'."
        ),
    ] = Path("chai_results"),
    force: Annotated[
        bool,
        typer.Option(
            help=(
                "Forces the download to overwrite any existing file "
                "with the same name in the specified location."
            )
        ),
    ] = False,
    unzip: Annotated[
        bool, typer.Option(help="Unzip the file after its download.")
    ] = False,
    spinner: Annotated[
        bool, typer.Option(help="Use live spinner in log output.")
    ] = True,
):
    """Synchronous Chai-1 folding submission.

    Workflow:
      1. If ``msa_path`` is given, turn off the automated MSA server search.
      2. Authenticate a ``Client``.
      3. Create a timestamped ``submission_YYYYmmddHHMMSS`` directory under
         ``output``.
      4. Build a ``ChaiQuery``. ``ChaiQuery.from_file`` is used when ``source``
         is a file, otherwise ``ChaiQuery.from_directory``. The query
         parameters are saved into the submission directory and the payload
         is printed.
      5. Send the query with ``project_code``, print the response's confidence
         data, and download the results into the submission directory.

    Args:
        source: Fasta file, directory of fasta files, or csv/json batch file.
        project_code: Project code forwarded to ``client.send_request``.
        use_msa_server: Request automated MSA search. This is forced to
            ``False`` when ``msa_path`` is provided.
        use_templates_server: Request template search.
        num_trunk_recycles: Forwarded to the query builder.
        seed: Forwarded to the query builder.
        num_diffn_timesteps: Forwarded to the query builder.
        restraints: Forwarded to the query builder.
        recycle_msa_subsample: Forwarded to the query builder.
        num_trunk_samples: Forwarded to the query builder.
        msa_path: Custom MSA location, forwarded as ``custom_msa_paths``.
        output: Parent directory for the submission folder. A plain ``str``
            is also accepted when the function is called directly from Python.
        force: Forwarded to ``response.download_results``.
        unzip: Forwarded to ``response.download_results``.
        spinner: Use the live-spinner progress reporter instead of plain
            printing.
    """
    # A user-supplied MSA takes precedence over the automated server search.
    # Only warn when this actually overrides the requested setting.
    if msa_path is not None and use_msa_server:
        console.print(
            "\n[yellow]:warning: Custom MSA path provided. Disabling automated MSA search.[/yellow]"
        )
        use_msa_server = False

    console.print(
        Panel("[bold cyan]:dna: Chai-1 Folding submission [/bold cyan]", expand=False)
    )

    # Both reporters are used as context managers around each step below.
    success_fail_catch = (
        success_fail_catch_spinner if spinner else success_fail_catch_print
    )

    # Create a client (credential resolution is handled by Client.authenticate).
    with success_fail_catch(":key: Authenticating client"):
        client = Client.authenticate()

    # Coerce to Path. Typer already does this for CLI calls, but a direct
    # Python call may pass a str, and `str / str` raises TypeError.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    output_dir = Path(output) / f"submission_{timestamp}"
    output_dir.mkdir(parents=True, exist_ok=True)

    # Define a query.
    with success_fail_catch(":package: Generating query"):
        query_builder = (
            ChaiQuery.from_file if source.is_file() else ChaiQuery.from_directory
        )
        query: ChaiQuery = query_builder(
            source,
            restraints=restraints,
            use_msa_server=use_msa_server,
            use_templates_server=use_templates_server,
            num_trunk_recycles=num_trunk_recycles,
            seed=seed,
            num_diffn_timesteps=num_diffn_timesteps,
            recycle_msa_subsample=recycle_msa_subsample,
            num_trunk_samples=num_trunk_samples,
            custom_msa_paths=msa_path,
        )
        query.save_parameters(output_dir)

    console.print("[blue]Generated query: [/blue]", end="")
    console.print(JSON.from_data(query.payload), style="blue")

    # Send the request.
    with success_fail_catch(":brain: Processing folding job"):
        response = client.send_request(query, project_code)

    # Show the confidence data returned with the response.
    console.print("[blue]Confidence Data:[/blue]", end=" ")
    console.print(JSON.from_data(response.confidence_data), style="blue")

    # Download the results.
    with success_fail_catch(
        f":floppy_disk: Downloading results to `[green]{output_dir}[/green]`"
    ):
        response.download_results(output_dir=output_dir, force=force, unzip=unzip)
|