# jfaustin's picture
# add dockerfile and folding studio cli
# 44459bb
"""Protenix folding submission command."""
from datetime import datetime
from pathlib import Path
from typing import Annotated
import typer
from rich.json import JSON
from rich.panel import Panel
from folding_studio.client import Client
from folding_studio.commands.utils import (
success_fail_catch_print,
success_fail_catch_spinner,
)
from folding_studio.config import FOLDING_API_KEY
from folding_studio.console import console
from folding_studio.query import ProtenixQuery
def protenix(
    source: Annotated[
        Path,
        typer.Argument(
            help=(
                # The trailing space matters: adjacent literals are concatenated
                # verbatim, so without it the help reads "filesdescribing".
                "Path to the data source. Either a fasta file, a directory of fasta files "
                "describing a batch prediction request."
            ),
            exists=True,
        ),
    ],
    project_code: Annotated[  # noqa: ANN001
        str,
        typer.Option(
            help=(
                "Project code. If unknown, contact your PM or the Folding Studio team."
            ),
            # No `exists=True` here: that check only applies to path parameters
            # and is meaningless for a plain string option.
            envvar="FOLDING_PROJECT_CODE",
        ),
    ],
    use_msa_server: Annotated[  # pylint: disable=unused-argument
        bool,
        typer.Option(
            # A bool option is already a flag; no explicit `is_flag` is needed.
            help="Flag to use the MSA server for inference. Forced to True.",
        ),
    ] = True,
    seed: Annotated[int, typer.Option(help="Random seed.")] = 0,
    cycle: Annotated[int, typer.Option(help="Pairformer cycle number.")] = 10,
    step: Annotated[
        int, typer.Option(help="Number of steps for the diffusion process.")
    ] = 200,
    sample: Annotated[int, typer.Option(help="Number of samples in each seed.")] = 5,
    output: Annotated[
        Path,
        typer.Option(
            help="Local path to download the result zip and query parameters to. "
            "Default to 'protenix_results'."
        ),
    ] = Path("protenix_results"),
    force: Annotated[
        bool,
        typer.Option(
            help=(
                "Forces the download to overwrite any existing file "
                "with the same name in the specified location."
            )
        ),
    ] = False,
    unzip: Annotated[
        bool, typer.Option(help="Unzip the file after its download.")
    ] = False,
    spinner: Annotated[
        bool, typer.Option(help="Use live spinner in log output.")
    ] = True,
) -> None:
    """Synchronous Protenix folding submission."""
    # NOTE(review): if this function is registered as a Typer command, the
    # docstring above is presumably what `--help` displays, so it is left
    # unchanged; the detailed description lives in these comments instead.
    #
    # Flow: authenticate -> build the query from `source` -> save the query
    # parameters -> send the request -> print confidence data -> download
    # the results. Everything is written under a timestamped sub-directory
    # of `output`.
    #
    # `use_msa_server` is accepted for interface compatibility but ignored:
    # the query is always built with `use_msa_server=True`.

    # Pick the progress reporter once; both are used as context managers below.
    success_fail_catch = (
        success_fail_catch_spinner if spinner else success_fail_catch_print
    )

    console.print(
        Panel("[bold cyan]:dna: Protenix Folding submission [/bold cyan]", expand=False)
    )

    # `Path(...)` keeps this working when the function is called directly
    # from Python with a plain string instead of through the CLI.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    output_dir = Path(output) / f"submission_{timestamp}"

    # Create a client using API key or JWT
    with success_fail_catch(":key: Authenticating client"):
        client = Client.authenticate()

    # Define a query: a single file and a directory use different builders.
    with success_fail_catch(":package: Generating query"):
        query_builder = (
            ProtenixQuery.from_file
            if source.is_file()
            else ProtenixQuery.from_directory
        )
        query: ProtenixQuery = query_builder(
            source,
            use_msa_server=True,
            seed=seed,
            cycle=cycle,
            step=step,
            sample=sample,
        )
        query.save_parameters(output_dir)

    console.print("[blue]Generated query: [/blue]", end="")
    console.print(JSON.from_data(query.payload), style="blue")

    # Send a request
    with success_fail_catch(":brain: Processing folding job"):
        response = client.send_request(query, project_code)

    # Access confidence data
    console.print("[blue]Confidence Data:[/blue]", end=" ")
    console.print(JSON.from_data(response.confidence_data), style="blue")

    with success_fail_catch(
        f":floppy_disk: Downloading results to `[green]{output_dir}[/green]`"
    ):
        response.download_results(output_dir=output_dir, force=force, unzip=unzip)