|
import numpy as np |
|
import torch as th |
|
from einops import rearrange |
|
|
|
from .generic_sampler import SimpleWork |
|
from .w_img import split_wimg, avg_merge_wimg |
|
|
|
class AvgCircle(SimpleWork):
    """Sample a wide canvas whose width (last) axis wraps around.

    The canvas has width ``num_img * (w - overlap_size)``. On each noise
    prediction it is randomly rolled, extended by its own first
    ``overlap_size`` columns (so the end overlaps the start), split into
    ``num_img`` overlapping tiles with ``split_wimg``, denoised tile-wise,
    merged back with ``avg_merge_wimg``, and the duplicated wrap-around
    columns are averaged with the start of the canvas.
    """

    def __init__(self, shape, eps_scalar_t_fn, num_img, overlap_size=32):
        """
        Args:
            shape: ``(c, h, w)`` of a single base tile.
            eps_scalar_t_fn: per-tile noise predictor, called as
                ``eps_scalar_t_fn(tiles, scalar_t, enable_grad)`` where
                ``tiles`` is whatever ``split_wimg`` returns.
            num_img: number of tiles around the loop (must be >= 1).
            overlap_size: columns shared by neighbouring tiles, including
                the last/first pair (must satisfy ``0 <= overlap_size < w``).

        Raises:
            ValueError: if ``num_img`` or ``overlap_size`` would give an
                empty or negative canvas width.
        """
        c, h, w = shape
        if num_img < 1:
            raise ValueError(f"num_img must be >= 1, got {num_img}")
        if not 0 <= overlap_size < w:
            raise ValueError(
                f"overlap_size must satisfy 0 <= overlap_size < {w}, "
                f"got {overlap_size}"
            )

        self.base_img_w = w
        self.overlap_size = overlap_size
        self.num_img = num_img

        # Each tile contributes (w - overlap) new columns; because the canvas
        # is circular, the last tile's overlap lands on the first tile.
        final_img_w = w * num_img - self.overlap_size * num_img

        super().__init__((c, h, final_img_w), self.get_eps_t_fn(eps_scalar_t_fn))

    def get_eps_t_fn(self, eps_scalar_t_fn):
        """Wrap a per-tile noise predictor into one for the whole circular canvas.

        The returned closure reads ``self.base_img_w``, ``self.overlap_size``
        and ``self.num_img`` at call time. It is built before
        ``super().__init__`` runs, after those attributes are set.
        """

        def eps_t_fn(long_x, scalar_t, enable_grad=False):
            # long_x is sliced on 4 axes with the last one as width;
            # presumably (batch, c, h, W) -- TODO confirm.
            overlap = self.overlap_size

            # Random circular shift (global NumPy RNG) so tile seams fall at
            # different canvas positions on each call.
            shift = np.random.randint(self.base_img_w)
            # Same as cat([long_x[..., shift:], long_x[..., :shift]]) for
            # shift < width, and still a proper circular shift otherwise.
            rolled = th.roll(long_x, shifts=-shift, dims=-1)

            # Close the loop: append the first `overlap` columns at the end
            # so the last tile overlaps the first one.
            x = th.cat([rolled, rolled[:, :, :, :overlap]], dim=-1)

            xs, _overlap = split_wimg(x, self.num_img, rtn_overlap=True)
            # Internal consistency check between this class and split_wimg.
            # NOTE(review): overlap_size == 0 additionally relies on
            # split_wimg / avg_merge_wimg supporting zero overlap -- confirm.
            assert _overlap == overlap

            full_eps = eps_scalar_t_fn(xs, scalar_t, enable_grad)
            eps = avg_merge_wimg(full_eps, overlap, n=self.num_img)

            # The first and last `overlap` columns of `eps` both predict the
            # same canvas columns (rolled[..., :overlap]). Average them and
            # drop the duplicate tail. The explicit end index (rather than
            # `-overlap`) keeps this correct when overlap == 0, where `-0`
            # would select the whole tensor instead of nothing.
            end = eps.shape[-1] - overlap
            eps = th.cat(
                [
                    (eps[:, :, :, :overlap] + eps[:, :, :, end:]) / 2.0,
                    eps[:, :, :, overlap:end],
                ],
                dim=-1,
            )
            assert eps.shape == long_x.shape

            # Undo the random circular shift.
            return th.roll(eps, shifts=shift, dims=-1)

        return eps_t_fn

    def x0_fn(self, xt, scalar_t, enable_grad=False):
        """Estimate x0 from xt as ``xt - scalar_t * eps``.

        Returns:
            ``(x0, {}, {"x0": x0.cpu()})``. The meaning of the two dicts is
            defined by ``SimpleWork``'s interface, which is not visible here.
        """
        # self.eps_scalar_t_fn is presumably the wrapped eps_t_fn handed to
        # SimpleWork.__init__ above -- TODO confirm.
        cur_eps = self.eps_scalar_t_fn(xt, scalar_t, enable_grad)

        # Inverts xt = x0 + scalar_t * eps (assumed noise model -- confirm).
        x0 = xt - scalar_t * cur_eps

        return x0, {}, {
            "x0": x0.cpu()
        }