File size: 4,000 Bytes
44459bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""Protenix folding submission command."""

from datetime import datetime
from pathlib import Path
from typing import Annotated

import typer
from rich.json import JSON
from rich.panel import Panel

from folding_studio.client import Client
from folding_studio.commands.utils import (
    success_fail_catch_print,
    success_fail_catch_spinner,
)
from folding_studio.config import FOLDING_API_KEY
from folding_studio.console import console
from folding_studio.query import ProtenixQuery


def protenix(
    source: Annotated[
        Path,
        typer.Argument(
            help=(
                "Path to the data source. Either a fasta file or a directory of "
                "fasta files describing a batch prediction request."
            ),
            exists=True,
        ),
    ],
    project_code: Annotated[  # noqa: ANN001
        str,
        typer.Option(
            help=(
                "Project code. If unknown, contact your PM or the Folding Studio team."
            ),
            envvar="FOLDING_PROJECT_CODE",
        ),
    ],
    use_msa_server: Annotated[  # pylint: disable=unused-argument
        bool,
        typer.Option(
            help="Flag to use the MSA server for inference. Forced to True.",
            is_flag=True,
        ),
    ] = True,
    seed: Annotated[int, typer.Option(help="Random seed.")] = 0,
    cycle: Annotated[int, typer.Option(help="Pairformer cycle number.")] = 10,
    step: Annotated[
        int, typer.Option(help="Number of steps for the diffusion process.")
    ] = 200,
    sample: Annotated[int, typer.Option(help="Number of samples in each seed.")] = 5,
    output: Annotated[
        Path,
        typer.Option(
            help="Local path to download the result zip and query parameters to. "
            "Default to 'protenix_results'."
        ),
    ] = Path("protenix_results"),
    force: Annotated[
        bool,
        typer.Option(
            help=(
                "Forces the download to overwrite any existing file "
                "with the same name in the specified location."
            )
        ),
    ] = False,
    unzip: Annotated[
        bool, typer.Option(help="Unzip the file after its download.")
    ] = False,
    spinner: Annotated[
        bool, typer.Option(help="Use live spinner in log output.")
    ] = True,
):
    """Synchronous Protenix folding submission."""
    # NOTE: Typer uses the docstring above as the command's help text, so it is
    # kept to one line; implementation notes live in these comments instead.
    #
    # Flow: authenticate -> build query from `source` -> save query parameters
    # -> send the request -> print confidence data -> download results.
    #
    # `use_msa_server` is accepted for CLI compatibility only; the query below
    # is always built with `use_msa_server=True`.

    # Pick the progress reporter: a live spinner or plain printed messages.
    success_fail_catch = (
        success_fail_catch_spinner if spinner else success_fail_catch_print
    )

    console.print(
        Panel("[bold cyan]:dna: Protenix Folding submission [/bold cyan]", expand=False)
    )
    # Every invocation gets its own timestamped sub-directory under `output`.
    # Path() guards against a plain str being passed when called from Python.
    output_dir = Path(output) / (
        f"submission_{datetime.now().strftime('%Y%m%d%H%M%S')}"
    )

    # Create a client using API key or JWT
    with success_fail_catch(":key: Authenticating client"):
        client = Client.authenticate()

    # Build the query: a single file uses `from_file`, anything else (a
    # directory, given `exists=True` on `source`) uses `from_directory`.
    with success_fail_catch(":package: Generating query"):
        query_builder = (
            ProtenixQuery.from_file
            if source.is_file()
            else ProtenixQuery.from_directory
        )
        query: ProtenixQuery = query_builder(
            source,
            use_msa_server=True,
            seed=seed,
            cycle=cycle,
            step=step,
            sample=sample,
        )
        # Persist the exact parameters next to the results for reproducibility.
        query.save_parameters(output_dir)

    console.print("[blue]Generated query: [/blue]", end="")
    console.print(JSON.from_data(query.payload), style="blue")

    # Send the request and wait for the folding job to finish.
    with success_fail_catch(":brain: Processing folding job"):
        response = client.send_request(query, project_code)

    # Access confidence data
    console.print("[blue]Confidence Data:[/blue]", end=" ")
    console.print(JSON.from_data(response.confidence_data), style="blue")

    with success_fail_catch(
        f":floppy_disk: Downloading results to `[green]{output_dir}[/green]`"
    ):
        response.download_results(output_dir=output_dir, force=force, unzip=unzip)