"""Utilities for validating FASTA files for AF3-like models."""
|
|
|
import logging |
|
import re |
|
import shutil |
|
from abc import abstractmethod |
|
from collections import defaultdict |
|
from enum import Enum |
|
from pathlib import Path |
|
|
|
from Bio import SeqIO |
|
|
|
|
|
class EntityType(str, Enum):
    """Closed set of entity kinds that may label a FASTA record.

    Every member is also a ``str`` whose value is the lowercase label used
    in FASTA headers, so ``EntityType.DNA == "dna"`` holds and
    ``EntityType("dna")`` looks a member up by its label.
    """

    # Sequence-based entities.
    DNA = "dna"
    RNA = "rna"
    PROTEIN = "protein"
    PEPTIDE = "peptide"

    # Non-sequence entities.
    ION = "ion"
    LIGAND = "ligand"

    # Labels named after how a molecule is written down.
    SMILES = "smiles"
    CCD = "ccd"
|
|
|
|
|
def get_entity_type(sequence: str) -> EntityType:
    """Classify a sequence string by the characters it contains.

    The checks run in a fixed order and the first one that matches wins:

    1. One or two letters, optionally followed by digits, ``+`` or ``-``
       -> ``EntityType.ION``. This is tested first, so *any* string of one
       or two letters (``"MG"``, but also ``"AT"``) is reported as an ion.
    2. Only ``ACGT`` characters (case-insensitive) -> ``EntityType.DNA``.
    3. Only ``ACGU`` characters (case-insensitive) -> ``EntityType.RNA``.
    4. Only the 20 letters ``ACDEFGHIKLMNPQRSTVWY`` (case-insensitive)
       -> ``EntityType.PROTEIN``.
    5. Anything else -> ``EntityType.LIGAND``.

    An empty string contains no character outside the DNA alphabet and is
    therefore reported as ``EntityType.DNA``.

    Args:
        sequence (str): The input sequence.

    Returns:
        EntityType: The entity type inferred for the input sequence.
    """
    if re.fullmatch(r"[A-Za-z]{1,2}[\d\+\-]*", sequence) is not None:
        return EntityType.ION

    # Upper-case once and reuse the character set for every alphabet test.
    residues = frozenset(sequence.upper())

    # Order matters: the DNA alphabet is tried before RNA, and both before
    # protein, because the protein alphabet contains A, C, G and T.
    alphabets = (
        (EntityType.DNA, "ACGT"),
        (EntityType.RNA, "ACGU"),
        (EntityType.PROTEIN, "ACDEFGHIKLMNPQRSTVWY"),
    )
    for kind, alphabet in alphabets:
        if residues <= frozenset(alphabet):
            return kind

    return EntityType.LIGAND
|
|
|
|
|
def has_multiple_chains(header: str) -> bool:
    """Check whether any RCSB-style chain list in the text names several chains.

    A multi-chain record carries a comma-separated chain list in its header:
    ```
    Chains A, B, C, ...
    ```
    where `A`, `B`, `C`, ... are the chain identifiers. Identifiers may
    contain digits and a bracketed author id (``A[auth H]``).

    Every ``Chain``/``Chains`` occurrence in the text is inspected, so the
    function also works when it is handed the content of a whole FASTA file
    rather than a single header line. Empty list entries (for example the
    one produced by a trailing comma in ``Chain A,``) are not counted as
    chains.

    Args:
        header (str): Header string, or any text containing header lines,
            with chain information.

    Returns:
        bool: True if at least one chain list names more than one chain,
            False otherwise (including when no chain list is found).
    """
    chain_list = r"chains?\s+([A-Za-z0-9, \[\]]+)"
    for match in re.finditer(chain_list, header, re.I):
        # Spaces are removed before splitting so "A, B" and "A,B" agree;
        # empty tokens are dropped so stray commas do not inflate the count.
        chain_ids = [
            token
            for token in match.group(1).replace(" ", "").split(",")
            if token
        ]
        if len(chain_ids) > 1:
            return True
    return False
