Spaces:
Sleeping
Sleeping
# Standard libraries | |
import itertools | |
import numpy as np | |
# PyTorch | |
import torch | |
import torch.nn as nn | |
# Local | |
from . import JPEG_utils as utils | |
class y_dequantize(nn.Module):
    """Undo quantization of the luma (Y) DCT coefficients.

    The incoming coefficients are multiplied element-wise by the luma
    quantization table scaled by ``factor``.

    Inputs:
        image (tensor): quantized Y coefficients, broadcast against the
            table (presumably batch x n_blocks x 8 x 8 — TODO confirm
            against the table's shape in JPEG_utils)
        factor (float): compression factor scaling the table
    Outputs:
        tensor: dequantized coefficients, same shape as ``image``
    """

    def __init__(self, factor=1):
        super().__init__()
        self.factor = factor
        # NOTE(review): if utils.y_table is an nn.Parameter it becomes a
        # registered parameter of this module and is shared by every
        # instance — confirm that is intended.
        self.y_table = utils.y_table

    def forward(self, image):
        step = self.y_table * self.factor
        return image * step
class c_dequantize(nn.Module):
    """Undo quantization of the chroma (Cb / Cr) DCT coefficients.

    The incoming coefficients are multiplied element-wise by the chroma
    quantization table scaled by ``factor``.

    Inputs:
        image (tensor): quantized Cb or Cr coefficients, broadcast against
            the table (presumably batch x n_blocks x 8 x 8 — TODO confirm
            against the table's shape in JPEG_utils)
        factor (float): compression factor scaling the table
    Outputs:
        tensor: dequantized coefficients, same shape as ``image``
    """

    def __init__(self, factor=1):
        super().__init__()
        self.factor = factor
        # NOTE(review): if utils.c_table is an nn.Parameter it becomes a
        # registered parameter of this module and is shared by every
        # instance — confirm that is intended.
        self.c_table = utils.c_table

    def forward(self, image):
        step = self.c_table * self.factor
        return image * step
