|
using System; |
|
using TorchSharp; |
|
|
|
/// <summary>
/// Thin C# wrapper around a TorchScript-exported DDPM module. Each inference method forwards to a
/// method of the corresponding snake_case name on the scripted module via
/// <see cref="torch.jit.ScriptModule"/>.<c>invoke</c>.
/// </summary>
/// <remarks>
/// The wrapper owns the loaded script module and releases it in <see cref="Dispose"/>.
/// Calling an inference method after disposal throws <see cref="ObjectDisposedException"/>.
/// </remarks>
public class DDPM : IDisposable
{
    // Deliberately not readonly: Dispose clears it so later calls fail fast with
    // ObjectDisposedException instead of touching a released native module.
    private torch.jit.ScriptModule _model;

    /// <summary>Device the scripted module was moved to at construction.</summary>
    public torch.Device Device { get; }

    /// <summary>
    /// Loads a TorchScript module from <paramref name="modelPath"/>, moves it to
    /// <paramref name="device"/>, and switches it to evaluation mode.
    /// </summary>
    /// <param name="modelPath">Path of the serialized TorchScript module to load.</param>
    /// <param name="device">Device the module is moved to.</param>
    /// <exception cref="ArgumentNullException"><paramref name="modelPath"/> or <paramref name="device"/> is null.</exception>
    public DDPM(string modelPath, torch.Device device)
    {
        if (modelPath is null)
        {
            throw new ArgumentNullException(nameof(modelPath));
        }

        Device = device ?? throw new ArgumentNullException(nameof(device));

        var model = torch.jit.load(modelPath);
        try
        {
            model.to(Device);
            model.eval();
        }
        catch
        {
            // Don't leak the native module if moving it or switching modes fails.
            model.Dispose();
            throw;
        }

        _model = model;
    }

    /// <summary>
    /// Evaluates the scripted <c>diffusion_model</c> method once on a doubled batch
    /// (unconditional half first, conditional half second) and splits the output back into two
    /// halves along dim 0.
    /// </summary>
    /// <param name="img">Input batch; duplicated along dim 0.</param>
    /// <param name="condition">Conditioning paired with the second copy of <paramref name="img"/>.</param>
    /// <param name="unconditional_condition">Conditioning paired with the first copy of <paramref name="img"/>.</param>
    /// <param name="t">Timestep tensor; duplicated along dim 0 like <paramref name="img"/>.</param>
    /// <returns>
    /// The first (unconditional) and second (conditional) chunk of the model output. Judging by
    /// the names, these are presumably epsilon/noise predictions; confirm against the exported model.
    /// </returns>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    public (torch.Tensor e_T_Uncondition, torch.Tensor e_T) DiffusionModel(torch.Tensor img, torch.Tensor condition, torch.Tensor unconditional_condition, torch.Tensor t)
    {
        var model = Model;

        // The concatenated inputs and the full batched output are temporaries: dispose them
        // eagerly instead of leaving native memory to the GC. The returned chunks are their own
        // tensor handles (views sharing storage), so they stay valid after `output` is disposed.
        using (var xIn = torch.cat(new[] { img, img }))
        using (var conditionIn = torch.cat(new[] { unconditional_condition, condition }))
        using (var tIn = torch.cat(new[] { t, t }))
        using (var output = model.invoke<torch.Tensor>("diffusion_model", xIn, tIn, conditionIn))
        {
            var halves = output.chunk(2);
            return (halves[0], halves[1]);
        }
    }

    /// <summary>Invokes the scripted <c>q_sample</c> method with (<paramref name="z"/>, <paramref name="t"/>, <paramref name="v"/>).</summary>
    /// <returns>The tensor returned by <c>q_sample</c>.</returns>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    public torch.Tensor QSample(torch.Tensor z, torch.Tensor t, torch.Tensor v) =>
        Model.invoke<torch.Tensor>("q_sample", z, t, v);

    /// <summary>Invokes the scripted <c>predict_eps_from_z_and_v</c> method with (<paramref name="z"/>, <paramref name="t"/>, <paramref name="v"/>).</summary>
    /// <returns>The tensor returned by <c>predict_eps_from_z_and_v</c>.</returns>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    public torch.Tensor PredictEPSFromZANDV(torch.Tensor z, torch.Tensor t, torch.Tensor v) =>
        Model.invoke<torch.Tensor>("predict_eps_from_z_and_v", z, t, v);

    /// <summary>Invokes the scripted <c>predict_start_from_z_and_v</c> method with (<paramref name="z"/>, <paramref name="t"/>, <paramref name="v"/>).</summary>
    /// <returns>The tensor returned by <c>predict_start_from_z_and_v</c>.</returns>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    public torch.Tensor PredictStartFromZANDV(torch.Tensor z, torch.Tensor t, torch.Tensor v) =>
        Model.invoke<torch.Tensor>("predict_start_from_z_and_v", z, t, v);

    /// <summary>
    /// Releases the scripted module. Safe to call more than once; calls after the first are no-ops.
    /// </summary>
    public void Dispose()
    {
        if (_model is null)
        {
            return;
        }

        _model.Dispose();
        _model = null;
        GC.SuppressFinalize(this);
    }

    // Live module, or ObjectDisposedException once Dispose has run.
    private torch.jit.ScriptModule Model =>
        _model ?? throw new ObjectDisposedException(nameof(DDPM));
}
|
|