nateraw committed
Commit 07b4fcf · 1 Parent(s): 7a4f2ae

Create new file

Files changed (1):
  1. app.py +153 -0
app.py ADDED
@@ -0,0 +1,153 @@
import time
from pathlib import Path

import gradio as gr
import torch
from diffusers.schedulers import LMSDiscreteScheduler
from stable_diffusion_videos import StableDiffusionWalkPipeline, generate_images


class ImageGenerationInterface:
    """Two-tab Gradio app: batch image generation and audio-synced prompt-walk videos."""

    def __init__(self, pipeline):
        self.pipeline = pipeline

        # Tab 1: generate batches of images for a single prompt.
        self.interface_images = gr.Interface(
            self.fn,
            inputs=[
                gr.Textbox("blueberry spaghetti", label='Prompt'),
                gr.Slider(1, 24, 16, step=1, label='Batch size'),
                gr.Slider(1, 16, 1, step=1, label='# Batches'),
                gr.Slider(10, 100, 50, step=1, label='# Inference Steps'),
                gr.Slider(5.0, 15.0, 7.5, step=0.5, label='Guidance Scale'),
                gr.Slider(512, 1024, 512, step=64, label='Height'),
                gr.Slider(512, 1024, 512, step=64, label='Width'),
                gr.Checkbox(False, label='Upsample'),
                gr.Textbox("nateraw/stable-diffusion-gallery", label='(Optional) Repo ID'),
                gr.Checkbox(False, label='Push to Hub'),
                gr.Checkbox(False, label='Private'),
                gr.Textbox("./images", label='Output directory'),
            ],
            outputs=gr.Gallery(),
        )

        # Tab 2: interpolate between prompts, timed to offsets within an audio file.
        self.interface_videos = gr.Interface(
            self.fn_videos,
            inputs=[
                gr.Textbox("blueberry spaghetti\nstrawberry spaghetti", lines=2, label='Prompts, separated by new line'),
                gr.Textbox("42\n1337", lines=2, label='Seeds, separated by new line'),
                gr.Textbox("25\n27", lines=2, label='Audio Offsets (seconds in song), separated by new line'),
                gr.Audio(type="filepath"),
                gr.Slider(3, 60, 5, step=1, label='FPS'),
                gr.Slider(1, 24, 16, step=1, label='Batch size'),
                gr.Slider(10, 100, 50, step=1, label='# Inference Steps'),
                gr.Slider(5.0, 15.0, 7.5, step=0.5, label='Guidance Scale'),
                gr.Slider(512, 1024, 512, step=64, label='Height'),
                gr.Slider(512, 1024, 512, step=64, label='Width'),
                gr.Checkbox(False, label='Upsample'),
            ],
            outputs=gr.Video(),
        )

        self.interface = gr.TabbedInterface(
            [self.interface_images, self.interface_videos],
            ['Images!', 'Videos!'],
        )

    def fn_videos(
        self,
        prompts,
        seeds,
        audio_offsets,
        audio_filepath,
        fps,
        batch_size,
        num_inference_steps,
        guidance_scale,
        height,
        width,
        upsample,
    ):
        # Parse the newline-separated textbox inputs.
        prompts = [x.strip() for x in prompts.split('\n')]
        seeds = [int(x.strip()) for x in seeds.split('\n')]
        audio_offsets = [float(x.strip()) for x in audio_offsets.split('\n')]

        # Each pair of consecutive offsets spans (b - a) seconds of audio, so
        # interpolating (b - a) * fps frames keeps the video in sync with the song.
        num_interpolation_steps = [(b - a) * fps for a, b in zip(audio_offsets, audio_offsets[1:])]

        return self.pipeline.walk(
            prompts=prompts,
            seeds=seeds,
            num_interpolation_steps=num_interpolation_steps,
            audio_filepath=audio_filepath,
            audio_start_sec=audio_offsets[0],
            fps=fps,
            height=height,
            width=width,
            output_dir='dreams',
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            upsample=upsample,
            batch_size=batch_size,
        )

    def fn(
        self,
        prompt,
        batch_size,
        num_batches,
        num_inference_steps,
        guidance_scale,
        height,
        width,
        upsample,
        repo_id,
        push_to_hub,
        private,
        output_dir,
    ):
        # Each run is written under a timestamped name inside output_dir.
        name = time.strftime("%Y%m%d-%H%M%S")
        image_filepaths = generate_images(
            self.pipeline,
            prompt,
            batch_size=batch_size,
            num_batches=num_batches,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            output_dir=output_dir,
            name=name,
            image_file_ext='.jpg',
            upsample=upsample,
            height=height,
            width=width,
            push_to_hub=push_to_hub,
            repo_id=repo_id,
            private=private,
            create_pr=False,
        )
        # Gallery items as (filepath, caption) pairs.
        return [(x, Path(x).stem) for x in sorted(image_filepaths)]

    def launch(self, *args, **kwargs):
        self.interface.launch(*args, **kwargs)


def main(
    model_id: str = "CompVis/stable-diffusion-v1-4",
    tiled=False,
    disable_safety_checker=False,
):
    safety_checker_kwargs = {'safety_checker': None} if disable_safety_checker else {}

    # Load the walk pipeline in fp16 with an LMS scheduler; requires a CUDA GPU.
    pipeline = StableDiffusionWalkPipeline.from_pretrained(
        model_id,
        revision="fp16",
        torch_dtype=torch.float16,
        scheduler=LMSDiscreteScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
        ),
        tiled=tiled,
        **safety_checker_kwargs,
    ).to("cuda")
    ImageGenerationInterface(pipeline).launch(debug=True)


if __name__ == '__main__':
    import fire

    fire.Fire(main)
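
A quick worked example of the timing math in fn_videos, using the video tab's default values (a sketch run outside the app, nothing beyond the code above is assumed):

    # Default offsets "25\n27" parsed to floats; the FPS slider defaults to 5.
    audio_offsets = [25.0, 27.0]
    fps = 5
    num_interpolation_steps = [(b - a) * fps for a, b in zip(audio_offsets, audio_offsets[1:])]
    print(num_interpolation_steps)  # [10.0]

Each consecutive pair of offsets contributes (b - a) * fps frames, so the 2-second span from 25s to 27s yields 10 interpolated frames, and playback at 5 FPS lands the second prompt exactly on its offset in the song.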
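Because the entry point hands main to fire.Fire, python-fire exposes main's parameters as command-line flags. A minimal launch sketch, assuming stable_diffusion_videos, diffusers, gradio, torch, and fire are installed and a CUDA device is available (the pipeline is moved to "cuda" unconditionally):

    # Defaults: CompVis/stable-diffusion-v1-4 with the safety checker enabled.
    python app.py

    # Flags mirror main's signature; these values are illustrative.
    python app.py --model_id CompVis/stable-diffusion-v1-4 --disable_safety_checker True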