|
"""CLI MSA search command and sub-commands.""" |
|
|
|
import json |
|
import shutil |
|
import zipfile |
|
from datetime import datetime |
|
from pathlib import Path |
|
from typing import Optional |
|
|
|
import requests |
|
import typer |
|
from folding_studio_data_models import ( |
|
FeatureMode, |
|
MessageStatus, |
|
MSAPublication, |
|
) |
|
from rich import print |
|
from rich.console import Console |
|
from rich.markdown import Markdown |
|
from rich.table import Table |
|
from typing_extensions import Annotated |
|
|
|
from folding_studio.api_call.msa import simple_msa |
|
from folding_studio.config import API_URL, REQUEST_TIMEOUT |
|
from folding_studio.utils.data_model import ( |
|
MSARequestParams, |
|
SimpleInputFile, |
|
) |
|
from folding_studio.utils.headers import get_auth_headers |
|
|
|
# Top-level Typer app for MSA commands (the hints printed by this module
# refer to it as `folding msa`).
app = typer.Typer(help="Handle MSA operation")
# Sub-group registered under the name "experiment", i.e. `... msa experiment <cmd>`.
msa_experiment_app = typer.Typer(help="Commands related to MSA experiments.")
app.add_typer(msa_experiment_app, name="experiment")

# Shared positional-argument definition reused by every experiment sub-command
# that takes an MSA experiment id.
msa_experiment_ID_argument = typer.Argument(help="ID of the MSA experiment.")
|
|
|
|
|
def _validate_source_path(path: Path) -> Path:
    """Typer callback checking the extension of the MSA search input file.

    Args:
        path (Path): Path passed on the command line for the source argument.

    Raises:
        typer.BadParameter: When the file suffix is not one of the values of
            ``SimpleInputFile``.

    Returns:
        Path: ``path`` unchanged, so Typer forwards it to the command.
    """
    allowed_suffixes = tuple(member.value for member in SimpleInputFile)
    # Guard first: accepted files are returned immediately.
    if path.suffix in allowed_suffixes:
        return path
    raise typer.BadParameter(
        f"The source file '{path}' is not supported. "
        f"Only {allowed_suffixes} files are supported."
    )
|
|
|
|
|
def _print_command_hint(console: Console, command: str) -> None:
    """Render ``command`` as a closed, fenced shell code block on ``console``."""
    console.print(Markdown(f"```shell\n{command}\n```"))


def _print_instructions_simple(response_json: dict, metadata_file: Path | None) -> None:
    """Print follow-up instructions after a call to the MSA search endpoint.

    The server response is validated as an ``MSAPublication`` and the printed
    guidance depends on its status:

    * ``NOT_PUBLISHED_DONE``: results were found in the cache; suggest the
      ``features`` download command.
    * ``NOT_PUBLISHED_PENDING``: the experiment is still running; suggest the
      ``status`` command.
    * ``PUBLISHED``: the experiment was submitted; write the raw response to
      ``metadata_file`` and suggest the ``status`` command.

    Args:
        response_json (dict): Server JSON response of the MSA search call.
        metadata_file (Path | None): File where the submission metadata are
            written (``PUBLISHED`` status only). When ``None``, a timestamped
            ``simple_prediction_<YYYYmmddHHMMSS>.json`` file is created in the
            current working directory.

    Raises:
        ValueError: If the publication status is not one of the handled values.
    """
    pub = MSAPublication.model_validate(response_json)
    msa_experiment_id = pub.message.msa_experiment_id

    console = Console()
    if pub.status == MessageStatus.NOT_PUBLISHED_DONE:
        print(
            f"The results of your msa_experiment {msa_experiment_id} were found in the cache."
        )
        print("Use the following command to download the msa results:")
        _print_command_hint(
            console, f"folding msa experiment features {msa_experiment_id}"
        )
    elif pub.status == MessageStatus.NOT_PUBLISHED_PENDING:
        print(
            f"Your msa_experiment [bold]{msa_experiment_id}[/bold] is [bold green]still running.[/bold green]"
        )
        print("Use the following command to check on its status at a later time.")
        _print_command_hint(
            console, f"folding msa experiment status {msa_experiment_id}"
        )
    elif pub.status == MessageStatus.PUBLISHED:
        print("[bold green]Experiment submitted successfully ![/bold green]")
        print(f"The msa_experiment_id is [bold]{msa_experiment_id}[/bold]")

        if not metadata_file:
            timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
            # Keep the variable a Path, matching its annotation.
            metadata_file = Path(f"simple_prediction_{timestamp}.json")
        with open(metadata_file, "w", encoding="utf-8") as f:
            json.dump(response_json, f, indent=4)

        print(f"Prediction job metadata written to [bold]{metadata_file}[/bold]")
        print("You can query your experiment status with the command:")
        _print_command_hint(
            console, f"folding msa experiment status {msa_experiment_id}"
        )
    else:
        raise ValueError(f"Unknown publication status: {pub.status}")
