import logging
from time import perf_counter
from typing import List, Optional, TypeVar

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard.writer import SummaryWriter

from rl_algo_impls.shared.algorithm import Algorithm
from rl_algo_impls.shared.callbacks import Callback
from rl_algo_impls.shared.gae import compute_advantages
from rl_algo_impls.shared.policy.actor_critic import ActorCritic
from rl_algo_impls.shared.schedule import schedule, update_learning_rate
from rl_algo_impls.shared.stats import log_scalars
from rl_algo_impls.wrappers.vectorable_wrapper import (
    VecEnv,
    single_action_space,
    single_observation_space,
)

# Type variable bound to A2C so that `A2C.learn` can declare it returns the same
# (sub)class instance it was called on (it ends with `return self`).
A2CSelf = TypeVar("A2CSelf", bound="A2C")


class A2C(Algorithm):
    """Advantage Actor-Critic (A2C) trainer.

    Each iteration of :meth:`learn` collects ``n_steps`` transitions from every
    sub-environment of ``env``, computes advantages with ``compute_advantages``
    (parameterized by ``gamma`` and ``gae_lambda``) and takes a single optimizer
    step on the combined policy, value and entropy loss over the whole rollout.
    """

    def __init__(
        self,
        policy: ActorCritic,
        env: VecEnv,
        device: torch.device,
        tb_writer: SummaryWriter,
        learning_rate: float = 7e-4,
        learning_rate_decay: str = "none",
        n_steps: int = 5,
        gamma: float = 0.99,
        gae_lambda: float = 1.0,
        ent_coef: float = 0.0,
        ent_coef_decay: str = "none",
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        rms_prop_eps: float = 1e-5,
        use_rms_prop: bool = True,
        sde_sample_freq: int = -1,
        normalize_advantage: bool = False,
    ) -> None:
        """Configure the optimizer, schedules and hyperparameters.

        Args:
            policy: Actor-critic network to train; all of its parameters are
                handed to the optimizer.
            env: Vectorized environment that rollouts are collected from.
            device: Device the training batch is moved to before the update.
            tb_writer: TensorBoard writer used for scalar logging.
            learning_rate: Base learning rate; also the optimizer's initial lr.
            learning_rate_decay: Schedule name passed to ``schedule`` together
                with ``learning_rate``; evaluated on training progress at the
                start of every iteration.
            n_steps: Environment steps per sub-environment in each rollout.
            gamma: Discount factor passed to ``compute_advantages``.
            gae_lambda: Lambda passed to ``compute_advantages``.
            ent_coef: Base coefficient of the entropy loss term.
            ent_coef_decay: Schedule name for ``ent_coef`` (same mechanism as
                ``learning_rate_decay``).
            vf_coef: Weight of the value-function loss.
            max_grad_norm: Maximum total gradient norm for ``clip_grad_norm_``.
            rms_prop_eps: ``eps`` for RMSprop; unused when Adam is selected.
            use_rms_prop: Use RMSprop if True, otherwise Adam with default
                settings apart from the learning rate.
            sde_sample_freq: If > 0, ``policy.reset_noise()`` is additionally
                called every ``sde_sample_freq`` steps within a rollout.
            normalize_advantage: Standardize advantages to zero mean and unit
                std within each batch before computing the policy loss.
        """
        super().__init__(policy, env, device, tb_writer)
        self.policy = policy

        self.lr_schedule = schedule(learning_rate_decay, learning_rate)
        if use_rms_prop:
            self.optimizer = torch.optim.RMSprop(
                policy.parameters(), lr=learning_rate, eps=rms_prop_eps
            )
        else:
            self.optimizer = torch.optim.Adam(policy.parameters(), lr=learning_rate)

        self.n_steps = n_steps

        self.gamma = gamma
        self.gae_lambda = gae_lambda

        self.vf_coef = vf_coef
        self.ent_coef_schedule = schedule(ent_coef_decay, ent_coef)
        self.max_grad_norm = max_grad_norm

        self.sde_sample_freq = sde_sample_freq
        self.normalize_advantage = normalize_advantage

    @staticmethod
    def _explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> float:
        """Return ``1 - Var(y_true - y_pred) / Var(y_true)``, or NaN if Var(y_true) == 0."""
        var_y = np.var(y_true).item()
        if var_y == 0:
            return np.nan
        return 1 - np.var(y_true - y_pred).item() / var_y

    def learn(
        self: A2CSelf,
        train_timesteps: int,
        callbacks: Optional[List[Callback]] = None,
        total_timesteps: Optional[int] = None,
        start_timesteps: int = 0,
    ) -> A2CSelf:
        """Train for ``train_timesteps`` environment timesteps.

        The environment is reset at the start of every call. Training proceeds
        in whole rollouts of ``n_steps * env.num_envs`` timesteps, so the last
        rollout may overshoot ``start_timesteps + train_timesteps`` by less
        than one rollout.

        Args:
            train_timesteps: Timesteps to train for in this call.
            callbacks: Called after every update via
                ``on_step(timesteps_elapsed=<timesteps in the rollout>)``;
                training stops early if any returns a falsy value.
            total_timesteps: Denominator of schedule progress
                (``timesteps_elapsed / total_timesteps``). Defaults to
                ``train_timesteps``.
            start_timesteps: Timesteps already completed before this call;
                offsets schedule progress and TensorBoard step values.

        Returns:
            ``self``, to allow call chaining.

        Raises:
            AssertionError: If ``start_timesteps + train_timesteps`` exceeds
                ``total_timesteps``.
        """
        if total_timesteps is None:
            total_timesteps = train_timesteps
        assert start_timesteps + train_timesteps <= total_timesteps
        epoch_dim = (self.n_steps, self.env.num_envs)
        step_dim = (self.env.num_envs,)
        # Number of environment timesteps gathered per rollout/update.
        rollout_steps = self.n_steps * self.env.num_envs
        obs_space = single_observation_space(self.env)
        act_space = single_action_space(self.env)

        # Rollout buffers indexed [step, env]. np.bool_ is used instead of the
        # np.bool8 alias, which was deprecated in NumPy 1.24 and removed in 2.0.
        obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype)
        actions = np.zeros(epoch_dim + act_space.shape, dtype=act_space.dtype)
        rewards = np.zeros(epoch_dim, dtype=np.float32)
        episode_starts = np.zeros(epoch_dim, dtype=np.bool_)
        values = np.zeros(epoch_dim, dtype=np.float32)
        logprobs = np.zeros(epoch_dim, dtype=np.float32)

        next_obs = self.env.reset()
        # Right after a reset, every sub-environment is at an episode start.
        next_episode_starts = np.full(step_dim, True, dtype=np.bool_)

        timesteps_elapsed = start_timesteps
        while timesteps_elapsed < start_timesteps + train_timesteps:
            start_time = perf_counter()

            progress = timesteps_elapsed / total_timesteps
            ent_coef = self.ent_coef_schedule(progress)
            learning_rate = self.lr_schedule(progress)
            update_learning_rate(self.optimizer, learning_rate)
            log_scalars(
                self.tb_writer,
                "charts",
                {
                    "ent_coef": ent_coef,
                    "learning_rate": learning_rate,
                },
                timesteps_elapsed,
            )

            # Collect one rollout with the policy in eval mode. Noise is reset
            # at the start of each rollout and, when sde_sample_freq > 0, also
            # every sde_sample_freq steps within it.
            self.policy.eval()
            self.policy.reset_noise()
            for s in range(self.n_steps):
                timesteps_elapsed += self.env.num_envs
                if self.sde_sample_freq > 0 and s > 0 and s % self.sde_sample_freq == 0:
                    self.policy.reset_noise()

                obs[s] = next_obs
                episode_starts[s] = next_episode_starts

                # The first output of policy.step is stored for training; its
                # clamped_action output is what the environment receives.
                actions[s], values[s], logprobs[s], clamped_action = self.policy.step(
                    next_obs
                )
                next_obs, rewards[s], next_episode_starts, _ = self.env.step(
                    clamped_action
                )

            # The post-rollout observation/episode-start flags and the policy
            # are passed along, presumably so the value of the state after the
            # last step can be bootstrapped -- see compute_advantages.
            advantages = compute_advantages(
                rewards,
                values,
                episode_starts,
                next_episode_starts,
                next_obs,
                self.policy,
                self.gamma,
                self.gae_lambda,
            )
            # Value-function targets.
            returns = advantages + values

            # Flatten the leading (n_steps, num_envs) axes into one batch axis.
            b_obs = torch.tensor(obs.reshape((-1,) + obs_space.shape)).to(self.device)
            b_actions = torch.tensor(actions.reshape((-1,) + act_space.shape)).to(
                self.device
            )
            b_advantages = torch.tensor(advantages.reshape(-1)).to(self.device)
            b_returns = torch.tensor(returns.reshape(-1)).to(self.device)

            if self.normalize_advantage:
                # NOTE(review): Tensor.std() is the unbiased estimator, so a
                # one-element batch (n_steps * num_envs == 1) yields NaN here.
                b_advantages = (b_advantages - b_advantages.mean()) / (
                    b_advantages.std() + 1e-8
                )

            # Single gradient step on the whole rollout.
            self.policy.train()
            logp_a, entropy, v = self.policy(b_obs, b_actions)

            pi_loss = -(b_advantages * logp_a).mean()
            value_loss = F.mse_loss(b_returns, v)
            # With ent_coef > 0, minimizing this term rewards higher entropy.
            entropy_loss = -entropy.mean()

            loss = pi_loss + self.vf_coef * value_loss + ent_coef * entropy_loss

            self.optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.optimizer.step()

            # Compares the value estimates recorded during the rollout (i.e.
            # before this update) against the computed returns.
            explained_var = self._explained_variance(
                values.reshape(-1), returns.reshape(-1)
            )

            end_time = perf_counter()
            self.tb_writer.add_scalar(
                "train/steps_per_second",
                rollout_steps / (end_time - start_time),
                timesteps_elapsed,
            )

            log_scalars(
                self.tb_writer,
                "losses",
                {
                    "loss": loss.item(),
                    "pi_loss": pi_loss.item(),
                    "v_loss": value_loss.item(),
                    "entropy_loss": entropy_loss.item(),
                    "explained_var": explained_var,
                },
                timesteps_elapsed,
            )

            if callbacks:
                # on_step receives the number of timesteps in this rollout.
                # all() over a generator short-circuits: callbacks after the
                # first falsy result are not invoked for this rollout.
                if not all(
                    c.on_step(timesteps_elapsed=rollout_steps) for c in callbacks
                ):
                    logging.info(
                        f"Callback terminated training at {timesteps_elapsed} timesteps"
                    )
                    break

        return self