littlelittlecloud
commited on
Commit
·
e3192e0
1
Parent(s):
db2ecf2
refactor and add DDPM and DDIMSampler
Browse files- .gitattributes +2 -0
- DDIMSampler.cs +34 -0
- DDPM.cs +48 -0
- Program.cs +11 -29
- cat.png +2 -2
- ddim_v_sampler.ckpt +2 -2
.gitattributes
CHANGED
@@ -22,6 +22,7 @@
|
|
22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
|
|
25 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
@@ -32,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
26 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
27 |
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
28 |
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
cat.png filter=lfs diff=lfs merge=lfs -text
|
DDIMSampler.cs
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
using TorchSharp;
|
2 |
+
|
3 |
+
public class DDIMSampler
|
4 |
+
{
|
5 |
+
private readonly DDPM _model;
|
6 |
+
private const int TIME_STEPS = 1000;
|
7 |
+
private readonly torch.Device _device;
|
8 |
+
|
9 |
+
public DDIMSampler(DDPM model, float scale = 9.0f)
|
10 |
+
{
|
11 |
+
_model = model;
|
12 |
+
_device = model.Device;
|
13 |
+
}
|
14 |
+
|
15 |
+
public torch.Tensor Sample(torch.Tensor img, torch.Tensor condition, torch.Tensor unconditional_condition, int steps = 50, float scale = 9.0f)
|
16 |
+
{
|
17 |
+
var gap = DDIMSampler.TIME_STEPS / steps;
|
18 |
+
using(var context = torch.enable_grad(false))
|
19 |
+
{
|
20 |
+
for(var i = DDIMSampler.TIME_STEPS-1; i >=0; i -= gap)
|
21 |
+
{
|
22 |
+
var t_cur = torch.full(1, i, dtype: torch.ScalarType.Int64, device: _device);
|
23 |
+
var t_prev = torch.full(1, i - gap >= 0? i - gap: 0, dtype: torch.ScalarType.Int64, device: _device);
|
24 |
+
(var e_t_uncond, var e_t) = _model.DiffusionModel(img, condition, unconditional_condition, t_cur);
|
25 |
+
var model_output = e_t_uncond + scale * (e_t - e_t_uncond);
|
26 |
+
e_t = _model.PredictEPSFromZANDV(img, t_cur, model_output);
|
27 |
+
var pred_x0 = _model.PredictStartFromZANDV(img, t_cur, model_output);
|
28 |
+
img = _model.QSample(pred_x0, t_prev, e_t);
|
29 |
+
}
|
30 |
+
|
31 |
+
return img;
|
32 |
+
}
|
33 |
+
}
|
34 |
+
}
|
DDPM.cs
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
using TorchSharp;
|
2 |
+
|
3 |
+
public class DDPM
|
4 |
+
{
|
5 |
+
private readonly torch.jit.ScriptModule _model;
|
6 |
+
public torch.Device Device {get;}
|
7 |
+
public DDPM(string modelPath, torch.Device device)
|
8 |
+
{
|
9 |
+
_model = TorchSharp.torch.jit.load(modelPath);
|
10 |
+
Device = device;
|
11 |
+
_model.to(Device);
|
12 |
+
_model.eval();
|
13 |
+
}
|
14 |
+
|
15 |
+
public (torch.Tensor e_T_Uncondition, torch.Tensor e_T) DiffusionModel(torch.Tensor img, torch.Tensor condition, torch.Tensor unconditional_condition, torch.Tensor t)
|
16 |
+
{
|
17 |
+
var x_in = torch.cat(new[] { img, img });
|
18 |
+
var condition_in = torch.cat(new[] { unconditional_condition, condition });
|
19 |
+
var t_in = torch.cat(new[] { t, t });
|
20 |
+
var res = _model.invoke<torch.Tensor>("diffusion_model", x_in, t_in, condition_in).chunk(2);
|
21 |
+
return (res[0], res[1]);
|
22 |
+
}
|
23 |
+
|
24 |
+
public torch.Tensor DecodeImage(torch.Tensor img)
|
25 |
+
{
|
26 |
+
return _model.invoke<torch.Tensor>("decode_image", img);
|
27 |
+
}
|
28 |
+
|
29 |
+
public torch.Tensor ClipEncoder(torch.Tensor tokenTensor)
|
30 |
+
{
|
31 |
+
return _model.invoke<torch.Tensor>("clip_encoder", tokenTensor);
|
32 |
+
}
|
33 |
+
|
34 |
+
public torch.Tensor QSample(torch.Tensor z, torch.Tensor t, torch.Tensor v)
|
35 |
+
{
|
36 |
+
return _model.invoke<torch.Tensor>("q_sample",z, t, v);
|
37 |
+
}
|
38 |
+
|
39 |
+
public torch.Tensor PredictEPSFromZANDV(torch.Tensor z, torch.Tensor t, torch.Tensor v)
|
40 |
+
{
|
41 |
+
return _model.invoke<torch.Tensor>("predict_eps_from_z_and_v", z, t, v);
|
42 |
+
}
|
43 |
+
|
44 |
+
public torch.Tensor PredictStartFromZANDV(torch.Tensor z, torch.Tensor t, torch.Tensor v)
|
45 |
+
{
|
46 |
+
return _model.invoke<torch.Tensor>("predict_start_from_z_and_v", z, t, v);
|
47 |
+
}
|
48 |
+
}
|
Program.cs
CHANGED
@@ -3,11 +3,11 @@ using System.Collections.Generic;
|
|
3 |
using System.IO;
|
4 |
using System.Linq;
|
5 |
using TorchSharp;
|
|
|
6 |
torchvision.io.DefaultImager = new torchvision.io.SkiaImager();
|
7 |
var device = TorchSharp.torch.device("cuda:0");
|
8 |
-
var
|
9 |
-
|
10 |
-
ddpm_v_sampler.eval();
|
11 |
|
12 |
var start_token = 49406;
|
13 |
var end_token = 49407;
|
@@ -34,30 +34,12 @@ var unconditional_tokenTensor = torch.tensor(uncontional_tokens.ToArray(), dtype
|
|
34 |
unconditional_tokenTensor = unconditional_tokenTensor.reshape((long)batch, -1);
|
35 |
var img = torch.randn(batch, 4, 96, 96, dtype: torch.ScalarType.Float32, device: device);
|
36 |
var t = torch.ones(batch, dtype: torch.ScalarType.Int32, device: device);
|
37 |
-
var condition =
|
38 |
-
var unconditional_condition =
|
39 |
-
Console.WriteLine(condition);
|
40 |
-
var timesteps = 1000;
|
41 |
var ddim_steps = 50;
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
var t_prev = torch.full(batch, i - gap >= 0? i - gap: 0, dtype: torch.ScalarType.Int64, device: device);
|
49 |
-
img = (torch.Tensor)ddpm_v_sampler.invoke("ddim_sampler", img, condition, unconditional_condition, t_cur, t_prev);
|
50 |
-
Console.WriteLine($"step {i}");
|
51 |
-
}
|
52 |
-
|
53 |
-
var decoded_images = (torch.Tensor)ddpm_v_sampler.invoke("decode_image", img);
|
54 |
-
decoded_images = torch.clamp((decoded_images + 1.0) / 2.0, 0.0, 1.0);
|
55 |
-
|
56 |
-
for(int i = 0; i!= batch; ++i)
|
57 |
-
{
|
58 |
-
// c * h * w
|
59 |
-
var image = decoded_images[i];
|
60 |
-
image = (image * 255.0).to(torch.ScalarType.Byte).cpu();
|
61 |
-
torchvision.io.write_image(image, $"{i}.png", torchvision.ImageFormat.Png);
|
62 |
-
}
|
63 |
-
}
|
|
|
3 |
using System.IO;
|
4 |
using System.Linq;
|
5 |
using TorchSharp;
|
6 |
+
|
7 |
torchvision.io.DefaultImager = new torchvision.io.SkiaImager();
|
8 |
var device = TorchSharp.torch.device("cuda:0");
|
9 |
+
var ddpm = new DDPM("ddim_v_sampler.ckpt", device);
|
10 |
+
var ddimSampler = new DDIMSampler(ddpm);
|
|
|
11 |
|
12 |
var start_token = 49406;
|
13 |
var end_token = 49407;
|
|
|
34 |
unconditional_tokenTensor = unconditional_tokenTensor.reshape((long)batch, -1);
|
35 |
var img = torch.randn(batch, 4, 96, 96, dtype: torch.ScalarType.Float32, device: device);
|
36 |
var t = torch.ones(batch, dtype: torch.ScalarType.Int32, device: device);
|
37 |
+
var condition = ddpm.ClipEncoder(tokenTensor);
|
38 |
+
var unconditional_condition = ddpm.ClipEncoder(unconditional_tokenTensor);
|
|
|
|
|
39 |
var ddim_steps = 50;
|
40 |
+
img = ddimSampler.Sample(img, condition, unconditional_condition, ddim_steps);
|
41 |
+
var decoded_images = (torch.Tensor)ddpm.DecodeImage(img);
|
42 |
+
decoded_images = torch.clamp((decoded_images + 1.0) / 2.0, 0.0, 1.0);
|
43 |
+
var image = decoded_images[0];
|
44 |
+
image = (image * 255.0).to(torch.ScalarType.Byte).cpu();
|
45 |
+
torchvision.io.write_image(image, $"0.png", torchvision.ImageFormat.Png);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cat.png
CHANGED
![]() |
Git LFS Details
|
![]() |
Git LFS Details
|
ddim_v_sampler.ckpt
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:22b16b2fc18c3b20c0eb74ed49a8f1834388fbfd84a49110340943f22fd30fa1
|
3 |
+
size 5216915007
|