# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import subprocess
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple


@dataclass
class LocalColabFoldConfig:
    """Settings from which the ColabFold search command line is assembled.

    ``run_colabfold_search`` reads every field of this class: the first four
    become the executable and its positional arguments, the remaining ones
    are forwarded as ``--option value`` pairs.
    """

    # Executable placed first on the command line.
    colabsearch: str
    # Positional arguments, passed in this order after the executable.
    query_fpath: str
    db_dir: str
    # Also the directory that is scanned for ``*.a3m`` files after the run.
    results_dir: str
    # Forwarded as ``--mmseqs``; the literal "mmseqs" is used when unset.
    mmseqs_path: Optional[str] = None
    # Database names forwarded as ``--db1`` / ``--db2`` / ``--db3``; a falsy
    # value means the corresponding option is left off the command line.
    db1: str = "uniref30_2302_db"
    db2: Optional[str] = None
    db3: Optional[str] = "colabfold_envdb_202108_db"
    # Integer switches forwarded as ``--use-env``, ``--filter`` and
    # ``--db-load-mode`` respectively.
    use_env: int = 1
    filter: int = 1
    db_load_mode: int = 0


class A3MProcessor:
    """Split a ColabFold A3M result into per-chain MSA files.

    The file is read and its first line inspected on construction:

    * If the first line starts with ``#`` it is parsed as
      ``#<len>,<len>,...<TAB><count>,<count>,...`` and ``chain_info`` becomes
      the 2-tuple ``(chain_names, seq_ranges)``. Call :meth:`split_sequences`
      to write ``<out_dir>/msa/<i>/{non_pairing,pairing}.a3m`` per chain.
    * Otherwise the file is treated as a single-chain MSA: the two files for
      chain ``0`` are written immediately and ``chain_info`` is the
      one-element sentinel ``[None]``.
    """

    def __init__(self, a3m_file: str, out_dir: str):
        """Read ``a3m_file`` and parse its header.

        Args:
            a3m_file: Path of the A3M file to process.
            out_dir: Directory under which the ``msa`` folder is created.
        """
        self.out_dir = out_dir
        self.a3m_file = Path(a3m_file)
        self.a3m_content = self._read_a3m_file()
        self.chain_info = self._parse_header()

    def _read_a3m_file(self) -> str:
        """Return the full text of the A3M file."""
        return self.a3m_file.read_text()

    def _parse_header(self) -> Tuple[List[str], Dict[str, Tuple[int, int]]]:
        """Parse the first line of the A3M content.

        Returns:
            ``(chain_names, seq_ranges)`` for a ``#`` header, where
            ``seq_ranges[name]`` is the ``(start, end)`` column span of that
            chain in the concatenated alignment. For a file without a ``#``
            header the single-chain files are written as a side effect and
            the sentinel ``[None]`` is returned instead.

        Raises:
            ValueError: If a header-less file has no line after its first.
        """
        lines = self.a3m_content.split("\n")
        first_line = lines[0]

        # startswith() is safe on an empty first line, unlike first_line[0].
        if not first_line.startswith("#"):
            self._write_single_chain_msa(lines)
            return [None]

        lengths, oligomeric_state = first_line.split("\t")
        chain_lengths = [int(x) for x in lengths[1:].split(",")]
        # Chains are numbered 101, 102, ...; arithmetic (rather than the
        # string "10" + digit) keeps the tenth chain at "110", not "1010".
        chain_names = [
            str(101 + idx) for idx in range(len(oligomeric_state.split(",")))
        ]

        # Each chain occupies the columns following all previous chains.
        seq_ranges = {}
        for i, name in enumerate(chain_names):
            start = sum(chain_lengths[:i])
            end = sum(chain_lengths[: i + 1])
            seq_ranges[name] = (start, end)

        return chain_names, seq_ranges

    def _write_single_chain_msa(self, lines: List[str]) -> None:
        """Write chain ``0`` files for an A3M that has no ``#`` header."""
        if len(lines) < 2:
            raise ValueError(
                f"A3M file {self.a3m_file} does not contain a query sequence"
            )

        query_seq = lines[1]
        # The original first header is replaced by ">query".
        non_pairing = ">query\n" + "\n".join(lines[1:])
        pairing = f">query\n{query_seq}"

        msa_dir = Path(self.out_dir) / "msa" / "0"
        # parents=True so a not-yet-existing out_dir does not abort the run.
        msa_dir.mkdir(parents=True, exist_ok=True)
        with open(msa_dir / "non_pairing.a3m", "w") as f:
            f.write(non_pairing)
        with open(msa_dir / "pairing.a3m", "w") as f:
            f.write(pairing)

    def _extract_sequence(self, line: str, range_tuple: Tuple[int, int]) -> str:
        """Cut one chain's columns out of a concatenated alignment row.

        Upper-case letters and ``-`` advance the column counter; every other
        character leaves it unchanged and is kept only while the counter is
        inside the requested span.

        Args:
            line: One sequence row of the A3M file.
            range_tuple: ``(start, end)`` span; columns ``start + 1`` through
                ``end`` (1-based) are returned.

        Returns:
            The characters of ``line`` that fall inside the span.
        """
        seq = []
        no_insert_count = 0
        start, end = range_tuple

        for char in line:
            if char.isupper() or char == "-":
                no_insert_count += 1
            # we keep insertions
            if start < no_insert_count <= end:
                seq.append(char)
            elif no_insert_count > end:
                break

        return "".join(seq)

    def split_sequences(self) -> None:
        """Split the A3M content into per-chain pairing / non-pairing files.

        Rows following a ``>`` header equal to a chain name go to that
        chain's non-pairing list; rows following the tab-joined list of all
        chain names (and everything before the first chain header) are
        sliced per chain into the pairing lists. Does nothing when the input
        had no ``#`` header, because those files were written on construction.
        """
        if len(self.chain_info) != 2:
            return

        out_dir = Path(self.out_dir) / "msa"
        chain_names, seq_ranges = self.chain_info
        paired_header = "\t".join(chain_names)

        pairing_a3ms = {name: [] for name in chain_names}
        nonpairing_a3ms = {name: [] for name in chain_names}

        current_query = None
        for line in self.a3m_content.split("\n"):
            if not line or line.startswith("#"):
                continue

            is_header = line.startswith(">")
            if is_header:
                header_name = line[1:]
                if header_name in chain_names:
                    current_query = header_name
                elif header_name == paired_header:
                    current_query = None
                # Any other header belongs to the section already active.

            if current_query:
                entry = (
                    line
                    if is_header
                    else self._extract_sequence(line, seq_ranges[current_query])
                )
                nonpairing_a3ms[current_query].append(entry)
            else:
                for name in chain_names:
                    entry = (
                        line
                        if is_header
                        else self._extract_sequence(line, seq_ranges[name])
                    )
                    pairing_a3ms[name].append(entry)

        self._write_output_files(out_dir, nonpairing_a3ms, pairing_a3ms)

    @staticmethod
    def _query_sequence(lines: List[str]) -> Optional[str]:
        """Return the query row (second entry) of ``lines``, or ``None``."""
        return lines[1] if len(lines) >= 2 else None

    def _write_output_files(
        self,
        out_dir: Path,
        nonpairing_a3ms: Dict[str, List[str]],
        pairing_a3ms: Dict[str, List[str]],
    ) -> None:
        """Write the collected rows to ``<out_dir>/<i>/*.a3m``.

        Each list is expected to hold alternating header / sequence entries
        with the query first. When one of the two lists of a chain is empty
        (e.g. an input without a paired block) the query sequence is taken
        from the other list so that both files still start with the query.

        Args:
            out_dir: Directory receiving one numbered sub-directory per chain.
            nonpairing_a3ms: Chain name -> rows for ``non_pairing.a3m``.
            pairing_a3ms: Chain name -> rows for ``pairing.a3m``.

        Raises:
            ValueError: If no query sequence is available for a chain.
        """
        out_dir.mkdir(parents=True, exist_ok=True)

        # Write non-pairing sequences
        for i, (name, lines) in enumerate(nonpairing_a3ms.items()):
            query_seq = self._query_sequence(lines)
            if query_seq is None:
                query_seq = self._query_sequence(pairing_a3ms.get(name, []))
            if query_seq is None:
                raise ValueError(f"No query sequence found for chain {name}")

            chain_dir = out_dir / str(i)
            chain_dir.mkdir(exist_ok=True)

            with open(chain_dir / "non_pairing.a3m", "w") as f:
                f.write(">query\n")
                f.write(f"{query_seq}\n")
                f.write("\n".join(lines[2:]))

        # Write pairing sequences
        for i, (name, lines) in enumerate(pairing_a3ms.items()):
            query_seq = self._query_sequence(lines)
            if query_seq is None:
                query_seq = self._query_sequence(nonpairing_a3ms.get(name, []))
            if query_seq is None:
                raise ValueError(f"No query sequence found for chain {name}")

            chain_dir = out_dir / str(i)
            chain_dir.mkdir(exist_ok=True)

            with open(chain_dir / "pairing.a3m", "w") as f:
                f.write(">query\n")
                f.write(f"{query_seq}\n")

                # Rename each hit after the i-th whitespace-separated field
                # of its header; j keeps the generated names unique.
                sequences = {}
                current_name = None
                for j, line in enumerate(lines[2:]):
                    if line.startswith(">"):
                        current_name = f"UniRef100_{line[1:].split()[i]}_{j}"
                        sequences[current_name] = ""
                    elif (
                        line
                        and current_name is not None
                        and "DUMMY" not in current_name
                    ):
                        sequences[current_name] = line

                # Hits whose sequence stayed empty (e.g. DUMMY) are dropped.
                for seq_name, seq in sequences.items():
                    if seq:
                        f.write(f">{seq_name}\n{seq}\n")


