|
"""Chai-1 folding submission command.""" |
|
|
|
from datetime import datetime |
|
from pathlib import Path |
|
from typing import Annotated, Optional |
|
|
|
import typer |
|
from rich.json import JSON |
|
from rich.panel import Panel |
|
|
|
from folding_studio.client import Client |
|
from folding_studio.commands.utils import ( |
|
success_fail_catch_print, |
|
success_fail_catch_spinner, |
|
) |
|
from folding_studio.console import console |
|
from folding_studio.query.chai import ChaiQuery |
|
|
|
|
|
def chai(
    source: Annotated[
        Path,
        typer.Argument(
            help=(
                "Path to the data source. Either a fasta file, a directory of fasta files "
                "or a csv/json file describing a batch prediction request."
            ),
            exists=True,
        ),
    ],
    project_code: Annotated[
        str,
        # `exists=` is only meaningful for Path parameters in Typer, so it is
        # intentionally not passed for this plain string option.
        typer.Option(
            help="Project code. If unknown, contact your PM or the Folding Studio team.",
            envvar="FOLDING_PROJECT_CODE",
        ),
    ],
    use_msa_server: Annotated[
        bool,
        typer.Option(
            help="Flag to enable MSA features. MSA search is performed by InstaDeep's MMseqs2 server.",
            is_flag=True,
        ),
    ] = True,
    use_templates_server: Annotated[
        bool,
        typer.Option(
            help="Flag to enable templates. Templates search is performed by InstaDeep's MMseqs2 server.",
            is_flag=True,
        ),
    ] = False,
    num_trunk_recycles: Annotated[
        int, typer.Option(help="Number of trunk recycles during inference.")
    ] = 3,
    seed: Annotated[int, typer.Option(help="Random seed for inference.")] = 0,
    num_diffn_timesteps: Annotated[
        int, typer.Option(help="Number of diffusion timesteps to run.")
    ] = 200,
    restraints: Annotated[
        Optional[str],
        typer.Option(help="Restraints information."),
    ] = None,
    recycle_msa_subsample: Annotated[
        int,
        typer.Option(help="Subsample parameter for recycling MSA during inference."),
    ] = 0,
    num_trunk_samples: Annotated[
        int, typer.Option(help="Number of trunk samples to generate during inference.")
    ] = 1,
    msa_path: Annotated[
        Optional[str],
        typer.Option(
            help="Path to the custom MSAs. It can be a .a3m or .aligned.pqt file, or a directory containing these files."
        ),
    ] = None,
    output: Annotated[
        Path,
        typer.Option(
            help="Local path to download the result zip and query parameters to. "
            "Default to 'chai_results'."
        ),
    ] = Path("chai_results"),
    force: Annotated[
        bool,
        typer.Option(
            help=(
                "Forces the download to overwrite any existing file "
                "with the same name in the specified location."
            )
        ),
    ] = False,
    unzip: Annotated[
        bool, typer.Option(help="Unzip the file after its download.")
    ] = False,
    spinner: Annotated[
        bool, typer.Option(help="Use live spinner in log output.")
    ] = True,
):
    """Synchronous Chai-1 folding submission.

    Authenticates the client, builds a Chai-1 query from SOURCE, saves the
    query parameters into a timestamped `submission_*` sub-directory of
    OUTPUT, sends the request, prints the returned confidence data and
    finally downloads the results into that same sub-directory.
    """
    # Callers invoking this function directly from Python (not through the
    # Typer CLI) may pass a plain string; normalize so `/` below works.
    output = Path(output)

    # A user-supplied MSA replaces the automated MSA search. Only warn when
    # this actually overrides the user's (default or explicit) choice.
    if msa_path is not None and use_msa_server:
        console.print(
            "\n[yellow]:warning: Custom MSA path provided. Disabling automated MSA search.[/yellow]"
        )
    if msa_path is not None:
        use_msa_server = False

    console.print(
        Panel("[bold cyan]:dna: Chai-1 Folding submission [/bold cyan]", expand=False)
    )

    # Pick the step reporter: live spinner or plain printed status lines.
    success_fail_catch = (
        success_fail_catch_spinner if spinner else success_fail_catch_print
    )

    with success_fail_catch(":key: Authenticating client"):
        client = Client.authenticate()

    # Each submission gets its own timestamped directory (second resolution)
    # for both the saved query parameters and the downloaded results.
    output_dir = output / f"submission_{datetime.now().strftime('%Y%m%d%H%M%S')}"
    output_dir.mkdir(parents=True, exist_ok=True)

    with success_fail_catch(":package: Generating query"):
        # A single file (fasta / csv / json) vs. a directory of fasta files.
        query_builder = (
            ChaiQuery.from_file if source.is_file() else ChaiQuery.from_directory
        )
        query: ChaiQuery = query_builder(
            source,
            restraints=restraints,
            use_msa_server=use_msa_server,
            use_templates_server=use_templates_server,
            num_trunk_recycles=num_trunk_recycles,
            seed=seed,
            num_diffn_timesteps=num_diffn_timesteps,
            recycle_msa_subsample=recycle_msa_subsample,
            num_trunk_samples=num_trunk_samples,
            custom_msa_paths=msa_path,
        )
        query.save_parameters(output_dir)

    console.print("[blue]Generated query: [/blue]", end="")
    console.print(JSON.from_data(query.payload), style="blue")

    with success_fail_catch(":brain: Processing folding job"):
        response = client.send_request(query, project_code)

    console.print("[blue]Confidence Data:[/blue]", end=" ")
    console.print(JSON.from_data(response.confidence_data), style="blue")

    with success_fail_catch(
        f":floppy_disk: Downloading results to `[green]{output_dir}[/green]`"
    ):
        response.download_results(output_dir=output_dir, force=force, unzip=unzip)
|
|