|
from PIL import Image |
|
import numpy as np |
|
|
|
import torch |
|
|
|
from .view_base import BaseView |
|
|
|
|
|
class SkewView(BaseView):
    '''
    View that skews an image by cyclically rolling each pixel column
    vertically, by an amount proportional to the column's horizontal
    distance from the centre column. The operation only moves pixels
    around (it is a permutation), so it can be undone exactly.
    '''

    def __init__(self, skew_factor=1.5):
        '''
        Args:
            skew_factor: rows of displacement applied per column of
                horizontal distance from the centre column.
        '''
        self.skew_factor = skew_factor

    def skew_image(self, im, skew_factor):
        '''
        Roll each column of the image by increasing displacements.
            This is a permutation of pixels

        Args:
            im: tensor whose last two dimensions are (height, width),
                e.g. (c, h, w). Extra leading dimensions are carried
                through unchanged.
            skew_factor: rows of displacement per column of distance from
                the centre column. Each column's shift is truncated
                toward zero to a whole number of rows.

        Returns:
            A new tensor with the same shape and dtype as `im`.
        '''
        h, w = im.shape[-2], im.shape[-1]

        # Nothing to permute (and `% h` below would divide by zero)
        if h == 0 or w == 0:
            return im.clone()

        # The shift grows with horizontal distance from the centre column,
        # so the reference point has to come from the width, not the height
        w_center = w // 2

        # int() truncates toward zero, hence int(-x) == -int(x): negating
        # `skew_factor` negates every shift, which makes the inverse exact
        shifts = torch.tensor(
            [int(skew_factor * (i - w_center)) for i in range(w)],
            dtype=torch.long,
            device=im.device,
        )

        # Rolling column i down by d means out[..., r, i] = im[..., (r - d) % h, i]
        rows = torch.arange(h, device=im.device).unsqueeze(1)   # (h, 1)
        src_rows = (rows - shifts.unsqueeze(0)) % h             # (h, w)
        src_cols = torch.arange(w, device=im.device).unsqueeze(0)  # (1, w)

        return im[..., src_rows, src_cols]

    def view(self, im):
        '''Apply the skew to `im` using this view's `skew_factor`.'''
        return self.skew_image(im, self.skew_factor)

    def inverse_view(self, noise):
        '''Undo `view` by rolling every column back by the negated shift.'''
        return self.skew_image(noise, -self.skew_factor)

    def make_frame(self, im, t):
        '''
        Render one animation frame: the image skewed by `t * skew_factor`,
        pasted at the centre of a white square canvas.

        Args:
            im: PIL image. It is converted to RGB to match the canvas.
            t: scale applied to `skew_factor` for this frame
                (0 leaves the image unskewed).

        Returns:
            An RGB PIL image whose side is 1.5x the input width.
        '''
        im_width, im_height = im.size
        frame_size = int(im_width * 1.5)
        skew_factor = t * self.skew_factor

        # The skew only permutes pixels, so stay in uint8 end to end. A
        # round trip through floats followed by a truncating cast is not
        # guaranteed to reproduce the original values.
        pixels = torch.from_numpy(np.array(im.convert('RGB'))).permute(2, 0, 1)
        skewed = self.skew_image(pixels, skew_factor)
        im = Image.fromarray(
            np.ascontiguousarray(skewed.permute(1, 2, 0).numpy())
        )

        # Centre on each axis separately so non-square images are placed
        # correctly
        frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255))
        frame.paste(
            im,
            ((frame_size - im_width) // 2, (frame_size - im_height) // 2),
        )

        return frame
|
|
|
|