"""Chai-1 folding submission command."""

from datetime import datetime
from pathlib import Path
from typing import Annotated, Optional

import typer
from rich.json import JSON
from rich.panel import Panel

from folding_studio.client import Client
from folding_studio.commands.utils import (
    success_fail_catch_print,
    success_fail_catch_spinner,
)
from folding_studio.console import console
from folding_studio.query.chai import ChaiQuery


# CLI entry point that submits a Chai-1 folding job and downloads its results.
#
# Flow, as implemented below:
#   1. If ``msa_path`` is given, the automated MSA server search is disabled
#      (``use_msa_server`` is forced to False) and a warning is printed.
#   2. A ``Client`` is authenticated via ``Client.authenticate()``.
#   3. A timestamped directory ``<output>/submission_<YYYYmmddHHMMSS>`` is
#      created (existing directories are reused).
#   4. A ``ChaiQuery`` is built from ``source``. ``ChaiQuery.from_file`` is
#      used when ``source`` is a file, ``ChaiQuery.from_directory`` otherwise.
#      Its parameters are saved to the output directory and its payload is
#      echoed.
#   5. The query is sent with ``project_code``. The response's confidence data
#      is printed, and the results are downloaded into the output directory
#      (honouring ``force`` and ``unzip``).
#
# ``spinner`` only selects the progress reporter: ``success_fail_catch_spinner``
# or ``success_fail_catch_print``.
#
# NOTE: the function docstring is intentionally kept short. Typer uses it as
# the command's ``--help`` text, so parameter documentation lives here and in
# each option's ``help=``.
def chai(
    source: Annotated[
        Path,
        typer.Argument(
            help=(
                "Path to the data source. Either a fasta file, a directory of fasta files "
                "or a csv/json file describing a batch prediction request."
            ),
            exists=True,
        ),
    ],
    project_code: Annotated[
        str,
        # ``exists=`` is only meaningful for Path parameters, so it is not set
        # on this plain string option.
        typer.Option(
            help="Project code. If unknown, contact your PM or the Folding Studio team.",
            envvar="FOLDING_PROJECT_CODE",
        ),
    ],
    use_msa_server: Annotated[
        bool,
        typer.Option(
            help="Flag to enable MSA features. MSA search is performed by InstaDeep's MMseqs2 server.",
            is_flag=True,
        ),
    ] = True,
    use_templates_server: Annotated[
        bool,
        typer.Option(
            help="Flag to enable templates. Templates search is performed by InstaDeep's MMseqs2 server.",
            is_flag=True,
        ),
    ] = False,
    num_trunk_recycles: Annotated[
        int, typer.Option(help="Number of trunk recycles during inference.")
    ] = 3,
    seed: Annotated[int, typer.Option(help="Random seed for inference.")] = 0,
    num_diffn_timesteps: Annotated[
        int, typer.Option(help="Number of diffusion timesteps to run.")
    ] = 200,
    restraints: Annotated[
        Optional[str],
        typer.Option(help="Restraints information."),
    ] = None,
    recycle_msa_subsample: Annotated[
        int,
        typer.Option(help="Subsample parameter for recycling MSA during inference."),
    ] = 0,
    num_trunk_samples: Annotated[
        int, typer.Option(help="Number of trunk samples to generate during inference.")
    ] = 1,
    msa_path: Annotated[
        Optional[str],
        typer.Option(
            help="Path to the custom MSAs. It can be a .a3m or .aligned.pqt file, or a directory containing these files."
        ),
    ] = None,
    output: Annotated[
        Path,
        typer.Option(
            help="Local path to download the result zip and query parameters to. "
            "Default to 'chai_results'."
        ),
    ] = "chai_results",
    force: Annotated[
        bool,
        typer.Option(
            help=(
                "Forces the download to overwrite any existing file "
                "with the same name in the specified location."
            )
        ),
    ] = False,
    unzip: Annotated[
        bool, typer.Option(help="Unzip the file after its download.")
    ] = False,
    spinner: Annotated[
        bool, typer.Option(help="Use live spinner in log output.")
    ] = True,
):
    """Synchronous Chai-1 folding submission."""
    # If a custom MSA path is provided, disable automated MSA search.
    if msa_path is not None:
        console.print(
            "\n[yellow]:warning: Custom MSA path provided. Disabling automated MSA search.[/yellow]"
        )
        use_msa_server = False

    console.print(
        Panel("[bold cyan]:dna: Chai-1 Folding submission [/bold cyan]", expand=False)
    )

    success_fail_catch = (
        success_fail_catch_spinner if spinner else success_fail_catch_print
    )

    # Create a client using API key or JWT
    with success_fail_catch(":key: Authenticating client"):
        client = Client.authenticate()

    # Typer converts ``output`` to a Path on the CLI, but the declared default
    # is a plain string. Normalise it here so that direct Python calls don't
    # fail on ``str / str``.
    output_dir = (
        Path(output) / f"submission_{datetime.now().strftime('%Y%m%d%H%M%S')}"
    )
    output_dir.mkdir(parents=True, exist_ok=True)

    # Define a query
    with success_fail_catch(":package: Generating query"):
        # A single file goes through ``from_file``; anything else (a
        # directory) goes through ``from_directory``.
        query_builder = (
            ChaiQuery.from_file if source.is_file() else ChaiQuery.from_directory
        )
        query: ChaiQuery = query_builder(
            source,
            restraints=restraints,
            use_msa_server=use_msa_server,
            use_templates_server=use_templates_server,
            num_trunk_recycles=num_trunk_recycles,
            seed=seed,
            num_diffn_timesteps=num_diffn_timesteps,
            recycle_msa_subsample=recycle_msa_subsample,
            num_trunk_samples=num_trunk_samples,
            custom_msa_paths=msa_path,
        )
        query.save_parameters(output_dir)

    console.print("[blue]Generated query: [/blue]", end="")
    console.print(JSON.from_data(query.payload), style="blue")

    # Send a request
    with success_fail_catch(":brain: Processing folding job"):
        response = client.send_request(query, project_code)

    # Access confidence data
    console.print("[blue]Confidence Data:[/blue]", end=" ")
    console.print(JSON.from_data(response.confidence_data), style="blue")

    # Download results
    with success_fail_catch(
        f":floppy_disk: Downloading results to `[green]{output_dir}[/green]`"
    ):
        response.download_results(output_dir=output_dir, force=force, unzip=unzip)