|
r"""Colour space conversion functions""" |
|
from typing import Union, Dict |
|
import torch |
|
|
|
|
|
def rgb2lmn(x: torch.Tensor) -> torch.Tensor:
    r"""Convert a batch of RGB images to a batch of LMN images.

    Every pixel's (R, G, B) vector is mapped to (L, M, N) by a fixed 3x3 linear
    transform; row ``i`` of the matrix below produces output channel ``i``.

    Args:
        x: Batch of images with shape (N, 3, H, W). RGB colour space.

    Returns:
        Batch of images with shape (N, 3, H, W). LMN colour space.
    """
    lmn_matrix = torch.tensor(
        [
            [0.06, 0.63, 0.27],
            [0.30, 0.04, -0.35],
            [0.34, -0.6, 0.17],
        ],
        dtype=x.dtype,
        device=x.device,
    )

    # Move channels last so each pixel is a row vector, multiply by the
    # transposed matrix, then restore the (N, 3, H, W) layout.
    channels_last = x.permute(0, 2, 3, 1)
    lmn_channels_last = torch.matmul(channels_last, lmn_matrix.t())
    return lmn_channels_last.permute(0, 3, 1, 2)
|
|
|
|
|
def rgb2xyz(x: torch.Tensor) -> torch.Tensor:
    r"""Convert a batch of RGB images to a batch of XYZ images.

    The input is first linearised piecewise: values at or below 0.04045 are
    divided by 12.92, larger values go through ``((x + 0.055) / 1.055) ** 2.4``.
    The linear values are then mapped to XYZ with a fixed 3x3 matrix whose row
    sums are approximately (0.95047, 1.0, 1.08883), so RGB (1, 1, 1) maps to
    that white point. Presumably the input is expected to be scaled to [0, 1],
    the scale on which the 0.04045 threshold is defined.

    Args:
        x: Batch of images with shape (N, 3, H, W). RGB colour space.

    Returns:
        Batch of images with shape (N, 3, H, W). XYZ colour space.
    """
    threshold = 0.04045
    linear_segment = x / 12.92
    # Clamp before the power so the branch that torch.where does NOT select can
    # never see a negative base. Otherwise pow() yields NaN for x < -0.055, and
    # NaN leaks through both the forward pass and the backward pass.
    power_segment = torch.pow((x.clamp(min=threshold) + 0.055) / 1.055, 2.4)
    tmp = torch.where(x <= threshold, linear_segment, power_segment)

    weights_rgb_to_xyz = torch.tensor([[0.4124564, 0.3575761, 0.1804375],
                                       [0.2126729, 0.7151522, 0.0721750],
                                       [0.0193339, 0.1191920, 0.9503041]], dtype=x.dtype, device=x.device)

    # Channels-last so each pixel is a row vector for the matrix product.
    x_xyz = torch.matmul(tmp.permute(0, 2, 3, 1), weights_rgb_to_xyz.t()).permute(0, 3, 1, 2)
    return x_xyz
|
|
|
|
|
def xyz2lab(x: torch.Tensor, illuminant: str = 'D50', observer: str = '2') -> torch.Tensor:
    r"""Convert a batch of XYZ images to a batch of LAB images.

    The XYZ values are divided by the reference white of the chosen illuminant
    and observer. Each normalised value t then goes through a piecewise function:
    t ** (1/3) when t > epsilon, and (kappa * t + 16) / 116 otherwise. The
    results are combined linearly into L, a and b.

    Args:
        x: Batch of images with shape (N, 3, H, W). XYZ colour space.
        illuminant: {"A", "D50", "D55", "D65", "D75", "E"}, optional. The name of the illuminant.
        observer: {"2", "10"}, optional. The aperture angle of the observer.

    Returns:
        Batch of images with shape (N, 3, H, W). LAB colour space.

    Raises:
        KeyError: If ``illuminant`` or ``observer`` is not one of the listed names.
    """
    epsilon = 0.008856
    kappa = 903.3
    # Reference white (X, Y, Z) per illuminant and observer angle.
    illuminants: Dict[str, Dict] = \
        {"A": {'2': (1.098466069456375, 1, 0.3558228003436005),
               '10': (1.111420406956693, 1, 0.3519978321919493)},
         "D50": {'2': (0.9642119944211994, 1, 0.8251882845188288),
                 '10': (0.9672062750333777, 1, 0.8142801513128616)},
         "D55": {'2': (0.956797052643698, 1, 0.9214805860173273),
                 '10': (0.9579665682254781, 1, 0.9092525159847462)},
         "D65": {'2': (0.95047, 1., 1.08883),
                 '10': (0.94809667673716, 1, 1.0730513595166162)},
         "D75": {'2': (0.9497220898840717, 1, 1.226393520724154),
                 '10': (0.9441713925645873, 1, 1.2064272211720228)},
         "E": {'2': (1.0, 1.0, 1.0),
               '10': (1.0, 1.0, 1.0)}}

    illuminants_to_use = torch.tensor(illuminants[illuminant][observer],
                                      dtype=x.dtype, device=x.device).view(1, 3, 1, 1)

    tmp = x / illuminants_to_use

    # Evaluate the cube root only on values clamped into its own domain. The
    # previous mask-multiply form evaluated pow(tmp, 1/3) everywhere. At tmp == 0
    # (black pixels) its derivative is infinite, and inf * 0 produced NaN
    # gradients. For tmp < 0 the forward value itself was NaN.
    cube_root = torch.pow(tmp.clamp(min=epsilon), 1. / 3.)
    linear = (kappa * tmp + 16.) / 116.
    tmp = torch.where(tmp > epsilon, cube_root, linear)

    # L = 116 * fy - 16, a = 500 * (fx - fy), b = 200 * (fy - fz)
    weights_xyz_to_lab = torch.tensor([[0, 116., 0],
                                       [500., -500., 0],
                                       [0, 200., -200.]], dtype=x.dtype, device=x.device)
    bias_xyz_to_lab = torch.tensor([-16., 0., 0.], dtype=x.dtype, device=x.device).view(1, 3, 1, 1)

    x_lab = torch.matmul(tmp.permute(0, 2, 3, 1), weights_xyz_to_lab.t()).permute(0, 3, 1, 2) + bias_xyz_to_lab
    return x_lab
