File size: 4,000 Bytes
44459bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
"""Protenix folding submission command."""
from datetime import datetime
from pathlib import Path
from typing import Annotated
import typer
from rich.json import JSON
from rich.panel import Panel
from folding_studio.client import Client
from folding_studio.commands.utils import (
success_fail_catch_print,
success_fail_catch_spinner,
)
from folding_studio.config import FOLDING_API_KEY
from folding_studio.console import console
from folding_studio.query import ProtenixQuery
def protenix(
    source: Annotated[
        Path,
        typer.Argument(
            help=(
                "Path to the data source. Either a fasta file or a directory of "
                "fasta files describing a batch prediction request."
            ),
            exists=True,
        ),
    ],
    project_code: Annotated[  # noqa: ANN001
        str,
        typer.Option(
            help=(
                "Project code. If unknown, contact your PM or the Folding Studio team."
            ),
            # `exists` only applies to Path parameters, so it is not set here.
            envvar="FOLDING_PROJECT_CODE",
        ),
    ],
    use_msa_server: Annotated[  # pylint: disable=unused-argument
        bool,
        typer.Option(
            help="Flag to use the MSA server for inference. Forced to True.",
            is_flag=True,
        ),
    ] = True,
    seed: Annotated[int, typer.Option(help="Random seed.")] = 0,
    cycle: Annotated[int, typer.Option(help="Pairformer cycle number.")] = 10,
    step: Annotated[
        int, typer.Option(help="Number of steps for the diffusion process.")
    ] = 200,
    sample: Annotated[int, typer.Option(help="Number of samples in each seed.")] = 5,
    output: Annotated[
        Path,
        typer.Option(
            help="Local path to download the result zip and query parameters to. "
            "Default to 'protenix_results'."
        ),
    ] = "protenix_results",
    force: Annotated[
        bool,
        typer.Option(
            help=(
                "Forces the download to overwrite any existing file "
                "with the same name in the specified location."
            )
        ),
    ] = False,
    unzip: Annotated[
        bool, typer.Option(help="Unzip the file after its download.")
    ] = False,
    spinner: Annotated[
        bool, typer.Option(help="Use live spinner in log output.")
    ] = True,
):
    """Synchronous Protenix folding submission."""
    # NOTE: the docstring above doubles as the Typer `--help` text, so the
    # implementation details live in comments instead.
    #
    # Flow: authenticate -> build a ProtenixQuery from `source` -> save its
    # parameters locally -> send it under `project_code` -> print the
    # confidence data -> download the results into a timestamped directory.
    # `use_msa_server` is accepted for CLI compatibility but ignored: the query
    # below is always built with use_msa_server=True.

    # Choose the progress-reporting context manager (live spinner vs. plain print).
    success_fail_catch = (
        success_fail_catch_spinner if spinner else success_fail_catch_print
    )
    console.print(
        Panel("[bold cyan]:dna: Protenix Folding submission [/bold cyan]", expand=False)
    )

    # Typer converts the CLI default to a Path, but direct Python callers may
    # pass (or fall back to) a plain str, on which `/` would raise TypeError.
    # One subdirectory per submission, keyed by local time to the second.
    output_dir = (
        Path(output) / f"submission_{datetime.now().strftime('%Y%m%d%H%M%S')}"
    )

    # Create a client using API key or JWT.
    with success_fail_catch(":key: Authenticating client"):
        client = Client.authenticate()

    # Build the query from a single file, or from every file in a directory.
    with success_fail_catch(":package: Generating query"):
        query_builder = (
            ProtenixQuery.from_file
            if source.is_file()
            else ProtenixQuery.from_directory
        )
        query: ProtenixQuery = query_builder(
            source,
            use_msa_server=True,  # forced on regardless of the CLI flag
            seed=seed,
            cycle=cycle,
            step=step,
            sample=sample,
        )
        query.save_parameters(output_dir)
    console.print("[blue]Generated query: [/blue]", end="")
    console.print(JSON.from_data(query.payload), style="blue")

    # Send the request and wait for the folding job (synchronous submission).
    with success_fail_catch(":brain: Processing folding job"):
        response = client.send_request(query, project_code)

    # Show the confidence data returned with the response.
    console.print("[blue]Confidence Data:[/blue]", end=" ")
    console.print(JSON.from_data(response.confidence_data), style="blue")

    with success_fail_catch(
        f":floppy_disk: Downloading results to `[green]{output_dir}[/green]`"
    ):
        response.download_results(output_dir=output_dir, force=force, unzip=unzip)
|