yjhuangcd
First commit
9965bf6
import torch as th
from .generic_sampler import batch_mul
def get_x0_grad_pred_fn(raw_net_model, cond_loss_fn, weight_fn, x0_update, thres_t):
    """Build an x0-prediction function with an optional gradient-based correction.

    The returned ``fn(xt, scalar_t)`` calls ``raw_net_model(xt, scalar_t)`` to get a
    raw x0 prediction. When ``scalar_t >= thres_t`` it differentiates
    ``cond_loss_fn(x0_pred).sum()`` with respect to ``xt``. It then subtracts
    ``batch_mul(weights, grad)`` from the prediction, where
    ``weights = weight_fn(x0_pred, grad, cond_loss_fn)``. Below the threshold the
    raw prediction is used unchanged.

    Args:
        raw_net_model: callable ``(xt, scalar_t) -> x0_pred`` returning a tensor.
        cond_loss_fn: callable mapping an x0 tensor to a loss tensor. It is summed
            before differentiation and also evaluated (without grad) for logging.
        weight_fn: callable ``(x0_pred, grad, cond_loss_fn) -> weights``. It is
            only called when the correction is applied.
        x0_update: optional callable ``(x0_cor, scalar_t) -> x0``. Its result is
            recorded in the diagnostics only (see NOTE in ``fn``).
        thres_t: the correction is applied iff ``scalar_t >= thres_t``.

    Returns:
        ``fn(xt, scalar_t) -> (x0_cor, loss_info, traj_info)``:

        - ``x0_cor`` is the detached, possibly corrected, prediction.
        - ``loss_info`` holds CPU ``cond_loss_fn`` values under ``"raw_x0"``,
          ``"cor_x0"`` and ``"x0"``, plus ``"weight"`` when corrected.
        - ``traj_info`` holds ``"t"`` and CPU copies of the raw, corrected and
          updated x0, plus ``"grad"`` when corrected.
    """

    def _diag_loss(x):
        # Diagnostics only: evaluate without recording a graph, so logged values
        # never keep an autograd graph alive (e.g. through cond_loss_fn's params).
        with th.no_grad():
            return cond_loss_fn(x).cpu()

    def fn(xt, scalar_t):
        guided = not bool(scalar_t < thres_t)
        # Fresh leaf: do not flip requires_grad on the caller's tensor, and keep
        # any upstream graph attached to xt out of this step. d(loss)/d(xt) is
        # the same either way.
        xt = xt.detach().requires_grad_(guided)
        # Record a graph only when the gradient is needed. Enabling it explicitly
        # also lets the correction work when the caller is under no_grad.
        with th.set_grad_enabled(guided):
            x0_pred = raw_net_model(xt, scalar_t)
            loss_info = {
                "raw_x0": _diag_loss(x0_pred.detach()),
            }
            traj_info = {
                "t": scalar_t,
            }
            if guided:
                pred_loss = cond_loss_fn(x0_pred)
                grad_term = th.autograd.grad(pred_loss.sum(), xt)[0]
                weights = weight_fn(x0_pred, grad_term, cond_loss_fn)
                x0_cor = (x0_pred - batch_mul(weights, grad_term)).detach()
                loss_info["weight"] = weights.detach().cpu()
                traj_info["grad"] = grad_term.detach().cpu()
            else:
                x0_cor = x0_pred.detach()

        x0 = x0_update(x0_cor, scalar_t) if x0_update is not None else x0_cor
        loss_info["cor_x0"] = _diag_loss(x0_cor)
        # Without an update step x0 *is* x0_cor; reuse its loss instead of
        # evaluating cond_loss_fn a second time.
        loss_info["x0"] = (
            loss_info["cor_x0"] if x0 is x0_cor else _diag_loss(x0.detach())
        )
        traj_info.update({
            "raw_x0": x0_pred.detach().cpu(),
            "cor_x0": x0_cor.cpu(),
            "x0": x0.detach().cpu(),
        })
        # NOTE(review): this returns x0_cor, not x0, so x0_update's output only
        # reaches loss_info/traj_info. Kept as-is for caller compatibility;
        # confirm this is intended.
        return x0_cor, loss_info, traj_info

    return fn