littlelittlecloud
commited on
Commit
·
2cc08ea
1
Parent(s):
7779efa
disposable pattern
Browse files- AutoencoderKL.cs +9 -2
- ClipEnocder.cs +9 -2
- DDPM.cs +7 -1
- Program.cs +10 -5
AutoencoderKL.cs
CHANGED
@@ -1,8 +1,9 @@
|
|
|
|
1 |
using TorchSharp;
|
2 |
|
3 |
-
public class AutoencoderKL
|
4 |
{
|
5 |
-
private
|
6 |
private readonly float _scale;
|
7 |
public torch.Device Device {get;}
|
8 |
|
@@ -21,4 +22,10 @@ public class AutoencoderKL
|
|
21 |
tokenTensor = 1.0f / _scale * tokenTensor;
|
22 |
return (torch.Tensor)_model.forward(tokenTensor);
|
23 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
}
|
|
|
1 |
+
using System;
|
2 |
using TorchSharp;
|
3 |
|
4 |
+
public class AutoencoderKL : IDisposable
|
5 |
{
|
6 |
+
private torch.jit.ScriptModule _model;
|
7 |
private readonly float _scale;
|
8 |
public torch.Device Device {get;}
|
9 |
|
|
|
22 |
tokenTensor = 1.0f / _scale * tokenTensor;
|
23 |
return (torch.Tensor)_model.forward(tokenTensor);
|
24 |
}
|
25 |
+
|
26 |
+
public void Dispose()
|
27 |
+
{
|
28 |
+
_model.Dispose();
|
29 |
+
_model = null;
|
30 |
+
}
|
31 |
}
|
ClipEnocder.cs
CHANGED
@@ -1,8 +1,9 @@
|
|
|
|
1 |
using TorchSharp;
|
2 |
|
3 |
-
public class ClipEncoder
|
4 |
{
|
5 |
-
private
|
6 |
public torch.Device Device {get;}
|
7 |
|
8 |
public ClipEncoder(string modelPath, torch.Device device)
|
@@ -17,4 +18,10 @@ public class ClipEncoder
|
|
17 |
{
|
18 |
return (torch.Tensor)_model.forward(tokenTensor);
|
19 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
}
|
|
|
1 |
+
using System;
|
2 |
using TorchSharp;
|
3 |
|
4 |
+
public class ClipEncoder : IDisposable
|
5 |
{
|
6 |
+
private torch.jit.ScriptModule _model;
|
7 |
public torch.Device Device {get;}
|
8 |
|
9 |
public ClipEncoder(string modelPath, torch.Device device)
|
|
|
18 |
{
|
19 |
return (torch.Tensor)_model.forward(tokenTensor);
|
20 |
}
|
21 |
+
|
22 |
+
public void Dispose()
|
23 |
+
{
|
24 |
+
_model.Dispose();
|
25 |
+
_model = null;
|
26 |
+
}
|
27 |
}
|
DDPM.cs
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
using System;
|
2 |
using TorchSharp;
|
3 |
|
4 |
-
public class DDPM
|
5 |
{
|
6 |
private readonly torch.jit.ScriptModule _model;
|
7 |
public torch.Device Device {get;}
|
@@ -36,4 +36,10 @@ public class DDPM
|
|
36 |
{
|
37 |
return _model.invoke<torch.Tensor>("predict_start_from_z_and_v", z, t, v);
|
38 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
}
|
|
|
1 |
using System;
|
2 |
using TorchSharp;
|
3 |
|
4 |
+
public class DDPM : IDisposable
|
5 |
{
|
6 |
private readonly torch.jit.ScriptModule _model;
|
7 |
public torch.Device Device {get;}
|
|
|
36 |
{
|
37 |
return _model.invoke<torch.Tensor>("predict_start_from_z_and_v", z, t, v);
|
38 |
}
|
39 |
+
|
40 |
+
public void Dispose()
|
41 |
+
{
|
42 |
+
_model.Dispose();
|
43 |
+
_model = null;
|
44 |
+
}
|
45 |
}
|
Program.cs
CHANGED
@@ -6,9 +6,6 @@ using TorchSharp;
|
|
6 |
|
7 |
torchvision.io.DefaultImager = new torchvision.io.SkiaImager();
|
8 |
var device = TorchSharp.torch.device("cuda:0");
|
9 |
-
var ddpm = new DDPM("ddim_v_sampler.ckpt", device);
|
10 |
-
var ddimSampler = new DDIMSampler(ddpm);
|
11 |
-
var autoencoderKL = new AutoencoderKL("autoencoder_kl.ckpt", device);
|
12 |
var clipEncoder = new ClipEncoder("clip_encoder.ckpt", device);
|
13 |
var start_token = 49406;
|
14 |
var end_token = 49407;
|
@@ -21,7 +18,7 @@ var dictionary = new Dictionary<string, long>(){
|
|
21 |
{"green", 1901},
|
22 |
};
|
23 |
|
24 |
-
var batch =
|
25 |
|
26 |
var prompt = "a wild cute green cat";
|
27 |
var tokens = prompt.Split(' ').Select(x => dictionary[x]).ToList();
|
@@ -33,15 +30,23 @@ var tokenTensor = torch.tensor(tokens.ToArray(), dtype: torch.ScalarType.Int64,
|
|
33 |
tokenTensor = tokenTensor.repeat(batch, 1);
|
34 |
var unconditional_tokenTensor = torch.tensor(uncontional_tokens.ToArray(), dtype: torch.ScalarType.Int64, device: device);
|
35 |
unconditional_tokenTensor = unconditional_tokenTensor.repeat(batch, 1);
|
36 |
-
var img = torch.randn(batch, 4,
|
37 |
var t = torch.full(new[]{batch, 1L}, value: batch, dtype: torch.ScalarType.Int32, device: device);
|
38 |
var condition = clipEncoder.Forward(tokenTensor);
|
39 |
var unconditional_condition = clipEncoder.Forward(unconditional_tokenTensor);
|
|
|
|
|
|
|
|
|
40 |
var ddim_steps = 50;
|
41 |
img = ddimSampler.Sample(img, condition, unconditional_condition, ddim_steps);
|
|
|
|
|
|
|
42 |
var decoded_images = (torch.Tensor)autoencoderKL.Forward(img);
|
43 |
decoded_images = torch.clamp((decoded_images + 1.0) / 2.0, 0.0, 1.0);
|
44 |
|
|
|
45 |
for(int i = 0; i!= batch; ++i)
|
46 |
{
|
47 |
var image = decoded_images[i];
|
|
|
6 |
|
7 |
torchvision.io.DefaultImager = new torchvision.io.SkiaImager();
|
8 |
var device = TorchSharp.torch.device("cuda:0");
|
|
|
|
|
|
|
9 |
var clipEncoder = new ClipEncoder("clip_encoder.ckpt", device);
|
10 |
var start_token = 49406;
|
11 |
var end_token = 49407;
|
|
|
18 |
{"green", 1901},
|
19 |
};
|
20 |
|
21 |
+
var batch = 1;
|
22 |
|
23 |
var prompt = "a wild cute green cat";
|
24 |
var tokens = prompt.Split(' ').Select(x => dictionary[x]).ToList();
|
|
|
30 |
tokenTensor = tokenTensor.repeat(batch, 1);
|
31 |
var unconditional_tokenTensor = torch.tensor(uncontional_tokens.ToArray(), dtype: torch.ScalarType.Int64, device: device);
|
32 |
unconditional_tokenTensor = unconditional_tokenTensor.repeat(batch, 1);
|
33 |
+
var img = torch.randn(batch, 4, 64, 64, dtype: torch.ScalarType.Float32, device: device);
|
34 |
var t = torch.full(new[]{batch, 1L}, value: batch, dtype: torch.ScalarType.Int32, device: device);
|
35 |
var condition = clipEncoder.Forward(tokenTensor);
|
36 |
var unconditional_condition = clipEncoder.Forward(unconditional_tokenTensor);
|
37 |
+
|
38 |
+
clipEncoder.Dispose();
|
39 |
+
var ddpm = new DDPM("ddim_v_sampler.ckpt", device);
|
40 |
+
var ddimSampler = new DDIMSampler(ddpm);
|
41 |
var ddim_steps = 50;
|
42 |
img = ddimSampler.Sample(img, condition, unconditional_condition, ddim_steps);
|
43 |
+
ddpm.Dispose();
|
44 |
+
|
45 |
+
var autoencoderKL = new AutoencoderKL("autoencoder_kl.ckpt", device);
|
46 |
var decoded_images = (torch.Tensor)autoencoderKL.Forward(img);
|
47 |
decoded_images = torch.clamp((decoded_images + 1.0) / 2.0, 0.0, 1.0);
|
48 |
|
49 |
+
|
50 |
for(int i = 0; i!= batch; ++i)
|
51 |
{
|
52 |
var image = decoded_images[i];
|