using System;
using TorchSharp;

/// <summary>
/// Strided reverse-diffusion sampler with classifier-free guidance.
/// It drives a <see cref="DDPM"/> model over a subset of its
/// <see cref="TimeSteps"/> discrete timesteps.
/// </summary>
public class DDIMSampler
{
    private readonly DDPM _model;

    // Number of discrete timesteps the schedule is strided over.
    // NOTE(review): assumed to match the DDPM's training schedule length; confirm.
    private const int TimeSteps = 1000;

    private readonly torch.Device _device;

    /// <summary>Creates a sampler bound to <paramref name="model"/>.</summary>
    /// <param name="model">Model that supplies the noise prediction and the
    /// eps/x0/q-sample helpers. Its <c>Device</c> is used for the timestep tensors.</param>
    /// <param name="scale">Currently unused. The guidance scale is supplied per call
    /// to <see cref="Sample"/>. The parameter is kept for source compatibility.</param>
    /// <exception cref="ArgumentNullException"><paramref name="model"/> is null.</exception>
    public DDIMSampler(DDPM model, float scale = 9.0f)
    {
        _model = model ?? throw new ArgumentNullException(nameof(model));
        _device = model.Device;
    }

    /// <summary>
    /// Runs the reverse process starting from <paramref name="img"/>.
    /// </summary>
    /// <remarks>
    /// Timesteps visited are <c>t = 999, 999 - gap, ...</c> while <c>t &gt;= 0</c>,
    /// where <c>gap = 1000 / steps</c> (integer division). The target timestep of
    /// the final step is clamped to 0. Each step:
    /// <list type="number">
    /// <item>Evaluates the model for both conditionings.</item>
    /// <item>Blends the two outputs with classifier-free guidance.</item>
    /// <item>Derives the noise estimate and predicted x0 from the blended output.</item>
    /// <item>Re-noises that x0 to the previous timestep via <c>QSample</c>.</item>
    /// </list>
    /// Gradient tracking is disabled for the whole loop. Per-step temporaries are
    /// released with a TorchSharp dispose scope instead of waiting for the GC.
    /// </remarks>
    /// <param name="img">Starting tensor. Its first dimension is treated as the batch
    /// size. It is neither modified nor disposed by this method.</param>
    /// <param name="condition">Conditioning passed to the model.</param>
    /// <param name="unconditional_condition">Unconditional conditioning passed to the
    /// model for the guidance baseline.</param>
    /// <param name="steps">Number of sampling steps. Must be in <c>[1, 1000]</c>.</param>
    /// <param name="scale">Guidance scale:
    /// <c>uncond + scale * (cond - uncond)</c>.</param>
    /// <returns>The tensor produced by the last step. It is registered with the
    /// caller's dispose scope, if one is active.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="img"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="steps"/> is
    /// outside <c>[1, 1000]</c>.</exception>
    public torch.Tensor Sample(torch.Tensor img, torch.Tensor condition, torch.Tensor unconditional_condition, int steps = 50, float scale = 9.0f)
    {
        if (img is null)
        {
            throw new ArgumentNullException(nameof(img));
        }

        // steps > TimeSteps would make gap == 0 (infinite loop).
        // steps <= 0 would divide by zero or walk the schedule backwards.
        if (steps < 1 || steps > TimeSteps)
        {
            throw new ArgumentOutOfRangeException(nameof(steps), steps, $"steps must be between 1 and {TimeSteps}.");
        }

        var gap = TimeSteps / steps;
        var batch = img.shape[0];
        var current = img;

        using (torch.enable_grad(false))
        {
            for (var t = TimeSteps - 1; t >= 0; t -= gap)
            {
                torch.Tensor next;

                // Everything allocated in this scope is freed at the end of the step.
                // This includes the model's forward-pass intermediates.
                // Only the new image is moved out.
                // NOTE(review): assumes the DDPM helpers do not cache tensors they
                // create during these calls; confirm.
                using (torch.NewDisposeScope())
                {
                    var tCur = torch.full(batch, t, dtype: torch.ScalarType.Int64, device: _device);
                    var tPrev = torch.full(batch, Math.Max(t - gap, 0), dtype: torch.ScalarType.Int64, device: _device);

                    var (epsUncond, epsCond) = _model.DiffusionModel(current, condition, unconditional_condition, tCur);
                    var guided = epsUncond + scale * (epsCond - epsUncond);

                    var eps = _model.PredictEPSFromZANDV(current, tCur, guided);
                    var predX0 = _model.PredictStartFromZANDV(current, tCur, guided);

                    next = _model.QSample(predX0, tPrev, eps).MoveToOuterDisposeScope();
                }

                // Free the previous intermediate image.
                // Never free the caller's input.
                if (!ReferenceEquals(current, img) && !ReferenceEquals(current, next))
                {
                    current.Dispose();
                }

                current = next;
            }
        }

        return current;
    }
}