|
import torch |
|
import torch as th |
|
from einops import rearrange |
|
|
|
from .generic_sampler import SimpleWork |
|
from .w_img import split_wimg, avg_merge_wimg |
|
|
|
class CondIndSimple(SimpleWork):
    """Sampler over a wide strip made of ``num_img`` overlapping tiles.

    Each tile has the single-tile ``shape`` ``(c, h, w)``. Neighbouring tiles
    share ``overlap_size`` columns, which must be exactly ``w // 2``. The strip
    handed to ``SimpleWork`` therefore has width
    ``w * num_img - overlap_size * (num_img - 1)``.

    The strip-level noise estimate (see ``get_eps_t_fn``) does the following:

    1. Run ``eps_scalar_t_fn`` on every tile and on the right ``overlap_size``
       columns of every tile.
    2. Subtract the latter from each tile's right overlap columns (except the
       last tile).
    3. Merge the tiles back with ``avg_merge_wimg(..., is_avg=False)``.

    Args:
        shape: ``(c, h, w)`` of a single tile.
        eps_scalar_t_fn: Callable invoked as ``eps_scalar_t_fn(x, t, y=y)`` on a
            batch of tiles.
        num_img: Number of tiles in the strip.
        overlap_size: Columns shared by neighbouring tiles; must equal ``w // 2``.
    """

    def __init__(self, shape, eps_scalar_t_fn, num_img, overlap_size=32):
        c, h, w = shape
        assert overlap_size == w // 2, (
            f"overlap_size ({overlap_size}) must equal w // 2 ({w // 2})"
        )
        self.overlap_size = overlap_size
        self.num_img = num_img
        # Every tile after the first adds only its non-overlapping columns.
        final_img_w = w * num_img - self.overlap_size * (num_img - 1)
        # The eps closure reads self.num_img / self.overlap_size at call time.
        # Both are set above, before SimpleWork receives the closure.
        super().__init__((c, h, final_img_w), self.get_eps_t_fn(eps_scalar_t_fn))

    def loss(self, x):
        """Return the squared mismatch between consecutive tiles' overlaps.

        The last ``overlap_size`` columns of ``x[i]`` are compared with the
        first ``overlap_size`` columns of ``x[i + 1]``. The squared absolute
        difference is summed over dims 1-3.

        Assumes ``x`` is a 4-D tile batch ``(n, c, h, w)`` (TODO confirm with
        callers). In that case the result has shape ``(n - 1,)``.
        """
        left, right = x[:-1], x[1:]
        diff = left[:, :, :, -self.overlap_size:] - right[:, :, :, : self.overlap_size]
        return th.sum(th.abs(diff) ** 2, dim=(1, 2, 3))

    @staticmethod
    def _repeat_per_tile(t, n):
        """Repeat every batch entry of ``t`` ``n`` times, consecutively.

        ``repeat_interleave`` without ``dim`` flattens its input first. That
        scrambles any per-sample tensor with more than one dimension, so we
        repeat along dim 0 instead. A 0-d tensor has no batch axis, so it keeps
        the flattening behaviour and yields shape ``(n,)``.
        """
        if t.dim() == 0:
            return t.repeat_interleave(n)
        return t.repeat_interleave(n, dim=0)

    def get_eps_t_fn(self, eps_scalar_t_fn):
        """Build the strip-level noise predictor passed to ``SimpleWork``.

        The returned ``eps_t_fn(long_x, scalar_t, y=None)`` does the following:

        1. Split ``long_x`` into tiles with ``split_wimg``.
        2. Repeat ``scalar_t`` (and ``y``, if given) once per tile.
        3. Predict eps for every full tile and for every tile's right overlap.
        4. Subtract the overlap prediction from each tile's right overlap
           columns (the last tile is left uncorrected).
        5. Merge with ``avg_merge_wimg(..., is_avg=False)``.
        """

        def eps_t_fn(long_x, scalar_t, y=None):
            num_img, overlap = self.num_img, self.overlap_size
            xs = split_wimg(long_x, num_img, rtn_overlap=False)
            # The model outputs are unpacked with "(b n)" below, so each
            # sample's conditioning is repeated once per tile, consecutively.
            if y is not None:
                y = self._repeat_per_tile(y, num_img)
            scalar_t = scalar_t.repeat_interleave(num_img)

            full_eps = rearrange(
                eps_scalar_t_fn(xs, scalar_t, y=y),
                "(b n) c h w -> n b c h w",
                n=num_img,
            )
            half_eps = rearrange(
                eps_scalar_t_fn(xs[:, :, :, -overlap:], scalar_t, y=y),
                "(b n) c h w -> n b c h w",
                n=num_img,
            )
            # The last tile's right half is the strip's right edge. It is not
            # shared with a neighbour, so it gets no correction.
            half_eps[-1] = 0
            # NOTE(review): presumably removes the double-counted prediction
            # of each shared region (conditional-independence factorisation).
            full_eps[:, :, :, :, -overlap:] = full_eps[:, :, :, :, -overlap:] - half_eps

            whole_eps = rearrange(full_eps, "n b c h w -> (b n) c h w")
            return avg_merge_wimg(whole_eps, overlap, n=num_img, is_avg=False)

        return eps_t_fn
