import numpy as np | |
import torch | |
import math | |
#------------------Denormalization--------------------------------------------- | |
def denormalize_bin(tensor): | |
tr = torch.clamp(tensor, -1., 1.) # Clamp the values between -1 and 1 | |
tr = tr.add(1).div(2) # Shift to [0, 1] | |
return tr | |
def denormalize_tr(tensor): | |
tr = torch.clamp(tensor, -1., 1.) # Clamp the values between -1 and 1 | |
tr = tr.add(1).div(2).mul(255) # Shift to [0, 1] and scale to [0, 255] | |
tr = tr.byte() # Convert the tensor to uint8 | |
return tr | |
def denormalize_ar(tensor): | |
tr = torch.clamp(tensor, -1., 1.) # Clamp the values between -1 and 1 | |
tr = tr.add(1).div(2).mul(255) # Shift to [0, 1] and scale to [0, 255] | |
tr = tr.byte() # Convert the tensor to uint8 | |
arr = tr.permute(0, 2, 3, 1).cpu().detach().numpy() # Convert to (N, H, W, C) and numpy array | |
return arr |