# jfaustin's picture
# add dockerfile and folding studio cli
# 44459bb
"""CLI MSA search command and sub-commands."""
import json
import shutil
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Optional
import requests
import typer
from folding_studio_data_models import (
FeatureMode,
MessageStatus,
MSAPublication,
)
from rich import print # pylint:disable=redefined-builtin
from rich.console import Console
from rich.markdown import Markdown
from rich.table import Table
from typing_extensions import Annotated
from folding_studio.api_call.msa import simple_msa
from folding_studio.config import API_URL, REQUEST_TIMEOUT
from folding_studio.utils.data_model import (
MSARequestParams,
SimpleInputFile,
)
from folding_studio.utils.headers import get_auth_headers
# Root Typer application for the MSA commands of this module.
app = typer.Typer(help="Handle MSA operation")
# Sub-application grouping the per-experiment commands; it is mounted on
# `app` under the name "experiment".
msa_experiment_app = typer.Typer(help="Commands related to MSA experiments.")
app.add_typer(msa_experiment_app, name="experiment")
# Shared positional-argument definition used by the `status`, `features`
# and `logs` commands for their experiment ID parameter.
msa_experiment_ID_argument = typer.Argument(help="ID of the MSA experiment.")
def _validate_source_path(path: Path) -> Path:
    """Check that the MSA job input file has a supported extension.

    Args:
        path (Path): Source path.

    Raises:
        typer.BadParameter: If the source is an unsupported file.

    Returns:
        Path: The source, unchanged.
    """
    # Accepted suffixes are the values of the SimpleInputFile enum.
    supported_simple_msa = tuple(member.value for member in SimpleInputFile)
    if path.suffix in supported_simple_msa:
        return path
    raise typer.BadParameter(
        f"The source file '{path}' is not supported. "
        f"Only {supported_simple_msa} files are supported."
    )
def _print_instructions_simple(
    response_json: dict, metadata_file: Optional[Path]
) -> None:
    """Print follow-up instructions after a successful MSA search submission.

    Args:
        response_json (dict): Server json response.
        metadata_file (Optional[Path]): File path where job submission metadata
            are written. When empty, a timestamped file name relative to the
            current working directory is used. The file is only written when
            the job was newly published.

    Raises:
        ValueError: If the publication status is not one of the handled values.
    """
    pub = MSAPublication.model_validate(response_json)
    msa_experiment_id = pub.message.msa_experiment_id
    console = Console()

    def _show_command(sub_command: str) -> None:
        # Render the suggested CLI call as a shell code block.
        console.print(
            Markdown(
                f"```shell\nfolding msa experiment {sub_command} {msa_experiment_id}\n"
            )
        )

    if pub.status == MessageStatus.NOT_PUBLISHED_DONE:
        print(
            f"The results of your msa_experiment {msa_experiment_id} were found in the cache."
        )
        print("Use the following command to download the msa results:")
        _show_command("features")
    elif pub.status == MessageStatus.NOT_PUBLISHED_PENDING:
        print(
            f"Your msa_experiment [bold]{msa_experiment_id}[/bold] is [bold green]still running.[/bold green]"
        )
        print("Use the following command to check on its status at a later time.")
        _show_command("status")
    elif pub.status == MessageStatus.PUBLISHED:
        print("[bold green]Experiment submitted successfully ![/bold green]")
        print(f"The msa_experiment_id is [bold]{msa_experiment_id}[/bold]")
        if not metadata_file:
            # No destination given: fall back to a timestamped file name.
            timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
            metadata_file = Path(f"simple_prediction_{timestamp}.json")
        with open(metadata_file, "w", encoding="utf-8") as f:
            json.dump(response_json, f, indent=4)
        print(f"Prediction job metadata written to [bold]{metadata_file}[/bold]")
        print("You can query your experiment status with the command:")
        _show_command("status")
    else:
        raise ValueError(f"Unknown publication status: {pub.status}")
