"""API entries data model."""

from __future__ import annotations

import csv
import json
import tempfile
from ast import literal_eval
from enum import Enum
from pathlib import Path
from typing import List

import cloudpathlib
from folding_studio_data_models import CustomFileType, FeatureMode
from folding_studio_data_models.content import TemplateMaskCollection
from folding_studio_data_models.exceptions import (
    TemplatesMasksSettingsError,
)
from pydantic import BaseModel, ConfigDict, model_validator
from rich import print  # pylint:disable=redefined-builtin
from typing_extensions import Self

from folding_studio.api_call.upload_custom_files import upload_custom_files
from folding_studio.utils.file_helpers import (
    partition_template_pdb_from_file,
)
from folding_studio.utils.headers import get_auth_headers


class SimpleInputFile(str, Enum):
    """Supported simple prediction file source extensions."""

    FASTA = ".fasta"


class BatchInputFile(str, Enum):
    """Supported batch prediction file source extensions."""

    CSV = ".csv"
    JSON = ".json"


class PredictRequestParams(BaseModel):
    """Prediction parameters model."""

    ignore_cache: bool
    template_mode: FeatureMode
    custom_template_ids: List[str]
    msa_mode: FeatureMode
    max_msa_clusters: int
    max_extra_msa: int
    gap_trick: bool
    num_recycle: int
    random_seed: int
    model_subset: set[int]

    # `model_subset` starts with "model_", a prefix pydantic v2 reserves as a
    # protected namespace; clearing protected_namespaces allows the field name.
    model_config = ConfigDict(protected_namespaces=())


class MSARequestParams(BaseModel):
    """MSA parameters model."""

    ignore_cache: bool
    msa_mode: FeatureMode


