File size: 5,265 Bytes
44459bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""Chai-1 folding submission command."""

from datetime import datetime
from pathlib import Path
from typing import Annotated, Optional

import typer
from rich.json import JSON
from rich.panel import Panel

from folding_studio.client import Client
from folding_studio.commands.utils import (
    success_fail_catch_print,
    success_fail_catch_spinner,
)
from folding_studio.console import console
from folding_studio.query.chai import ChaiQuery


def chai(
    source: Annotated[
        Path,
        typer.Argument(
            help=(
                "Path to the data source. Either a fasta file, a directory of fasta files "
                "or a csv/json file describing a batch prediction request."
            ),
            exists=True,
        ),
    ],
    project_code: Annotated[
        str,
        # NOTE: `exists=True` was dropped here; it only applies to path-typed
        # parameters and had no effect on a `str` option.
        typer.Option(
            help="Project code. If unknown, contact your PM or the Folding Studio team.",
            envvar="FOLDING_PROJECT_CODE",
        ),
    ],
    use_msa_server: Annotated[
        bool,
        typer.Option(
            help="Flag to enable MSA features. MSA search is performed by InstaDeep's MMseqs2 server.",
            is_flag=True,
        ),
    ] = True,
    use_templates_server: Annotated[
        bool,
        typer.Option(
            help="Flag to enable templates. Templates search is performed by InstaDeep's MMseqs2 server.",
            is_flag=True,
        ),
    ] = False,
    num_trunk_recycles: Annotated[
        int, typer.Option(help="Number of trunk recycles during inference.")
    ] = 3,
    seed: Annotated[int, typer.Option(help="Random seed for inference.")] = 0,
    num_diffn_timesteps: Annotated[
        int, typer.Option(help="Number of diffusion timesteps to run.")
    ] = 200,
    restraints: Annotated[
        Optional[str],
        typer.Option(help="Restraints information."),
    ] = None,
    recycle_msa_subsample: Annotated[
        int,
        typer.Option(help="Subsample parameter for recycling MSA during inference."),
    ] = 0,
    num_trunk_samples: Annotated[
        int, typer.Option(help="Number of trunk samples to generate during inference.")
    ] = 1,
    msa_path: Annotated[
        Optional[str],
        typer.Option(
            help="Path to the custom MSAs. It can be a .a3m or .aligned.pqt file, or a directory containing these files."
        ),
    ] = None,
    output: Annotated[
        Path,
        typer.Option(
            help="Local path to download the result zip and query parameters to. "
            "Default to 'chai_results'."
        ),
        # A real Path default (not a str) so `output / ...` below also works
        # when this function is called directly rather than through the CLI.
    ] = Path("chai_results"),
    force: Annotated[
        bool,
        typer.Option(
            help=(
                "Forces the download to overwrite any existing file "
                "with the same name in the specified location."
            )
        ),
    ] = False,
    unzip: Annotated[
        bool, typer.Option(help="Unzip the file after its download.")
    ] = False,
    spinner: Annotated[
        bool, typer.Option(help="Use live spinner in log output.")
    ] = True,
) -> None:
    """Synchronous Chai-1 folding submission.

    \f
    Authenticates a client, builds a ``ChaiQuery`` from ``source`` (via
    ``ChaiQuery.from_file`` when ``source`` is a file, otherwise
    ``ChaiQuery.from_directory``), saves the query parameters, sends the
    request under ``project_code``, prints the returned confidence data and
    downloads the results.

    All artefacts are written to ``output/submission_<YYYYmmddHHMMSS>``.

    Args:
        source: Fasta file, directory of fasta files, or csv/json batch file.
        project_code: Project code the request is billed/attributed to.
        use_msa_server: Enable automated MSA search. Forced to ``False`` when
            ``msa_path`` is given.
        use_templates_server: Enable template search.
        num_trunk_recycles: Number of trunk recycles during inference.
        seed: Random seed for inference.
        num_diffn_timesteps: Number of diffusion timesteps.
        restraints: Optional restraints information forwarded to the query.
        recycle_msa_subsample: MSA subsample parameter for recycling.
        num_trunk_samples: Number of trunk samples to generate.
        msa_path: Optional custom MSA file or directory.
        output: Root directory for the per-submission output directory.
        force: Overwrite existing files when downloading results.
        unzip: Unzip the result archive after download.
        spinner: Use a live spinner instead of plain printed step logs.

    Note:
        Error reporting for each step is delegated to
        ``success_fail_catch_spinner`` / ``success_fail_catch_print``; whether
        they re-raise or exit is defined there (not visible here).
    """
    # A custom MSA replaces automated MSA search. Only warn when this actually
    # changes the user's setting, so the message is never misleading.
    if msa_path is not None:
        if use_msa_server:
            console.print(
                "\n[yellow]:warning: Custom MSA path provided. Disabling automated MSA search.[/yellow]"
            )
        use_msa_server = False

    console.print(
        Panel("[bold cyan]:dna: Chai-1 Folding submission [/bold cyan]", expand=False)
    )

    # Pick the step-reporting context manager matching the requested log style.
    success_fail_catch = (
        success_fail_catch_spinner if spinner else success_fail_catch_print
    )

    # Create a client using API key or JWT
    with success_fail_catch(":key: Authenticating client"):
        client = Client.authenticate()

    # Timestamped subdirectory so successive submissions do not overwrite each
    # other (resolution is one second).
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    output_dir = Path(output) / f"submission_{timestamp}"
    output_dir.mkdir(parents=True, exist_ok=True)

    # Define a query
    with success_fail_catch(":package: Generating query"):
        query_builder = (
            ChaiQuery.from_file if source.is_file() else ChaiQuery.from_directory
        )
        query: ChaiQuery = query_builder(
            source,
            restraints=restraints,
            use_msa_server=use_msa_server,
            use_templates_server=use_templates_server,
            num_trunk_recycles=num_trunk_recycles,
            seed=seed,
            num_diffn_timesteps=num_diffn_timesteps,
            recycle_msa_subsample=recycle_msa_subsample,
            num_trunk_samples=num_trunk_samples,
            custom_msa_paths=msa_path,
        )
        query.save_parameters(output_dir)

    console.print("[blue]Generated query: [/blue]", end="")
    console.print(JSON.from_data(query.payload), style="blue")

    # Send a request
    with success_fail_catch(":brain: Processing folding job"):
        response = client.send_request(query, project_code)

    # Access confidence data
    console.print("[blue]Confidence Data:[/blue]", end=" ")
    console.print(JSON.from_data(response.confidence_data), style="blue")

    # Download results
    with success_fail_catch(
        f":floppy_disk: Downloading results to `[green]{output_dir}[/green]`"
    ):
        response.download_results(output_dir=output_dir, force=force, unzip=unzip)