from ddpg import Agent
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
import torch
from captum.attr import IntegratedGradients
from gymnasium.wrappers import RecordVideo


class TrainingLoop:
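    """Train, evaluate, and explain a DDPG agent on a gymnasium environment.

    Model checkpoints are handled by the accompanying ddpg.Agent, and action
    attributions are computed with Captum's Integrated Gradients.
    """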
    def __init__(self, env_spec, output_path='./output/', seed=0, **kwargs):
        assert env_spec in gym.envs.registry.keys()

        self.defaults = {
            "id": env_spec,
            "continuous": True,
            "gravity": -10.0,
            "render_mode": None
        }

        # Apply caller overrides before the environment is first created.
        self.defaults.update(**kwargs)

        self.env = gym.make(
            **self.defaults
        )

        torch.manual_seed(seed)

        self.agent = None
        self.output_path = output_path

    # TODO: spec-to-hyperparameters look-up
    def create_agent(self, alpha=0.000025, beta=0.00025, input_dims=[8], tau=0.001,
                     batch_size=64, layer1_size=400, layer2_size=300, n_actions=4):
        # alpha/beta are assumed to be the actor and critic learning rates, and tau
        # the soft target-update factor, in the accompanying ddpg.Agent.
        self.agent = Agent(alpha=alpha, beta=beta, input_dims=input_dims, tau=tau,
                           env=self.env, batch_size=batch_size, layer1_size=layer1_size,
                           layer2_size=layer2_size, n_actions=n_actions)

    def train(self):
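        """Train the agent from scratch for 10,000 episodes, saving the models every 25 episodes."""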
        assert self.agent is not None
        
        self.defaults["render_mode"] = None
        
        self.env = gym.make(
            **self.defaults
        )

        # self.agent.load_models()

        score_history = []

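        # Standard DDPG loop: choose an action, store the transition in the replay
        # buffer, and take one learning step per environment step.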
        for i in range(10000):
            done = False
            score = 0
            obs, _ = self.env.reset()
            while not done:
                act = self.agent.choose_action(obs)
                new_state, reward, terminated, truncated, info = self.env.step(act)
                done = terminated or truncated
                self.agent.remember(obs, act, reward, new_state, int(done))
                self.agent.learn()
                score += reward
                obs = new_state

            score_history.append(score)
            print("episode", i, "score %.2f" % score, "100 game average %.2f" % np.mean(score_history[-100:]))
            if i % 25 == 0:
                self.agent.save_models()

        self.env.close()


    def load_trained(self):
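        """Load the saved model weights and evaluate the agent for 50 episodes without learning."""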
        assert self.agent is not None

        self.defaults["render_mode"] = None
        
        self.env = gym.make(
            **self.defaults
        )

        self.agent.load_models()

        score_history = []

        for i in range(50):
            done = False
            score = 0
            obs, _ = self.env.reset()
            
            while not done:
                act = self.agent.choose_action(obs)
                new_state, reward, terminated, truncated, info = self.env.step(act)
                done = terminated or truncated
                score += reward
                obs = new_state

            score_history.append(score)
            print("episode", i, "score %.2f" % score, "100 game average %.2f" % np.mean(score_history[-100:]))

        self.env.close()

    # Video Recording

    # def render_video(self, episode_trigger=100):
    #     assert self.agent is not None

    #     self.defaults["render_mode"] = "rgb_array"
    #     self.env = gym.make(
    #         **self.defaults
    #     )

    #     episode_trigger_callable = lambda x: x % episode_trigger == 0

    #     self.env = RecordVideo(env=self.env, video_folder=self.output_path, name_prefix=f"{self.defaults['id']}-recording", episode_trigger=episode_trigger_callable, disable_logger=True)

    #     self.agent.load_models()

    #     score_history = []

    #     for i in range(200):
    #         done = False
    #         score = 0
    #         obs, _ = self.env.reset()
    #         while not done:
    #             act = self.agent.choose_action(observation=obs)
    #             new_state, reward, terminated, truncated, info = self.env.step(act)
    #             done = terminated or truncated
    #             score += reward
    #             obs = new_state


    #         score_history.append(score)
    #         print("episode", i, "score %.2f" % score, "100 game average %.2f" % np.mean(score_history[-100:]))

    #     self.env.close()
    

    # Model Explainability

    def _collect_running_baseline_average(self, num_iterations: int) -> torch.Tensor:
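        """Average the reset observations over num_iterations evaluation episodes.

        The result serves as an Integrated Gradients baseline; each episode is still
        rolled out to completion with the loaded policy, but only the initial
        observation contributes to the average.
        """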
        assert self.agent is not None

        self.defaults["render_mode"] = None
        
        self.env = gym.make(
            **self.defaults
        )

        print("--------- Collecting running baseline average ----------")

        self.agent.load_models()

        sum_obs = torch.zeros(8)

        for i in range(num_iterations):
            done = False
            score = 0
            obs, _ = self.env.reset()
            
            sum_obs += torch.as_tensor(obs, dtype=torch.float32)
            # print(f"Baseline on iteration #{i}: {obs}")

            while not done:
                act = self.agent.choose_action(obs, baseline=None)
                new_state, reward, terminated, truncated, info = self.env.step(act)
                done = terminated or truncated
                score += reward
                obs = new_state

        print(f"Baseline collected: {sum_obs / num_iterations}")

        self.env.close()
        

        return sum_obs / num_iterations


    def explain_trained(self, option: int, num_iterations: int = 10) -> tuple:
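        """Run the trained agent and attribute its actions with Integrated Gradients.

        option selects the IG baseline: 0 uses an all-zeros observation, any other
        value uses the average returned by _collect_running_baseline_average.
        Returns the rendered frames together with the agent's recorded attributions.
        """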
        assert self.agent is not None
        
        # Select the Integrated Gradients baseline lazily so the baseline-collection
        # rollouts only run when they are actually requested.
        if option == 0:
            baseline = torch.zeros(8)
        else:
            baseline = self._collect_running_baseline_average(num_iterations)

        self.defaults["render_mode"] = "rgb_array"

        self.env = gym.make(
            **self.defaults
        )
        

        print("\n\n\n\n--------- Performing Attributions -----------")

        self.agent.load_models()

        
        print(self.agent.actor)
        ig = IntegratedGradients(self.agent.actor)
        self.agent.ig = ig
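        # The Agent is assumed to call self.ig.attribute(observation, baselines=baseline)
        # inside choose_action and to append each result to self.attributions, which is
        # what gets returned alongside the rendered frames below.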

        score_history = []
        frames = []

        for i in range(10):
            done = False
            score = 0
            obs, _ = self.env.reset()
            while not done:
                frames.append(self.env.render())
                act = self.agent.choose_action(observation=obs, baseline=baseline)
                new_state, reward, terminated, truncated, info = self.env.step(act)
                done = terminated or truncated
                score += reward
                obs = new_state


            score_history.append(score)
            print("episode", i, "score %.2f" % score, "100 game average %.2f" % np.mean(score_history[-100:]))

        self.env.close()

        if len(frames) != len(self.agent.attributions):
            print("Frames and agent attribution history are not the same length!")

        return (frames, self.agent.attributions)
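

# Minimal usage sketch. The environment id and the create_agent overrides below are
# illustrative assumptions (the defaults above target a continuous LunarLander-style
# task); adjust them to whatever environment and ddpg.Agent configuration you use.
if __name__ == "__main__":
    loop = TrainingLoop("LunarLander-v2", output_path="./output/", seed=0)
    loop.create_agent()        # override alpha/beta/n_actions here if needed
    loop.train()               # or loop.load_trained() to evaluate saved weights
    # frames, attributions = loop.explain_trained(option=0)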