class PredictRequestCustomFiles(BaseModel):
    """Prediction custom files model.

    Collects the custom files referenced by prediction requests (templates,
    MSAs, initial guess files and templates masks files). After `upload`, the
    list fields hold the uploaded paths instead of the local ones.
    """

    templates: List[Path | str]
    msas: List[Path]
    initial_guess_files: List[Path] | None = None
    templates_masks_files: List[Path] | None = None
    # Set to True by `upload` so that later calls skip uploading again.
    uploaded: bool = False
    # Leading underscore makes this a pydantic private attribute; `upload`
    # stores its local -> uploaded path mapping here.
    _local_to_uploaded: dict | None = None

    @model_validator(mode="after")
    def _check_templates_and_masks_content(self) -> Self:
        """Check that every template referenced by a mask is being uploaded.

        Each templates masks file is parsed as a `TemplateMaskCollection`, and
        every `template_name` it lists must match the file name of one entry
        of `self.templates`.

        Raises:
            TemplatesMasksSettingsError: If a mask references a template whose
                file name is not among `self.templates`.
        """
        if not self.templates_masks_files:
            return self

        # Masks refer to templates by file name; a set gives O(1) lookups.
        custom_templates_names = {Path(m).name for m in self.templates}
        for tm_file in self.templates_masks_files:
            tm_collection = TemplateMaskCollection.model_validate_json(
                tm_file.read_text()
            )
            if not all(
                tm.template_name in custom_templates_names
                for tm in tm_collection.templates_masks
            ):
                err = "Templates files are missing. Check your input command."
                raise TemplatesMasksSettingsError(err)
        return self

    @classmethod
    def _from_json_batch_file(cls, batch_jobs_file: Path) -> PredictRequestCustomFiles:
        """Read a JSON batch jobs file and collect its custom files.

        The file must contain a top-level "requests" list; each request's
        "parameters" object may hold "custom_templates" and "custom_msas"
        lists plus single "initial_guess_file" / "templates_masks_file" paths.

        Args:
            batch_jobs_file: The path to the batch jobs file in JSON format.

        Returns:
            An instance of PredictRequestCustomFiles.
        """
        custom_templates = []
        custom_msas = []
        initial_guess_files = []
        templates_masks_files = []

        jobs = json.loads(batch_jobs_file.read_text())
        for req in jobs["requests"]:
            params = req["parameters"]
            # `or []` also covers keys present but explicitly set to null.
            custom_templates.extend(params.get("custom_templates") or [])
            custom_msas.extend(params.get("custom_msas") or [])
            if ig := params.get("initial_guess_file"):
                initial_guess_files.append(ig)
            if tm := params.get("templates_masks_file"):
                templates_masks_files.append(tm)

        return cls(
            templates=custom_templates,
            msas=custom_msas,
            initial_guess_files=initial_guess_files,
            templates_masks_files=templates_masks_files,
        )

    @classmethod
    def _from_csv_batch_file(cls, batch_jobs_file: Path) -> PredictRequestCustomFiles:
        """Read a CSV batch jobs file and collect its custom files.

        The "custom_templates" and "custom_msas" cells are parsed with
        `ast.literal_eval` (e.g. "['a.pdb', 'b.pdb']"); "initial_guess_file"
        and "templates_masks_file" cells are taken as single paths. Empty or
        missing cells are skipped.

        Args:
            batch_jobs_file: The path to the batch jobs file in CSV format.

        Returns:
            An instance of PredictRequestCustomFiles.
        """
        custom_templates = []
        custom_msas = []
        initial_guess_files = []
        templates_masks_files = []

        # newline="" is required by the csv module so quoted fields containing
        # newlines are parsed correctly.
        with batch_jobs_file.open("r", newline="") as file:
            jobs_reader = csv.DictReader(
                file,
                quotechar='"',
                delimiter=",",
                quoting=csv.QUOTE_ALL,
            )
            for row in jobs_reader:
                # literal_eval only accepts Python literals, never arbitrary code.
                if tmpl := row.get("custom_templates"):
                    custom_templates.extend(literal_eval(tmpl))
                if msa := row.get("custom_msas"):
                    custom_msas.extend(literal_eval(msa))
                if ig := row.get("initial_guess_file"):
                    initial_guess_files.append(ig)
                if tm := row.get("templates_masks_file"):
                    templates_masks_files.append(tm)

        return cls(
            templates=custom_templates,
            msas=custom_msas,
            initial_guess_files=initial_guess_files,
            templates_masks_files=templates_masks_files,
        )

    @classmethod
    def from_batch_jobs_file(cls, batch_jobs_file: Path) -> PredictRequestCustomFiles:
        """Create a PredictRequestCustomFiles instance from a batch jobs file.

        Dispatches on the file suffix to the CSV or JSON reader.

        Args:
            batch_jobs_file: The path to the batch jobs file. Must be a CSV or
                JSON file.

        Returns:
            PredictRequestCustomFiles: An instance containing the custom
            templates, MSAs, initial guess files and templates masks files.

        Raises:
            ValueError: If the file is not a CSV or JSON file.
        """
        if batch_jobs_file.suffix == BatchInputFile.CSV:
            return cls._from_csv_batch_file(batch_jobs_file)
        if batch_jobs_file.suffix == BatchInputFile.JSON:
            return cls._from_json_batch_file(batch_jobs_file)
        raise ValueError(
            f"Unsupported file type {batch_jobs_file.suffix}: {batch_jobs_file}"
        )

    def upload(self, api_key: str | None = None) -> dict | None:
        """Upload local custom paths to GCP through an API request.

        Each non-empty category is uploaded with `upload_custom_files` and the
        corresponding field is replaced by the uploaded paths. Templates masks
        files are first rewritten so the template names they contain match the
        uploaded template file names, and the rewritten copies are uploaded.

        Args:
            api_key: Optional API key forwarded to `get_auth_headers`.

        Returns:
            A dict mapping local to uploaded file paths. If the files were
            already uploaded, the mapping stored by the previous call is
            returned without uploading again.
        """
        if self.uploaded:
            print("Custom files already uploaded, skipping upload.")
            return self._local_to_uploaded

        local_to_uploaded = {}
        headers = get_auth_headers(api_key)

        if self.templates:
            # NOTE(review): the first element returned by the partition helper is
            # discarded and `self.templates` is overwritten with the uploaded
            # files only — confirm those other entries are meant to be dropped.
            _, templates_to_upload = partition_template_pdb_from_file(
                custom_templates=self.templates
            )
            filename_to_gcs_path = upload_custom_files(
                headers=headers,
                paths=[Path(t) for t in templates_to_upload],
                file_type=CustomFileType.TEMPLATE,
            )
            self.templates = list(filename_to_gcs_path.values())
            local_to_uploaded.update(filename_to_gcs_path)

        if self.msas:
            filename_to_gcs_path = upload_custom_files(
                headers=headers,
                paths=[Path(m) for m in self.msas],
                file_type=CustomFileType.MSA,
            )
            self.msas = list(filename_to_gcs_path.values())
            local_to_uploaded.update(filename_to_gcs_path)

        if self.initial_guess_files:
            filename_to_gcs_path = upload_custom_files(
                headers=headers,
                paths=[Path(ig) for ig in self.initial_guess_files],
                file_type=CustomFileType.INITIAL_GUESS,
            )
            self.initial_guess_files = list(filename_to_gcs_path.values())
            local_to_uploaded.update(filename_to_gcs_path)

        if self.templates_masks_files:
            # Rewrite the masks so they reference the uploaded template names,
            # then upload the rewritten temporary copies.
            new_tm_files = _replace_tm_file_template_content(
                templates_masks_files=self.templates_masks_files,
                local_to_uploaded=local_to_uploaded,
            )
            filename_to_gcs_path = upload_custom_files(
                headers=headers,
                paths=list(new_tm_files.values()),
                file_type=CustomFileType.TEMPLATE_MASK,
            )
            # Key by the ORIGINAL local mask path, not the temporary copy.
            uploaded_tm_files = {
                local_path: filename_to_gcs_path[str(temp_path)]
                for local_path, temp_path in new_tm_files.items()
            }
            self.templates_masks_files = list(uploaded_tm_files.values())
            local_to_uploaded.update(uploaded_tm_files)

        self.uploaded = True
        self._local_to_uploaded = local_to_uploaded
        return local_to_uploaded


