littlelittlecloud commited on
Commit
2cc08ea
·
1 Parent(s): 7779efa

disposable pattern

Browse files
Files changed (4) hide show
  1. AutoencoderKL.cs +9 -2
  2. ClipEnocder.cs +9 -2
  3. DDPM.cs +7 -1
  4. Program.cs +10 -5
AutoencoderKL.cs CHANGED
@@ -1,8 +1,9 @@
 
1
  using TorchSharp;
2
 
3
- public class AutoencoderKL
4
  {
5
- private readonly torch.jit.ScriptModule _model;
6
  private readonly float _scale;
7
  public torch.Device Device {get;}
8
 
@@ -21,4 +22,10 @@ public class AutoencoderKL
21
  tokenTensor = 1.0f / _scale * tokenTensor;
22
  return (torch.Tensor)_model.forward(tokenTensor);
23
  }
 
 
 
 
 
 
24
  }
 
1
+ using System;
2
  using TorchSharp;
3
 
4
+ public class AutoencoderKL : IDisposable
5
  {
6
+ private torch.jit.ScriptModule _model;
7
  private readonly float _scale;
8
  public torch.Device Device {get;}
9
 
 
22
  tokenTensor = 1.0f / _scale * tokenTensor;
23
  return (torch.Tensor)_model.forward(tokenTensor);
24
  }
25
+
26
+ public void Dispose()
27
+ {
28
+ _model.Dispose();
29
+ _model = null;
30
+ }
31
  }
ClipEnocder.cs CHANGED
@@ -1,8 +1,9 @@
 
1
  using TorchSharp;
2
 
3
- public class ClipEncoder
4
  {
5
- private readonly torch.jit.ScriptModule _model;
6
  public torch.Device Device {get;}
7
 
8
  public ClipEncoder(string modelPath, torch.Device device)
@@ -17,4 +18,10 @@ public class ClipEncoder
17
  {
18
  return (torch.Tensor)_model.forward(tokenTensor);
19
  }
 
 
 
 
 
 
20
  }
 
1
+ using System;
2
  using TorchSharp;
3
 
4
+ public class ClipEncoder : IDisposable
5
  {
6
+ private torch.jit.ScriptModule _model;
7
  public torch.Device Device {get;}
8
 
9
  public ClipEncoder(string modelPath, torch.Device device)
 
18
  {
19
  return (torch.Tensor)_model.forward(tokenTensor);
20
  }
21
+
22
+ public void Dispose()
23
+ {
24
+ _model.Dispose();
25
+ _model = null;
26
+ }
27
  }
DDPM.cs CHANGED
@@ -1,7 +1,7 @@
1
  using System;
2
  using TorchSharp;
3
 
4
- public class DDPM
5
  {
6
  private readonly torch.jit.ScriptModule _model;
7
  public torch.Device Device {get;}
@@ -36,4 +36,10 @@ public class DDPM
36
  {
37
  return _model.invoke<torch.Tensor>("predict_start_from_z_and_v", z, t, v);
38
  }
 
 
 
 
 
 
39
  }
 
1
  using System;
2
  using TorchSharp;
3
 
4
+ public class DDPM : IDisposable
5
  {
6
  private readonly torch.jit.ScriptModule _model;
7
  public torch.Device Device {get;}
 
36
  {
37
  return _model.invoke<torch.Tensor>("predict_start_from_z_and_v", z, t, v);
38
  }
39
+
40
+ public void Dispose()
41
+ {
42
+ _model.Dispose();
43
+ _model = null;
44
+ }
45
  }
Program.cs CHANGED
@@ -6,9 +6,6 @@ using TorchSharp;
6
 
7
  torchvision.io.DefaultImager = new torchvision.io.SkiaImager();
8
  var device = TorchSharp.torch.device("cuda:0");
9
- var ddpm = new DDPM("ddim_v_sampler.ckpt", device);
10
- var ddimSampler = new DDIMSampler(ddpm);
11
- var autoencoderKL = new AutoencoderKL("autoencoder_kl.ckpt", device);
12
  var clipEncoder = new ClipEncoder("clip_encoder.ckpt", device);
