aaronb commited on
Commit
e077396
·
1 Parent(s): 13935bd
Files changed (2) hide show
  1. gradio_dmd.py +71 -0
  2. scheduling_dmd.py +48 -0
gradio_dmd.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import time
3
+
4
+ import gradio as gr
5
+ import torch
6
+ from diffusers import DiffusionPipeline, UNet2DConditionModel
7
+
8
+ from scheduling_dmd import DMDScheduler
9
+
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument("--unet-path", type='Lykon/dreamshaper-8')
12
+ parser.add_argument("--model-path", type='aaronb/dreamshaper-8-dmd-1kstep')
13
+ args = parser.parse_args()
14
+
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
+ unet = UNet2DConditionModel.from_pretrained(args.unet_path)
18
+ pipe = DiffusionPipeline.from_pretrained(args.model_path, unet=unet)
19
+ pipe.scheduler = DMDScheduler.from_config(pipe.scheduler.config)
20
+ pipe.to(device=device, dtype=torch.float16)
21
+
22
+
23
+ def predict(prompt, seed=1231231):
24
+ generator = torch.manual_seed(seed)
25
+ last_time = time.time()
26
+
27
+ image = pipe(
28
+ prompt,
29
+ num_inference_steps=1,
30
+ guidance_scale=0.0,
31
+ generator=generator,
32
+ ).images[0]
33
+
34
+ print(f"Pipe took {time.time() - last_time} seconds")
35
+ return image
36
+
37
+
38
+ css = """
39
+ #container{
40
+ margin: 0 auto;
41
+ max-width: 40rem;
42
+ }
43
+ #intro{
44
+ max-width: 100%;
45
+ text-align: center;
46
+ margin: 0 auto;
47
+ }
48
+ """
49
+ with gr.Blocks(css=css) as demo:
50
+ with gr.Column(elem_id="container"):
51
+ gr.Markdown(
52
+ """# Distribution Matching Distillation
53
+ """,
54
+ elem_id="intro",
55
+ )
56
+ with gr.Row():
57
+ with gr.Row():
58
+ prompt = gr.Textbox(placeholder="Insert your prompt here:", scale=5, container=False)
59
+ generate_bt = gr.Button("Generate", scale=1)
60
+
61
+ image = gr.Image(type="filepath")
62
+ with gr.Accordion("Advanced options", open=False):
63
+ seed = gr.Slider(randomize=True, minimum=0, maximum=12013012031030, label="Seed", step=1)
64
+
65
+ inputs = [prompt, seed]
66
+ generate_bt.click(fn=predict, inputs=inputs, outputs=image, show_progress=False)
67
+ prompt.input(fn=predict, inputs=inputs, outputs=image, show_progress=False)
68
+ seed.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
69
+
70
+ demo.queue(api_open=False)
71
+ demo.launch(show_api=False)
scheduling_dmd.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List, Tuple, Union, Optional
3
+
4
+ import torch
5
+ from diffusers import DDPMScheduler
6
+ from diffusers.utils import BaseOutput
7
+
8
+
9
+ @dataclass
10
+ class DMDSchedulerOutput(BaseOutput):
11
+ pred_original_sample: Optional[torch.FloatTensor] = None
12
+
13
+
14
+ class DMDScheduler(DDPMScheduler):
15
+ def set_timesteps(
16
+ self,
17
+ num_inference_steps: Optional[int] = None,
18
+ device: Union[str, torch.device] = None,
19
+ timesteps: Optional[List[int]] = None,
20
+ ):
21
+ self.timesteps = torch.tensor([self.config.num_train_timesteps-1]).long().to(device)
22
+
23
+ def step(
24
+ self,
25
+ model_output: torch.FloatTensor,
26
+ timestep: int,
27
+ sample: torch.FloatTensor,
28
+ generator=None,
29
+ return_dict: bool = True,
30
+ ) -> Union[DMDSchedulerOutput, Tuple]:
31
+ t = self.config.num_train_timesteps - 1
32
+
33
+ # 1. compute alphas, betas
34
+ alpha_prod_t = self.alphas_cumprod[t]
35
+ beta_prod_t = 1 - alpha_prod_t
36
+
37
+ if self.config.prediction_type == "epsilon":
38
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
39
+ else:
40
+ raise ValueError(
41
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
42
+ " `v_prediction` for the DDPMScheduler."
43
+ )
44
+
45
+ if not return_dict:
46
+ return (pred_original_sample,)
47
+
48
+ return DMDSchedulerOutput(pred_original_sample=pred_original_sample)