def _replace_tm_file_template_content(
    templates_masks_files: List[Path], local_to_uploaded: dict
) -> dict[str, Path]:
    """Write copies of templates masks files that point at uploaded files.

    In each mask file's text, every occurrence of the file name of a key of
    `local_to_uploaded` is replaced with the file name of the corresponding
    uploaded path. The rewritten text is written to a new temporary file that
    keeps the original mask's file name. Temporary files are not deleted here.

    Args:
        templates_masks_files: Local templates masks files to rewrite.
        local_to_uploaded: Mapping of local file paths to uploaded paths; the
            uploaded paths are parsed with `cloudpathlib.CloudPath`.

    Returns:
        A dict mapping each original mask path (as str) to its rewritten
        temporary copy.
    """
    # Only file names are substituted; compute them once for all masks.
    # Path(...).name handles any platform separator and trailing slashes
    # (a bare split("/") could yield "" and corrupt the whole file).
    replacements = [
        (Path(str(local)).name, cloudpathlib.CloudPath(uploaded).name)
        for local, uploaded in local_to_uploaded.items()
    ]

    new_tm_files = {}
    for tm in templates_masks_files:
        mask_content = tm.read_text()
        # NOTE: plain substring replacement — a file name that is a substring
        # of another name appearing in the mask is rewritten too.
        for local_name, uploaded_name in replacements:
            mask_content = mask_content.replace(local_name, uploaded_name)

        # A fresh directory per mask: masks sharing a file name (from different
        # folders) no longer overwrite each other, and no pre-existing file in
        # the shared temp dir is clobbered. The original file name is kept, as
        # the upload may derive the remote name from it.
        temp_file_path = Path(tempfile.mkdtemp(prefix="tm_")) / tm.name
        temp_file_path.write_text(mask_content)
        new_tm_files[str(tm)] = temp_file_path
    return new_tm_files