# Provenance (from file-view header): commit 01fba1c, "add more informations (#15)" by chengzhang1006.
"""API entries data model."""
from __future__ import annotations
import csv
import json
import tempfile
from ast import literal_eval
from enum import Enum
from pathlib import Path
from typing import List
import cloudpathlib
from folding_studio_data_models import CustomFileType, FeatureMode
from folding_studio_data_models.content import TemplateMaskCollection
from folding_studio_data_models.exceptions import (
TemplatesMasksSettingsError,
)
from pydantic import BaseModel, ConfigDict, model_validator
from rich import print # pylint:disable=redefined-builtin
from typing_extensions import Self
from folding_studio.api_call.upload_custom_files import upload_custom_files
from folding_studio.utils.file_helpers import (
partition_template_pdb_from_file,
)
from folding_studio.utils.headers import get_auth_headers
class SimpleInputFile(str, Enum):
    """File extensions accepted as single-job prediction inputs."""

    # One FASTA file describes a single prediction request.
    FASTA = ".fasta"
class BatchInputFile(str, Enum):
    """File extensions accepted as batch prediction inputs."""

    # Tabular batch description, one job per row.
    CSV = ".csv"
    # Structured batch description (parsed from a top-level "requests" list).
    JSON = ".json"
class PredictRequestParams(BaseModel):
    """Prediction parameters model.

    Request-level options forwarded with a prediction API call.
    Field semantics below are inferred from the field names — confirm
    against the prediction API documentation.
    """

    ignore_cache: bool  # presumably: recompute even if a cached result exists
    template_mode: FeatureMode  # how template features are obtained
    custom_template_ids: List[str]  # ids of user-provided templates
    msa_mode: FeatureMode  # how MSA features are obtained
    max_msa_clusters: int
    max_extra_msa: int
    gap_trick: bool
    num_recycle: int  # number of model recycling iterations — TODO confirm
    random_seed: int
    model_subset: set[int]  # subset of model indices to run — TODO confirm
    # Disable pydantic's "model_" protected namespace so the
    # `model_subset` field does not trigger a warning.
    model_config = ConfigDict(protected_namespaces=())
class MSARequestParams(BaseModel):
    """MSA parameters model.

    Options forwarded with a standalone MSA request.
    """

    ignore_cache: bool  # presumably: recompute even if a cached MSA exists
    msa_mode: FeatureMode  # how the MSA is obtained
