import cv2
import tqdm
import torch
# import lpips
import numpy as np
from ops import utils
import torch.nn.functional as F
import torchvision.transforms as tvtf
from ops.gs.basic import Gaussian_Scene,Frame
from torchmetrics.image import StructuralSimilarityIndexMeasure

class RGB_Loss():
    '''
    Photometric loss: smooth-L1 over (optionally masked) pixels plus a weighted
    (1 - SSIM) term computed on the full image.

    The LPIPS term is disabled (its import is commented out at the top of the
    file); `w_lpips` is only stored so existing callers keep working.
    '''
    def __init__(self,w_lpips=0.2,w_ssim=0.2,device=None):
        '''
        w_lpips: stored but unused while the LPIPS term is disabled
        w_ssim:  weight of the (1 - SSIM) term
        device:  device for the SSIM metric; defaults to CUDA when available,
                 otherwise CPU, so construction no longer fails without a GPU
        '''
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.rgb_loss = F.smooth_l1_loss
        self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
        self.w_ssim = w_ssim
        self.w_lpips = w_lpips

    def __call__(self,pr,gt,valid_mask=None):
        '''
        pr, gt:     predicted / reference images, channels last; a 2D input is
                    replicated to 3 channels
        valid_mask: optional index (e.g. a boolean mask over the first two
                    dims) selecting the pixels used by the smooth-L1 term;
                    the SSIM term always sees the whole image
        Returns the scalar loss  l_rgb + w_ssim * (1 - SSIM).
        '''
        # NaN / inf in either image would otherwise propagate into the loss
        pr = torch.nan_to_num(pr)
        gt = torch.nan_to_num(gt)
        if pr.dim() < 3: pr = pr[:,:,None].repeat(1,1,3)
        if gt.dim() < 3: gt = gt[:,:,None].repeat(1,1,3)
        if valid_mask is None:
            pr_valid = pr.reshape(-1,pr.shape[-1])
            gt_valid = gt.reshape(-1,gt.shape[-1])
        else:
            pr_valid = pr[valid_mask]
            gt_valid = gt[valid_mask]
        if pr_valid.numel() == 0:
            # A mean over zero selected pixels is NaN and would poison the
            # optimizer step; an empty sum is 0 and stays attached to the graph.
            l_rgb = pr_valid.sum()
        else:
            l_rgb = self.rgb_loss(pr_valid,gt_valid)
        # the SSIM metric is called with batched, channels-first tensors
        pr_bchw = pr[None].permute(0, 3, 1, 2)
        gt_bchw = gt[None].permute(0, 3, 1, 2)
        l_ssim = 1.0 - self.ssim(pr_bchw, gt_bchw)
        return l_rgb + self.w_ssim * l_ssim

class GS_Train_Tool():
    '''
    Frames and well-trained gaussians are kept, refine the trainable gaussians
    The supervision comes from the Frames of GS_Scene

    Only gaussian frames whose `rgb` requires grad are handed to the optimizer.
    Calling the tool runs `iters` optimization steps, each on one randomly
    chosen target frame, then disables gradients on every gaussian frame.
    '''
    def __init__(self,
                 GS:'Gaussian_Scene',
                 iters = 100) -> None:
        '''
        GS:    scene providing `gaussian_frames`, `frames` and `_render_RGBD`
        iters: number of optimization steps performed per call
        '''
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # hyperparameters for prune, densify, and update
        # (stored here; not read anywhere else in this class)
        self.lr_factor = 1.00
        self.lr_update = 0.99
        # learning rate
        self.rgb_lr = 0.0005
        self.xyz_lr = 0.0001
        self.scale_lr = 0.005
        self.opacity_lr = 0.05
        self.rotation_lr = 0.001
        # GSs for training
        self.GS = GS
        # hyperparameters for training
        self.iters = iters
        self._init_optimizer()
        self.rgb_lossfunc = RGB_Loss(w_lpips=0)

    def _init_optimizer(self):
        '''Build an Adam optimizer with one parameter group (and learning rate) per gaussian attribute.'''
        self.optimize_frames = [gf for gf in self.GS.gaussian_frames if gf.rgb.requires_grad]
        # following https://github.com/pointrix-project/msplat
        self.optimizer = torch.optim.Adam([
            {'params': [gf.xyz for gf in self.optimize_frames],      'lr': self.xyz_lr},
            {'params': [gf.rgb for gf in self.optimize_frames],      'lr': self.rgb_lr},
            {'params': [gf.scale for gf in self.optimize_frames],    'lr': self.scale_lr},
            {'params': [gf.opacity for gf in self.optimize_frames],  'lr': self.opacity_lr},
            {'params': [gf.rotation for gf in self.optimize_frames], 'lr': self.rotation_lr}
        ])

    def _render(self,frame):
        '''Render `frame` with the scene; returns (rgb, depth, alpha) as given by `_render_RGBD`.'''
        rgb,dpt,alpha = self.GS._render_RGBD(frame)
        return rgb,dpt,alpha

    def _to_cuda(self,tensor):
        '''
        Convert a numpy array to a float32 torch tensor on `self.device`.
        The name is historical: the target is CUDA when available, else CPU.
        '''
        tensor = torch.from_numpy(tensor.astype(np.float32)).to(self.device)
        return tensor

    def __call__(self,target_frames=None):
        '''
        Refine the trainable gaussians against `target_frames`
        (defaults to `self.GS.frames`) and return the scene.

        Raises ValueError if training is requested (`iters > 0`) but there is
        no target frame to sample from.
        '''
        target_frames = self.GS.frames if target_frames is None else target_frames
        if self.iters > 0 and len(target_frames) == 0:
            # same exception type np.random.randint(0, 0) would raise, clearer message
            raise ValueError('GS_Train_Tool needs at least one target frame to train on')
        for _ in tqdm.tqdm(range(self.iters)):
            # one randomly sampled supervision frame per step
            frame_idx = np.random.randint(0,len(target_frames))
            frame: Frame = target_frames[frame_idx]
            render_rgb,_dpt,_alpha = self._render(frame)
            # the smooth-L1 part of the loss is restricted to `frame.inpaint`
            loss_rgb = self.rgb_lossfunc(render_rgb,self._to_cuda(frame.rgb),valid_mask=frame.inpaint)
            # optimization
            loss = loss_rgb
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
        refined_scene = self.GS
        # training is finished: switch gradients off on every gaussian frame
        for gf in refined_scene.gaussian_frames:
            gf._require_grad(False)
        return refined_scene