shrikant11 commited on
Commit
76af0c3
1 Parent(s): 52a5c48

checkpoint-10

Browse files
Files changed (1) hide show
  1. pipeline.py +33 -0
pipeline.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import DiffusionPipeline
3
+
4
+
5
+ class MyPipeline(DiffusionPipeline):
6
+ def __init__(self, unet, scheduler):
7
+ super().__init__()
8
+
9
+ self.register_modules(unet=unet, scheduler=scheduler)
10
+
11
+ @torch.no_grad()
12
+ def __call__(self, batch_size: int = 1, num_inference_steps: int = 50):
13
+ # Sample gaussian noise to begin loop
14
+ image = torch.randn((batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size))
15
+
16
+ image = image.to(self.device)
17
+
18
+ # set step values
19
+ self.scheduler.set_timesteps(num_inference_steps)
20
+
21
+ for t in self.progress_bar(self.scheduler.timesteps):
22
+ # 1. predict noise model_output
23
+ model_output = self.unet(image, t).sample
24
+
25
+ # 2. predict previous mean of image x_t-1 and add variance depending on eta
26
+ # eta corresponds to η in paper and should be between [0, 1]
27
+ # do x_t -> x_t-1
28
+ image = self.scheduler.step(model_output, t, image, eta).prev_sample
29
+
30
+ image = (image / 2 + 0.5).clamp(0, 1)
31
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
32
+
33
+ return image