|
|
|
|
|
@app.command()
def search(
    source: Annotated[
        Path,
        typer.Argument(
            help="Path to the input fasta file.",
            callback=_validate_source_path,
            exists=True,
        ),
    ],
    project_code: Annotated[
        str,
        # NOTE: no `exists=` here: Typer only honours it for Path parameters,
        # so it was a misleading no-op on this str option.
        typer.Option(
            help=(
                "Project code. If unknown, contact your PM or the Folding Studio team."
            ),
            envvar="FOLDING_PROJECT_CODE",
        ),
    ],
    cache: Annotated[
        bool,
        typer.Option(help="Use cached experiment results if any."),
    ] = True,
    msa_mode: Annotated[
        FeatureMode,
        typer.Option(help="Mode of the MSA features generation."),
    ] = FeatureMode.SEARCH,
    metadata_file: Annotated[
        Optional[Path],
        typer.Option(
            help=(
                "Path to the file where the job metadata returned by the server are written."
            ),
        ),
    ] = None,
):
    """Run an MSA tool. Read more at https://int-bio-foldingstudio-gcp.nw.r.appspot.com/tutorials/msa_search/."""
    # The docstring above is the command's --help text; keep it user-facing.
    # Typer exposes `cache` as --cache/--no-cache; the API expects the inverse flag.
    params = MSARequestParams(
        ignore_cache=not cache,
        msa_mode=msa_mode,
    )
    response = simple_msa(
        file=source,
        params=params,
        project_code=project_code,
    )

    _print_instructions_simple(response_json=response, metadata_file=metadata_file)
|
|
|
|
|
def _download_file_from_signed_url(
    msa_exp_id: str,
    endpoint: str,
    output: Path,
    force: bool,
    unzip: bool = False,
) -> None:
    """Download an experiment artifact through a signed URL returned by the API.

    ``API_URL + endpoint`` is queried with the experiment id. Its JSON response
    must contain a ``signed_url`` key, whose content is then streamed to
    ``output`` (and optionally extracted next to it).

    Args:
        msa_exp_id (str): MSA Experiment id.
        endpoint (str): API endpoint to call.
        output (Path): Output file path.
        force (bool): Force file writing if it already exists.
        unzip (bool): Unzip the zip file after downloading, into ``output``
            without its ``.zip`` suffix.

    Raises:
        typer.Exit: If output file path exists but force set to false.
        typer.Exit: If unzip set to true but ``output`` does not end in ``.zip``.
        typer.Exit: If unzip set to true but the directory already exists and
            force set to false.
        typer.Exit: If an error occurred during the initial API call or while
            downloading from the signed URL.
    """
    if output.exists() and not force:
        print(
            f"Warning: The file '{output}' already exists. Use the --force flag to overwrite it."
        )
        raise typer.Exit(code=1)

    dir_path = output.with_suffix("")
    if unzip:
        if output.suffix != ".zip":
            print(
                "Error: The downloaded file is not a .zip file. Please ensure the correct file format."
            )
            raise typer.Exit(code=1)

        if dir_path.exists() and not force:
            print(
                f"Warning: The --unzip flag is raised but the directory '{dir_path}' "
                "already exists. Use the --force flag to overwrite it."
            )
            raise typer.Exit(code=1)

    url = API_URL + endpoint
    headers = get_auth_headers()
    response = requests.get(
        url,
        params={"msa_experiment_id": msa_exp_id},
        headers=headers,
        timeout=REQUEST_TIMEOUT,
    )
    if not response.ok:
        print(f"Failed to download the file: {response.content.decode()}.")
        raise typer.Exit(code=1)

    # The signed URL is fetched without the API auth headers. Using the
    # response as a context manager releases the streamed connection.
    with requests.get(
        response.json()["signed_url"],
        stream=True,
        timeout=REQUEST_TIMEOUT,
    ) as file_response:
        # Check the status before opening `output`, so that an error body is
        # never written to disk (nor clobbers an existing file under --force).
        if not file_response.ok:
            print(f"Failed to download the file: {file_response.text}.")
            raise typer.Exit(code=1)

        with output.open("wb") as f:
            # Let urllib3 undo any Content-Encoding while copying the raw stream.
            file_response.raw.decode_content = True
            shutil.copyfileobj(file_response.raw, f)
    print(f"File downloaded successfully to {output}.")

    if unzip:
        dir_path.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(output, "r") as zip_ref:
            zip_ref.extractall(dir_path)
        print(f"Extracted all files to {dir_path}.")