|
|
|
|
|
def rgb2lab(x: torch.Tensor, data_range: Union[int, float] = 255,
            illuminant: str = 'D50', observer: str = '2') -> torch.Tensor:
    r"""Convert a batch of RGB images to a batch of LAB images.

    The input is scaled by ``1 / data_range``, converted to XYZ with ``rgb2xyz``,
    and then to LAB with ``xyz2lab``.

    Note:
        ``rgb2xyz`` maps RGB white (1, 1, 1) to roughly (0.95047, 1.0, 1.08883),
        which is the D65 / 2-degree reference white. The default ``illuminant``
        here is nevertheless 'D50', kept for backward compatibility. As a result,
        neutral greys get small non-zero a/b values. Pass ``illuminant='D65'`` for
        a white-point-consistent conversion.

    Args:
        x: Batch of images with shape (N, 3, H, W). RGB colour space.
        data_range: dynamic range of the input image.
        illuminant: Reference-white illuminant forwarded to ``xyz2lab``.
        observer: Observer angle forwarded to ``xyz2lab``.

    Returns:
        Batch of images with shape (N, 3, H, W). LAB colour space.
    """
    return xyz2lab(rgb2xyz(x / float(data_range)), illuminant=illuminant, observer=observer)
|
|
|
|
|
def rgb2yiq(x: torch.Tensor) -> torch.Tensor:
    r"""Convert a batch of RGB images to a batch of YIQ images.

    Each pixel is multiplied by a fixed 3x3 matrix. The Y row's weights sum to 1,
    and the I and Q rows each sum to 0, so a neutral grey keeps its value in Y
    and gets zero I and Q.

    Args:
        x: Batch of images with shape (N, 3, H, W). RGB colour space.

    Returns:
        Batch of images with shape (N, 3, H, W). YIQ colour space.
    """
    y_row = [0.299, 0.587, 0.114]
    i_row = [0.5959, -0.2746, -0.3213]
    q_row = [0.2115, -0.5227, 0.3112]
    rgb_to_yiq = torch.tensor([y_row, i_row, q_row], dtype=x.dtype, device=x.device)

    pixels = x.permute(0, 2, 3, 1)  # channels last: (N, H, W, 3)
    yiq_pixels = torch.matmul(pixels, rgb_to_yiq.t())
    return yiq_pixels.permute(0, 3, 1, 2)
|
|
|
|
|
def rgb2lhm(x: torch.Tensor) -> torch.Tensor:
    r"""Convert a batch of RGB images to a batch of LHM images.

    Each pixel is multiplied by a fixed 3x3 matrix. The first row is a luma-style
    weighting, and the second and third rows use the same coefficients as the
    M and N rows of ``rgb2lmn``.

    Args:
        x: Batch of images with shape (N, 3, H, W). RGB colour space.

    Returns:
        Batch of images with shape (N, 3, H, W). LHM colour space.

    Reference:
        https://arxiv.org/pdf/1608.07433.pdf
    """
    rgb_to_lhm = torch.tensor(
        [
            [0.2989, 0.587, 0.114],
            [0.3, 0.04, -0.35],
            [0.34, -0.6, 0.17],
        ],
        dtype=x.dtype,
        device=x.device,
    )
    pixels = x.permute(0, 2, 3, 1)  # channels last: (N, H, W, 3)
    return torch.matmul(pixels, rgb_to_lhm.t()).permute(0, 3, 1, 2)
|
|