littlelittlecloud committed on
Commit · 7779efa
1 Parent(s): e3192e0
split model into small ones

Browse files:
- .gitattributes +2 -0
- AutoencoderKL.cs +24 -0
- ClipEnocder.cs +20 -0
- DDIMSampler.cs +6 -2
- DDPM.cs +1 -10
- Program.cs +16 -11
- autoencoder_kl.ckpt +3 -0
- clip_encoder.ckpt +3 -0
- ddim_v_sampler.ckpt +2 -2
.gitattributes
CHANGED
@@ -34,3 +34,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 cat.png filter=lfs diff=lfs merge=lfs -text
+autoencoder_kl.ckpt filter=lfs diff=lfs merge=lfs -text
+clip_encoder.ckpt filter=lfs diff=lfs merge=lfs -text
AutoencoderKL.cs
ADDED
@@ -0,0 +1,24 @@
+using TorchSharp;
+
+public class AutoencoderKL
+{
+    private readonly torch.jit.ScriptModule _model;
+    private readonly float _scale;
+    public torch.Device Device { get; }
+
+    public AutoencoderKL(string modelPath, torch.Device device, float scale = 0.18215f)
+    {
+        _model = TorchSharp.torch.jit.load(modelPath);
+        Device = device;
+        _model.to(Device);
+        _model.eval();
+        _scale = scale;
+    }
+
+    public torch.Tensor Forward(torch.Tensor tokenTensor)
+    {
+        var context = torch.enable_grad(false);
+        tokenTensor = 1.0f / _scale * tokenTensor;
+        return (torch.Tensor)_model.forward(tokenTensor);
+    }
+}
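The new AutoencoderKL wrapper owns the VAE decoder that was previously folded into the monolithic checkpoint. `Forward` multiplies the latent by 1/_scale before handing it to the TorchScript decoder (0.18215 is the standard Stable Diffusion latent scale factor). A minimal usage sketch on a dummy latent; the 768x768 output shape is an assumption based on the usual 8x VAE upscale, not something shown in the diff:

    // Hypothetical usage: decode a dummy latent to an image tensor.
    var vae = new AutoencoderKL("autoencoder_kl.ckpt", torch.device("cuda:0"));
    var latent = torch.randn(1, 4, 96, 96, device: vae.Device);
    var image = vae.Forward(latent);  // expected [1, 3, 768, 768], values roughly in [-1, 1]

One wrinkle worth noting: `torch.enable_grad(false)` returns a disposable guard that is never disposed here; the `using(...)` form in DDIMSampler.cs is the tidier TorchSharp idiom.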
ClipEnocder.cs
ADDED
@@ -0,0 +1,20 @@
+using TorchSharp;
+
+public class ClipEncoder
+{
+    private readonly torch.jit.ScriptModule _model;
+    public torch.Device Device { get; }
+
+    public ClipEncoder(string modelPath, torch.Device device)
+    {
+        _model = TorchSharp.torch.jit.load(modelPath);
+        Device = device;
+        _model.to(Device);
+        _model.eval();
+    }
+
+    public torch.Tensor Forward(torch.Tensor tokenTensor)
+    {
+        return (torch.Tensor)_model.forward(tokenTensor);
+    }
+}
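ClipEnocder.cs (the filename typo is in the commit itself; the class is spelled ClipEncoder) repeats the same wrapper pattern without the scale factor: load a TorchScript module, move it to the device, switch to eval mode, forward. A minimal sketch, assuming the export maps a [batch, 77] Int64 token tensor to per-token embeddings used as conditioning:

    // Hypothetical usage: encode an all-padding token tensor.
    var clip = new ClipEncoder("clip_encoder.ckpt", torch.device("cuda:0"));
    var tokens = torch.zeros(1, 77, dtype: torch.ScalarType.Int64, device: clip.Device);
    var condition = clip.Forward(tokens);  // conditioning tensor for DDIMSampler.Sample

Unlike AutoencoderKL.Forward, there is no no-grad guard here, so this Forward records an autograd graph unless the caller suppresses it.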
DDIMSampler.cs
CHANGED
@@ -1,3 +1,4 @@
+using System;
 using TorchSharp;
 
 public class DDIMSampler
@@ -15,17 +16,20 @@ public class DDIMSampler
     public torch.Tensor Sample(torch.Tensor img, torch.Tensor condition, torch.Tensor unconditional_condition, int steps = 50, float scale = 9.0f)
     {
         var gap = DDIMSampler.TIME_STEPS / steps;
+        var batch = img.shape[0];
+
         using(var context = torch.enable_grad(false))
         {
             for(var i = DDIMSampler.TIME_STEPS-1; i >= 0; i -= gap)
             {
-                var t_cur = torch.full(
-                var t_prev = torch.full(
+                var t_cur = torch.full(batch, i, dtype: torch.ScalarType.Int64, device: _device);
+                var t_prev = torch.full(batch, i - gap >= 0 ? i - gap : 0, dtype: torch.ScalarType.Int64, device: _device);
                 (var e_t_uncond, var e_t) = _model.DiffusionModel(img, condition, unconditional_condition, t_cur);
                 var model_output = e_t_uncond + scale * (e_t - e_t_uncond);
                 e_t = _model.PredictEPSFromZANDV(img, t_cur, model_output);
                 var pred_x0 = _model.PredictStartFromZANDV(img, t_cur, model_output);
                 img = _model.QSample(pred_x0, t_prev, e_t);
+                Console.WriteLine(img);
             }
 
             return img;
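The substantive change here is that `t_cur` and `t_prev` are now built as batch-sized Int64 tensors directly on the sampler's device, and a `Console.WriteLine(img)` was added inside the loop as debug output. The loop walks the DDIM schedule from the last timestep down in strides of `gap`, clamping the previous step at 0. An illustrative enumeration of that schedule, assuming TIME_STEPS is 1000 (the constant's value is not visible in this diff):

    // Illustrative only: the (t_cur, t_prev) pairs visited for 50 steps.
    const int TIME_STEPS = 1000;
    int steps = 50, gap = TIME_STEPS / steps;    // gap = 20
    for (var i = TIME_STEPS - 1; i >= 0; i -= gap)
    {
        var tCur = i;                            // 999, 979, ..., 19
        var tPrev = i - gap >= 0 ? i - gap : 0;  // 979, 959, ..., 0
    }

The `model_output = e_t_uncond + scale * (e_t - e_t_uncond)` line is standard classifier-free guidance, with a default guidance scale of 9.0.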
DDPM.cs
CHANGED
@@ -1,3 +1,4 @@
+using System;
 using TorchSharp;
 
 public class DDPM
@@ -21,16 +22,6 @@ public class DDPM
         return (res[0], res[1]);
     }
 
-    public torch.Tensor DecodeImage(torch.Tensor img)
-    {
-        return _model.invoke<torch.Tensor>("decode_image", img);
-    }
-
-    public torch.Tensor ClipEncoder(torch.Tensor tokenTensor)
-    {
-        return _model.invoke<torch.Tensor>("clip_encoder", tokenTensor);
-    }
-
     public torch.Tensor QSample(torch.Tensor z, torch.Tensor t, torch.Tensor v)
     {
         return _model.invoke<torch.Tensor>("q_sample", z, t, v);
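This hunk is the heart of the "split model into small ones" change: the monolithic script module no longer carries `decode_image` and `clip_encoder`, so those wrappers are deleted and their roles move to the standalone AutoencoderKL and ClipEncoder classes backed by separate checkpoints. The surviving methods still dispatch by name into the TorchScript module. A minimal sketch of that dispatch pattern, assuming the checkpoint exports a method named q_sample (as the remaining wrapper implies):

    // Hypothetical stand-alone use of named-method dispatch on a ScriptModule.
    var module = torch.jit.load("ddim_v_sampler.ckpt");
    var z = torch.randn(1, 4, 96, 96);
    var t = torch.full(1L, 999L, dtype: torch.ScalarType.Int64);
    var v = torch.randn(1, 4, 96, 96);
    var noised = module.invoke<torch.Tensor>("q_sample", z, t, v);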
Program.cs
CHANGED
@@ -8,7 +8,8 @@ torchvision.io.DefaultImager = new torchvision.io.SkiaImager();
 var device = TorchSharp.torch.device("cuda:0");
 var ddpm = new DDPM("ddim_v_sampler.ckpt", device);
 var ddimSampler = new DDIMSampler(ddpm);
-
+var autoencoderKL = new AutoencoderKL("autoencoder_kl.ckpt", device);
+var clipEncoder = new ClipEncoder("clip_encoder.ckpt", device);
 var start_token = 49406;
 var end_token = 49407;
 var dictionary = new Dictionary<string, long>(){
@@ -20,7 +21,7 @@ var dictionary = new Dictionary<string, long>(){
     {"green", 1901},
 };
 
-var batch =
+var batch = 2;
 
 var prompt = "a wild cute green cat";
 var tokens = prompt.Split(' ').Select(x => dictionary[x]).ToList();
@@ -29,17 +30,21 @@ tokens = tokens.Append(end_token).ToList();
 tokens = tokens.Concat(Enumerable.Repeat<long>(0, 77 - tokens.Count)).ToList();
 var uncontional_tokens = new[]{start_token, end_token}.Concat(Enumerable.Repeat(0, 75)).ToList();
 var tokenTensor = torch.tensor(tokens.ToArray(), dtype: torch.ScalarType.Int64, device: device);
-tokenTensor = tokenTensor.
+tokenTensor = tokenTensor.repeat(batch, 1);
 var unconditional_tokenTensor = torch.tensor(uncontional_tokens.ToArray(), dtype: torch.ScalarType.Int64, device: device);
-unconditional_tokenTensor = unconditional_tokenTensor.
+unconditional_tokenTensor = unconditional_tokenTensor.repeat(batch, 1);
 var img = torch.randn(batch, 4, 96, 96, dtype: torch.ScalarType.Float32, device: device);
-var t = torch.
-var condition =
-var unconditional_condition =
+var t = torch.full(new[]{batch, 1L}, value: batch, dtype: torch.ScalarType.Int32, device: device);
+var condition = clipEncoder.Forward(tokenTensor);
+var unconditional_condition = clipEncoder.Forward(unconditional_tokenTensor);
 var ddim_steps = 50;
 img = ddimSampler.Sample(img, condition, unconditional_condition, ddim_steps);
-var decoded_images = (torch.Tensor)
+var decoded_images = (torch.Tensor)autoencoderKL.Forward(img);
 decoded_images = torch.clamp((decoded_images + 1.0) / 2.0, 0.0, 1.0);
-
-
-
+
+for(int i = 0; i != batch; ++i)
+{
+    var image = decoded_images[i];
+    image = (image * 255.0).to(torch.ScalarType.Byte).cpu();
+    torchvision.io.write_image(image, $"{i}.png", torchvision.ImageFormat.Png);
+}
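With the encoder and decoder split out, Program.cs now wires the three checkpoints together explicitly: tokenize and pad the prompt to 77 tokens, encode conditional and unconditional prompts with ClipEncoder, run the DDIM loop in latent space, decode with AutoencoderKL, and write one PNG per batch element. A shape walk-through of the pipeline under two assumptions not visible in the diff (the CLIP context width and an 8x VAE upscale):

    // tokenTensor / unconditional_tokenTensor : [2, 77]  Int64
    // condition / unconditional_condition     : [2, 77, ctxDim]  (ctxDim depends on the CLIP export)
    // img (latent)                            : [2, 4, 96, 96]  Float32, sampled from N(0, 1)
    // after ddimSampler.Sample(...)           : [2, 4, 96, 96]
    // decoded_images                          : [2, 3, 768, 768] if the VAE upscales 8x, clamped to [0, 1]
    // output files                            : 0.png, 1.png

Note that `t` is constructed but never consumed in the code shown; it reads like a leftover from the pre-split API.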
autoencoder_kl.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0f5b15ed1a0f81a0ec4a274ac368a5f4fb84f0ce7c3676e683de527e69a59840
+size 334940269

clip_encoder.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ef9706f02a78b2cf93acff22f3036bc3e629d0a5b595c640ada1f73788826f37
+size 1416615515

ddim_v_sampler.ckpt
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:ffa5c521f78e160bb4907a197f8308fa498f21bc3738ff49aded45afe9dbc47d
+size 3465251643