class PredictRequestCustomFiles(BaseModel):
    """Prediction custom files model.

    Collects the custom files referenced by a prediction request
    (templates, MSAs, initial-guess structures and templates-masks files)
    and handles their upload to GCP via the API.
    """

    templates: List[Path | str]
    msas: List[Path]
    initial_guess_files: List[Path] | None = None
    templates_masks_files: List[Path] | None = None
    # Set to True once upload() has run on this instance.
    uploaded: bool = False
    # Cache of the local -> uploaded path mapping built by upload().
    _local_to_uploaded: dict | None = None

    @model_validator(mode="after")
    def _check_templates_and_masks_content(self) -> Self:
        """Checks if templates used by mask are being uploaded.

        Raises:
            TemplatesMasksSettingsError: If a mask file references a
                template that is not part of ``templates``.
        """
        if not self.templates_masks_files:
            return self
        # Masks reference templates by bare file name; use a set for O(1) lookups.
        custom_templates_names = {Path(t).name for t in self.templates}
        for tm_file in self.templates_masks_files:
            tm_collection = TemplateMaskCollection.model_validate_json(
                tm_file.read_text()
            )
            if any(
                tm.template_name not in custom_templates_names
                for tm in tm_collection.templates_masks
            ):
                err = "Templates files are missing. Check your input command."
                raise TemplatesMasksSettingsError(err)
        return self

    @classmethod
    def _from_json_batch_file(cls, batch_jobs_file: Path) -> PredictRequestCustomFiles:
        """
        Reads a JSON batch jobs file and extracts custom templates and MSAs.

        The file is expected to hold a top-level "requests" list whose items
        each carry a "parameters" mapping.

        Args:
            batch_jobs_file (Path): The path to the batch jobs file in JSON format.

        Returns:
            An instance of PredictRequestCustomFiles.
        """
        custom_templates: list = []
        custom_msas: list = []
        initial_guess_files: list = []
        templates_masks_files: list = []
        jobs = json.loads(batch_jobs_file.read_text())
        for req in jobs["requests"]:
            params = req["parameters"]
            custom_templates.extend(params.get("custom_templates", []))
            custom_msas.extend(params.get("custom_msas", []))
            ig = params.get("initial_guess_file")
            if ig:
                initial_guess_files.append(ig)
            tm = params.get("templates_masks_file")
            if tm:
                templates_masks_files.append(tm)
        return cls(
            templates=custom_templates,
            msas=custom_msas,
            initial_guess_files=initial_guess_files,
            templates_masks_files=templates_masks_files,
        )

    @classmethod
    def _from_csv_batch_file(cls, batch_jobs_file: Path) -> PredictRequestCustomFiles:
        """
        Reads a CSV batch jobs file and extracts custom templates and MSAs.

        Args:
            batch_jobs_file (Path): The path to the batch jobs file in CSV format.

        Returns:
            An instance of PredictRequestCustomFiles.
        """
        custom_templates: list = []
        custom_msas: list = []
        initial_guess_files: list = []
        templates_masks_files: list = []
        with batch_jobs_file.open("r") as file:
            jobs_reader = csv.DictReader(
                file,
                quotechar='"',
                delimiter=",",
                quoting=csv.QUOTE_ALL,
            )
            for row in jobs_reader:
                # List-valued cells are stored as Python literals, e.g.
                # "['a.pdb', 'b.pdb']". literal_eval only accepts literal
                # syntax (it never executes code), so this is safe.
                tmpl = row.get("custom_templates")
                if tmpl:
                    custom_templates.extend(literal_eval(tmpl))
                msa = row.get("custom_msas")
                if msa:
                    custom_msas.extend(literal_eval(msa))
                ig = row.get("initial_guess_file")
                if ig:
                    initial_guess_files.append(ig)
                tm = row.get("templates_masks_file")
                if tm:
                    templates_masks_files.append(tm)
        return cls(
            templates=custom_templates,
            msas=custom_msas,
            initial_guess_files=initial_guess_files,
            templates_masks_files=templates_masks_files,
        )

    @classmethod
    def from_batch_jobs_file(cls, batch_jobs_file: Path) -> PredictRequestCustomFiles:
        """Creates a PredictRequestCustomFiles instance from a batch jobs file (CSV or JSON).

        This function reads a batch jobs file, resolves file paths for custom templates and MSAs,
        and returns a PredictRequestCustomFiles object.

        Args:
            batch_jobs_file (Path): The path to the batch jobs file. Must be a CSV or JSON file.

        Returns:
            PredictRequestCustomFiles: An instance containing the custom templates and MSAs.

        Raises:
            ValueError: If the file is not a CSV or JSON file.
        """
        if batch_jobs_file.suffix == BatchInputFile.CSV:
            return cls._from_csv_batch_file(batch_jobs_file)
        if batch_jobs_file.suffix == BatchInputFile.JSON:
            return cls._from_json_batch_file(batch_jobs_file)
        raise ValueError(
            f"Unsupported file type {batch_jobs_file.suffix}: {batch_jobs_file}"
        )

    def upload(self, api_key: str | None = None) -> dict | None:
        """Upload local custom paths to GCP through an API request.

        Replaces the local paths stored on this instance with their
        uploaded counterparts and marks the instance as uploaded.

        Args:
            api_key: Optional API key used to build the auth headers.

        Returns:
            A dict mapping local to uploaded files path. If the files were
            already uploaded, the cached mapping (possibly None) is returned.
        """
        if self.uploaded:
            print("Custom files already uploaded, skipping upload.")
            return self._local_to_uploaded
        local_to_uploaded: dict = {}
        headers = get_auth_headers(api_key)
        if self.templates:
            # NOTE(review): entries partitioned out as existing template PDB
            # ids are dropped from self.templates below — confirm intended.
            _, templates_to_upload = partition_template_pdb_from_file(
                custom_templates=self.templates
            )
            filename_to_gcs_path = upload_custom_files(
                headers=headers,
                paths=[Path(t) for t in templates_to_upload],
                file_type=CustomFileType.TEMPLATE,
            )
            self.templates = list(filename_to_gcs_path.values())
            local_to_uploaded.update(filename_to_gcs_path)
        if self.msas:
            filename_to_gcs_path = upload_custom_files(
                headers=headers,
                paths=[Path(m) for m in self.msas],
                file_type=CustomFileType.MSA,
            )
            self.msas = list(filename_to_gcs_path.values())
            local_to_uploaded.update(filename_to_gcs_path)
        if self.initial_guess_files:
            # (The previous inline "if ... else ..." on paths was redundant:
            # this branch already guarantees initial_guess_files is non-empty.)
            filename_to_gcs_path = upload_custom_files(
                headers=headers,
                paths=[Path(ig) for ig in self.initial_guess_files],
                file_type=CustomFileType.INITIAL_GUESS,
            )
            self.initial_guess_files = list(filename_to_gcs_path.values())
            local_to_uploaded.update(filename_to_gcs_path)
        if self.templates_masks_files:
            # Rewrite the mask files so they reference the uploaded template
            # file names instead of the local ones.
            new_tm_files = _replace_tm_file_template_content(
                templates_masks_files=self.templates_masks_files,
                local_to_uploaded=local_to_uploaded,
            )
            filename_to_gcs_path = upload_custom_files(
                headers=headers,
                paths=list(new_tm_files.values()),
                file_type=CustomFileType.TEMPLATE_MASK,
            )
            # Map each original mask file to its uploaded (GCS) path.
            uploaded_tm_files = {
                local: filename_to_gcs_path[str(tmp)]
                for local, tmp in new_tm_files.items()
            }
            self.templates_masks_files = list(uploaded_tm_files.values())
            local_to_uploaded.update(uploaded_tm_files)
        self.uploaded = True
        self._local_to_uploaded = local_to_uploaded
        return local_to_uploaded
