Thaweewat commited on
Commit
87cceff
·
1 Parent(s): 6863d4d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +269 -0
app.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import einops
3
+ import gradio as gr
4
+ import numpy as np
5
+ import torch
6
+
7
+ from pytorch_lightning import seed_everything
8
+ from util import resize_image, HWC3, apply_canny
9
+ from ldm.models.diffusion.ddim import DDIMSampler
10
+ from annotator.openpose import apply_openpose
11
+ from cldm.model import create_model, load_state_dict
12
+ from huggingface_hub import hf_hub_url, cached_download
13
+
14
+
15
+
16
+ REPO_ID = "Thaweewat/ControlNet-Architecture"
17
+ canny_checkpoint = "models/control_sd15_canny.pth"
18
+ scribble_checkpoint = "models/control_sd15_scribble.pth"
19
+ pose_checkpoint = "models/control_sd15_openpose.pth"
20
+
21
+
22
+ canny_model = create_model('./models/cldm_v15.yaml').cpu()
23
+ canny_model.load_state_dict(load_state_dict(cached_download(
24
+ hf_hub_url(REPO_ID, canny_checkpoint)
25
+ ), location='cpu'))
26
+ canny_model = canny_model.cuda()
27
+ ddim_sampler = DDIMSampler(canny_model)
28
+
29
+ pose_model = create_model('./models/cldm_v15.yaml').cpu()
30
+ pose_model.load_state_dict(load_state_dict(cached_download(
31
+ hf_hub_url(REPO_ID, pose_checkpoint)
32
+ ), location='cpu'))
33
+ pose_model = pose_model.cuda()
34
+ ddim_sampler_pose = DDIMSampler(pose_model)
35
+
36
+ scribble_model = create_model('./models/cldm_v15.yaml').cpu()
37
+ scribble_model.load_state_dict(load_state_dict(cached_download(
38
+ hf_hub_url(REPO_ID, scribble_checkpoint)
39
+ ), location='cpu'))
40
+ scribble_model = scribble_model.cuda()
41
+ ddim_sampler_scribble = DDIMSampler(scribble_model)
42
+
43
+ save_memory = False
44
+
45
+ def process(input_image, prompt, input_control, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold):
46
+ # TODO: Add other control tasks
47
+ if input_control == "Scribble":
48
+ return process_scribble(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta)
49
+ elif input_control == "Pose":
50
+ return process_pose(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, image_resolution, ddim_steps, scale, seed, eta)
51
+
52
+ return process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold)
53
+
54
+ def process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold):
55
+ with torch.no_grad():
56
+ img = resize_image(HWC3(input_image), image_resolution)
57
+ H, W, C = img.shape
58
+
59
+ detected_map = apply_canny(img, low_threshold, high_threshold)
60
+ detected_map = HWC3(detected_map)
61
+
62
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
63
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
64
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
65
+
66
+ seed_everything(seed)
67
+
68
+ if save_memory:
69
+ canny_model.low_vram_shift(is_diffusing=False)
70
+
71
+ cond = {"c_concat": [control], "c_crossattn": [canny_model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
72
+ un_cond = {"c_concat": [control], "c_crossattn": [canny_model.get_learned_conditioning([n_prompt] * num_samples)]}
73
+ shape = (4, H // 8, W // 8)
74
+
75
+ if save_memory:
76
+ canny_model.low_vram_shift(is_diffusing=False)
77
+
78
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
79
+ shape, cond, verbose=False, eta=eta,
80
+ unconditional_guidance_scale=scale,
81
+ unconditional_conditioning=un_cond)
82
+
83
+ if save_memory:
84
+ canny_model.low_vram_shift(is_diffusing=False)
85
+
86
+ x_samples = canny_model.decode_first_stage(samples)
87
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
88
+
89
+ results = [x_samples[i] for i in range(num_samples)]
90
+ return [255 - detected_map] + results
91
+
92
+ def process_scribble(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta):
93
+ with torch.no_grad():
94
+ img = resize_image(HWC3(input_image), image_resolution)
95
+ H, W, C = img.shape
96
+
97
+ detected_map = np.zeros_like(img, dtype=np.uint8)
98
+ detected_map[np.min(img, axis=2) < 127] = 255
99
+
100
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
101
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
102
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
103
+
104
+ seed_everything(seed)
105
+
106
+ if save_memory:
107
+ scribble_model.low_vram_shift(is_diffusing=False)
108
+
109
+ cond = {"c_concat": [control], "c_crossattn": [scribble_model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
110
+ un_cond = {"c_concat": [control], "c_crossattn": [scribble_model.get_learned_conditioning([n_prompt] * num_samples)]}
111
+ shape = (4, H // 8, W // 8)
112
+
113
+ if save_memory:
114
+ scribble_model.low_vram_shift(is_diffusing=False)
115
+
116
+ samples, intermediates = ddim_sampler_scribble.sample(ddim_steps, num_samples,
117
+ shape, cond, verbose=False, eta=eta,
118
+ unconditional_guidance_scale=scale,
119
+ unconditional_conditioning=un_cond)
120
+
121
+ if save_memory:
122
+ scribble_model.low_vram_shift(is_diffusing=False)
123
+
124
+ x_samples = scribble_model.decode_first_stage(samples)
125
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
126
+
127
+ results = [x_samples[i] for i in range(num_samples)]
128
+ return [255 - detected_map] + results
129
+
130
+ def process_pose(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, scale, seed, eta):
131
+ with torch.no_grad():
132
+ input_image = HWC3(input_image)
133
+ detected_map, _ = apply_openpose(resize_image(input_image, detect_resolution))
134
+ detected_map = HWC3(detected_map)
135
+ img = resize_image(input_image, image_resolution)
136
+ H, W, C = img.shape
137
+
138
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
139
+
140
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
141
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
142
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
143
+
144
+ if seed == -1:
145
+ seed = random.randint(0, 65535)
146
+ seed_everything(seed)
147
+
148
+ if save_memory:
149
+ pose_model.low_vram_shift(is_diffusing=False)
150
+
151
+
152
+ cond = {"c_concat": [control], "c_crossattn": [pose_model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
153
+ un_cond = {"c_concat": [control], "c_crossattn": [pose_model.get_learned_conditioning([n_prompt] * num_samples)]}
154
+ shape = (4, H // 8, W // 8)
155
+
156
+ if save_memory:
157
+ pose_model.low_vram_shift(is_diffusing=False)
158
+
159
+ samples, intermediates = ddim_sampler_pose.sample(ddim_steps, num_samples,
160
+ shape, cond, verbose=False, eta=eta,
161
+ unconditional_guidance_scale=scale,
162
+ unconditional_conditioning=un_cond)
163
+
164
+ if save_memory:
165
+ pose_model.low_vram_shift(is_diffusing=False)
166
+
167
+ x_samples = pose_model.decode_first_stage(samples)
168
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
169
+
170
+ results = [x_samples[i] for i in range(num_samples)]
171
+ return [detected_map] + results
172
+
173
+ def create_canvas(w, h):
174
+ new_control_options = ["Interactive Scribble"]
175
+ return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
176
+
177
+
178
+ block = gr.Blocks().queue()
179
+ control_task_list = [
180
+ "Canny Edge Map",
181
+ "Scribble",
182
+ "Pose"
183
+ ]
184
+ with block:
185
+ gr.Markdown("## Adding Conditional Control to Text-to-Image Diffusion Models")
186
+ gr.HTML('''
187
+ <p style="margin-bottom: 10px; font-size: 94%">
188
+ This is an unofficial demo for ControlNet, which is a neural network structure to control diffusion models by adding extra conditions such as canny edge detection. The demo is based on the <a href="https://github.com/lllyasviel/ControlNet" style="text-decoration: underline;" target="_blank"> Github </a> implementation.
189
+ </p>
190
+ ''')
191
+ gr.HTML("<p>You can duplicate this Space to run it privately without a queue and load additional checkpoints. : <a style='display:inline-block' href='https://huggingface.co/spaces/RamAnanth1/ControlNet?duplicate=true'><img src='https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14' alt='Duplicate Space'></a> <a style='display:inline-block' href='https://colab.research.google.com/github/camenduru/controlnet-colab/blob/main/controlnet-colab.ipynb'><img src = 'https://colab.research.google.com/assets/colab-badge.svg' alt='Open in Colab'></a></p>")
192
+ with gr.Row():
193
+ with gr.Column():
194
+ input_image = gr.Image(source='upload', type="numpy")
195
+ input_control = gr.Dropdown(control_task_list, value="Scribble", label="Control Task")
196
+ prompt = gr.Textbox(label="Prompt")
197
+ run_button = gr.Button(label="Run")
198
+
199
+ with gr.Accordion("Advanced options", open=False):
200
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
201
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
202
+ low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1)
203
+ high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1)
204
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
205
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
206
+ seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True)
207
+ eta = gr.Slider(label="eta (DDIM)", minimum=0.0,maximum =1.0, value=0.0, step=0.1)
208
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
209
+ n_prompt = gr.Textbox(label="Negative Prompt",
210
+ value='longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair,extra digit, fewer digits, cropped, worst quality, low quality')
211
+ with gr.Column():
212
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
213
+ ips = [input_image, prompt, input_control, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold]
214
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
215
+ examples_list = [
216
+ [
217
+ "bird.png",
218
+ "bird",
219
+ "Canny Edge Map",
220
+ "best quality, extremely detailed",
221
+ 'longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair,extra digit, fewer digits, cropped, worst quality, low quality',
222
+ 1,
223
+ 512,
224
+ 20,
225
+ 9.0,
226
+ 123490213,
227
+ 0.0,
228
+ 100,
229
+ 200
230
+
231
+ ],
232
+
233
+ [
234
+ "turtle.png",
235
+ "turtle",
236
+ "Scribble",
237
+ "best quality, extremely detailed",
238
+ 'longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair,extra digit, fewer digits, cropped, worst quality, low quality',
239
+ 1,
240
+ 512,
241
+ 20,
242
+ 9.0,
243
+ 123490213,
244
+ 0.0,
245
+ 100,
246
+ 200
247
+
248
+ ],
249
+ [
250
+ "pose1.png",
251
+ "Chef in the Kitchen",
252
+ "Pose",
253
+ "best quality, extremely detailed",
254
+ 'longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair,extra digit, fewer digits, cropped, worst quality, low quality',
255
+ 1,
256
+ 512,
257
+ 20,
258
+ 9.0,
259
+ 123490213,
260
+ 0.0,
261
+ 100,
262
+ 200
263
+
264
+ ]
265
+ ]
266
+ examples = gr.Examples(examples=examples_list,inputs = [input_image, prompt, input_control, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold], outputs = [result_gallery], cache_examples = True, fn = process)
267
+ gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=RamAnanth1.ControlNet)")
268
+
269
+ block.launch(debug = True)