import configargparse
import torch
import torch.nn as nn
import torch.utils.data
import torchvision.utils as tvutils
import torchvision.transforms
from video3d.utils.segmentation_transforms import *
from video3d.utils.misc import setup_runtime
from video3d import networks
from video3d.trainer import Trainer
from video3d.dataloaders import SegmentationDataset


class Segmentation:
    """Binary segmentation model wrapper handed to ``Trainer``.

    Bundles an ``EDDeconv`` network with its Adam optimizer and exposes
    data-loader construction, forward/backward steps, train/eval switching,
    device placement and checkpoint save/load (presumably the hooks that
    ``Trainer`` calls; see the ``__main__`` block).
    """

    def __init__(self, cfgs, _):
        """Build the network and optimizer.

        Args:
            cfgs: Mapping of configuration values. Reads ``device``
                (default ``'cpu'``) and ``lr`` (default ``1e-4``).
            _: Unused; accepted only to match the caller's constructor
                signature (its meaning is not visible here).
        """
        self.cfgs = cfgs
        self.device = cfgs.get('device', 'cpu')
        # Set by ``forward`` and consumed by ``backward``.
        self.total_loss = None
        # cin=3 / cout=1 presumably mean 3 input and 1 output channel; with
        # activation=None the output is presumably raw logits, which matches
        # the BCE-with-logits loss used in ``forward``.
        self.net = networks.EDDeconv(cin=3, cout=1, zdim=128, nf=64, activation=None)
        # Only parameters with requires_grad=True are optimized.
        self.optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.net.parameters()),
            lr=cfgs.get('lr', 1e-4),
            betas=(0.9, 0.999),
            weight_decay=5e-4)

    def load_model_state(self, cp):
        """Restore network weights from checkpoint dict ``cp`` (key ``"net"``)."""
        self.net.load_state_dict(cp["net"])

    def load_optimizer_state(self, cp):
        """Restore optimizer state from checkpoint dict ``cp`` (key ``"optimizer"``).

        ``cp`` has the layout produced by ``get_state_dict``.
        """
        # cp["optimizer"] is an Optimizer.state_dict(), so it must be loaded
        # into the optimizer; loading it into the network is invalid.
        self.optimizer.load_state_dict(cp["optimizer"])

    @staticmethod
    def get_data_loaders(cfgs):
        """Build the training and validation data loaders.

        Training samples get random resize, horizontal flip and crop, plus
        ``ImageOnly``-wrapped colour jitter and random grayscale (presumably
        applied to the image but not the mask). Validation samples are only
        converted with ``ToTensor``.

        Config keys read (defaults): ``batch_size`` (64), ``num_workers`` (4),
        ``data_dir`` (``'./data'``), ``image_size`` (64), ``aug_min_resize``
        (0.5) and ``aug_max_resize`` (2.0) as multiples of ``image_size``,
        ``aug_horizontal_flip`` (0.4), ``aug_color_jitter`` (``{}``, passed as
        keyword arguments to ``ColorJitter``), ``aug_grayscale`` (0.2).

        Returns:
            ``(train_loader, val_loader, None)``. The third slot is presumably
            an (unused) test loader.
        """
        batch_size = cfgs.get('batch_size', 64)
        num_workers = cfgs.get('num_workers', 4)
        data_dir = cfgs.get('data_dir', './data')
        img_size = cfgs.get('image_size', 64)
        min_size = int(img_size * cfgs.get('aug_min_resize', 0.5))
        max_size = int(img_size * cfgs.get('aug_max_resize', 2.0))
        train_transform = Compose([
            RandomResize(min_size, max_size),
            RandomHorizontalFlip(cfgs.get("aug_horizontal_flip", 0.4)),
            RandomCrop(img_size),
            ImageOnly(torchvision.transforms.ColorJitter(**cfgs.get("aug_color_jitter", {}))),
            ImageOnly(torchvision.transforms.RandomGrayscale(cfgs.get("aug_grayscale", 0.2))),
            ToTensor(),
        ])
        # sequence_range (0, 0.5) vs (0.5, 1.0): presumably splits each
        # sequence into disjoint train/val halves; confirm in SegmentationDataset.
        train_loader = torch.utils.data.DataLoader(
            SegmentationDataset(data_dir, is_validation=False, transform=train_transform, sequence_range=(0, 0.5)),
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True
        )
        val_transform = Compose([ToTensor()])
        val_loader = torch.utils.data.DataLoader(
            SegmentationDataset(data_dir, is_validation=True, transform=val_transform, sequence_range=(0.5, 1.0)),
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True
        )
        return train_loader, val_loader, None

    def get_state_dict(self):
        """Return a checkpoint dict with ``"net"`` and ``"optimizer"`` state dicts."""
        return {
            "net": self.net.state_dict(),
            "optimizer": self.optimizer.state_dict()
        }

    def to(self, device):
        """Move the network to ``device`` and use it for incoming batches."""
        self.device = device
        self.net.to(device)

    def set_train(self):
        """Put the network in training mode."""
        self.net.train()

    def set_eval(self):
        """Put the network in evaluation mode."""
        self.net.eval()

    def backward(self):
        """Backpropagate the loss from the last ``forward`` and step the optimizer."""
        self.optimizer.zero_grad()
        self.total_loss.backward()
        self.optimizer.step()

    def forward(self, batch, visualize=False):
        """Run the network on one batch and compute the BCE-with-logits loss.

        Args:
            batch: ``(image, target)`` pair. Only channel 0 of ``target`` is
                used as the mask. The image is mapped via ``x * 2 - 1``,
                which assumes inputs in [0, 1] (TODO confirm).
            visualize: If true, also build grid images for logging.

        Returns:
            ``{'loss': total_loss}``, or ``(metrics, visuals)`` when
            ``visualize`` is true. ``visuals`` holds ``'rgb'``, ``'target'``
            and ``'pred'`` images from ``image_visual``.
        """
        image, target = batch
        image = image.to(self.device)*2 - 1
        # Keep only the first target channel, restoring a channel dim of 1.
        target = target[:, 0, :, :].to(self.device).unsqueeze(1)
        pred = self.net(image)

        # Stored so that ``backward`` can use it.
        self.total_loss = nn.functional.binary_cross_entropy_with_logits(pred, target)

        metrics = {'loss': self.total_loss}

        visuals = {}
        if visualize:
            visuals['rgb'] = self.image_visual(image, normalize=True, range=(-1, 1))
            visuals['target'] = self.image_visual(target, normalize=True, range=(0, 1))
            # torch.sigmoid replaces the deprecated nn.functional.sigmoid.
            visuals['pred'] = self.image_visual(torch.sigmoid(pred), normalize=True, range=(0, 1))

            return metrics, visuals

        return metrics

    def visualize(self, logger, total_iter, max_bs=25):
        """No-op; visualization is produced by ``forward(visualize=True)``."""
        pass

    def save_results(self, save_dir):
        """No-op; this model saves no per-sample results."""
        pass

    def save_scores(self, path):
        """No-op; this model saves no scores."""
        pass

    @staticmethod
    def image_visual(tensor, **kwargs):
        """Tile a batch of images into one uint8 grid image.

        Single-channel inputs are repeated to 3 channels. The grid has
        ``round(sqrt(batch))`` images per row. Extra ``kwargs`` are forwarded
        to ``torchvision.utils.make_grid``. A legacy ``range`` kwarg is
        renamed to ``value_range`` when ``make_grid`` supports it.

        Returns:
            A CPU ``uint8`` tensor of shape ``(H, W, 3)``.
        """
        import inspect

        # torchvision renamed make_grid's ``range`` argument to
        # ``value_range``, and newer releases reject ``range``. Translate it
        # only when the new name exists, so older torchvision keeps working.
        if 'range' in kwargs and 'value_range' in inspect.signature(tvutils.make_grid).parameters:
            kwargs['value_range'] = kwargs.pop('range')
        if tensor.shape[1] == 1:
            tensor = tensor.repeat(1, 3, 1, 1)
        n = int(tensor.shape[0]**0.5 + 0.5)
        tensor = tvutils.make_grid(tensor.detach(), nrow=n, **kwargs).permute(1, 2, 0)
        return torch.clamp(tensor[:, :, :3] * 255, 0, 255).byte().cpu()


if __name__ == "__main__":
    # Command-line entry point: parse options, set up the runtime, train.
    arg_parser = configargparse.ArgumentParser(description='Training configurations.')
    # configargparse reads further option values from the file given here.
    arg_parser.add_argument(
        '--config',
        type=str,
        default="config/train_segmentation.yml",
        is_config_file=True,
        help='Specify a config file path',
    )
    arg_parser.add_argument('--gpu', type=int, default=1, help='Specify a GPU device')
    arg_parser.add_argument('--seed', type=int, default=0, help='Specify a random seed')
    # Unrecognised arguments are tolerated and discarded.
    known_args = arg_parser.parse_known_args()[0]

    runtime_cfgs = setup_runtime(known_args)
    Trainer(runtime_cfgs, Segmentation).train()