|
"""API entries data model.""" |
|
|
|
from __future__ import annotations |
|
|
|
import csv |
|
import json |
|
import tempfile |
|
from ast import literal_eval |
|
from enum import Enum |
|
from pathlib import Path |
|
from typing import List |
|
|
|
import cloudpathlib |
|
from folding_studio_data_models import CustomFileType, FeatureMode |
|
from folding_studio_data_models.content import TemplateMaskCollection |
|
from folding_studio_data_models.exceptions import ( |
|
TemplatesMasksSettingsError, |
|
) |
|
from pydantic import BaseModel, ConfigDict, model_validator |
|
from rich import print |
|
from typing_extensions import Self |
|
|
|
from folding_studio.api_call.upload_custom_files import upload_custom_files |
|
from folding_studio.utils.file_helpers import ( |
|
partition_template_pdb_from_file, |
|
) |
|
from folding_studio.utils.headers import get_auth_headers |
|
|
|
|
|
class SimpleInputFile(str, Enum): |
|
"""Supported simple prediction file source extensions.""" |
|
|
|
FASTA = ".fasta" |
|
|
|
|
|
class BatchInputFile(str, Enum): |
|
"""Supported batch prediction file source extensions.""" |
|
|
|
CSV = ".csv" |
|
JSON = ".json" |
|
|
|
|
|
class PredictRequestParams(BaseModel): |
|
"""Prediction parameters model.""" |
|
|
|
ignore_cache: bool |
|
template_mode: FeatureMode |
|
custom_template_ids: List[str] |
|
msa_mode: FeatureMode |
|
max_msa_clusters: int |
|
max_extra_msa: int |
|
gap_trick: bool |
|
num_recycle: int |
|
random_seed: int |
|
model_subset: set[int] |
|
|
|
model_config = ConfigDict(protected_namespaces=()) |
|
|
|
|
|
class MSARequestParams(BaseModel): |
|
"""MSA parameters model.""" |
|
|
|
ignore_cache: bool |
|
msa_mode: FeatureMode |
|
|
|
|
|
class PredictRequestCustomFiles(BaseModel): |
|
"""Prediction custom files model.""" |
|
|
|
templates: List[Path | str] |
|
msas: List[Path] |
|
initial_guess_files: List[Path] | None = None |
|
templates_masks_files: List[Path] | None = None |
|
uploaded: bool = False |
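    # `uploaded` guards against double uploads; `_local_to_uploaded` caches the
    # local-to-uploaded mapping built by upload() and, thanks to the leading
    # underscore, is a Pydantic private attribute rather than a model field.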
|
_local_to_uploaded: dict | None = None |
|
|
|
@model_validator(mode="after") |
|
def _check_templates_and_masks_content(self) -> Self: |
|
"""Checks if templates used by mask are being uploaded.""" |
|
if not self.templates_masks_files: |
|
return self |
|
|
|
custom_templates_names = [Path(m).name for m in self.templates] |
|
for tm_file in self.templates_masks_files: |
|
tm_collection = TemplateMaskCollection.model_validate_json( |
|
tm_file.read_text() |
|
) |
|
            if not all(
                tm.template_name in custom_templates_names
                for tm in tm_collection.templates_masks
            ):
                err = "Template files are missing. Check your input command."
                raise TemplatesMasksSettingsError(err)
|
return self |
|
|
|
@classmethod |
|
def _from_json_batch_file(cls, batch_jobs_file: Path) -> PredictRequestCustomFiles: |
|
""" |
|
Reads a JSON batch jobs file and extracts custom templates and MSAs. |
|
|
|
Args: |
|
batch_jobs_file (Path): The path to the batch jobs file in JSON format. |
|
|
|
Returns: |
|
An instance of PredictRequestCustomFiles. |
|
""" |
|
custom_templates = [] |
|
custom_msas = [] |
|
initial_guess_files = [] |
|
templates_masks_files = [] |
|
|
|
jobs = json.loads(batch_jobs_file.read_text()) |
|
for req in jobs["requests"]: |
|
tmpl = req["parameters"].get("custom_templates", []) |
|
custom_templates.extend(tmpl) |
|
|
|
msa = req["parameters"].get("custom_msas", []) |
|
custom_msas.extend(msa) |
|
|
|
ig = req["parameters"].get("initial_guess_file") |
|
if ig: |
|
initial_guess_files.append(ig) |
|
|
|
tm = req["parameters"].get("templates_masks_file") |
|
if tm: |
|
templates_masks_files.append(tm) |
|
|
|
return cls( |
|
templates=custom_templates, |
|
msas=custom_msas, |
|
initial_guess_files=initial_guess_files, |
|
templates_masks_files=templates_masks_files, |
|
) |
|
|
|
@classmethod |
|
def _from_csv_batch_file(cls, batch_jobs_file: Path) -> PredictRequestCustomFiles: |
|
""" |
|
Reads a CSV batch jobs file and extracts custom templates and MSAs. |
|
|
|
Args: |
|
batch_jobs_file (Path): The path to the batch jobs file in CSV format. |
|
|
|
Returns: |
|
An instance of PredictRequestCustomFiles. |
|
""" |
|
custom_templates = [] |
|
custom_msas = [] |
|
initial_guess_files = [] |
|
templates_masks_files = [] |
|
|
|
with batch_jobs_file.open("r") as file: |
|
jobs_reader = csv.DictReader( |
|
file, |
|
quotechar='"', |
|
delimiter=",", |
|
quoting=csv.QUOTE_ALL, |
|
) |
|
for row in jobs_reader: |
|
tmpl = row.get("custom_templates") |
|
if tmpl: |
|
tmpl = literal_eval(tmpl) |
|
custom_templates.extend(tmpl) |
|
|
|
msa = row.get("custom_msas") |
|
if msa: |
|
msa = literal_eval(msa) |
|
custom_msas.extend(msa) |
|
|
|
                ig = row.get("initial_guess_file")
                if ig:
                    initial_guess_files.append(ig)

                tm = row.get("templates_masks_file")
                if tm:
                    templates_masks_files.append(tm)
|
return cls( |
|
templates=custom_templates, |
|
msas=custom_msas, |
|
initial_guess_files=initial_guess_files, |
|
templates_masks_files=templates_masks_files, |
|
) |
|
|
|
@classmethod |
|
def from_batch_jobs_file(cls, batch_jobs_file: Path) -> PredictRequestCustomFiles: |
|
"""Creates a PredictRequestCustomFiles instance from a batch jobs file (CSV or JSON). |
|
|
|
This function reads a batch jobs file, resolves file paths for custom templates and MSAs, |
|
and returns a PredictRequestCustomFiles object. |
|
|
|
Args: |
|
batch_jobs_file (Path): The path to the batch jobs file. Must be a CSV or JSON file. |
|
|
|
Returns: |
|
PredictRequestCustomFiles: An instance containing the custom templates and MSAs. |
|
|
|
Raises: |
|
ValueError: If the file is not a CSV or JSON file. |
|
""" |
|
if batch_jobs_file.suffix == BatchInputFile.CSV: |
|
return cls._from_csv_batch_file(batch_jobs_file) |
|
elif batch_jobs_file.suffix == BatchInputFile.JSON: |
|
return cls._from_json_batch_file(batch_jobs_file) |
|
else: |
|
raise ValueError( |
|
f"Unsupported file type {batch_jobs_file.suffix}: {batch_jobs_file}" |
|
) |
|
|
|
    def upload(self, api_key: str | None = None) -> dict:
        """Upload local custom file paths to GCP through an API request.

        Args:
            api_key: Optional API key used to build the authentication headers.

        Returns:
            A dict mapping local file paths to their uploaded counterparts.
        """
|
if self.uploaded: |
|
print("Custom files already uploaded, skipping upload.") |
|
return self._local_to_uploaded |
|
|
|
local_to_uploaded = {} |
|
|
|
headers = get_auth_headers(api_key) |
|
if len(self.templates) > 0: |
|
_, templates_to_upload = partition_template_pdb_from_file( |
|
custom_templates=self.templates |
|
) |
|
filename_to_gcs_path = upload_custom_files( |
|
headers=headers, |
|
paths=[Path(t) for t in templates_to_upload], |
|
file_type=CustomFileType.TEMPLATE, |
|
) |
|
self.templates = list(filename_to_gcs_path.values()) |
|
local_to_uploaded.update(filename_to_gcs_path) |
|
|
|
if len(self.msas) > 0: |
|
filename_to_gcs_path = upload_custom_files( |
|
headers=headers, |
|
paths=[Path(m) for m in self.msas], |
|
file_type=CustomFileType.MSA, |
|
) |
|
self.msas = list(filename_to_gcs_path.values()) |
|
local_to_uploaded.update(filename_to_gcs_path) |
|
|
|
        if self.initial_guess_files:
            filename_to_gcs_path = upload_custom_files(
                headers=headers,
                paths=[Path(ig) for ig in self.initial_guess_files],
                file_type=CustomFileType.INITIAL_GUESS,
            )
|
self.initial_guess_files = list(filename_to_gcs_path.values()) |
|
local_to_uploaded.update(filename_to_gcs_path) |
|
|
|
if self.templates_masks_files: |
|
|
|
new_tm_files = _replace_tm_file_template_content( |
|
templates_masks_files=self.templates_masks_files, |
|
local_to_uploaded=local_to_uploaded, |
|
) |
|
filename_to_gcs_path = upload_custom_files( |
|
headers=headers, |
|
                paths=list(new_tm_files.values()),
|
file_type=CustomFileType.TEMPLATE_MASK, |
|
) |
|
for k, v in new_tm_files.items(): |
|
new_tm_files[k] = filename_to_gcs_path[str(v)] |
|
self.templates_masks_files = list(new_tm_files.values()) |
|
local_to_uploaded.update(new_tm_files) |
|
|
|
self.uploaded = True |
|
self._local_to_uploaded = local_to_uploaded |
|
return local_to_uploaded |
|
|
|
|
|
def _replace_tm_file_template_content(
    templates_masks_files: List[Path], local_to_uploaded: dict
) -> dict[str, Path]:
    """Rewrite templates masks files so they reference the uploaded template names.

    Returns:
        A dict mapping each original mask file path (as a string) to the rewritten
        temporary file.
    """
|
new_tm_files = {} |
|
for tm in templates_masks_files: |
|
mask_content = tm.read_text() |
|
for ( |
|
template, |
|
uploaded_file, |
|
) in local_to_uploaded.items(): |
|
mask_content = mask_content.replace( |
|
template.split("/")[-1], |
|
cloudpathlib.CloudPath(uploaded_file).name, |
|
) |
|
|
|
|
|
temp_dir = tempfile.gettempdir() |
|
temp_file_path = Path(temp_dir) / tm.name |
|
temp_file_path.write_text(mask_content) |
|
new_tm_files[str(tm)] = temp_file_path |
|
return new_tm_files |
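
# Minimal usage sketch (hypothetical file names and API key):
#
#     custom_files = PredictRequestCustomFiles.from_batch_jobs_file(Path("batch_jobs.csv"))
#     local_to_uploaded = custom_files.upload(api_key="my-api-key")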
|
|