class idct_8x8(nn.Module):
    """Inverse discrete cosine transform on 8x8 blocks.

    For every 8x8 block of coefficients F this computes

        f(u, v) = 1/4 * sum_{x, y} C(x) C(y) F(x, y)
                  * cos((2u + 1) x pi / 16) * cos((2v + 1) y pi / 16) + 128

    with C(0) = 1/sqrt(2) and C(k) = 1 otherwise. The transform is done as
    one tensor contraction against a precomputed 8x8x8x8 cosine basis.

    Input:
        image (tensor): ... x 8 x 8 coefficient blocks. The last two
            dimensions are the frequency indices; leading dimensions
            (e.g. batch x n_blocks) are preserved.
    Output:
        tensor: same shape as the input, pixel values with the +128 level
            shift re-applied.
    """

    def __init__(self):
        super().__init__()
        # JPEG normalization C(x) * C(y): 1/sqrt(2) on the DC row/column.
        alpha = np.array([1.0 / np.sqrt(2)] + [1] * 7)
        # The transform constants are fixed, so they are frozen. They stay
        # nn.Parameters so state_dict keys and .parameters() are unchanged,
        # but an optimizer over an enclosing model no longer updates them.
        self.alpha = nn.Parameter(
            torch.from_numpy(np.outer(alpha, alpha)).float(), requires_grad=False
        )
        # basis[x, y, u, v]: x, y are frequency indices; u, v are pixel indices.
        basis = np.zeros((8, 8, 8, 8), dtype=np.float32)
        for x, y, u, v in itertools.product(range(8), repeat=4):
            basis[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos(
                (2 * v + 1) * y * np.pi / 16
            )
        self.tensor = nn.Parameter(torch.from_numpy(basis), requires_grad=False)

    def forward(self, image):
        image = image * self.alpha
        # Contract the two frequency axes with the basis; the result has the
        # input's shape, so no reshape is needed. The previous discarded
        # `result.view(image.shape)` call was a no-op and has been removed.
        return 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
class block_merging(nn.Module):
    """Reassemble 8x8 blocks into full image planes.

    This inverts a row-major tiling: block ``i`` of the input becomes the
    tile at block-row ``i // (width // 8)`` and block-column
    ``i % (width // 8)``.

    Inputs:
        patches (tensor): batch x (height*width/64) x 8 x 8
        height (int): output height in pixels (must be a multiple of 8,
            otherwise the final reshape fails)
        width (int): output width in pixels (must be a multiple of 8)
    Output:
        image (tensor): batch x height x width
    """

    def forward(self, patches, height, width):
        block = 8
        n = patches.shape[0]
        grid = patches.view(n, height // block, width // block, block, block)
        # Put each tile's pixel-row axis next to the block-row axis so a flat
        # view walks the image in raster order.
        rows_first = grid.transpose(2, 3)
        return rows_first.contiguous().view(n, height, width)
class chroma_upsampling(nn.Module):
    """Upsample the chroma planes 2x and stack them with luma.

    Each Cb/Cr sample is replicated into a 2x2 square (nearest-neighbour
    upsampling). The three planes are then stacked along a new trailing
    channel axis.

    Input:
        y (tensor): batch x height x width luma plane
        cb (tensor): batch x height/2 x width/2 Cb plane
        cr (tensor): batch x height/2 x width/2 Cr plane
    Output:
        image (tensor): batch x height x width x 3, channels (Y, Cb, Cr)
    """

    def forward(self, y, cb, cr):
        def _double(plane, k=2):
            # Replicate each sample k times along rows, then along columns.
            return plane.repeat_interleave(k, dim=1).repeat_interleave(k, dim=2)

        return torch.stack((y, _double(cb), _double(cr)), dim=3)
class ycbcr_to_rgb_jpeg(nn.Module):
    """Convert a channels-last YCbCr image to a channels-first RGB image.

    Applies the JFIF conversion coefficients:
        R = Y + 1.402 (Cr - 128)
        G = Y - 0.344136 (Cb - 128) - 0.714136 (Cr - 128)
        B = Y + 1.772 (Cb - 128)
    No clamping is done here.

    Input:
        image (tensor): batch x height x width x 3, channels (Y, Cb, Cr)
    Output:
        result (tensor): batch x 3 x height x width, channels (R, G, B)
    """

    def __init__(self):
        super().__init__()
        # Rows are the R, G, B equations. The transpose lets the trailing
        # channel axis be contracted directly with tensordot(dims=1).
        matrix = np.array(
            [[1.0, 0.0, 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]],
            dtype=np.float32,
        ).T
        # The conversion constants are fixed, so they are frozen. They stay
        # nn.Parameters so state_dict keys and .parameters() are unchanged,
        # but an optimizer over an enclosing model no longer updates them.
        self.shift = nn.Parameter(
            torch.tensor([0, -128.0, -128.0]), requires_grad=False
        )
        self.matrix = nn.Parameter(torch.from_numpy(matrix), requires_grad=False)

    def forward(self, image):
        # tensordot already yields batch x height x width x 3; the previous
        # discarded `result.view(image.shape)` call was a no-op.
        result = torch.tensordot(image + self.shift, self.matrix, dims=1)
        return result.permute(0, 3, 1, 2)
class decompress_jpeg(nn.Module):
    """JPEG decompression pipeline from quantized coefficient blocks to RGB.

    Steps:
    1. Dequantize the Y, Cb and Cr coefficient blocks.
    2. Apply the 8x8 inverse DCT.
    3. Merge the blocks back into planes.
    4. Upsample the half-resolution chroma planes 2x in each direction.
    5. Convert YCbCr to RGB.
    6. Clamp to [0, 255] and rescale to [0, 1].

    Constructor args:
        rounding (callable): accepted for API compatibility; not used by
            this decoder.
        factor (float): compression factor forwarded to both dequantizers.

    Forward inputs:
        y (tensor): batch x (height*width/64) x 8 x 8 quantized luma blocks
        cb (tensor): batch x (height*width/256) x 8 x 8 quantized Cb blocks
            (chroma is stored at half height and half width)
        cr (tensor): same layout as ``cb``
        height (int): output height in pixels
        width (int): output width in pixels

    Output:
        image (tensor): batch x 3 x height x width, RGB in [0, 1]
    """

    def __init__(self, rounding=torch.round, factor=1):
        super().__init__()
        self.c_dequantize = c_dequantize(factor=factor)
        self.y_dequantize = y_dequantize(factor=factor)
        self.idct = idct_8x8()
        self.merging = block_merging()
        # Chroma planes are decoded at half resolution and upsampled here.
        self.chroma = chroma_upsampling()
        self.colors = ycbcr_to_rgb_jpeg()

    def _decode_plane(self, coeffs, height, width):
        """Inverse-DCT dequantized blocks and merge them into one plane."""
        return self.merging(self.idct(coeffs), height, width)

    def forward(self, y, cb, cr, height, width):
        # Recorded on the module; these attributes stay visible after forward.
        self.height = height
        self.width = width
        # Chroma planes are subsampled by 2 in each direction.
        chroma_h, chroma_w = int(self.height / 2), int(self.width / 2)

        # Decode in the same order as before: Y, then Cb, then Cr.
        y_plane = self._decode_plane(self.y_dequantize(y), self.height, self.width)
        cb_plane = self._decode_plane(self.c_dequantize(cb), chroma_h, chroma_w)
        cr_plane = self._decode_plane(self.c_dequantize(cr), chroma_h, chroma_w)

        image = self.chroma(y_plane, cb_plane, cr_plane)
        image = self.colors(image)
        # Element-wise min/max against constant tensors clamps to [0, 255].
        image = torch.min(
            255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image)
        )
        return image / 255