13
  var start_token = 49406;
14
  var end_token = 49407;
@@ -21,7 +18,7 @@ var dictionary = new Dictionary<string, long>(){
21
  {"green", 1901},
22
  };
23
 
24
- var batch = 2;
25
 
26
  var prompt = "a wild cute green cat";
27
  var tokens = prompt.Split(' ').Select(x => dictionary[x]).ToList();
@@ -33,15 +30,23 @@ var tokenTensor = torch.tensor(tokens.ToArray(), dtype: torch.ScalarType.Int64,
33
  tokenTensor = tokenTensor.repeat(batch, 1);
34
  var unconditional_tokenTensor = torch.tensor(uncontional_tokens.ToArray(), dtype: torch.ScalarType.Int64, device: device);
35
  unconditional_tokenTensor = unconditional_tokenTensor.repeat(batch, 1);
36
- var img = torch.randn(batch, 4, 96, 96, dtype: torch.ScalarType.Float32, device: device);
37
  var t = torch.full(new[]{batch, 1L}, value: batch, dtype: torch.ScalarType.Int32, device: device);
38
  var condition = clipEncoder.Forward(tokenTensor);
39
  var unconditional_condition = clipEncoder.Forward(unconditional_tokenTensor);
 
 
 
 
40
  var ddim_steps = 50;
41
  img = ddimSampler.Sample(img, condition, unconditional_condition, ddim_steps);
 
 
 
42
  var decoded_images = (torch.Tensor)autoencoderKL.Forward(img);
43
  decoded_images = torch.clamp((decoded_images + 1.0) / 2.0, 0.0, 1.0);
44
 
 
45
  for(int i = 0; i!= batch; ++i)
46
  {
47
  var image = decoded_images[i];
 
6
 
7
  torchvision.io.DefaultImager = new torchvision.io.SkiaImager();
8
  var device = TorchSharp.torch.device("cuda:0");
 
 
 
9
  var clipEncoder = new ClipEncoder("clip_encoder.ckpt", device);
10
  var start_token = 49406;
11
  var end_token = 49407;
 
18
  {"green", 1901},
19
  };
20
 
21
+ var batch = 1;
22
 
23
  var prompt = "a wild cute green cat";
24
  var tokens = prompt.Split(' ').Select(x => dictionary[x]).ToList();
 
30
  tokenTensor = tokenTensor.repeat(batch, 1);
31
  var unconditional_tokenTensor = torch.tensor(uncontional_tokens.ToArray(), dtype: torch.ScalarType.Int64, device: device);
32
  unconditional_tokenTensor = unconditional_tokenTensor.repeat(batch, 1);
33
+ var img = torch.randn(batch, 4, 64, 64, dtype: torch.ScalarType.Float32, device: device);
34
  var t = torch.full(new[]{batch, 1L}, value: batch, dtype: torch.ScalarType.Int32, device: device);
35
  var condition = clipEncoder.Forward(tokenTensor);
36
  var unconditional_condition = clipEncoder.Forward(unconditional_tokenTensor);
37
+
38
+ clipEncoder.Dispose();
39
+ var ddpm = new DDPM("ddim_v_sampler.ckpt", device);
40
+ var ddimSampler = new DDIMSampler(ddpm);
41
  var ddim_steps = 50;
42
  img = ddimSampler.Sample(img, condition, unconditional_condition, ddim_steps);
43
+ ddpm.Dispose();
44
+
45
+ var autoencoderKL = new AutoencoderKL("autoencoder_kl.ckpt", device);
46
  var decoded_images = (torch.Tensor)autoencoderKL.Forward(img);
47
  decoded_images = torch.clamp((decoded_images + 1.0) / 2.0, 0.0, 1.0);
48
 
49
+
50
  for(int i = 0; i!= batch; ++i)
51
  {
52
  var image = decoded_images[i];