# rule-guided-music / diff_collage / condind_long.py
# yjhuangcd
# First commit
# 9965bf6
import torch
import torch as th
from einops import rearrange
from .generic_sampler import SimpleWork
from .w_img import split_wimg, avg_merge_wimg
class CondIndSimple(SimpleWork):
    """Noise-prediction wrapper for one wide image built from overlapping tiles.

    The wide image is covered by ``num_img`` tiles of width ``w`` where
    consecutive tiles share ``overlap_size == w // 2`` columns.  The
    per-tile noise predictor is evaluated on every tile and on the
    right-hand overlap strip of every tile; the strip prediction is
    subtracted from the tile prediction (except for the last tile) before
    the tiles are merged back into one wide prediction.
    """

    def __init__(self, shape, eps_scalar_t_fn, num_img, overlap_size=32):
        """Set up the tiling geometry and hand the wide-image eps fn to the base class.

        Args:
            shape: ``(c, h, w)`` of a single tile.
            eps_scalar_t_fn: callable invoked as ``fn(x, t, y=y)`` on a
                batch of tiles; its output is treated as
                ``((b n), c, h, w)``.
            num_img: number of tiles making up the wide image.
            overlap_size: number of columns shared by neighbouring tiles;
                must equal ``w // 2``.
        """
        channels, height, tile_w = shape
        assert overlap_size == tile_w // 2
        self.overlap_size = overlap_size
        self.num_img = num_img
        # Every tile after the first adds only its non-overlapping columns.
        long_w = tile_w * num_img - self.overlap_size * (num_img - 1)
        long_eps_fn = self.get_eps_t_fn(eps_scalar_t_fn)
        super().__init__((channels, height, long_w), long_eps_fn)

    def loss(self, x):
        """Return the squared mismatch between facing overlap strips.

        For each consecutive pair ``(x[i], x[i + 1])`` along dim 0, the
        last ``overlap_size`` columns of ``x[i]`` are compared with the
        first ``overlap_size`` columns of ``x[i + 1]``; the squared
        absolute difference is summed over dims 1-3.
        """
        ov = self.overlap_size
        right_strip = x[:-1][:, :, :, -ov:]
        left_strip = x[1:][:, :, :, :ov]
        gap = th.abs(right_strip - left_strip)
        return th.sum(gap ** 2, dim=(1, 2, 3))

    def get_eps_t_fn(self, eps_scalar_t_fn):
        """Build the eps function that operates on the wide image.

        The returned callable has signature ``(long_x, scalar_t, y=None)``.
        """

        def eps_t_fn(long_x, scalar_t, y=None):
            n_tiles = self.num_img
            ov = self.overlap_size

            # presumably yields the tiles stacked along the batch dim as
            # ((b n), c, h, w) -- confirm against split_wimg
            tiles = split_wimg(long_x, n_tiles, rtn_overlap=False)

            # One label / timestep entry per tile, repeated consecutively
            # to line up with the "(b n)" layout used below.
            if y is not None:
                y = y.repeat_interleave(n_tiles)
            scalar_t = scalar_t.repeat_interleave(n_tiles)

            # Prediction on whole tiles.
            tile_eps = rearrange(
                eps_scalar_t_fn(tiles, scalar_t, y=y),
                "(b n) c h w -> n b c h w",
                n=n_tiles,
            )

            # Prediction on each tile's right-hand overlap strip alone.
            overlap_eps = rearrange(
                eps_scalar_t_fn(tiles[:, :, :, -ov:], scalar_t, y=y),
                "(b n) c h w -> n b c h w",
                n=n_tiles,
            )
            # The last tile's correction is zeroed out, so its prediction
            # is left unchanged.
            overlap_eps[-1] = 0

            tile_eps[:, :, :, :, -ov:] = tile_eps[:, :, :, :, -ov:] - overlap_eps

            flat_eps = rearrange(tile_eps, "n b c h w -> (b n) c h w")
            # is_avg=False: presumably overlapping columns are summed
            # rather than averaged -- confirm against avg_merge_wimg
            return avg_merge_wimg(flat_eps, ov, n=n_tiles, is_avg=False)

        return eps_t_fn
