Vijish committed on
Commit 682647c · verified · 1 Parent(s): c39c22f

Create app.py

Files changed (1): app.py +393 -0
app.py ADDED
@@ -0,0 +1,393 @@
+ import gradio as gr
+ from diffusers import DiffusionPipeline, ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL, UniPCMultistepScheduler
+ from stable_diffusion_xl_reference import StableDiffusionXLReferencePipeline
+ from controlnet_aux import OpenposeDetector
+ #from transformers import DPTFeatureExtractor, DPTForDepthEstimation
+ from controlnet_aux import MidasDetector, ZoeDetector
+ from tqdm import tqdm
+
+ import torch
+ import numpy as np
+ import cv2
+ from PIL import Image
+ import os
+ import random
+ import gc
+
+
+ def clear_memory():
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+         torch.cuda.ipc_collect()
+
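+ # Note: torch.cuda.empty_cache() returns cached CUDA blocks to the driver and
+ # torch.cuda.ipc_collect() frees memory held through CUDA IPC handles; both
+ # matter here because pipelines are repeatedly swapped between CPU and GPU.
+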
+ # Global variable definitions
+ controlnet_pipe = None
+ reference_pipe = None
+ pipe = None
+
+ # Load the base model
+ model = "aicollective1/aicollective"
+ pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.float16)
+ pipe.to("cuda")
+
+ # Placeholder for ControlNet models to be loaded dynamically
+ controlnet_models = {
+     "Canny": None,
+     "Depth": None,
+     "OpenPose": None,
+     "Reference": None
+ }
+
+ # Load necessary models and feature extractors for depth estimation and OpenPose
+ #feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
+ #depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
+
+ processor_zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
+ processor_midas = MidasDetector.from_pretrained("lllyasviel/Annotators")
+
+ openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
+
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
+ controlnet_pipe = None  # Initial placeholder, will be loaded dynamically
+ reference_pipe = None  # Initial placeholder for reference pipeline
+
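+ # The fp16-fix VAE is used because the stock SDXL VAE tends to produce
+ # NaNs/black images when decoding in float16.
+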
+ # Define the prompts and negative prompts for each style
+ styles = {
+     "Anime Studio Dance": {
+         "prompt": ("anime screencap of a man wearing a white helmet with pointed ears,\n"
+                    "Outfit: closed animal print shirt,\n"
+                    "Action: anime style, looking at viewer, solo, upper body,\n"
+                    "((masterpiece)), (best quality), (extremely detailed), depth of field, sketch, "
+                    "dark intense shadows, sharp focus, soft lighting, hdr, colorful, good composition, spectacular,"),
+         "negative_prompt": ("realistic, (painting by bad-artist-anime:0.9), (painting by bad-artist:0.9), watermark, "
+                             "text, error, blurry, jpeg artifacts, cropped, worst quality, low quality, normal quality, "
+                             "jpeg artifacts, signature, watermark, username, artist name, (worst quality, low quality:1.4), "
+                             "bad anatomy, watermark, signature, text, logo")
+     },
+     "Vintage Realistic": {
+         "prompt": ("a masterpiece close up shoot photography of a man wearing an animal print helmet with pointed ears,\n"
+                    "Outfit: wearing a big oversized outfit, white leather jacket,\n"
+                    "Action: sitting on steps,\n"
+                    "hyper realistic with detailed textures, cinematic film still of Photorealism, realistic skin texture, "
+                    "subsurface scattering, skinny, Photorealism, often for highly detailed representation, photographic accuracy, "
+                    "shallow depth of field, vignette, highly detailed, bokeh, epic, gorgeous, sharp, perfect hands,\n"
+                    "<lora:add-detail-xl:1> <lora:Vintage_Street_Photo:0.9>"),
+         "negative_prompt": ("deformed skin, skin veins, black skin, blurry, text, yellow, deformed, (worst quality, low resolution, "
+                             "bad hands, open mouth), text, watermark, artist name, distorted, twisted, watermark, 3d render, "
+                             "distorted, twisted, watermark, anime, cartoon, graphic, text, painting, crayon, graphite, abstract, "
+                             "glitch, deformed, mutated, ugly, disfigured, photoshopped skin, airbrushed skin, glossy skin, canvas frame, "
+                             "(high contrast:1.2), (over saturated:1.2), (glossy:1.1), cartoon, 3d, disfigured, Photoshop, video game, "
+                             "ugly, tiling, poorly drawn hands, 3d render, impressionism, digital art")
+     },
+     "Anime 90's Aesthetic": {
+         "prompt": ("a man wearing a white helmet with pointed ears, perfect chin,\n"
+                    "Outfit: wearing oversized hoodie, animal print pants,\n"
+                    "Action: dancing in nature, music production, music instruments made of wood,\n"
+                    "A screengrab of an anime, 90's aesthetic,"),
+         "negative_prompt": ("photo, real, realistic, blurry, text, yellow, deformed, (worst quality, low resolution, bad hands,), "
+                             "text, watermark, artist name, distorted, twisted, watermark, 3d render, distorted, twisted, watermark, "
+                             "text, abstract, glitch, deformed, mutated, ugly, disfigured, photoshopped skin, airbrushed skin, glossy skin, "
+                             "canvas frame, (high contrast:1.2), (over saturated:1.2), (glossy:1.1), disfigured, Photoshop, video game, "
+                             "ugly, tiling, poorly drawn hands, 3d render, impressionism, eyes, mouth, black skin, pale skin, hair, beard")
+     },
+     "Anime Style": {
+         "prompt": ("A man wearing a white helmet with pointed ears sitting on the steps of an Asian street shop,\n"
+                    "Outfit: wearing blue pants and a yellow jacket with a red backpack, in the anime style with detailed "
+                    "character design in the style of Atey Ghailan, featured in CGSociety, character concept art in the style of Katsuhiro Otomo"),
+         "negative_prompt": ("real, deformed fingers, chin, deformed hands, blurry, text, yellow, deformed, (worst quality, low resolution, "
+                             "bad hands, open mouth), text, watermark, artist name, distorted, twisted, watermark, 3d, distorted, twisted, "
+                             "watermark, anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, "
+                             "ugly, disfigured, photoshopped skin, airbrushed skin, glossy skin, canvas frame, (high contrast:1.2), "
+                             "(over saturated:1.2), (glossy:1.1), cartoon, 3d, disfigured, Photoshop, video game, ugly, tiling, "
+                             "poorly drawn hands, 3d render, impressionism, digital art")
+     },
+     "Real 70s": {
+         "prompt": ("a masterpiece close up shoot photography of a man wearing a white helmet with pointed ears,\n"
+                    "Outfit: wearing an oversized trippy 70s shirt and scarf,\n"
+                    "Action: standing on the ocean,\n"
+                    "shot in the style of Erwin Olaf, hyper realistic with detailed textures, cinematic film still of Photorealism, "
+                    "realistic skin texture, subsurface scattering, skinny, Photorealism, often for highly detailed representation, "
+                    "photographic accuracy, shallow depth of field, vignette, highly detailed, bokeh, epic, gorgeous, sharp,"),
+         "negative_prompt": ("deformed skin, skin veins, black skin, blurry, text, yellow, deformed, (worst quality, low resolution, "
+                             "bad hands, open mouth), text, watermark, artist name, distorted, twisted, watermark, 3d render, distorted, "
+                             "twisted, watermark, anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, "
+                             "mutated, ugly, disfigured, photoshopped skin, airbrushed skin, glossy skin, canvas frame, (high contrast:1.2), "
+                             "(over saturated:1.2), (glossy:1.1), cartoon, 3d, disfigured, Photoshop, video game, ugly, tiling, "
+                             "poorly drawn hands, 3d render, impressionism, digital art")
+     }
+ }
+
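+ # Note: the (term:1.2) / ((term)) emphasis and <lora:...> tags above are
+ # A1111-style conventions; a vanilla diffusers pipeline passes them through
+ # as literal text unless a prompt-weighting helper (e.g. compel) is added.
+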
+ # Define the style images
+ style_images = {
+     "Anime Studio Dance": "style/Anime Studio Dance.png",
+     "Vintage Realistic": "style/Vintage Realistic.png",
+     "Anime 90's Aesthetic": "style/Anime 90's Aesthetic.png",
+     "Anime Style": "style/Anime Style.png",
+     "Real 70s": "style/Real 70s.png"
+ }
+
+ # Function to load ControlNet models dynamically
+ def load_controlnet_model(controlnet_type):
+     global controlnet_pipe, pipe, reference_pipe, controlnet_models, vae, model
+
+     clear_memory()
+
+     if controlnet_models[controlnet_type] is None:
+         if controlnet_type in ["Canny", "Depth", "OpenPose"]:
+             controlnet_models[controlnet_type] = ControlNetModel.from_pretrained(
+                 "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16, use_safetensors=True
+             )
+         elif controlnet_type == "Reference":
+             controlnet_models[controlnet_type] = StableDiffusionXLReferencePipeline.from_pretrained(
+                 model, torch_dtype=torch.float16, use_safetensors=True
+             )
+
+     # Drop any previously active pipelines before building the new one
+     if controlnet_pipe is not None:
+         controlnet_pipe.to("cpu")
+         controlnet_pipe = None
+
+     if reference_pipe is not None:
+         reference_pipe.to("cpu")
+         reference_pipe = None
+
+     if pipe is not None:
+         pipe.to("cpu")
+
+     clear_memory()
+
+     if controlnet_type == "Reference":
+         reference_pipe = controlnet_models[controlnet_type]
+         reference_pipe.scheduler = UniPCMultistepScheduler.from_config(reference_pipe.scheduler.config)
+         reference_pipe.to("cuda")
+     else:
+         controlnet_pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+             model, controlnet=controlnet_models[controlnet_type], vae=vae, torch_dtype=torch.float16, use_safetensors=True
+         )
+         controlnet_pipe.scheduler = UniPCMultistepScheduler.from_config(controlnet_pipe.scheduler.config)
+         controlnet_pipe.to("cuda")
+
+     clear_memory()
+     return f"Loaded {controlnet_type} model."
+
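+ # Canny/Depth/OpenPose all share a single "union" ControlNet checkpoint
+ # (xinsir/controlnet-union-sdxl-1.0), which supports multiple condition
+ # types, so only one ControlNet download covers the three modes. E.g.
+ # load_controlnet_model("Canny") returns "Loaded Canny model."
+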
+
+
+ # Preprocessing functions for each ControlNet type
+ def preprocess_canny(image):
+     if isinstance(image, Image.Image):
+         image = np.array(image)
+     if image.dtype != np.uint8:
+         image = (image * 255).astype(np.uint8)
+     image = cv2.Canny(image, 100, 200)
+     image = image[:, :, None]
+     image = np.concatenate([image, image, image], axis=2)
+     return Image.fromarray(image)
+
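+ # The Canny map is single-channel, so it is stacked to three channels:
+ # the ControlNet pipeline expects an RGB conditioning image. 100/200 are
+ # the commonly used low/high hysteresis thresholds.
+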
+
+ def preprocess_depth(image, target_size=(1024, 1024)):
+     if isinstance(image, Image.Image):
+         img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+     else:
+         img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+
+     depth_img = processor_zoe(img, output_type='cv2') if random.random() > 0.5 else processor_midas(img, output_type='cv2')
+
+     height, width = depth_img.shape[:2]
+     ratio = min(target_size[0] / width, target_size[1] / height)
+     new_width, new_height = int(width * ratio), int(height * ratio)
+     depth_img_resized = cv2.resize(depth_img, (new_width, new_height))
+
+     return Image.fromarray(depth_img_resized)
+
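+ # Alternating randomly between ZoeDepth and MiDaS gives some variety in the
+ # depth maps; both detectors load from the same lllyasviel/Annotators repo.
+ # The resize preserves aspect ratio so the longer side fits target_size.
+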
+ def preprocess_openpose(image):
+     if isinstance(image, Image.Image):
+         image = np.array(image)
+     image = openpose_processor(image, hand_and_face=False, output_type='cv2')
+     height, width = image.shape[:2]
+     ratio = np.sqrt(1024. * 1024. / (width * height))
+     new_width, new_height = int(width * ratio), int(height * ratio)
+     image = cv2.resize(image, (new_width, new_height))
+     return Image.fromarray(image)
+
+ def process_image_batch(images, pipe, prompt, negative_prompt, progress, batch_size=2):
+     all_processed_images = []
+     for i in range(0, len(images), batch_size):
+         batch = images[i:i+batch_size]
+         batch_prompt = [prompt] * len(batch)
+         batch_negative_prompt = [negative_prompt] * len(batch)
+
+         if isinstance(pipe, StableDiffusionXLReferencePipeline):
+             processed_batch = []
+             for img in batch:
+                 result = pipe(
+                     prompt=prompt,
+                     negative_prompt=negative_prompt,
+                     ref_image=img,
+                     num_inference_steps=20
+                 ).images
+                 processed_batch.extend(result)
+         else:
+             processed_batch = pipe(
+                 prompt=batch_prompt,
+                 negative_prompt=batch_negative_prompt,
+                 image=batch,
+                 num_inference_steps=20
+             ).images
+
+         all_processed_images.extend(processed_batch)
+         progress(min(1.0, (i + batch_size) / len(images)))  # Update progress bar, clamped so the last partial batch cannot exceed 1.0
+         clear_memory()  # Clear memory after each batch
+     return all_processed_images
+
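+ # The reference pipeline takes one ref_image per call, so it is invoked
+ # image-by-image; the ControlNet pipeline accepts lists and is batched.
+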
+
+ # Define the function to generate images
+ def generate_images_with_progress(prompt, negative_prompt, batch_count, use_controlnet, controlnet_type, mode, control_image, batch_images_input, progress=gr.Progress(track_tqdm=True)):
+     global controlnet_pipe, pipe, reference_pipe
+
+     clear_memory()
+
+     if use_controlnet:
+         if controlnet_type not in controlnet_models or controlnet_models[controlnet_type] is None:
+             raise ValueError(f"{controlnet_type} model not loaded. Please load the model first.")
+
+         if mode == "Single Image":
+             control_images = [control_image]
+         else:
+             control_images = [Image.open(img).convert("RGB") for img in batch_images_input]
+
+         preprocessed_images = []
+         for img in tqdm(control_images, desc="Preprocessing images"):
+             if controlnet_type == "Canny":
+                 preprocessed_images.append(preprocess_canny(img))
+             elif controlnet_type == "Depth":
+                 preprocessed_images.append(preprocess_depth(img))
+             elif controlnet_type == "OpenPose":
+                 preprocessed_images.append(preprocess_openpose(img))
+             else:  # Reference
+                 preprocessed_images.append(img)
+
+         if controlnet_type == "Reference":
+             images = process_image_batch(preprocessed_images, reference_pipe, prompt, negative_prompt, progress)
+         else:
+             images = process_image_batch(preprocessed_images, controlnet_pipe, prompt, negative_prompt, progress)
+     else:
+         # Free any ControlNet/reference pipelines before plain generation
+         if controlnet_pipe is not None:
+             controlnet_pipe.to("cpu")
+             controlnet_pipe = None
+
+         if reference_pipe is not None:
+             reference_pipe.to("cpu")
+             reference_pipe = None
+
+         clear_memory()
+
+         if pipe is None:
+             pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.float16)
+         pipe.to("cuda")  # load_controlnet_model may have offloaded the base pipe to CPU
+
+         images = []
+         for i in tqdm(range(batch_count), desc="Generating images"):
+             generated = pipe(prompt=[prompt], negative_prompt=[negative_prompt], num_inference_steps=20, width=1024, height=1024).images
+             images.extend(generated)
+             progress((i + 1) / batch_count)  # Update progress bar
+             clear_memory()  # Clear memory after each image, even in single image mode
+
+     clear_memory()
+     return images
+
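+ # gr.Progress(track_tqdm=True) mirrors the tqdm bars above in the Gradio UI,
+ # so both preprocessing and generation report progress to the browser.
+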
+
+
+ # Function to extract PNG metadata
+ def extract_png_info(image):
+     # gr.Image(type='pil') hands us a PIL image; .info holds its metadata dict
+     metadata = image.info
+     return metadata
+
+ # Define the Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# Image Generation with Custom Prompts and Styles")
+
+     with gr.Row():
+         with gr.Column():
+             prompt = gr.Textbox(label="Prompt", lines=8, interactive=True)
+             with gr.Accordion("Negative Prompt (Minimize/Expand)", open=False):
+                 negative_prompt = gr.Textbox(
+                     label="Negative Prompt",
+                     value="",
+                     lines=5
+                 )
+             batch_count = gr.Slider(minimum=1, maximum=10, step=1, label="Batch Count", value=1)
+             use_controlnet = gr.Checkbox(label="Use ControlNet", value=False)
+             controlnet_type = gr.Dropdown(choices=["Canny", "Depth", "OpenPose", "Reference"], label="ControlNet Type")
+             controlnet_status = gr.Textbox(label="ControlNet Status", value="", interactive=False)
+             mode = gr.Radio(choices=["Single Image", "Batch"], label="Mode", value="Single Image")
+
+             with gr.Tabs() as tabs:
+                 with gr.TabItem("Single Image"):
+                     control_image = gr.Image(label="Control Image", type='pil')
+
+                 with gr.TabItem("Batch"):
+                     batch_images_input = gr.File(label="Upload Images", file_count='multiple')
+
+                 with gr.TabItem("Extract Metadata"):
+                     png_image = gr.Image(label="Upload PNG Image", type='pil')
+                     metadata_output = gr.JSON(label="PNG Metadata")
+
+         with gr.Column(scale=2):
+             style_images_gallery = gr.Gallery(
+                 label="Choose a Style",
+                 value=list(style_images.values()),
+                 interactive=True,
+                 elem_id="style-gallery",
+                 columns=5,
+                 object_fit="contain",
+                 allow_preview=False
+             )
+             gallery = gr.Gallery(label="Generated Images", show_label=False, elem_id="gallery", height="auto")
+
+     selected_style = gr.State(value="Anime Studio Dance")
+
+     def select_style(evt: gr.SelectData):
+         style_names = list(styles.keys())
+         if evt.index < 0 or evt.index >= len(style_names):
+             raise ValueError(f"Invalid index: {evt.index}")
+         selected_style = style_names[evt.index]
+         return styles[selected_style]["prompt"], styles[selected_style]["negative_prompt"], selected_style
+
+     style_images_gallery.select(fn=select_style, inputs=[], outputs=[prompt, negative_prompt, selected_style])
+
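+     # The gallery and the styles dict rely on matching insertion order: the
+     # .select() event only carries the clicked index (via the gr.SelectData
+     # annotation, so no explicit inputs are needed), and the index is mapped
+     # back to a style name here.
+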
+     def update_controlnet(controlnet_type):
+         status = load_controlnet_model(controlnet_type)
+         return status
+
+     controlnet_type.change(fn=update_controlnet, inputs=controlnet_type, outputs=controlnet_status)
+
+     generate_button = gr.Button("Generate Images")
+     generate_button.click(
+         generate_images_with_progress,
+         inputs=[prompt, negative_prompt, batch_count, use_controlnet, controlnet_type, mode, control_image, batch_images_input],
+         outputs=gallery
+     )
+
+     metadata_button = gr.Button("Extract Metadata")
+     metadata_button.click(
+         fn=extract_png_info,
+         inputs=png_image,
+         outputs=metadata_output
+     )
+
+
+
+
+ if __name__ == "__main__":
+     demo.launch(debug=True)
+     clear_memory()