@app.command()
def search(  # pylint: disable=dangerous-default-value, too-many-arguments, too-many-locals
    source: Annotated[
        Path,
        typer.Argument(
            help=("Path to the input fasta file."),
            callback=_validate_source_path,
            exists=True,
        ),
    ],
    # NOTE: `exists` is a path-only setting, so it is not passed to this
    # string option.
    project_code: Annotated[
        str,
        typer.Option(
            help=(
                "Project code. If unknown, contact your PM or the Folding Studio team."
            ),
            envvar="FOLDING_PROJECT_CODE",
        ),
    ],
    cache: Annotated[
        bool,
        typer.Option(help="Use cached experiment results if any."),
    ] = True,
    msa_mode: Annotated[
        FeatureMode,
        typer.Option(help="Mode of the MSA features generation."),
    ] = FeatureMode.SEARCH,
    metadata_file: Annotated[
        Optional[Path],
        typer.Option(
            help=(
                "Path to the file where the job metadata returned by the server are written."
            ),
        ),
    ] = None,
):
    """Run an MSA tool. Read more at https://int-bio-foldingstudio-gcp.nw.r.appspot.com/tutorials/msa_search/."""
    # The docstring above is the command's help text; keep it verbatim.
    # The CLI flag is expressed positively (`--cache`) while the request
    # parameter is its negation (`ignore_cache`).
    params = MSARequestParams(
        ignore_cache=not cache,
        msa_mode=msa_mode,
    )
    response = simple_msa(
        file=source,
        params=params,
        project_code=project_code,
    )
    # Report the submission outcome and the follow-up commands to the user.
    _print_instructions_simple(response_json=response, metadata_file=metadata_file)
def _download_file_from_signed_url(
    msa_exp_id: str,
    endpoint: str,
    output: Path,
    force: bool,
    unzip: bool = False,
) -> None:
    """Download an experiment file through a signed URL, optionally unzipping it.

    The API endpoint is queried first to obtain a signed URL, then the file
    behind that URL is streamed to ``output``.

    Args:
        msa_exp_id (str): MSA Experiment id.
        endpoint (str): API endpoint to call.
        output (Path): Output file path.
        force (bool): Force file writing if it already exists.
        unzip (bool): Unzip the zip file after downloading.

    Raises:
        typer.Exit: If output file path exists but force set to false.
        typer.Exit: If unzip set to true but the output path has no '.zip' suffix.
        typer.Exit: If unzip set to true but the directory already exists and force set to false.
        typer.Exit: If an error occurred during the initial API call.
        typer.Exit: If the download from the signed URL failed.
    """
    if output.exists() and not force:
        print(
            f"Warning: The file '{output}' already exists. Use the --force flag to overwrite it."
        )
        raise typer.Exit(code=1)

    if unzip:
        if output.suffix != ".zip":
            print(
                "Error: The downloaded file is not a .zip file. Please ensure the correct file format."
            )
            raise typer.Exit(code=1)

        # The archive is extracted next to the zip, in a directory named
        # after the file without its suffix.
        dir_path = output.with_suffix("")
        if dir_path.exists() and not force:
            print(
                f"Warning: The --unzip flag is raised but the directory '{dir_path}' "
                "already exists. Use the --force flag to overwrite it."
            )
            raise typer.Exit(code=1)

    url = API_URL + endpoint
    headers = get_auth_headers()
    response = requests.get(
        url,
        params={"msa_experiment_id": msa_exp_id},
        headers=headers,
        timeout=REQUEST_TIMEOUT,
    )
    if not response.ok:
        print(f"Failed to download the file: {response.content.decode()}.")
        raise typer.Exit(code=1)

    # Stream the payload; the context manager releases the connection even
    # if writing fails.
    with requests.get(
        response.json()["signed_url"],
        stream=True,
        timeout=REQUEST_TIMEOUT,
    ) as file_response:
        # Check the status before opening `output`, so that a failed download
        # neither truncates an existing file nor stores an error body.
        if not file_response.ok:
            print(
                f"Failed to download the file: HTTP status {file_response.status_code}."
            )
            raise typer.Exit(code=1)

        # Let urllib3 transparently decode gzip/deflate content encodings.
        file_response.raw.decode_content = True
        with output.open("wb") as f:
            shutil.copyfileobj(file_response.raw, f)
    print(f"File downloaded successfully to {output}.")

    if unzip:
        dir_path.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(output, "r") as zip_ref:
            zip_ref.extractall(dir_path)
        print(f"Extracted all files to {dir_path}.")
