|
"""Protenix folding submission command.""" |
|
|
|
from datetime import datetime |
|
from pathlib import Path |
|
from typing import Annotated |
|
|
|
import typer |
|
from rich.json import JSON |
|
from rich.panel import Panel |
|
|
|
from folding_studio.client import Client |
|
from folding_studio.commands.utils import ( |
|
success_fail_catch_print, |
|
success_fail_catch_spinner, |
|
) |
|
from folding_studio.config import FOLDING_API_KEY |
|
from folding_studio.console import console |
|
from folding_studio.query import ProtenixQuery |
|
|
|
|
|
def protenix(
    source: Annotated[
        Path,
        typer.Argument(
            help=(
                "Path to the data source. Either a fasta file, a directory of fasta "
                "files describing a batch prediction request."
            ),
            exists=True,
        ),
    ],
    project_code: Annotated[
        str,
        typer.Option(
            help=(
                "Project code. If unknown, contact your PM or the Folding Studio team."
            ),
            envvar="FOLDING_PROJECT_CODE",
        ),
    ],
    use_msa_server: Annotated[
        bool,
        typer.Option(
            help="Flag to use the MSA server for inference. Forced to True.",
            is_flag=True,
        ),
    ] = True,
    seed: Annotated[int, typer.Option(help="Random seed.")] = 0,
    cycle: Annotated[int, typer.Option(help="Pairformer cycle number.")] = 10,
    step: Annotated[
        int, typer.Option(help="Number of steps for the diffusion process.")
    ] = 200,
    sample: Annotated[int, typer.Option(help="Number of samples in each seed.")] = 5,
    output: Annotated[
        Path,
        typer.Option(
            help="Local path to download the result zip and query parameters to. "
            "Default to 'protenix_results'."
        ),
    ] = Path("protenix_results"),
    force: Annotated[
        bool,
        typer.Option(
            help=(
                "Forces the download to overwrite any existing file "
                "with the same name in the specified location."
            )
        ),
    ] = False,
    unzip: Annotated[
        bool, typer.Option(help="Unzip the file after its download.")
    ] = False,
    spinner: Annotated[
        bool, typer.Option(help="Use live spinner in log output.")
    ] = True,
) -> None:
    """Synchronous Protenix folding submission."""
    # NOTE(review): the docstring above is deliberately kept to one line; if this
    # function is registered as a Typer command (registration is not visible
    # here), Typer shows the docstring as the command's help text.
    #
    # Workflow: authenticate -> build the query from `source` -> persist the
    # query parameters -> send the request -> print confidence data -> download
    # the results into a timestamped sub-directory of `output`.

    # Each step is wrapped in a context manager that reports its outcome, either
    # with a live spinner or with plain prints.
    success_fail_catch = (
        success_fail_catch_spinner if spinner else success_fail_catch_print
    )

    console.print(
        Panel("[bold cyan]:dna: Protenix Folding submission [/bold cyan]", expand=False)
    )

    # One sub-directory per submission, named after the local wall-clock time
    # (second resolution). `Path(output)` keeps the function usable when it is
    # called directly from Python with a plain string.
    output_dir = Path(output) / f"submission_{datetime.now().strftime('%Y%m%d%H%M%S')}"

    with success_fail_catch(":key: Authenticating client"):
        client = Client.authenticate()

    with success_fail_catch(":package: Generating query"):
        # A single file and a directory of files use different constructors.
        query_builder = (
            ProtenixQuery.from_file
            if source.is_file()
            else ProtenixQuery.from_directory
        )
        query: ProtenixQuery = query_builder(
            source,
            # The `use_msa_server` CLI flag is accepted but intentionally not
            # forwarded: the MSA server is always used (see the option's help).
            use_msa_server=True,
            seed=seed,
            cycle=cycle,
            step=step,
            sample=sample,
        )
        query.save_parameters(output_dir)

    console.print("[blue]Generated query: [/blue]", end="")
    console.print(JSON.from_data(query.payload), style="blue")

    with success_fail_catch(":brain: Processing folding job"):
        response = client.send_request(query, project_code)

    console.print("[blue]Confidence Data:[/blue]", end=" ")
    console.print(JSON.from_data(response.confidence_data), style="blue")

    with success_fail_catch(
        f":floppy_disk: Downloading results to `[green]{output_dir}[/green]`"
    ):
        response.download_results(output_dir=output_dir, force=force, unzip=unzip)
|
|