class CondIndSR(SimpleWork):
    """Tiled noise-prediction wrapper conditioned on a ``low_res`` tensor.

    The wide image is covered by ``num_img`` tiles of width ``w`` where
    consecutive tiles share ``overlap_size == w // 2`` columns.  Each tile
    is paired with the matching slice of ``low_res`` (presumably a
    low-resolution conditioning image for super-resolution -- confirm
    with the caller).  The prediction on each tile's right-hand overlap
    strip is subtracted from the tile prediction (except for the last
    tile) before the tiles are merged into one wide prediction.
    """

    def __init__(self, shape, eps_scalar_t_fn, num_img, low_res, overlap_size=128):
        """Set up tiling geometry and the per-tile / per-strip eps closures.

        Args:
            shape: ``(c, h, w)`` of a single tile.
            eps_scalar_t_fn: callable invoked as ``fn(x, vec_t, cond)``.
            num_img: number of tiles making up the wide image.
            low_res: conditioning tensor; its last dim must equal
                ``(low_res.shape[-2] // 2) * (num_img + 1)``.
            overlap_size: columns shared by neighbouring tiles; must equal
                ``w // 2``.

        Raises:
            AssertionError: if either size constraint above is violated.
        """
        c, h, w = shape
        assert overlap_size == w // 2, (
            f"overlap_size ({overlap_size}) must equal w // 2 ({w // 2})"
        )
        self.overlap_size = overlap_size
        # Overlap width on the low_res grid is taken as half of the
        # second-to-last dim (i.e. low_res tiles are assumed square --
        # TODO confirm).
        self.low_overlap_size = low_res.shape[-2] // 2
        self.num_img = num_img
        # Every tile after the first adds only its non-overlapping columns.
        final_img_w = w * num_img - self.overlap_size * (num_img - 1)
        expected_low_w = self.low_overlap_size * (num_img + 1)
        assert low_res.shape[-1] == expected_low_w, (
            f"low_res width ({low_res.shape[-1]}) must equal "
            f"low_overlap_size * (num_img + 1) ({expected_low_w})"
        )

        self.square_fn = self.get_square_sr_fn(eps_scalar_t_fn, low_res)
        self.half_fn = self.get_half_sr_fn(eps_scalar_t_fn, low_res)
        super().__init__((c, h, final_img_w), self.get_eps_t_fn())

    @staticmethod
    def _batch_t(_x, _t):
        """Broadcast ``_t`` to one entry per row of ``_x``, on ``_x``'s device."""
        # Allocate on the input's device instead of forcing the default
        # CUDA device, so CPU runs and non-default GPUs work too.
        return th.ones(_x.shape[0], device=_x.device) * _t

    def get_square_sr_fn(self, eps_fn, low_res):
        """Return ``_fn(_x, _t, enable_grad)`` that predicts eps on whole tiles.

        ``low_res`` is split into per-tile pieces once, here.  The closure
        returns the prediction rearranged to ``(n, b, c, h, w)``.
        """
        low_res = split_wimg(low_res, self.num_img, False)

        def _fn(_x, _t, enable_grad):
            context = th.enable_grad if enable_grad else th.no_grad
            with context():
                vec_t = self._batch_t(_x, _t)
                rtn = eps_fn(_x, vec_t, low_res)
                rtn = rearrange(
                    rtn,
                    "(b n) c h w -> n b c h w", n=self.num_img
                )
            return rtn

        return _fn

    def get_half_sr_fn(self, eps_fn, low_res):
        """Return ``_fn(_x, _t, enable_grad)`` that predicts eps on overlap strips.

        The closure evaluates ``eps_fn`` on the last ``overlap_size``
        columns of every tile, paired with the last ``low_overlap_size``
        columns of the matching ``low_res`` piece.  The result is
        rearranged to ``(n, b, c, h, w)`` and the last tile's entry is
        zeroed out.
        """
        low_res = split_wimg(low_res, self.num_img, False)

        def _fn(_x, _t, enable_grad):
            context = th.enable_grad if enable_grad else th.no_grad
            with context():
                vec_t = self._batch_t(_x, _t)
                half_eps = eps_fn(
                    _x[:, :, :, -self.overlap_size:],
                    vec_t,
                    low_res[:, :, :, -self.low_overlap_size:],
                )
                half_eps = rearrange(
                    half_eps,
                    "(b n) c h w -> n b c h w", n=self.num_img
                )
                # Zero correction for the last tile, so its prediction is
                # left unchanged by the subtraction in eps_t_fn.
                half_eps[-1] = 0
            return half_eps

        return _fn

    def get_eps_t_fn(self):
        """Build the eps function that operates on the wide image.

        The returned callable has signature
        ``(in_x, scalar_t, enable_grad=False)``.
        """

        def eps_t_fn(in_x, scalar_t, enable_grad=False):
            # presumably yields the tiles stacked along the batch dim as
            # ((b n), c, h, w) -- confirm against split_wimg
            xs = split_wimg(in_x, self.num_img, rtn_overlap=False)

            # Prediction on whole tiles, then on the right overlap strips.
            full_eps = self.square_fn(xs, scalar_t, enable_grad)
            half_eps = self.half_fn(xs, scalar_t, enable_grad)

            # Remove the strip prediction from each tile's right half.
            full_eps[:, :, :, :, -self.overlap_size:] = (
                full_eps[:, :, :, :, -self.overlap_size:] - half_eps
            )

            whole_eps = rearrange(
                full_eps,
                "n b c h w -> (b n) c h w"
            )
            # is_avg=False: presumably overlapping columns are summed
            # rather than averaged -- confirm against avg_merge_wimg
            out_eps = avg_merge_wimg(
                whole_eps, self.overlap_size, n=self.num_img, is_avg=False
            )
            return out_eps

        return eps_t_fn
# class CondIndLong(SimpleWork):
# def __init__(self, shape, eps_scalar_t_fn, overlap_size=32):
# super().__init__(shape, eps_scalar_t_fn)
# self.overlap_size = overlap_size
# def loss(self, x):
# x1, x2 = x[:-1], x[1:]
# return th.sum(
# (th.abs(x1[:, :, :, -self.overlap_size :] - x2[:, :, :, : self.overlap_size])) ** 2,
# dim=(1, 2, 3),
# )
# def generate_xT(self, n):
# white_noise = th.randn((n , *self.shape)).cuda()
# return self.noise(white_noise, None) * 80.0
# def noise(self, xt, scalar_t):
# del scalar_t
# noise = th.randn_like(xt)
# b, _, _, w = xt.shape
# final_img_w = w * b - self.overlap_size * (b - 1)
# noise = rearrange(noise, "(t n) c h w -> t c h (n w)", t=1)[:, :, :, :final_img_w]
# noise = split_wimg(noise, b, rtn_overlap=False)
# return noise
# def merge(self, xs):
# return avg_merge_wimg(xs, self.overlap_size)