|
|
|
|
|
|
|
class CondIndSR(SimpleWork):
    """Super-resolve a wide low-res strip as ``num_img`` overlapping high-res tiles.

    High-res tiles have shape ``(c, h, w)`` and overlap their neighbour by
    ``overlap_size == w // 2`` columns. ``low_res`` is the whole low-res strip.
    It is split into per-tile conditions once, at construction time, and reused
    on every call.

    The strip-level eps (see ``get_eps_t_fn``) is computed as follows:

    1. Predict eps for every full tile.
    2. Predict eps for every tile's right overlap (conditioned on the matching
       low-res right half).
    3. Subtract the overlap prediction from each tile's right overlap columns
       (except the last tile).
    4. Merge with ``avg_merge_wimg(..., is_avg=False)``.

    Args:
        shape: ``(c, h, w)`` of a single high-res tile.
        eps_scalar_t_fn: Callable invoked as ``eps_fn(x, t_vec, low_res_tiles)``.
        num_img: Number of tiles in the strip.
        low_res: Low-res strip; its last dim is the strip width.
        overlap_size: High-res columns shared by neighbours; must equal ``w // 2``.
    """

    def __init__(self, shape, eps_scalar_t_fn, num_img, low_res, overlap_size=128):
        c, h, w = shape
        assert overlap_size == w // 2, (
            f"overlap_size ({overlap_size}) must equal w // 2 ({w // 2})"
        )
        self.overlap_size = overlap_size
        # Half the low-res tile side, taken from the height. This presumes
        # square low-res tiles.
        self.low_overlap_size = low_res.shape[-2] // 2
        self.num_img = num_img
        final_img_w = w * num_img - self.overlap_size * (num_img - 1)
        # num_img tiles that each overlap their neighbour by half a tile span
        # (num_img + 1) half-tile widths.
        assert low_res.shape[-1] == self.low_overlap_size * (num_img + 1), (
            f"low_res width {low_res.shape[-1]} != "
            f"{self.low_overlap_size} * ({num_img} + 1)"
        )

        self.square_fn = self.get_square_sr_fn(eps_scalar_t_fn, low_res)
        self.half_fn = self.get_half_sr_fn(eps_scalar_t_fn, low_res)

        super().__init__((c, h, final_img_w), self.get_eps_t_fn())

    @staticmethod
    def _t_vector(x, t):
        """Return a float vector of ``t`` with one entry per row of ``x``.

        The vector is allocated on ``x``'s device rather than forced onto CUDA,
        so CPU (or a non-default GPU) input works.
        """
        return th.ones(x.shape[0], device=x.device) * t

    def _tiles_first(self, eps):
        """Rearrange a ``(b n)``-batched prediction to ``(n, b, c, h, w)``."""
        return rearrange(eps, "(b n) c h w -> n b c h w", n=self.num_img)

    def get_square_sr_fn(self, eps_fn, low_res):
        """Return ``_fn(x, t, enable_grad)`` predicting eps for every full tile.

        ``_fn`` runs under ``torch.enable_grad`` when ``enable_grad`` is true,
        otherwise under ``torch.no_grad``. Its result has shape
        ``(num_img, b, c, h, w)``.
        """
        low_res = split_wimg(low_res, self.num_img, False)

        def _fn(_x, _t, enable_grad):
            context = th.enable_grad if enable_grad else th.no_grad
            with context():
                eps = eps_fn(_x, self._t_vector(_x, _t), low_res)
                return self._tiles_first(eps)

        return _fn

    def get_half_sr_fn(self, eps_fn, low_res):
        """Return ``_fn(x, t, enable_grad)`` predicting eps for each right overlap.

        The right ``overlap_size`` high-res columns of each tile are paired with
        the right ``low_overlap_size`` columns of the matching low-res tile. The
        entry for the last tile is zeroed because its right half is not shared.
        """
        low_res = split_wimg(low_res, self.num_img, False)

        def _fn(_x, _t, enable_grad):
            context = th.enable_grad if enable_grad else th.no_grad
            with context():
                half_eps = eps_fn(
                    _x[:, :, :, -self.overlap_size:],
                    self._t_vector(_x, _t),
                    low_res[:, :, :, -self.low_overlap_size:],
                )
                half_eps = self._tiles_first(half_eps)
                # Zero inside the grad context so the in-place write is tracked
                # consistently when enable_grad is set.
                half_eps[-1] = 0
                return half_eps

        return _fn

    def get_eps_t_fn(self):
        """Build the strip-level predictor ``eps_t_fn(in_x, scalar_t, enable_grad=False)``."""

        def eps_t_fn(in_x, scalar_t, enable_grad=False):
            xs = split_wimg(in_x, self.num_img, rtn_overlap=False)
            full_eps = self.square_fn(xs, scalar_t, enable_grad)
            half_eps = self.half_fn(xs, scalar_t, enable_grad)

            overlap = self.overlap_size
            # Subtract each tile's overlap-only prediction from its right overlap.
            full_eps[:, :, :, :, -overlap:] = full_eps[:, :, :, :, -overlap:] - half_eps

            whole_eps = rearrange(full_eps, "n b c h w -> (b n) c h w")
            return avg_merge_wimg(whole_eps, overlap, n=self.num_img, is_avg=False)

        return eps_t_fn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|