import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from gymnasium import spaces
from torch import BoolTensor, FloatTensor, LongTensor, Tensor, nn
from transformers import GPTNeoModel, GPTNeoPreTrainedModel
from transformers.modeling_outputs import ModelOutput
from transformers.models.vit.modeling_vit import ViTPatchEmbeddings

from .configuration_jat import JatConfig
from .processing_jat import JatProcessor


def compute_mse_loss(
    predicted: FloatTensor, true: FloatTensor, mask: Optional[BoolTensor], weights: Optional[FloatTensor] = None
) -> FloatTensor:
    """
    Compute the Mean Squared Error (MSE) loss between predicted and true observations, considering valid timesteps.

    Args:
        predicted (`FloatTensor` of shape `(batch_size, max_seq_len, ...)`):
            Predicted observations at the output of the model.
        true (`FloatTensor` of shape `(batch_size, max_seq_len, ...)`):
            Ground truth observations.
        mask (`BoolTensor` of shape `(batch_size, max_seq_len)`, *optional*):
            Boolean mask indicating valid timesteps.
        weights (`FloatTensor` of shape `(batch_size, max_seq_len)`, *optional*):
            Weights to be applied to the loss.

    Returns:
        loss (`FloatTensor` of shape `()`):
            MSE loss between predicted and true observations.
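
    Example:
        A minimal sketch with arbitrary placeholder values, illustrating the expected shapes:

        >>> predicted = torch.zeros(2, 3, 4)
        >>> true = torch.ones(2, 3, 4)
        >>> mask = torch.tensor([[True, True, False], [True, False, False]])
        >>> compute_mse_loss(predicted, true, mask)
        tensor(1.)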
    """
    # Compute element-wise MSE loss
    loss = F.mse_loss(predicted, true, reduction="none")

    # Average the loss over all dimensions after the second one
    for dim in reversed(range(2, loss.dim())):
        loss = loss.mean(dim=dim)

    # Use the mask to zero out invalid entries
    if mask is not None:
        loss = loss * mask

    # Apply weights if provided
    if weights is not None:
        loss = loss * weights

    # Sum the loss and normalize by the number of valid elements
    loss = loss.sum() / mask.sum() if mask is not None else loss.mean()

    return loss


def compute_ce_loss(
    logits: FloatTensor, labels: torch.LongTensor, mask: Optional[BoolTensor], weights: Optional[FloatTensor] = None
) -> FloatTensor:
    """
    Compute the Cross Entropy (CE) loss between predicted logits and true class labels, considering valid timesteps.

    Args:
        logits (`FloatTensor` of shape `(batch_size, max_seq_len, [inner_size,] num_classes)`):
            Predicted logits at the output of the model.
        labels (`torch.LongTensor` of shape `(batch_size, max_seq_len, [inner_size,])`):
            Ground truth class labels.
        mask (`BoolTensor` of shape `(batch_size, max_seq_len)`, *optional*):
            Boolean mask indicating valid timesteps.
        weights (`FloatTensor` of shape `(batch_size, max_seq_len)`, *optional*):
            Weights to be applied to the loss.

    Returns:
        loss (`FloatTensor` of shape `()`):
            CE loss between predicted logits and true class labels.
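
    Example:
        A minimal sketch with random placeholder values, illustrating the expected shapes and the scalar output:

        >>> logits = torch.randn(2, 3, 5)  # batch of 2, sequence length 3, 5 classes
        >>> labels = torch.randint(0, 5, (2, 3))
        >>> mask = torch.tensor([[True, True, False], [True, False, False]])
        >>> compute_ce_loss(logits, labels, mask).shape
        torch.Size([])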
    """
    if mask is not None:
        logits = logits[mask.bool()]  # (Y, X, C)
        labels = labels[mask.bool()]  # (Y, X)
        if weights is not None:
            weights = weights[mask.bool()]  # (Y,)
    else:
        logits = logits.flatten(end_dim=1)  # (B, L, X, C) -> (B*L, X, C)
        labels = labels.flatten(end_dim=1)  # (B, L, X) -> (B*L, X)
        if weights is not None:
            weights = weights.flatten(end_dim=1)  # (B, L) -> (B*L,)

    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), reduction="none")  # (Y*X,)
    loss = loss.view(labels.size())  # (Y, X) or (Y,)
    if loss.dim() > 1:
        loss = loss.mean(-1)  # average over the inner dimension -> (Y,)

    # Multiply the loss by the weights
    if weights is not None:
        loss = loss * weights  # (Y,)

    # Average the loss
    loss = loss.mean()

    return loss


def cyclic_expand_dim(tensor: Tensor, expanded_dim_size: int) -> Tensor:
    """
    Expands the last dimension of a tensor cyclically to a specified size.

    Args:
        tensor (`torch.Tensor` of shape `(batch_size, seq_len, ...)`):
            Input tensor whose last dimension is to be expanded cyclically.
        expanded_dim_size (`int`):
            The desired size of the last dimension after expansion.

    Returns:
        `torch.Tensor` of shape `(batch_size, seq_len, expanded_dim_size)`:
            A tensor with its last dimension expanded cyclically to the specified size.

    Examples:
        >>> tensor = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
        >>> cyclic_expand_dim(tensor, 5)
        tensor([[[1, 2, 1, 2, 1], [3, 4, 3, 4, 3]], [[5, 6, 5, 6, 5], [7, 8, 7, 8, 7]]])
    """
    B, L, X = tensor.shape
    if expanded_dim_size < X:
        raise ValueError(
            f"Expanded dimension size ({expanded_dim_size}) must be greater than the original dimension size ({X})."
        )
    indices = torch.arange(expanded_dim_size) % X
    return tensor[..., indices]


class ResidualBlock(nn.Module):
    """
    A residual block module that consists of two convolutional layers with a residual connection.

    Args:
        in_shape (`Tuple[int, int, int]`):
            Shape of the input tensor.
        out_channels (`int`):
            Number of output channels.

    Returns:
        `torch.Tensor` of shape `(batch_size, out_channels, in_shape[1], in_shape[2])`:
            Output tensor.
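
    Example:
        A minimal shape check with placeholder sizes:

        >>> block = ResidualBlock((32, 16, 16), out_channels=64)
        >>> x = torch.randn(1, 32, 16, 16)
        >>> block(x).shape
        torch.Size([1, 64, 16, 16])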
    """

    def __init__(self, in_shape: Tuple[int, int, int], out_channels: int) -> None:
        super().__init__()
        out_shape = (out_channels, in_shape[1], in_shape[2])

        self.conv1 = nn.Conv2d(in_shape[0], out_channels, kernel_size=3, stride=1, padding=1)
        self.norm1 = nn.LayerNorm(out_shape)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.LayerNorm(out_shape)

        # Handling the change in dimensions with a 1x1 convolution
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_shape[0], out_channels, kernel_size=1, stride=1), nn.LayerNorm(out_shape)
        )

    def forward(self, x: FloatTensor) -> FloatTensor:
        out = F.leaky_relu(self.norm1(self.conv1(x)))
        out = self.norm2(self.conv2(out))
        out += self.shortcut(x)
        return F.leaky_relu(out, inplace=True)


class AttentionLayer(nn.Module):
    """
    Attention layer that applies an attention mechanism to the input tensor.

    Args:
        num_channels (`int`):
            Number of channels.

    Returns:
        `torch.Tensor`:
            Output tensor of the same shape as the input tensor.
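
    Example:
        A minimal shape check with placeholder sizes (channel-wise gating in the spirit of squeeze-and-excitation):

        >>> layer = AttentionLayer(num_channels=32)
        >>> x = torch.randn(1, 32, 8, 8)
        >>> layer(x).shape
        torch.Size([1, 32, 8, 8])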
    """

    def __init__(self, num_channels: int) -> None:
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(num_channels, num_channels // 8, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(num_channels // 8, num_channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x: FloatTensor) -> FloatTensor:
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


class ImageEncoder(nn.Module):
    """
    Image encoder that encodes a batch of images.

    Args:
        hidden_size (`int`):
            Size of the output hidden state.

    Returns:
        `torch.Tensor` of shape `(batch_size, hidden_size)`:
            Output tensor.
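
    Example:
        A minimal shape check; the encoder expects 4-channel 84x84 images:

        >>> encoder = ImageEncoder(hidden_size=768)
        >>> x = torch.randn(1, 4, 84, 84)
        >>> encoder(x).shape
        torch.Size([1, 768])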
    """

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(4, 32, kernel_size=3, stride=2, padding=1)  # 42x42
        self.norm1 = nn.InstanceNorm2d(32)
        self.att1 = AttentionLayer(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)  # 21x21
        self.norm2 = nn.InstanceNorm2d(64)
        self.att2 = AttentionLayer(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)  # 11x11
        self.norm3 = nn.InstanceNorm2d(128)
        self.att3 = AttentionLayer(128)
        self.fc = nn.Linear(128 * 11 * 11, hidden_size)  # 128 channels at 11x11 spatial resolution

    def forward(self, x: FloatTensor) -> FloatTensor:
        x = F.leaky_relu(self.norm1(self.conv1(x)), inplace=True)
        x = self.att1(x)
        x = F.leaky_relu(self.norm2(self.conv2(x)), inplace=True)
        x = self.att2(x)
        x = F.leaky_relu(self.norm3(self.conv3(x)), inplace=True)
        x = self.att3(x)
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = self.fc(x)
        return x


class ImageDecoder(nn.Module):
    """
    Image decoder that decodes a batch of encoded representations.

    Args:
        hidden_size (`int`):
            Size of the input hidden state.

    Returns:
        `torch.Tensor` of shape `(batch_size, 4, 84, 84)`:
            Output tensor representing the reconstructed images.
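
    Example:
        A minimal shape check; the decoder maps hidden states back to 4-channel 84x84 images:

        >>> decoder = ImageDecoder(hidden_size=768)
        >>> x = torch.randn(1, 768)
        >>> decoder(x).shape
        torch.Size([1, 4, 84, 84])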
    """

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.fc = nn.Linear(hidden_size, 128 * 11 * 11)
        self.deconv1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)  # 22x22
        self.norm1 = nn.InstanceNorm2d(64)
        self.att1 = AttentionLayer(64)
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)  # 42x42
        self.norm2 = nn.InstanceNorm2d(32)
        self.att2 = AttentionLayer(32)
        self.deconv3 = nn.ConvTranspose2d(32, 4, kernel_size=3, stride=2, padding=1, output_padding=1)  # 84x84

    def forward(self, x: FloatTensor) -> FloatTensor:
        x = self.fc(x)
        x = x.view(x.size(0), 128, 11, 11)  # Reshape to the spatial dimension of encoder's last conv layer
        x = F.leaky_relu(self.norm1(self.deconv1(x)), inplace=True)  # 22x22
        x = F.interpolate(x, size=(21, 21))  # 21x21
        x = self.att1(x)
        x = F.leaky_relu(self.norm2(self.deconv2(x)), inplace=True)
        x = self.att2(x)
        x = torch.tanh(self.deconv3(x))
        return x


class DualBatchReshapeWrapper(nn.Module):
    """
    Wrapper to make a module designed for a single batch work with a dual batch.

    Args:
        module (`nn.Module`):
            Module to be wrapped.
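
    Example:
        A minimal sketch wrapping the image encoder so it accepts `(batch_size, seq_len, ...)` inputs:

        >>> encoder = DualBatchReshapeWrapper(ImageEncoder(hidden_size=768))
        >>> x = torch.randn(2, 5, 4, 84, 84)
        >>> encoder(x).shape
        torch.Size([2, 5, 768])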
    """

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: FloatTensor) -> FloatTensor:
        n1, n2 = x.shape[:2]
        x = x.view(n1 * n2, *x.shape[2:])
        x = self.module(x)
        x = x.view(n1, n2, *x.shape[1:])
        return x


@dataclass
class JatOutput(ModelOutput):
    """
    Output of the Jat model.

    The model can be used for both RL and NLP tasks. For RL tasks, the model takes in observations and actions
    (`continuous_observations`, `discrete_actions`, etc.). For textual tasks, the model takes in a sequence of tokens
    and/or images (`input_ids`, `image`). The output depends on the type of input.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            For RL input, the loss is the sum of the observation loss and the action loss.
            For textual input, the causal language modeling loss.
        observation_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            Only returned when RL input is provided. The MSE loss between predicted and true observations for
            continuous observations and the cross-entropy loss for discrete observations.
        action_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            Only returned when RL input is provided. The MSE loss between predicted and true actions for
            continuous actions and the cross-entropy loss for discrete actions.
        pred_observations (`torch.FloatTensor` of shape `(batch_size, max_seq_len, ...)`):
            Only returned when RL input is provided. Predicted observations from t=1 to t=max_seq_len+1.
        pred_actions (`torch.FloatTensor` of shape `(batch_size, max_seq_len, ...)`):
            Only returned when RL input is provided. Predicted actions from t=0 to t=max_seq_len. When input actions
            are discrete, the predicted actions are logits.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Only returned when textual input is provided. Prediction scores of the language modeling head (scores for
            each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or
            when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
            encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
            input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
            when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when
            `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[FloatTensor] = None
    observation_loss: Optional[FloatTensor] = None
    action_loss: Optional[FloatTensor] = None
    pred_observations: Optional[FloatTensor] = None
    pred_actions: Optional[FloatTensor] = None
    logits: Optional[FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[FloatTensor]]] = None
    hidden_states: Optional[Tuple[FloatTensor]] = None
    attentions: Optional[Tuple[FloatTensor]] = None


class JatModel(GPTNeoPreTrainedModel):
    """
    Jat model.
    """

    config_class = JatConfig

    def __init__(self, config: JatConfig) -> None:
        super().__init__(config)

        vocab_size = config.vocab_size
        hidden_size = config.hidden_size
        max_discrete_value = config.max_discrete_value
        max_continuous_size = config.max_continuous_size
        self.observation_loss_coef = config.observation_loss_coef
        self.action_loss_coef = config.action_loss_coef

        # Transformer
        self.transformer = GPTNeoModel(config)

        # Encoders
        self.vit_encoder = ViTPatchEmbeddings(config)
        self.single_discrete_encoder = self.transformer.wte
        self.continuous_encoder = nn.Linear(max_continuous_size, hidden_size)
        self.multi_discrete_encoder = nn.Sequential(
            self.single_discrete_encoder,  # (B, L, X, H)
            nn.Linear(hidden_size, hidden_size // 50),  # (B, L, X, H // 50)
            nn.ReLU(),
            nn.Flatten(start_dim=2),  # (B, L, X * (H // 50))
            nn.Linear(max_discrete_value * (hidden_size // 50), hidden_size - 1),  # (B, L, H - 1)
        )  # -1 to account for the reward, which is concatenated afterwards
        self.image_encoder = DualBatchReshapeWrapper(ImageEncoder(hidden_size))

        # Decoders
        self.single_discrete_decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.continuous_decoder = nn.Linear(hidden_size, max_continuous_size)
        self.multi_discrete_decoder = nn.Sequential(
            nn.Linear(hidden_size, max_discrete_value * (hidden_size // 50)),  # (B, L, X * (H // 50))
            nn.Unflatten(dim=2, unflattened_size=(max_discrete_value, hidden_size // 50)),  # (B, L, X, H // 50)
            nn.ReLU(),
            nn.Linear(hidden_size // 50, hidden_size),  # (B, L, X, H)
            nn.ReLU(),
            nn.Linear(hidden_size, 8, bias=False),  # (B, L, X, 8) - the max possible value in the dataset is 8
        )
        self.image_decoder = DualBatchReshapeWrapper(ImageDecoder(hidden_size))

        # Initialize weights and apply final processing
        self.post_init()

    def embed_textual(
        self,
        input_ids: Optional[LongTensor],
        pixel_values: Optional[FloatTensor] = None,
        attention_mask: Optional[BoolTensor] = None,
    ) -> Tuple[Tensor, Optional[BoolTensor]]:
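        """
        Embed textual inputs (token ids and/or image patches) into a single sequence of embeddings.

        Image patch embeddings, if any, are prepended to the text embeddings and the attention mask is extended
        accordingly. Returns a tuple `(inputs_embeds, attention_mask)`.
        """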
        text_inputs_embeds = self.single_discrete_encoder(input_ids) if input_ids is not None else None
        image_inputs_embeds = self.vit_encoder(pixel_values) if pixel_values is not None else None
        # Concatenate text and image inputs
        if image_inputs_embeds is not None and text_inputs_embeds is not None:
            inputs_embeds = torch.cat((image_inputs_embeds, text_inputs_embeds), dim=1)
            # Add attention mask for image inputs
            image_mask = torch.ones(image_inputs_embeds.shape[:2], dtype=torch.bool, device=self.device)
            if attention_mask is None:
                attention_mask = torch.ones(text_inputs_embeds.shape[:2], dtype=torch.bool, device=self.device)
            attention_mask = torch.cat((image_mask, attention_mask), dim=1)
        elif image_inputs_embeds is not None:
            inputs_embeds = image_inputs_embeds
        elif text_inputs_embeds is not None:
            inputs_embeds = text_inputs_embeds
        else:
            raise ValueError("At least one of `input_ids` or `pixel_values` must be provided.")
        return inputs_embeds, attention_mask

    def embed_rl(
        self,
        continuous_observations: Optional[FloatTensor] = None,
        discrete_observations: Optional[LongTensor] = None,
        image_observations: Optional[FloatTensor] = None,
        continuous_actions: Optional[FloatTensor] = None,
        discrete_actions: Optional[LongTensor] = None,
        rewards: Optional[FloatTensor] = None,
        attention_mask: Optional[BoolTensor] = None,
    ):
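        """
        Embed RL inputs into an interleaved sequence of observation and action embeddings.

        Observations and actions are passed through their modality-specific encoders (rewards, when applicable, are
        folded into the observation embeddings) and interleaved as `[obs_0, act_0, obs_1, act_1, ...]`. The attention
        mask, if provided, is repeated to match. Returns a tuple `(inputs_embeds, attention_mask)`.
        """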
        # Prepare RL inputs (pad and cat rewards to observations)
        assert rewards is not None
        if continuous_observations is not None:
            continuous_observations = torch.cat((continuous_observations, rewards.unsqueeze(-1)), dim=-1)
            continuous_observations = cyclic_expand_dim(continuous_observations, self.config.max_continuous_size)
        if continuous_actions is not None:
            continuous_actions = cyclic_expand_dim(continuous_actions, self.config.max_continuous_size)

        # Encode
        if continuous_observations is not None:
            batch_size, seq_len = continuous_observations.shape[:2]
            inputs_embeds_observations = self.continuous_encoder(continuous_observations)
        elif discrete_observations is not None:
            batch_size, seq_len = discrete_observations.shape[:2]
            inputs_embeds_observations = self.multi_discrete_encoder(discrete_observations)
            inputs_embeds_observations = torch.cat((inputs_embeds_observations, rewards.unsqueeze(-1)), dim=-1)
        elif image_observations is not None:
            batch_size, seq_len = image_observations.shape[:2]
            inputs_embeds_observations = self.image_encoder(image_observations)
        else:
            raise ValueError("Missing observations.")
        if continuous_actions is not None:
            inputs_embeds_actions = self.continuous_encoder(continuous_actions)
        elif discrete_actions is not None:
            inputs_embeds_actions = self.single_discrete_encoder(discrete_actions)
        else:
            raise ValueError("Missing actions.")

        # Concatenate observations and actions
        inputs_embeds = torch.cat((inputs_embeds_observations, inputs_embeds_actions), dim=2)
        inputs_embeds = inputs_embeds.view(batch_size, 2 * seq_len, self.config.hidden_size)
        if attention_mask is not None:
            attention_mask = torch.repeat_interleave(attention_mask, repeats=2, dim=1)
        return inputs_embeds, attention_mask

    def output_textual(
        self,
        transformer_outputs,
        input_ids: Optional[LongTensor] = None,
        attention_mask: Optional[BoolTensor] = None,
        return_loss: bool = True,
        return_dict: Optional[bool] = None,
    ):
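        """
        Compute language-modeling logits from the transformer output and, when `return_loss=True`, the shifted
        cross-entropy loss over the textual positions.
        """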
        hidden_states = transformer_outputs[0]
        loss = None
        # Compute language modeling logits; only the textual positions are used for the loss below
        lm_logits = self.single_discrete_decoder(hidden_states)
        if return_loss:
            if input_ids is None:
                raise ValueError("Input IDs must be provided when `return_loss=True`.")

            # Shift so that tokens < n predict n
            num_text_tokens = input_ids.shape[1]
            shift_logits = lm_logits[:, -num_text_tokens:-1, :].contiguous()
            shift_labels = input_ids[:, 1:].contiguous()
            if attention_mask is not None:
                shift_attention_mask = attention_mask[:, -num_text_tokens:]
                shift_attention_mask = shift_attention_mask[:, 1:]
            else:
                shift_attention_mask = torch.ones(shift_labels.shape, dtype=torch.bool, device=self.device)
            shift_logits = shift_logits[shift_attention_mask.bool()]
            shift_labels = shift_labels[shift_attention_mask.bool()]
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return JatOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def output_rl(
        self,
        transformer_outputs,
        continuous_observations: Optional[FloatTensor] = None,
        discrete_observations: Optional[LongTensor] = None,
        image_observations: Optional[FloatTensor] = None,
        continuous_actions: Optional[FloatTensor] = None,
        discrete_actions: Optional[LongTensor] = None,
        rewards: Optional[FloatTensor] = None,
        attention_mask: Optional[BoolTensor] = None,
        return_loss: bool = True,
        return_dict: Optional[bool] = None,
        loss_weight: Optional[FloatTensor] = None,
    ):
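        """
        Decode the transformer output into observation and action predictions for RL inputs and, when
        `return_loss=True`, compute the corresponding losses (MSE for continuous targets, cross-entropy for discrete
        ones).
        """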
        hidden_states = transformer_outputs.last_hidden_state
        loss, observation_loss, action_loss = None, None, None
        # Observations
        assert rewards is not None
        observations_mask = attention_mask[:, 1::2] if attention_mask is not None else None
        if continuous_observations is not None:
            if self.observation_loss_coef == 0.0:
                warnings.warn("observation_loss_coef is 0.0, skipping memory-intensive observations prediction.")
                pred_observations = None
                observation_loss = 0.0
            else:
                obs_size = continuous_observations.shape[-1]
                continuous_observations = torch.cat((continuous_observations, rewards.unsqueeze(-1)), dim=-1)
                continuous_observations = cyclic_expand_dim(continuous_observations, self.config.max_continuous_size)
                pred_observations = self.continuous_decoder(hidden_states[:, 1::2])
                if return_loss:
                    observation_loss = compute_mse_loss(
                        pred_observations[:, :-1],
                        continuous_observations[:, 1:],
                        observations_mask[:, 1:] if observations_mask is not None else None,
                        weights=loss_weight[:, 1:] if loss_weight is not None else None,
                    )
                pred_observations = pred_observations[..., :obs_size]
        elif discrete_observations is not None:  # Note: reward is not predicted
            if self.observation_loss_coef == 0.0:
                warnings.warn("observation_loss_coef is 0.0, skipping memory-intensive observations prediction.")
                pred_observations = None
                observation_loss = 0.0
            else:
                warnings.warn("Discrete observations prediction are not supported yet.")  # way too expensive
                pred_observations = None
                observation_loss = 0.0
                # pred_observations = self.multi_discrete_decoder(hidden_states[:, 1::2])
                # if return_loss:
                #     observation_loss = compute_ce_loss(
                #         pred_observations[:, :-1],
                #         discrete_observations[:, 1:],
                #         observations_mask[:, 1:] if observations_mask is not None else None,
                #         weights=loss_weight[:, 1:] if loss_weight is not None else None,
                #     )
        elif image_observations is not None:
            if self.observation_loss_coef == 0.0:
                warnings.warn("observation_loss_coef is 0.0, skipping memory-intensive observations prediction.")
                pred_observations = None
                observation_loss = 0.0
            else:
                pred_observations = self.image_decoder(hidden_states[:, 1::2])
                if return_loss:
                    observation_loss = compute_mse_loss(
                        pred_observations[:, :-1],
                        image_observations[:, 1:],
                        observations_mask[:, 1:] if observations_mask is not None else None,
                        weights=loss_weight[:, 1:] if loss_weight is not None else None,
                    )

        # Actions
        actions_mask = attention_mask[:, ::2] if attention_mask is not None else None
        if continuous_actions is not None:
            act_size = continuous_actions.shape[-1]
            continuous_actions = cyclic_expand_dim(continuous_actions, self.config.max_continuous_size)
            pred_actions = self.continuous_decoder(hidden_states[:, ::2])
            if return_loss:
                action_loss = compute_mse_loss(pred_actions, continuous_actions, actions_mask, weights=loss_weight)
            pred_actions = pred_actions[..., :act_size]
        elif discrete_actions is not None:
            pred_actions = self.single_discrete_decoder(hidden_states[:, ::2])
            if return_loss:
                action_loss = compute_ce_loss(pred_actions, discrete_actions, actions_mask, weights=loss_weight)

        # Return output
        if return_loss:
            loss = self.observation_loss_coef * observation_loss + self.action_loss_coef * action_loss

        if not return_dict:
            output = (pred_observations, pred_actions) + transformer_outputs[1:]
            return ((loss, observation_loss, action_loss) + output) if loss is not None else output

        return JatOutput(
            loss=loss,
            observation_loss=observation_loss,
            action_loss=action_loss,
            pred_observations=pred_observations,
            pred_actions=pred_actions,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def forward(
        self,
        input_ids: Optional[LongTensor] = None,
        pixel_values: Optional[FloatTensor] = None,
        continuous_observations: Optional[FloatTensor] = None,
        discrete_observations: Optional[LongTensor] = None,
        image_observations: Optional[FloatTensor] = None,
        continuous_actions: Optional[FloatTensor] = None,
        discrete_actions: Optional[LongTensor] = None,
        rewards: Optional[FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[FloatTensor]]] = None,
        attention_mask: Optional[BoolTensor] = None,
        token_type_ids: Optional[LongTensor] = None,
        position_ids: Optional[LongTensor] = None,
        return_loss: bool = True,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        loss_weight: Optional[FloatTensor] = None,
    ) -> JatOutput:
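        """
        Run a forward pass on either textual inputs (`input_ids` and/or `pixel_values`) or RL inputs (observations,
        actions and rewards). If both families are provided, the textual inputs take precedence. Returns a
        `JatOutput` when `return_dict=True`, otherwise a plain tuple.
        """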
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Textual tasks
        if input_ids is not None or pixel_values is not None:
            inputs_embeds, attention_mask = self.embed_textual(input_ids, pixel_values, attention_mask)
        # RL tasks
        elif (
            continuous_observations is not None or discrete_observations is not None or image_observations is not None
        ):
            inputs_embeds, attention_mask = self.embed_rl(
                continuous_observations,
                discrete_observations,
                image_observations,
                continuous_actions,
                discrete_actions,
                rewards,
                attention_mask,
            )
        else:
            raise ValueError("Input not provided.")

        # Pass through transformer
        transformer_outputs = self.transformer(
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if input_ids is not None or pixel_values is not None:
            return self.output_textual(transformer_outputs, input_ids, attention_mask, return_loss, return_dict)
        else:
            return self.output_rl(
                transformer_outputs,
                continuous_observations,
                discrete_observations,
                image_observations,
                continuous_actions,
                discrete_actions,
                rewards,
                attention_mask,
                return_loss,
                return_dict,
                loss_weight,
            )

    def reset_rl(self):
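        """Reset the cached observations, actions, rewards and past key values used by `get_next_action`."""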
        self._last_key_values = None
        self.last_discrete_observation = None
        self.last_continuous_observation = None
        self.last_text_observation = None
        self.last_image_observation = None
        self.last_discrete_action = None
        self.last_continuous_action = None
        self.last_reward = None

    @torch.no_grad()
    def get_next_action(
        self,
        processor: JatProcessor,
        continuous_observation: Optional[List[float]] = None,
        discrete_observation: Optional[List[int]] = None,
        text_observation: Optional[str] = None,
        image_observation: Optional[np.ndarray] = None,
        action_space: Union[spaces.Box, spaces.Discrete] = None,
        reward: Optional[float] = None,
        deterministic: bool = False,
    ):
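        """
        Predict the next action for a single environment step.

        Past key values and the latest observation/action/reward are cached between calls so that successive calls
        extend the context incrementally; call `reset_rl()` at the start of each new episode. The returned action is
        a list of floats for `spaces.Box` action spaces and an int for `spaces.Discrete` ones.
        """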
        # Get the maximum sequence length
        max_length = self.config.max_position_embeddings // 2

        # Convert everything to lists
        def to_list(x):
            return x.tolist() if isinstance(x, np.ndarray) else x

        continuous_observation = to_list(continuous_observation)
        discrete_observation = to_list(discrete_observation)

        # Add a fake action to the end of the sequence
        if isinstance(action_space, spaces.Box):
            fake_continuous_action = [0.0 for _ in range(action_space.shape[0])]
            fake_discrete_action = None
        elif isinstance(action_space, spaces.Discrete):
            fake_continuous_action = None
            fake_discrete_action = 0
        else:
            raise ValueError("`action_space` must be a `spaces.Box` or `spaces.Discrete` instance.")

        continuous_observations = [continuous_observation] if continuous_observation is not None else None
        discrete_observations = [discrete_observation] if discrete_observation is not None else None
        text_observations = [text_observation] if text_observation is not None else None
        image_observations = [image_observation] if image_observation is not None else None
        continuous_actions = [fake_continuous_action] if fake_continuous_action is not None else None
        discrete_actions = [fake_discrete_action] if fake_discrete_action is not None else None
        rewards = [reward] if reward is not None else [0.0]

        if self._last_key_values is not None:
            # We concatenate the last observation with the current one
            continuous_observations = (
                [self.last_continuous_observation] + continuous_observations
                if continuous_observations is not None
                else None
            )
            discrete_observations = (
                [self.last_discrete_observation] + discrete_observations if discrete_observations is not None else None
            )
            text_observations = (
                [self.last_text_observation] + text_observations if text_observations is not None else None
            )
            image_observations = (
                [self.last_image_observation] + image_observations if image_observations is not None else None
            )
            continuous_actions = (
                [self.last_continuous_action] + continuous_actions if continuous_actions is not None else None
            )
            discrete_actions = [self.last_discrete_action] + discrete_actions if discrete_actions is not None else None
            rewards = [self.last_reward] + rewards

        # Store the last observation
        self.last_continuous_observation = continuous_observations[-1] if continuous_observations is not None else None
        self.last_discrete_observation = discrete_observations[-1] if discrete_observations is not None else None
        self.last_text_observation = text_observations[-1] if text_observations is not None else None
        self.last_image_observation = image_observations[-1] if image_observations is not None else None
        self.last_reward = rewards[-1]

        # Add the batch dimension
        continuous_observations = [continuous_observations] if continuous_observations is not None else None
        discrete_observations = [discrete_observations] if discrete_observations is not None else None
        text_observations = [text_observations] if text_observations is not None else None
        image_observations = [image_observations] if image_observations is not None else None
        continuous_actions = [continuous_actions] if continuous_actions is not None else None
        discrete_actions = [discrete_actions] if discrete_actions is not None else None
        rewards = [rewards]

        # Process the inputs
        processed = processor(
            continuous_observations=continuous_observations,
            discrete_observations=discrete_observations,
            text_observations=text_observations,
            image_observations=image_observations,
            continuous_actions=continuous_actions,
            discrete_actions=discrete_actions,
            rewards=rewards,
            truncation=True,
            truncation_side="left",
            max_length=max_length,
            return_tensors="pt",
        )
        processed.to(self.device)

        # Forward pass
        outputs = self(**processed, past_key_values=self._last_key_values, return_loss=False)

        # Truncate the past key-values
        self._last_key_values = tuple(
            tuple(pkv[:, :, -self.config.max_position_embeddings + 2 :] for pkv in pkvs)
            for pkvs in outputs.past_key_values
        )
        # Store the last key values
        # We remove the last two values, as the inputs are [s_0, 0], [s_0, a_0, s_1, 0], [s_1, a_1, s_2, 0], ...
        self._last_key_values = tuple(tuple(pkv[:, :, :-2] for pkv in pkvs) for pkvs in self._last_key_values)

        # Return the predicted action
        if continuous_actions is not None:
            self.last_continuous_action = outputs.pred_actions[0, -1].cpu().tolist()
            return self.last_continuous_action
        elif discrete_actions is not None:
            logits = outputs.pred_actions[0, -1, : action_space.n]
            if deterministic:
                self.last_discrete_action = logits.argmax().cpu().item()
            else:  # sample
                self.last_discrete_action = torch.multinomial(logits.softmax(dim=-1), num_samples=1)[0].item()
            return self.last_discrete_action

    # Allows the use of `.generate()`
    def prepare_inputs_for_generation(self, input_ids, pixel_values=None, past_key_values=None, **kwargs):
        # only last token for inputs_ids if past is defined in kwargs
        if past_key_values is not None:
            pixel_values = None
            input_ids = input_ids[:, -1].unsqueeze(-1)

        model_inputs = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
        }

        return model_inputs


JatModel.register_for_auto_class("AutoModelForCausalLM")