"""Transducer decoder interface module."""

from dataclasses import dataclass
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch


@dataclass
class Hypothesis:
    """Default hypothesis definition for beam search."""

    score: float
    yseq: List[int]
    dec_state: Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]], List[torch.Tensor], torch.Tensor
    ]
    lm_state: Union[Dict[str, Any], List[Any]] = None


@dataclass
class NSCHypothesis(Hypothesis):
    """Extended hypothesis definition for NSC beam search."""

    y: List[torch.Tensor] = None
    lm_scores: torch.Tensor = None


class TransducerDecoderInterface:
    """Decoder interface for transducer models."""

    def init_state(
        self,
        batch_size: int,
        device: torch.device,
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]]
    ]:
        """Initialize decoder states.

        Args:
            batch_size: Batch size for initial state
            device: Device for initial state

        Returns:
            state: Initialized state

        """
        raise NotImplementedError("init_state method is not implemented")

    def score(
        self,
        hyp: Union[Hypothesis, NSCHypothesis],
        cache: Dict[str, Any],
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]],
        torch.Tensor,
        List[Optional[torch.Tensor]],
    ]:
        """Forward one hypothesis.

        Args:
            hyp: Hypothesis.
            cache: Pairs of (y, state) for each token sequence (key)

        Returns:
            y: Decoder outputs
            new_state: New decoder state
            lm_tokens: Token id for LM

        """
        raise NotImplementedError("score method is not implemented")

    def batch_score(
        self,
        hyps: Union[List[Hypothesis], List[NSCHypothesis]],
        batch_states: Union[
            Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]]
        ],
        cache: Dict[str, Any],
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]],
        torch.Tensor,
        List[Optional[torch.Tensor]],
    ]:
        """Forward batch of hypotheses.

        Args:
            hyps: Batch of hypotheses
            batch_states: Batch of decoder states
            cache: pairs of (y, state) for each token sequence (key)

        Returns:
            batch_y: Decoder outputs
            batch_states: Batch of decoder states
            lm_tokens: Batch of token ids for LM

        """
        raise NotImplementedError("batch_score method is not implemented")

    def select_state(
        self,
        batch_states: Union[
            Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]]
        ],
        idx: int,
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]]
    ]:
        """Get decoder state from batch for given id.

        Args:
            batch_states: Batch of decoder states
            idx: Index to extract state from batch

        Returns:
            state_idx: Decoder state for given id

        """
        raise NotImplementedError("select_state method is not implemented")

    def create_batch_states(
        self,
        batch_states: Union[
            Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]]
        ],
        l_states: List[
            Union[
                Tuple[torch.Tensor, Optional[torch.Tensor]],
                List[Optional[torch.Tensor]],
            ]
        ],
        l_tokens: List[List[int]],
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]]
    ]:
        """Create batch of decoder states.

        Args:
            batch_states: Batch of decoder states
            l_states: List of decoder states
            l_tokens: List of token sequences for input batch

        Returns:
            batch_states: Batch of decoder states

        """
        raise NotImplementedError("create_batch_states method is not implemented")