import math
import torch as th
from einops import rearrange
import numpy as np
from .generic_sampler import batch_mul
def split_wimg(himg, n_img, rtn_overlap=True):
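    """Split a wide image of width w into n_img overlapping square crops.

    Each crop has side length h (the image height); the overlap between
    neighbouring crops is chosen so that the crops tile the width exactly.
    """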
if himg.ndim == 3:
himg = himg[None]
_, _, h, w = himg.shape
overlap_size = (n_img * h - w) // (n_img - 1)
assert n_img * h - overlap_size * (n_img - 1) == w
himg = himg[0]
rtn_img = [himg[:, :, :h]]
for i in range(n_img - 1):
rtn_img.append(himg[:, :, (h - overlap_size) * (i + 1) : h + (h - overlap_size) * (i + 1)])
if rtn_overlap:
return th.stack(rtn_img), overlap_size
return th.stack(rtn_img)
def merge_wimg(imgs, overlap_size):
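    """Merge overlapping crops back into one wide image.

    Inverse of `split_wimg`: the two ways of dropping the shared columns
    (keep-left vs. keep-right) are concatenated and averaged, so pixels in
    the overlap regions become the mean of the two neighbouring crops.
    """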
_, _, _, w = imgs.shape
rtn_img = [imgs[0]]
for cur_img in imgs[1:]:
rtn_img.append(cur_img[:, :, overlap_size:])
first_img = th.cat(rtn_img, dim=-1)
rtn_img = []
for cur_img in imgs[:-1]:
rtn_img.append(cur_img[:, :, : w - overlap_size])
rtn_img.append(imgs[-1])
second_img = th.cat(rtn_img, dim=-1)
return (first_img + second_img) / 2.0
def get_x0_pred_fn(raw_net_model, cond_loss_fn, weight_fn, x0_fn, thres_t, init_fn=None):
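    """Wrap a raw denoiser into an x0-prediction function with guidance.

    The returned `fn(xt, scalar_t)` first predicts x0 with `raw_net_model`,
    then (while `scalar_t >= thres_t`) takes one gradient step on
    `cond_loss_fn` with per-sample step sizes from `weight_fn`, and finally
    applies the optional `x0_fn` post-processing (e.g. a worker's
    `x0_replace`). Loss and trajectory diagnostics are returned alongside.
    """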
def fn(xt, scalar_t):
if init_fn is not None:
xt = init_fn(xt, scalar_t)
xt = xt.requires_grad_(True)
x0_pred = raw_net_model(xt, scalar_t)
loss_info = {
"raw_x0": cond_loss_fn(x0_pred.detach()).cpu(),
}
traj_info = {
"t": scalar_t,
}
if scalar_t < thres_t:
x0_cor = x0_pred.detach()
else:
pred_loss = cond_loss_fn(x0_pred)
grad_term = th.autograd.grad(pred_loss.sum(), xt)[0]
weights = weight_fn(x0_pred, grad_term, cond_loss_fn)
x0_cor = (x0_pred - batch_mul(weights, grad_term)).detach()
loss_info["weight"] = weights.detach().cpu()
traj_info["grad"] = grad_term.detach().cpu()
if x0_fn:
x0 = x0_fn(x0_cor, scalar_t)
else:
x0 = x0_cor
loss_info["cor_x0"] = cond_loss_fn(x0_cor.detach()).cpu()
loss_info["x0"] = cond_loss_fn(x0.detach()).cpu()
traj_info.update({
"raw_x0": x0_pred.detach().cpu(),
"cor_x0": x0_cor.detach().cpu(),
"x0": x0.detach().cpu(),
        })
        # return the post-processed prediction; raw and corrected versions stay in traj_info
        return x0, loss_info, traj_info
return fn
def simple_noise(cur_t, xt):
del cur_t
return th.randn_like(xt)
def get_fix_weight_fn(fix_weight):
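    """Return a `weight_fn` that applies the same fixed step size to every sample."""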
def weight_fn(xs, grads, *args):
del grads, args
return th.ones(xs.shape[0]).to(xs) * fix_weight
return weight_fn
class SeqWorker:
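    """Sequential worker: generate one crop at a time.

    The conditioning loss ties the leading `overlap_size` columns of the new
    crop to the trailing columns of the previously generated image `src_img`.
    """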
def __init__(self, overlap_size=10, src_img=None):
self.overlap_size = overlap_size
self.src_img = src_img
def loss(self, x):
return th.sum(
(th.abs(self.src_img[:, :, :, -self.overlap_size :] - x[:, :, :, : self.overlap_size]))
** 2,
dim=(1, 2, 3),
)
def x0_replace(self, x0):
rtn_x0 = x0.clone()
rtn_x0[:, :, :, : self.overlap_size] = self.src_img[:, :, :, -self.overlap_size :]
        return rtn_x0
def optimal_weight_fn(self, x0, grads, *args, ratio=1.0):
del args
overlap_size = self.overlap_size
        # argmin_w ||delta_pixel - w * delta_grads||^2  =>  w = <delta_pixel, delta_grads> / ||delta_grads||^2
        delta_pixel = x0[:, :, :, :overlap_size] - self.src_img[:, :, :, -overlap_size:]
        delta_grads = grads[:, :, :, :overlap_size]
        num = th.sum(delta_pixel * delta_grads).item()
        denom = th.sum(delta_grads * delta_grads).item()
        _optimal_weight = num / denom
        if math.isnan(_optimal_weight):
            print(denom)
raise RuntimeError("nan for weights")
return ratio * _optimal_weight * th.ones(x0.shape[0]).to(x0)
class CircleWorker:
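    """Worker for a circularly closed wide image.

    Crops are generated in parallel; the loss ties each crop's leading
    columns to the trailing columns of the previous crop (with wrap-around),
    so the merged image closes on itself.
    """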
def __init__(self, overlap_size=10, adam_num_iter=100):
self.overlap_size = overlap_size
self.adam_num_iter = adam_num_iter
def get_match_patch(self, x):
tail = x[:, :, :, -self.overlap_size :]
head = x[:, :, :, : self.overlap_size]
tail = th.roll(tail, 1, 0)
return tail, head
def loss(self, x):
tail, head = self.get_match_patch(x)
return th.sum(
(tail - head)**2,
dim=(1, 2, 3),
)
def split_noise(self, cur_t, xt):
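        """Sample noise that is shared across the overlapping (circular) regions.

        A single long noise image of the final panorama width is drawn, its
        first `overlap_size` columns are appended at the end to close the
        circle, and the result is re-split into the per-crop layout of `xt`.
        """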
noise = simple_noise(cur_t, xt)
b, _, _, w = xt.shape
final_img_w = w * b - self.overlap_size * b
noise = rearrange(noise, "(t n) c h w -> t c h (n w)", t=1)[:, :, :, :final_img_w]
noise = th.cat([noise, noise[:,:,:, :self.overlap_size]], dim=-1)
noise, _ = split_wimg(noise, b)
return noise
def merge_circle_image(self, xt):
merged_long_img = merge_wimg(xt, self.overlap_size)
return th.cat(
[
(merged_long_img[:,:,:self.overlap_size] + merged_long_img[:,:,-self.overlap_size:]) / 2.0,
merged_long_img[:,:,self.overlap_size:-self.overlap_size],
],
dim=-1
)
def split_circle_image(self, merged_long_img, n):
imgs,_ = split_wimg(
th.cat(
[
merged_long_img,
merged_long_img[:,:,:self.overlap_size],
],
dim = -1,
),
n
)
return imgs
def optimal_weight_fn(self, xs, grads, *args):
del args
        # argmin_w ||delta_pixel - w * delta_grads||^2  =>  w = <delta_pixel, delta_grads> / ||delta_grads||^2
        tail, head = self.get_match_patch(xs)
        delta_pixel = tail - head
        tail, head = self.get_match_patch(grads)
        delta_grads = tail - head
        num = th.sum(delta_pixel * delta_grads).item()
        denom = th.sum(delta_grads * delta_grads).item()
        _optimal_weight = num / denom
return _optimal_weight * th.ones(xs.shape[0]).to(xs)
def adam_grad_weight(self, x0, grad_term, cond_loss_fn):
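        """Refine the per-sample correction weights with Adam.

        Starts from the closed-form `optimal_weight_fn` value and minimises
        `cond_loss_fn(x0 - w * grad_term)` over w for `adam_num_iter` steps.
        """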
init_weight = self.optimal_weight_fn(x0, grad_term)
grad_term = grad_term.detach()
x0 = x0.detach()
with th.enable_grad():
weights = init_weight.requires_grad_()
optimizer = th.optim.Adam(
[
weights,
],
lr=1e-2,
)
def _loss(w):
cor_x0 = x0 - batch_mul(w, grad_term)
return cond_loss_fn(cor_x0).sum()
for _ in range(self.adam_num_iter):
optimizer.zero_grad()
_cur_loss = _loss(weights)
_cur_loss.backward()
optimizer.step()
return weights
# TODO:
    def x0_replace(self, x0, scalar_t, thres_t):
        if scalar_t > thres_t:
merge_x0 = merge_wimg(x0, self.overlap_size)
return split_wimg(merge_x0, x0.shape[0])[0]
else:
return x0
class ParaWorker:
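    """Parallel worker: generate all crops jointly.

    The loss penalises the mismatch between the trailing columns of each crop
    and the leading columns of the next one, so the crops stitch into a
    single wide image.
    """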
def __init__(self, overlap_size=10, adam_num_iter=100):
self.overlap_size = overlap_size
self.adam_num_iter = adam_num_iter
def loss(self, x):
x1, x2 = x[:-1], x[1:]
return th.sum(
(th.abs(x1[:, :, :, -self.overlap_size :] - x2[:, :, :, : self.overlap_size])) ** 2,
dim=(1, 2, 3),
)
    def split_noise(self, cur_t, xt):
noise = simple_noise(cur_t, xt)
b, _, _, w = xt.shape
final_img_w = w * b - self.overlap_size * (b - 1)
noise = rearrange(noise, "(t n) c h w -> t c h (n w)", t=1)[:, :, :, :final_img_w]
noise, _ = split_wimg(noise, b)
return noise
def optimal_weight_fn(self, xs, grads, *args):
del args
overlap_size = self.overlap_size
        # argmin_w ||delta_pixel - w * delta_grads||^2  =>  w = <delta_pixel, delta_grads> / ||delta_grads||^2
        delta_pixel = xs[:-1, :, :, -overlap_size:] - xs[1:, :, :, :overlap_size]
        delta_grads = grads[:-1, :, :, -overlap_size:] - grads[1:, :, :, :overlap_size]
        num = th.sum(delta_pixel * delta_grads).item()
        denom = th.sum(delta_grads * delta_grads).item()
        _optimal_weight = num / denom
return _optimal_weight * th.ones(xs.shape[0]).to(xs)
def adam_grad_weight(self, x0, grad_term, cond_loss_fn):
init_weight = self.optimal_weight_fn(x0, grad_term)
grad_term = grad_term.detach()
x0 = x0.detach()
with th.enable_grad():
weights = init_weight.requires_grad_()
optimizer = th.optim.Adam(
[
weights,
],
lr=1e-2,
)
def _loss(w):
cor_x0 = x0 - batch_mul(w, grad_term)
return cond_loss_fn(cor_x0).sum()
for _ in range(self.adam_num_iter):
optimizer.zero_grad()
_cur_loss = _loss(weights)
_cur_loss.backward()
optimizer.step()
return weights
    def x0_replace(self, x0, scalar_t, thres_t):
        if scalar_t > thres_t:
merge_x0 = merge_wimg(x0, self.overlap_size)
return split_wimg(merge_x0, x0.shape[0])[0]
else:
return x0
class ParaWorkerC(ParaWorker):
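    """ParaWorker with an additional inpainting constraint.

    Besides the overlap-consistency loss, the merged image is pulled towards
    `src_img` on the pixels marked by `mask_img`.
    """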
    def __init__(self, src_img, mask_img, inpaint_w=1.0, overlap_size=10, adam_num_iter=100):
        self.src_img = src_img
        self.inpaint_w = inpaint_w
        self.mask_img = mask_img  # mask value 1 marks the given (observed) pixels of src_img
super().__init__(overlap_size, adam_num_iter)
def loss(self, x):
if x.shape[0] == 1:
return th.sum(
th.sum(
th.square(self.src_img[:,:,:,:x.shape[-1]] - x), dim=(0,1)
) * self.mask_img[:,:x.shape[-1]]
)
else:
consistent_loss = super().loss(x)
# merge image
merge_x = merge_wimg(x, self.overlap_size)
            inpainting_loss = th.sum(
                th.sum(
                    th.square(self.src_img[:,:,:,:merge_x.shape[-1]] - merge_x), dim=(0,1)
                ) * self.mask_img[:,:merge_x.shape[-1]]
            )
            return consistent_loss + inpainting_loss / (x.shape[-1] - 1)
    def x0_replace(self, x0, scalar_t, thres_t):
        if scalar_t > thres_t:
            merge_x = merge_wimg(x0, self.overlap_size)
            src_img = self.src_img[:,:,:,:merge_x.shape[-1]]
            mask_img = self.mask_img[:,:merge_x.shape[-1]]
            merge_x = th.where(mask_img[None,None].bool(), src_img, merge_x)
return split_wimg(merge_x, x0.shape[0])[0]
else:
return x0
class SplitMergeOp:
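    """Split/merge helper with per-seam overlap sizes.

    `sample` is written to draw random overlaps averaging `avg_overlap` per
    seam, but is currently overridden to a constant overlap (see the FIXME);
    `split` and `merge` use the stored `cur_overlap_int` to cut a wide image
    into fixed-width crops and to average the crops back together.
    """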
def __init__(self, avg_overlap=32):
self.avg_overlap = avg_overlap
self.cur_overlap_int = None
def sample(self, n):
# lower_coef = 3 / 4.0
_lower_bound = self.avg_overlap - 6
base_overlap = np.ones(n) * _lower_bound
total_ball = (self.avg_overlap - _lower_bound) * n
random_number = np.random.randint(0, total_ball - n, n-1)
random_number = np.sort(random_number)
balls = np.append(random_number, total_ball - n) - np.insert(random_number, 0, 0) + np.ones(n) + base_overlap
assert np.sum(balls) == n * self.avg_overlap
# TODO: FIXME
balls = np.ones(n) * self.avg_overlap
        return balls.astype(int)
def reset(self, n):
self.cur_overlap_int = self.sample(n)
def split(self, img, n, img_w=64):
assert img.ndim == 3
# assert img.shape[-1] > (n-1) * self.avg_overlap
assert (n-1) == self.cur_overlap_int.shape[0]
assert (n-1) * self.avg_overlap + img.shape[-1] == n * img_w
cur_idx = 0
imgs = []
for cur_overlap in self.cur_overlap_int:
imgs.append(img[:,:,cur_idx:cur_idx + img_w])
cur_idx = cur_idx + img_w - cur_overlap
imgs.append(img[:,:,cur_idx:])
return th.stack(imgs)
def merge(self, imgs):
b = imgs.shape[0]
img_size = imgs.shape[-1]
assert b - 1 == self.cur_overlap_int.shape[0]
img_width = b * imgs.shape[-1] - np.sum(self.cur_overlap_int)
        wimg = th.zeros((imgs.shape[1], imgs.shape[-2], img_width)).to(imgs)
ncnt = th.zeros(img_width).to(imgs)
cur_idx = 0
for i_th, cur_img in enumerate(imgs):
wimg[:,:,cur_idx:cur_idx + img_size] += cur_img
ncnt[cur_idx:cur_idx + img_size] += 1.0
if i_th < b -1:
cur_idx = cur_idx + img_size - self.cur_overlap_int[i_th]
return wimg / ncnt[None,None,:]
class ParaWorkerFix:
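    """Variant of ParaWorker built on `SplitMergeOp`.

    The consistency loss compares each crop against the crop obtained by
    merging all crops (overlap-averaged) and splitting the result again,
    i.e. it pulls every crop towards the current merged estimate.
    """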
def __init__(self, overlap_size=10, adam_num_iter=100):
self.overlap_size = overlap_size
self.adam_num_iter = adam_num_iter
self.op = SplitMergeOp(overlap_size)
def loss(self, x):
avg_x = self.op.split(
self.op.merge(x), x.shape[0], x.shape[-1]
)
return th.sum(
(x - avg_x) ** 2,
dim=(1, 2, 3),
)
def split_noise(self, cur_t, xt):
noise = simple_noise(cur_t, xt)
b, _, _, w = xt.shape
final_img_w = w * b - self.overlap_size * (b - 1)
noise = rearrange(noise, "(t n) c h w -> t c h (n w)", t=1)[:, :, :, :final_img_w][0]
noise = self.op.split(noise, b, w)
return noise
def adam_grad_weight(self, x0, grad_term, cond_loss_fn):
init_weight = th.ones(x0.shape[0]).to(x0)
grad_term = grad_term.detach()
x0 = x0.detach()
with th.enable_grad():
weights = init_weight.requires_grad_()
optimizer = th.optim.Adam(
[
weights,
],
lr=1e-2,
)
def _loss(w):
cor_x0 = x0 - batch_mul(w, grad_term)
return cond_loss_fn(cor_x0).sum()
for _ in range(self.adam_num_iter):
optimizer.zero_grad()
_cur_loss = _loss(weights)
_cur_loss.backward()
optimizer.step()
return weights
    def x0_replace(self, x0, scalar_t, thres_t):
        if scalar_t > thres_t:
merge_x0 = self.op.merge(x0)
return self.op.split(merge_x0, x0.shape[0], x0.shape[-1])
else:
return x0