@msa_experiment_app.command()
def status(
    msa_exp_id: Annotated[str, msa_experiment_ID_argument],
):
    """Get an MSA experiment status."""
    # Query the status endpoint for the given experiment.
    response = requests.get(
        API_URL + "getMSAExperimentStatus",
        params={"msa_experiment_id": msa_exp_id},
        headers=get_auth_headers(),
        timeout=REQUEST_TIMEOUT,
    )

    if response.ok:
        # Only the "status" field of the JSON payload is displayed.
        print(response.json()["status"])
        return

    print(f"An error occurred : {response.content.decode()}")
    raise typer.Exit(code=1)
@msa_experiment_app.command()
def features(
    msa_exp_id: Annotated[str, msa_experiment_ID_argument],
    output: Annotated[
        Optional[Path],
        typer.Option(
            help="Local path to download the zip to. Default to '<msa_exp_id>_features.zip'."
        ),
    ] = None,
    force: Annotated[
        bool,
        typer.Option(
            help=(
                "Forces the download to overwrite any existing file "
                "with the same name in the specified location."
            )
        ),
    ] = False,
    unzip: Annotated[
        bool, typer.Option(help="Automatically unzip the file after its download.")
    ] = False,
):
    """Get an experiment features."""
    # Without an explicit destination, name the archive after the experiment.
    destination = Path(f"{msa_exp_id}_features.zip") if output is None else output

    _download_file_from_signed_url(
        endpoint="getZippedMSAExperimentFeatures",
        msa_exp_id=msa_exp_id,
        output=destination,
        force=force,
        unzip=unzip,
    )
@msa_experiment_app.command()
def logs(
    msa_exp_id: Annotated[str, msa_experiment_ID_argument],
    output: Annotated[
        Optional[Path],
        typer.Option(
            help="Local path to download the logs to. Default to '<exp_id>_logs.txt'."
        ),
    ] = None,
    force: Annotated[
        bool,
        typer.Option(
            help=(
                "Forces the download to overwrite any existing file "
                "with the same name in the specified location."
            )
        ),
    ] = False,
):
    """Get an experiment logs."""
    # Without an explicit destination, name the log file after the experiment.
    destination = Path(f"{msa_exp_id}_logs.txt") if output is None else output

    # `unzip` is left at its default (False): the logs are saved as downloaded.
    _download_file_from_signed_url(
        endpoint="getExperimentLogs",
        msa_exp_id=msa_exp_id,
        output=destination,
        force=force,
    )
@msa_experiment_app.command()
def list(
    limit: Annotated[
        int,
        typer.Option(
            help=("Max number of experiment to display in the terminal."),
        ),
    ] = 100,
    output: Annotated[
        Optional[Path],
        typer.Option(
            "--output",
            "-o",
            help=(
                "Path to the file where the job metadata returned by the server are written."
            ),
        ),
    ] = None,
):  # pylint:disable=redefined-builtin
    """Get all your done and pending experiment ids. The IDs are provided in the order of submission, starting with the most recent."""
    # The docstring above is the command's help text; keep it verbatim.
    headers = get_auth_headers()
    url = API_URL + "getDoneAndPendingMSAExperiments"
    response = requests.get(
        url,
        headers=headers,
        timeout=REQUEST_TIMEOUT,
    )

    if not response.ok:
        print(f"An error occurred : {response.content.decode()}")
        raise typer.Exit(code=1)

    response_json = response.json()

    # The complete, untruncated payload goes to the file when requested.
    if output:
        with open(output, "w", encoding="utf-8") as f:
            json.dump(response_json, f, indent=4)
        print(f"Done and pending MSA experiments list written to [bold]{output}[/bold]")

    table = Table(title="Done and pending MSA experiments")
    table.add_column("MSA Experiment ID", justify="right", style="cyan", no_wrap=True)
    table.add_column("Status", style="magenta")

    # The payload is iterated as a mapping of status -> experiment IDs;
    # flatten it into (experiment ID, status) rows in iteration order.
    rows = [
        (exp, exp_status)
        for exp_status, exp_list in response_json.items()
        for exp in exp_list
    ]

    if limit < len(rows):
        print(
            f"The table below is truncated to the last [bold]{limit}[/bold] submitted MSA experiments. Increase '--limit' to see more."
        )
        if not output:
            print("Use '--output' to get the full list in file format.")
        else:
            print(f"See the full list in file format at [bold]{output}[/bold]")

    # Cap the displayed rows at `limit`; a negative limit shows nothing.
    for exp, exp_status in rows[: max(limit, 0)]:
        table.add_row(exp, exp_status)

    console = Console()
    console.print(table)