def _replace_tm_file_template_content(
    templates_masks_files: List[Path], local_to_uploaded: dict
) -> dict:
    """Helper function to replace the template name in TM files.

    Rewrites each templates-masks file so that local template file names
    are replaced by the corresponding uploaded (cloud) file names, writing
    the result to a fresh temporary directory.

    Args:
        templates_masks_files: Templates-masks files to rewrite.
        local_to_uploaded: Mapping of local template path to uploaded path.

    Returns:
        A dict mapping each original file path (as str) to the rewritten
        temporary file's Path.
    """
    new_tm_files: dict = {}
    # Use a unique per-call directory instead of the shared temp dir so
    # concurrent runs (or a pre-existing file with the same name) cannot
    # clobber each other's rewritten files.
    temp_dir = Path(tempfile.mkdtemp(prefix="templates_masks_"))
    for tm in templates_masks_files:
        mask_content = tm.read_text()
        for template, uploaded_file in local_to_uploaded.items():
            # Swap the local file name for the uploaded file's name.
            mask_content = mask_content.replace(
                Path(template).name,
                cloudpathlib.CloudPath(uploaded_file).name,
            )
        temp_file_path = temp_dir / tm.name
        temp_file_path.write_text(mask_content)
        new_tm_files[str(tm)] = temp_file_path
    return new_tm_files