def run_colabfold_search(config: "LocalColabFoldConfig") -> str:
    """Run the ColabFold search executable and return one resulting A3M path.

    The command is executed as an argument list without a shell, so paths
    containing spaces or shell metacharacters are passed through verbatim.

    Args:
        config: Search settings; see ``LocalColabFoldConfig``.

    Returns:
        Path of the lexicographically first ``*.a3m`` file found directly in
        ``config.results_dir``.

    Raises:
        subprocess.CalledProcessError: If the search exits with a non-zero
            status.
        FileNotFoundError: If the executable cannot be found, or if the run
            left no ``.a3m`` file in ``config.results_dir``.
    """
    cmd = [config.colabsearch, config.query_fpath, config.db_dir, config.results_dir]

    # Database names are optional: an unset one is simply not mentioned.
    for flag, db_name in (
        ("--db1", config.db1),
        ("--db2", config.db2),
        ("--db3", config.db3),
    ):
        if db_name:
            cmd.extend([flag, db_name])

    cmd.extend(["--mmseqs", config.mmseqs_path or "mmseqs"])

    # Integer switches are forwarded whenever they are set, including 0.
    # Testing truthiness here would drop an explicit 0 and let the tool's
    # own default take over instead of the caller's choice.
    for flag, value in (
        ("--use-env", config.use_env),
        ("--filter", config.filter),
        ("--db-load-mode", config.db_load_mode),
    ):
        if value is not None:
            cmd.extend([flag, str(value)])

    # check=True: do not go on to pick up stale results after a failed run.
    subprocess.run(cmd, check=True)

    # glob() order is filesystem dependent; sort to make the choice stable.
    result_files = sorted(Path(config.results_dir).glob("*.a3m"))
    if not result_files:
        raise FileNotFoundError(f"No .a3m files found in {config.results_dir}")
    return str(result_files[0])


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command line of the ColabFold search / A3M splitting tool.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        The populated namespace; attribute names match the option names.
    """
    parser = argparse.ArgumentParser(
        description="ColabFold search and A3M processing tool",
        # Appends "(default: ...)" to every help text automatically, so help
        # strings below must not spell out defaults themselves.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Required arguments
    parser.add_argument("query_fpath", help="Path to the query FASTA file")
    parser.add_argument("db_dir", help="Directory containing the databases")
    parser.add_argument("results_dir", help="Directory for storing results")

    # Optional arguments
    parser.add_argument(
        "--colabsearch", help="Path to colabfold_search", default="colabfold_search"
    )
    parser.add_argument(
        "--mmseqs_path", help="Path to MMseqs2 binary", default="mmseqs"
    )
    parser.add_argument("--db1", help="First database name", default="uniref30_2302_db")
    parser.add_argument("--db2", help="Templates database")
    # The default is set for real here; previously it was only mentioned in
    # the help text while the parsed value was None.
    parser.add_argument(
        "--db3",
        help="Environmental database",
        default="colabfold_envdb_202108_db",
    )
    parser.add_argument(
        "--use_env",
        help="Whether to search the environmental database (0 or 1)",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--filter", help="Apply filtering (0 or 1)", type=int, default=1
    )
    parser.add_argument(
        "--db_load_mode", help="Database load mode", type=int, default=0
    )
    # NOTE(review): this option is parsed but the __main__ block of this file
    # does not read it; split files go under results_dir. Confirm intent.
    parser.add_argument(
        "--output_split", help="Directory for split A3M files", default=None
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    cli = parse_args()

    # Every configuration field has a same-named command-line attribute, so
    # the settings object is filled by copying them over one by one.
    config_fields = (
        "colabsearch",
        "query_fpath",
        "db_dir",
        "results_dir",
        "mmseqs_path",
        "db1",
        "db2",
        "db3",
        "use_env",
        "filter",
        "db_load_mode",
    )
    search_config = LocalColabFoldConfig(
        **{field_name: getattr(cli, field_name) for field_name in config_fields}
    )

    # Search first, then post-process the A3M file the search produced.
    a3m_path = run_colabfold_search(search_config)
    splitter = A3MProcessor(a3m_path, cli.results_dir)

    # A two-element chain_info is the (names, ranges) pair of a multi-chain
    # header; only then is there anything left to split.
    has_chain_table = len(splitter.chain_info) == 2
    if has_chain_table:
        splitter.split_sequences()