jfaustin's picture
add dockerfile and folding studio cli
44459bb
raw
history blame
4 kB
"""Protenix folding submission command."""
from datetime import datetime
from pathlib import Path
from typing import Annotated
import typer
from rich.json import JSON
from rich.panel import Panel
from folding_studio.client import Client
from folding_studio.commands.utils import (
success_fail_catch_print,
success_fail_catch_spinner,
)
from folding_studio.config import FOLDING_API_KEY
from folding_studio.console import console
from folding_studio.query import ProtenixQuery
def protenix(
    source: Annotated[
        Path,
        typer.Argument(
            help=(
                "Path to the data source. Either a fasta file or a directory of "
                "fasta files describing a batch prediction request."
            ),
            # callback=_validate_source_path,
            exists=True,
        ),
    ],
    project_code: Annotated[  # noqa: ANN001
        str,
        typer.Option(
            help=(
                "Project code. If unknown, contact your PM or the Folding Studio team."
            ),
            envvar="FOLDING_PROJECT_CODE",
        ),
    ],
    use_msa_server: Annotated[  # pylint: disable=unused-argument
        bool,
        typer.Option(
            help="Flag to use the MSA server for inference. Forced to True.",
            is_flag=True,
        ),
    ] = True,
    seed: Annotated[int, typer.Option(help="Random seed.")] = 0,
    cycle: Annotated[int, typer.Option(help="Pairformer cycle number.")] = 10,
    step: Annotated[
        int, typer.Option(help="Number of steps for the diffusion process.")
    ] = 200,
    sample: Annotated[int, typer.Option(help="Number of samples in each seed.")] = 5,
    output: Annotated[
        Path,
        typer.Option(
            help="Local path to download the result zip and query parameters to. "
            "Default to 'protenix_results'."
        ),
    ] = Path("protenix_results"),
    force: Annotated[
        bool,
        typer.Option(
            help=(
                "Forces the download to overwrite any existing file "
                "with the same name in the specified location."
            )
        ),
    ] = False,
    unzip: Annotated[
        bool, typer.Option(help="Unzip the file after its download.")
    ] = False,
    spinner: Annotated[
        bool, typer.Option(help="Use live spinner in log output.")
    ] = True,
) -> None:
    """Synchronous Protenix folding submission."""
    # NOTE: the docstring above is deliberately left as a one-liner: Typer
    # renders a command's docstring as its ``--help`` text, so per-parameter
    # documentation lives in the ``help=`` strings of the signature instead.
    #
    # Workflow: authenticate -> build the query from ``source`` -> persist the
    # query parameters -> send the request -> print the confidence data ->
    # download the results. Every run writes into its own timestamped
    # sub-directory of ``output``.

    # Pick the progress-reporting context manager once; both variants are
    # used the same way below (``with success_fail_catch(message): ...``).
    success_fail_catch = (
        success_fail_catch_spinner if spinner else success_fail_catch_print
    )

    console.print(
        Panel("[bold cyan]:dna: Protenix Folding submission [/bold cyan]", expand=False)
    )

    # ``Path(output)`` keeps direct Python callers that pass a plain string
    # working; the CLI already hands over a ``Path``.
    output_dir = Path(output) / f"submission_{datetime.now().strftime('%Y%m%d%H%M%S')}"

    # Create a client using API key or JWT
    with success_fail_catch(":key: Authenticating client"):
        client = Client.authenticate()

    # Define a query
    with success_fail_catch(":package: Generating query"):
        # ``source`` is either a single file or a directory (batch request).
        query_builder = (
            ProtenixQuery.from_file
            if source.is_file()
            else ProtenixQuery.from_directory
        )
        query: ProtenixQuery = query_builder(
            source,
            # The MSA server is always used; the ``use_msa_server`` option is
            # accepted on the command line but intentionally ignored.
            use_msa_server=True,
            seed=seed,
            cycle=cycle,
            step=step,
            sample=sample,
        )
        query.save_parameters(output_dir)

    console.print("[blue]Generated query: [/blue]", end="")
    console.print(JSON.from_data(query.payload), style="blue")

    # Send a request
    with success_fail_catch(":brain: Processing folding job"):
        response = client.send_request(query, project_code)

    # Access confidence data
    console.print("[blue]Confidence Data:[/blue]", end=" ")
    console.print(JSON.from_data(response.confidence_data), style="blue")

    with success_fail_catch(
        f":floppy_disk: Downloading results to `[green]{output_dir}[/green]`"
    ):
        response.download_results(output_dir=output_dir, force=force, unzip=unzip)