|
|
|
|
|
class BaseFastaValidator:
    """Base class for validating FASTA files.

    Subclasses supply ``is_valid_fasta`` and ``transform_fasta``;
    ``process_directory`` applies them to every ``*.fasta`` file in a
    directory.

    NOTE(review): the class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers are not enforced at instantiation time; the
    base implementations raise ``NotImplementedError`` when called.
    """

    @abstractmethod
    def is_valid_fasta(self, fasta_path: Path) -> tuple[bool, str | None]:
        """Validate whether a given FASTA file follows the required format.

        Args:
            fasta_path (Path): Path to the FASTA file.

        Returns:
            tuple[bool, str | None]: Tuple containing a boolean indicating if
                the format is correct and an error message if not.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Subclasses must implement this method")

    @abstractmethod
    def transform_fasta(self, fasta_path: Path) -> str:
        """Transform a FASTA file into the required format.

        Args:
            fasta_path (Path): Path to the FASTA file.

        Returns:
            Transformed FASTA content in the required format.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Subclasses must implement this method")

    def process_directory(self, input_dir: str, output_dir: str) -> None:
        """Validate or transform every ``*.fasta`` file of a directory.

        For each file directly inside ``input_dir``:

        * if its text contains a multi-chain header, it is skipped and a
          warning is logged;
        * if ``is_valid_fasta`` reports it valid, it is copied unchanged;
        * otherwise the result of ``transform_fasta`` is written instead.

        ``output_dir`` is created (with parents) when it does not exist.
        Output files keep the name of their input file.

        Args:
            input_dir (str): Path to the directory containing FASTA files.
            output_dir (str): Path to the output directory where processed
                files will be saved.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        for fasta_file in Path(input_dir).glob("*.fasta"):
            if has_multiple_chains(fasta_file.read_text()):
                logging.warning(
                    f"Skipping {fasta_file} because it contains multiple chains in a single sequence.\n"
                    "Please split multiple chains into separate sequences using the following format:\n"
                    ">Chain A\n"
                    "MTEIVLKFL...\n"
                    ">Chain B\n"
                    "MTEIVLKFL...\n\n"
                    "Instead of:\n"
                    ">Chains A, B\n"
                    "MTEIVLKFL..."
                )
                continue

            output_file = output_path / fasta_file.name

            # is_valid_fasta returns a (bool, message) pair. The pair itself
            # is always truthy, so it must be unpacked before branching.
            is_valid, error = self.is_valid_fasta(fasta_file)
            if is_valid:
                shutil.copy(fasta_file, output_file)
            else:
                logging.info("Transforming %s: %s", fasta_file, error)
                output_file.write_text(self.transform_fasta(fasta_file))
|
|
|
|
|
class BoltzFastaValidator(BaseFastaValidator):
    """Validate whether a given FASTA file follows the required format for Boltz."""

    SUPPORTED_ENTITY_TYPES = {
        EntityType.PROTEIN,
        EntityType.RNA,
        EntityType.DNA,
        EntityType.SMILES,
        EntityType.CCD,
    }

    @staticmethod
    def _chain_id(index: int) -> str:
        """Return the chain label for a zero-based record index.

        Labels follow spreadsheet column naming: ``A`` .. ``Z``, then ``AA``,
        ``AB``, ... so that more than 26 records still get letter-only ids.
        """
        label = ""
        remaining = index + 1
        while remaining > 0:
            remaining, offset = divmod(remaining - 1, 26)
            label = chr(ord("A") + offset) + label
        return label

    def is_valid_fasta(self, fasta_path: Path) -> tuple[bool, str | None]:
        """Validate whether a given FASTA file follows the required format.

        The expected FASTA header format is:
        ```
        >CHAIN_ID|ENTITY_TYPE
        ```
        where `ENTITY_TYPE` must be one of: "protein", "rna", "dna", "smiles" or "ccd"
        (compared case-insensitively). A third ``|``-separated field is
        tolerated. Validation stops at the first offending record; a file
        without records is reported as valid.

        Args:
            fasta_path (Path): Path to the FASTA file.

        Returns:
            tuple[bool, str | None]: Tuple containing a boolean indicating if
                the format is correct and an error message if not.
        """
        with fasta_path.open("r") as f:
            for record in SeqIO.parse(f, "fasta"):
                # Biopython's record.id is the header up to the first
                # whitespace, so text after a space is not inspected.
                header_parts = record.id.split("|")
                if not (1 < len(header_parts) <= 3):
                    msg = "BOLTZ Validation Error: Invalid header format. Expected '>CHAIN_ID|ENTITY_TYPE'"
                    return False, msg

                if header_parts[1].lower() not in self.SUPPORTED_ENTITY_TYPES:
                    # Sorted so the message does not depend on set ordering.
                    supported = ", ".join(
                        sorted(kind.value for kind in self.SUPPORTED_ENTITY_TYPES)
                    )
                    return (
                        False,
                        f"BOLTZ Validation Error: Invalid entity type '{header_parts[1]}'. Supported types: {supported}",
                    )
        return True, None

    def transform_fasta(self, fasta_path: Path) -> str:
        """Transform a FASTA file into the '>CHAIN_ID|ENTITY_TYPE' format.

        The original headers are discarded. Chain identifiers are assigned by
        record position (``A``, ``B``, ..., ``Z``, ``AA``, ...) and the entity
        type is inferred from the sequence composition by
        ``get_entity_type``.

        ``get_entity_type`` can return types outside
        ``SUPPORTED_ENTITY_TYPES`` (for example ``ion`` or ``ligand``). Such
        records are still written, but a warning is logged because the
        result will not pass ``is_valid_fasta``.

        Args:
            fasta_path (Path): Path to the FASTA file.

        Returns:
            Transformed FASTA content in the required format.
        """
        transformed_lines: list[str] = []

        with fasta_path.open("r") as f:
            for record_index, record in enumerate(SeqIO.parse(f, "fasta")):
                sequence = str(record.seq)
                entity_type = get_entity_type(sequence)
                if entity_type not in self.SUPPORTED_ENTITY_TYPES:
                    logging.warning(
                        "Inferred entity type '%s' for record %d of %s is not supported by Boltz",
                        entity_type.value,
                        record_index + 1,
                        fasta_path,
                    )

                chain = self._chain_id(record_index)
                transformed_lines.append(f">{chain}|{entity_type.value}")
                transformed_lines.append(sequence)

        return "\n".join(transformed_lines)
|
|
|
|
|
class ChaiFastaValidator(BaseFastaValidator):
    """Validate whether a given FASTA file follows the required format for Chai."""

    SUPPORTED_ENTITY_TYPES = EntityType.__members__.values()

    @staticmethod
    def _sanitize_name(raw_name: str) -> str:
        """Return ``raw_name`` restricted to the characters a Chai name may hold.

        Runs of characters outside ``[A-Za-z0-9_-]`` (as matched by ``\\w``
        and ``-``) become a single underscore. Names that are already valid
        are returned unchanged; an empty result falls back to ``"sequence"``.
        """
        return re.sub(r"[^\w\-]+", "_", raw_name.strip()) or "sequence"

    @staticmethod
    def _unique_name(base: str, used: set[str], counts: dict[str, int]) -> str:
        """Return ``base`` or ``base_N`` such that the result is not in ``used``.

        The first occurrence keeps ``base``; later ones get ``_2``, ``_3``,
        ... and candidates already taken by another record are skipped.
        ``used`` and ``counts`` are updated in place.
        """
        count = counts.get(base, 0)
        while True:
            count += 1
            candidate = base if count == 1 else f"{base}_{count}"
            if candidate not in used:
                break
        counts[base] = count
        used.add(candidate)
        return candidate

    def is_valid_fasta(self, fasta_path: Path) -> tuple[bool, str | None]:
        """Validate whether a given FASTA file follows the required format.

        The expected FASTA header format is:
        ```
        >ENTITY_TYPE|name=NAME
        ```
        where ``ENTITY_TYPE`` is one of the ``EntityType`` values (matched
        case-sensitively) and ``NAME`` consists of word characters and ``-``
        and is unique within the file.

        For ``protein`` and ``peptide`` records the sequence must also be
        classified as protein by ``get_entity_type``. Validation stops at the
        first offending record; a file without records is reported as valid.

        Args:
            fasta_path (Path): Path to the FASTA file.

        Returns:
            tuple[bool, str | None]: Tuple containing a boolean indicating if
                the format is correct and an error message if not.
        """
        amino_acid_types = (EntityType.PEPTIDE, EntityType.PROTEIN)
        seen_names: set[str] = set()

        with fasta_path.open("r") as f:
            for record in SeqIO.parse(f, "fasta"):
                match = re.match(r"^([A-Za-z]+)\|name=([\w\-]+)$", record.description)
                if not match:
                    return (
                        False,
                        "CHAI Validation Error: Invalid header format. Expected '>ENTITY_TYPE|name=NAME'",
                    )

                entity_type, name = match.groups()
                if entity_type not in self.SUPPORTED_ENTITY_TYPES:
                    supported = ", ".join(
                        kind.value for kind in self.SUPPORTED_ENTITY_TYPES
                    )
                    return (
                        False,
                        f"CHAI Validation Error: Invalid entity type '{entity_type}'. Supported types: {supported}",
                    )

                if name in seen_names:
                    return (
                        False,
                        f"CHAI Validation Error: Duplicate name '{name}'. Each sequence must have a unique name",
                    )
                seen_names.add(name)

                if entity_type in amino_acid_types:
                    # get_entity_type never returns PEPTIDE, so peptides are
                    # checked against PROTEIN just like proteins; comparing
                    # against the declared type would reject every peptide.
                    detected = get_entity_type(str(record.seq).strip())
                    if detected != EntityType.PROTEIN:
                        # .value keeps the message identical across Python
                        # versions (f-string formatting of str-mixin enums
                        # changed in 3.11).
                        return (
                            False,
                            f"CHAI Validation Error: Sequence type mismatch. Expected '{entity_type}' but found '{detected.value}'",
                        )

        return True, None

    def transform_fasta(self, fasta_path: Path) -> str:
        """Transform a FASTA file into the '>TYPE|name=NAME' format.

        ``NAME`` is taken from the part of the original header before the
        first ``|``, reduced to the characters ``is_valid_fasta`` accepts and
        made unique within the file (adding a ``_N`` suffix if necessary).
        ``TYPE`` is inferred from the sequence by ``get_entity_type``.

        The expected output format is:
            '>protein|name=NAME'
            'SEQUENCE'

        Args:
            fasta_path (Path): Path to the FASTA file.

        Returns:
            Transformed FASTA content in the required Chai format.
        """
        transformed_lines: list[str] = []
        used_names: set[str] = set()
        name_counts: dict[str, int] = {}

        with fasta_path.open("r") as f:
            for record in SeqIO.parse(f, "fasta"):
                sequence = str(record.seq)
                base_name = self._sanitize_name(record.description.split("|")[0])
                name = self._unique_name(base_name, used_names, name_counts)
                entity_type = get_entity_type(sequence)

                transformed_lines.append(f">{entity_type.value}|name={name}")
                transformed_lines.append(sequence)

        return "\n".join(transformed_lines)
|
|
|
|
|
class ProtenixFastaValidator(BaseFastaValidator):
    """Validate whether a given FASTA file follows the required format for Protenix."""

    @staticmethod
    def _unique_header(base: str, used: set[str], counts: dict[str, int]) -> str:
        """Return ``base`` or ``base_N`` such that the result is not in ``used``.

        The first occurrence keeps ``base``; later ones get ``_2``, ``_3``,
        ... and candidates already taken by another record are skipped, so a
        generated header can never duplicate one that is already present.
        ``used`` and ``counts`` are updated in place.
        """
        count = counts.get(base, 0)
        while True:
            count += 1
            candidate = base if count == 1 else f"{base}_{count}"
            if candidate not in used:
                break
        counts[base] = count
        used.add(candidate)
        return candidate

    def is_valid_fasta(self, fasta_path: Path) -> tuple[bool, str | None]:
        """Validate whether a given FASTA file follows the required format.

        The expected FASTA header format is:
        ```
        > UNIQUE ID[|...]
        ```
        Only the part of each header before the first ``|`` (with
        surrounding whitespace removed) is compared; it must not repeat
        within the file. A file without records is reported as valid.

        Args:
            fasta_path (Path): Path to the FASTA file.

        Returns:
            tuple[bool, str | None]: Tuple containing a boolean indicating if
                the format is correct and an error message if not.
        """
        seen_headers: set[str] = set()

        with fasta_path.open("r") as f:
            for record in SeqIO.parse(f, "fasta"):
                main_header = record.description.split("|")[0].strip()
                if main_header in seen_headers:
                    return (
                        False,
                        f"PROTENIX Validation Error: Duplicate header '{main_header}'. Each sequence must have a unique header",
                    )
                seen_headers.add(main_header)

        return True, None

    def transform_fasta(self, fasta_path: Path) -> str:
        """Rewrite a FASTA file so that every main header is unique.

        The main header is the part of the original header before the first
        ``|``. Repeated main headers get a numeric suffix (``NAME``,
        ``NAME_2``, ``NAME_3``, ...). Only the main header is written; any
        fields after the first ``|`` are dropped.

        The produced output has the form:
            '>protein'
            'SEQUENCE'
            '>protein_2'
            'SEQUENCE'

        Args:
            fasta_path (Path): Path to the FASTA file.

        Returns:
            Transformed FASTA content in the required Protenix format.
        """
        transformed_lines: list[str] = []
        used_headers: set[str] = set()
        header_counts: dict[str, int] = {}

        with fasta_path.open("r") as f:
            for record in SeqIO.parse(f, "fasta"):
                main_header = record.description.split("|")[0].strip()
                unique_header = self._unique_header(
                    main_header, used_headers, header_counts
                )

                transformed_lines.append(f">{unique_header}")
                transformed_lines.append(str(record.seq))

        return "\n".join(transformed_lines)
|
|