|
|
|
|
|
@msa_experiment_app.command()
def status(
    msa_exp_id: Annotated[str, msa_experiment_ID_argument],
):
    """Get an MSA experiment status."""
    # Query the status endpoint for a single experiment id.
    response = requests.get(
        API_URL + "getMSAExperimentStatus",
        params={"msa_experiment_id": msa_exp_id},
        headers=get_auth_headers(),
        timeout=REQUEST_TIMEOUT,
    )

    if response.ok:
        # The JSON body's "status" field is what the user asked for.
        print(response.json()["status"])
        return

    print(f"An error occurred : {response.content.decode()}")
    raise typer.Exit(code=1)
|
|
|
|
|
@msa_experiment_app.command()
def features(
    msa_exp_id: Annotated[str, msa_experiment_ID_argument],
    output: Annotated[
        Optional[Path],
        typer.Option(
            help="Local path to download the zip to. Default to '<msa_exp_id>_features.zip'."
        ),
    ] = None,
    force: Annotated[
        bool,
        typer.Option(
            help=(
                "Forces the download to overwrite any existing file "
                "with the same name in the specified location."
            )
        ),
    ] = False,
    unzip: Annotated[
        bool, typer.Option(help="Automatically unzip the file after its download.")
    ] = False,
):
    """Get an experiment features."""
    # Without --output, save "<msa_exp_id>_features.zip" in the working directory.
    zip_path = Path(f"{msa_exp_id}_features.zip") if output is None else output

    _download_file_from_signed_url(
        endpoint="getZippedMSAExperimentFeatures",
        msa_exp_id=msa_exp_id,
        unzip=unzip,
        force=force,
        output=zip_path,
    )
|
|
|
|
|
@msa_experiment_app.command()
def logs(
    msa_exp_id: Annotated[str, msa_experiment_ID_argument],
    output: Annotated[
        Optional[Path],
        typer.Option(
            help="Local path to download the logs to. Default to '<exp_id>_logs.txt'."
        ),
    ] = None,
    force: Annotated[
        bool,
        typer.Option(
            help=(
                "Forces the download to overwrite any existing file "
                "with the same name in the specified location."
            )
        ),
    ] = False,
):
    """Get an experiment logs."""
    # Without --output, save "<msa_exp_id>_logs.txt" in the working directory.
    log_path = Path(f"{msa_exp_id}_logs.txt") if output is None else output

    _download_file_from_signed_url(
        endpoint="getExperimentLogs",
        msa_exp_id=msa_exp_id,
        force=force,
        output=log_path,
    )
|
|
|
|
|
@msa_experiment_app.command()
def list(
    limit: Annotated[
        int,
        typer.Option(
            help=("Max number of experiment to display in the terminal."),
        ),
    ] = 100,
    output: Annotated[
        Optional[Path],
        typer.Option(
            "--output",
            "-o",
            help=(
                "Path to the file where the job metadata returned by the server are written."
            ),
        ),
    ] = None,
):
    """Get all your done and pending experiment ids. The IDs are provided in the order of submission, starting with the most recent."""
    # NOTE: the function name shadows the `list` builtin at module level; it is
    # kept because Typer derives the sub-command name from it.
    headers = get_auth_headers()
    url = API_URL + "getDoneAndPendingMSAExperiments"
    response = requests.get(
        url,
        headers=headers,
        timeout=REQUEST_TIMEOUT,
    )

    if not response.ok:
        print(f"An error occurred : {response.content.decode()}")
        raise typer.Exit(code=1)

    response_json = response.json()
    if output:
        # The file always receives the full, untruncated server response.
        with open(output, "w", encoding="utf-8") as f:
            json.dump(response_json, f, indent=4)

        print(f"Done and pending MSA experiments list written to [bold]{output}[/bold]")

    table = Table(title="Done and pending MSA experiments")
    table.add_column("MSA Experiment ID", justify="right", style="cyan", no_wrap=True)
    table.add_column("Status", style="magenta")

    # The response is iterated as {status: [experiment_id, ...]}; flatten it
    # into (id, status) rows, keeping the server's ordering.
    rows = [
        (exp, exp_status)
        for exp_status, exp_list in response_json.items()
        for exp in exp_list
    ]
    total_exp_nb = len(rows)

    # Display at most `limit` rows. Previously, a whole status group was added
    # before the limit was checked, so the table could far exceed `limit`.
    for exp, exp_status in rows[: max(limit, 0)]:
        table.add_row(exp, exp_status)

    if limit < total_exp_nb:
        print(
            f"The table below is truncated to the last [bold]{limit}[/bold] submitted MSA experiments. Increase '--limit' to see more."
        )
        if not output:
            print("Use '--output' to get the full list in file format.")
        else:
            print(f"See the full list in file format at [bold]{output}[/bold]")

    console = Console()
    console.print(table)
|
|