Create app.py
app.py
ADDED
@@ -0,0 +1,393 @@
import gradio as gr
from diffusers import DiffusionPipeline, ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL, UniPCMultistepScheduler
from stable_diffusion_xl_reference import StableDiffusionXLReferencePipeline
from controlnet_aux import OpenposeDetector
#from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from controlnet_aux import MidasDetector, ZoeDetector
from tqdm import tqdm

import torch
import numpy as np
import cv2
from PIL import Image
import os
import random
import gc


def clear_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
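
# clear_memory() is called before and after every pipeline swap and after each generation
# batch so that stale Python references are collected and the CUDA allocator cache is emptied.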

# Global variable definitions
controlnet_pipe = None
reference_pipe = None
pipe = None

# Load the base model
model = "aicollective1/aicollective"
pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.float16)
pipe.to("cuda")

# Placeholder for ControlNet models to be loaded dynamically
controlnet_models = {
    "Canny": None,
    "Depth": None,
    "OpenPose": None,
    "Reference": None
}

# Load necessary models and feature extractors for depth estimation and OpenPose
#feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
#depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")

processor_zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
processor_midas = MidasDetector.from_pretrained("lllyasviel/Annotators")


openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
controlnet_pipe = None  # Initial placeholder, will be loaded dynamically
reference_pipe = None  # Initial placeholder for reference pipeline

# Define the prompts and negative prompts for each style
styles = {
    "Anime Studio Dance": {
        "prompt": ("anime screencap of a man wearing a white helmet with pointed ears,\n"
                   "Outfit: closed animal print shirt,\n"
                   "Action: anime style, looking at viewer, solo, upper body,\n"
                   "((masterpiece)), (best quality), (extremely detailed), depth of field, sketch, "
                   "dark intense shadows, sharp focus, soft lighting, hdr, colorful, good composition, spectacular,"),
        "negative_prompt": ("realistic, (painting by bad-artist-anime:0.9), (painting by bad-artist:0.9), watermark, "
                            "text, error, blurry, jpeg artifacts, cropped, worst quality, low quality, normal quality, "
                            "jpeg artifacts, signature, watermark, username, artist name, (worst quality, low quality:1.4), "
                            "bad anatomy, watermark, signature, text, logo")
    },
    "Vintage Realistic": {
        "prompt": ("a masterpiece close up shoot photography of an man wearing a animal print helmet with pointed ears,\n"
                   "Outfit: wearing an big oversized outfit, white leather jacket,\n"
                   "Action: sitting on steps,\n"
                   "hyper realistic with detailed textures, cinematic film still of Photorealism, realistic skin texture, "
                   "subsurface scattering, skinny, Photorealism, often for highly detailed representation, photographic accuracy, "
                   "shallow depth of field, vignette, highly detailed, bokeh, epic, gorgeous, sharp, perfect hands,\n"
                   "<lora:add-detail-xl:1> <lora:Vintage_Street_Photo:0.9>"),
        "negative_prompt": ("deformed skin, skin veins, black skin, blurry, text, yellow, deformed, (worst quality, low resolution, "
                            "bad hands, open mouth), text, watermark, artist name, distorted, twisted, watermark, 3d render, "
                            "distorted, twisted, watermark, anime, cartoon, graphic, text, painting, crayon, graphite, abstract, "
                            "glitch, deformed, mutated, ugly, disfigured, photoshopped skin, airbrushed skin, glossy skin, canvas frame, "
                            "(high contrast:1.2), (over saturated:1.2), (glossy:1.1), cartoon, 3d, disfigured, Photoshop, video game, "
                            "ugly, tiling, poorly drawn hands, 3d render, impressionism, digital art")
    },
    "Anime 90's Aesthetic": {
        "prompt": ("an man wearing a white helmet with pointed ears, perfect chin,\n"
                   "Outfit: wearing oversized hoodie, animal print pants,\n"
                   "Action: dancing in nature, music production, music instruments made of wood,\n"
                   "A screengrab of an anime, 90's aesthetic,"),
        "negative_prompt": ("photo, real, realistic, blurry, text, yellow, deformed, (worst quality, low resolution, bad hands,), "
                            "text, watermark, artist name, distorted, twisted, watermark, 3d render, distorted, twisted, watermark, "
                            "text, abstract, glitch, deformed, mutated, ugly, disfigured, photoshopped skin, airbrushed skin, glossy skin, "
                            "canvas frame, (high contrast:1.2), (over saturated:1.2), (glossy:1.1), disfigured, Photoshop, video game, "
                            "ugly, tiling, poorly drawn hands, 3d render, impressionism, eyes, mouth, black skin, pale skin, hair, beard")
    },
    "Anime Style": {
        "prompt": ("An man wearing a white helmet with pointed ears sitting on the steps of an Asian street shop,\n"
                   "Outfit: wearing blue pants and a yellow jacket with a red backpack, in the anime style with detailed "
                   "character design in the style of Atey Ghailan, featured in CGSociety, character concept art in the style of Katsuhiro Otomo"),
        "negative_prompt": ("real, deformed fingers, chin, deformed hands, blurry, text, yellow, deformed, (worst quality, low resolution, "
                            "bad hands, open mouth), text, watermark, artist name, distorted, twisted, watermark, 3d, distorted, twisted, "
                            "watermark, anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, "
                            "ugly, disfigured, photoshopped skin, airbrushed skin, glossy skin, canvas frame, (high contrast:1.2), "
                            "(over saturated:1.2), (glossy:1.1), cartoon, 3d, disfigured, Photoshop, video game, ugly, tiling, "
                            "poorly drawn hands, 3d render, impressionism, digital art")
    },
    "Real 70s": {
        "prompt": ("a masterpiece close up shoot photography of an man wearing a white helmet with pointed ears,\n"
                   "Outfit: wearing an oversized trippy 70s shirt and scarf,\n"
                   "Action: standing on the ocean,\n"
                   "shot in the style of Erwin Olaf, hyper realistic with detailed textures, cinematic film still of Photorealism, "
                   "realistic skin texture, subsurface scattering, skinny, Photorealism, often for highly detailed representation, "
                   "photographic accuracy, shallow depth of field, vignette, highly detailed, bokeh, epic, gorgeous, sharp,"),
        "negative_prompt": ("deformed skin, skin veins, black skin, blurry, text, yellow, deformed, (worst quality, low resolution, "
                            "bad hands, open mouth), text, watermark, artist name, distorted, twisted, watermark, 3d render, distorted, "
                            "twisted, watermark, anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, "
                            "mutated, ugly, disfigured, photoshopped skin, airbrushed skin, glossy skin, canvas frame, (high contrast:1.2), "
                            "(over saturated:1.2), (glossy:1.1), cartoon, 3d, disfigured, Photoshop, video game, ugly, tiling, "
                            "poorly drawn hands, 3d render, impressionism, digital art")
    }
}

# Define the style images
style_images = {
    "Anime Studio Dance": "style/Anime Studio Dance.png",
    "Vintage Realistic": "style/Vintage Realistic.png",
    "Anime 90's Aesthetic": "style/Anime 90's Aesthetic.png",
    "Anime Style": "style/Anime Style.png",
    "Real 70s": "style/Real 70s.png"
}
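
# The style gallery below is built from style_images.values(), and select_style() maps the
# clicked gallery index back into list(styles.keys()), so the two dictionaries must keep
# their styles in the same order.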

# Function to load ControlNet models dynamically
def load_controlnet_model(controlnet_type):
    global controlnet_pipe, pipe, reference_pipe, controlnet_models, vae, model

    clear_memory()

    if controlnet_models[controlnet_type] is None:
        if controlnet_type in ["Canny", "Depth", "OpenPose"]:
            controlnet_models[controlnet_type] = ControlNetModel.from_pretrained(
                "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16, use_safetensors=True
            )
        elif controlnet_type == "Reference":
            controlnet_models[controlnet_type] = StableDiffusionXLReferencePipeline.from_pretrained(
                model, torch_dtype=torch.float16, use_safetensors=True
            )

    if 'controlnet_pipe' in globals() and controlnet_pipe is not None:
        controlnet_pipe.to("cpu")
        del controlnet_pipe
        globals()['controlnet_pipe'] = None

    if 'reference_pipe' in globals() and reference_pipe is not None:
        reference_pipe.to("cpu")
        del reference_pipe
        globals()['reference_pipe'] = None

    if pipe is not None:
        pipe.to("cpu")

    clear_memory()

    if controlnet_type == "Reference":
        reference_pipe = controlnet_models[controlnet_type]
        reference_pipe.scheduler = UniPCMultistepScheduler.from_config(reference_pipe.scheduler.config)
        reference_pipe.to("cuda")
        globals()['reference_pipe'] = reference_pipe
    else:
        controlnet_pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
            model, controlnet=controlnet_models[controlnet_type], vae=vae, torch_dtype=torch.float16, use_safetensors=True
        )
        controlnet_pipe.scheduler = UniPCMultistepScheduler.from_config(controlnet_pipe.scheduler.config)
        controlnet_pipe.to("cuda")
        globals()['controlnet_pipe'] = controlnet_pipe

    clear_memory()
    return f"Loaded {controlnet_type} model."
+
|
178 |
+
|
179 |
+
# Preprocessing functions for each ControlNet type
|
180 |
+
def preprocess_canny(image):
|
181 |
+
if isinstance(image, Image.Image):
|
182 |
+
image = np.array(image)
|
183 |
+
if image.dtype != np.uint8:
|
184 |
+
image = (image * 255).astype(np.uint8)
|
185 |
+
image = cv2.Canny(image, 100, 200)
|
186 |
+
image = image[:, :, None]
|
187 |
+
image = np.concatenate([image, image, image], axis=2)
|
188 |
+
return Image.fromarray(image)
|
189 |
+
|
190 |
+
|
191 |
+
def preprocess_depth(image, target_size=(1024, 1024)):
|
192 |
+
if isinstance(image, Image.Image):
|
193 |
+
img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
194 |
+
else:
|
195 |
+
img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
196 |
+
|
197 |
+
depth_img = processor_zoe(img, output_type='cv2') if random.random() > 0.5 else processor_midas(img, output_type='cv2')
|
198 |
+
|
199 |
+
height, width = depth_img.shape[:2]
|
200 |
+
ratio = min(target_size[0] / width, target_size[1] / height)
|
201 |
+
new_width, new_height = int(width * ratio), int(height * ratio)
|
202 |
+
depth_img_resized = cv2.resize(depth_img, (new_width, new_height))
|
203 |
+
|
204 |
+
return Image.fromarray(depth_img_resized)
|
205 |
+
|
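
# Depth preprocessing note: each conditioning image is sent to either ZoeDepth or MiDaS,
# chosen at random per image, and the resulting depth map is resized to fit within
# target_size while preserving the aspect ratio.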

def preprocess_openpose(image):
    if isinstance(image, Image.Image):
        image = np.array(image)
    image = openpose_processor(image, hand_and_face=False, output_type='cv2')
    height, width = image.shape[:2]
    ratio = np.sqrt(1024. * 1024. / (width * height))
    new_width, new_height = int(width * ratio), int(height * ratio)
    image = cv2.resize(image, (new_width, new_height))
    return Image.fromarray(image)

def process_image_batch(images, pipe, prompt, negative_prompt, progress, batch_size=2):
    all_processed_images = []
    for i in range(0, len(images), batch_size):
        batch = images[i:i+batch_size]
        batch_prompt = [prompt] * len(batch)
        batch_negative_prompt = [negative_prompt] * len(batch)

        if isinstance(pipe, StableDiffusionXLReferencePipeline):
            processed_batch = []
            for img in batch:
                result = pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    ref_image=img,
                    num_inference_steps=20
                ).images
                processed_batch.extend(result)
        else:
            processed_batch = pipe(
                prompt=batch_prompt,
                negative_prompt=batch_negative_prompt,
                image=batch,
                num_inference_steps=20
            ).images

        all_processed_images.extend(processed_batch)
        progress(min((i + batch_size) / len(images), 1.0))  # Update progress bar (clamped so the last partial batch does not exceed 1.0)
        clear_memory()  # Clear memory after each batch
    return all_processed_images
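
# Batching note: ControlNet pipelines receive the conditioning images in lists of up to
# batch_size at a time, while the Reference pipeline is driven one image per call
# (a single ref_image) inside the same loop.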

# Define the function to generate images
def generate_images_with_progress(prompt, negative_prompt, batch_count, use_controlnet, controlnet_type, mode, control_image, batch_images_input, progress=gr.Progress(track_tqdm=True)):
    global controlnet_pipe, pipe, reference_pipe

    clear_memory()

    if use_controlnet:
        if controlnet_type not in controlnet_models or controlnet_models[controlnet_type] is None:
            raise ValueError(f"{controlnet_type} model not loaded. Please load the model first.")

        if mode == "Single Image":
            control_images = [control_image]
        else:
            control_images = [Image.open(img).convert("RGB") for img in batch_images_input]

        preprocessed_images = []
        for img in tqdm(control_images, desc="Preprocessing images"):
            if controlnet_type == "Canny":
                preprocessed_images.append(preprocess_canny(img))
            elif controlnet_type == "Depth":
                preprocessed_images.append(preprocess_depth(img))
            elif controlnet_type == "OpenPose":
                preprocessed_images.append(preprocess_openpose(img))
            else:  # Reference
                preprocessed_images.append(img)

        if controlnet_type == "Reference":
            images = process_image_batch(preprocessed_images, reference_pipe, prompt, negative_prompt, progress)
        else:
            images = process_image_batch(preprocessed_images, controlnet_pipe, prompt, negative_prompt, progress)
    else:
        if 'controlnet_pipe' in globals() and controlnet_pipe is not None:
            controlnet_pipe.to("cpu")
            del controlnet_pipe
            globals()['controlnet_pipe'] = None

        if 'reference_pipe' in globals() and reference_pipe is not None:
            reference_pipe.to("cpu")
            del reference_pipe
            globals()['reference_pipe'] = None

        clear_memory()

        if pipe is None:
            pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.float16)
        pipe.to("cuda")  # Ensure the base pipeline is back on the GPU; load_controlnet_model() may have offloaded it to the CPU

        images = []
        for i in tqdm(range(batch_count), desc="Generating images"):
            generated = pipe(prompt=[prompt], negative_prompt=[negative_prompt], num_inference_steps=20, width=1024, height=1024).images
            images.extend(generated)
            progress((i + 1) / batch_count)  # Update progress bar
            clear_memory()  # Clear memory after each image, even in single image mode

    clear_memory()
    return images


# Function to extract PNG metadata
def extract_png_info(image_path):
    metadata = image_path.info  # This is a dictionary containing key-value pairs of metadata
    return metadata

# Define the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Image Generation with Custom Prompts and Styles")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", lines=8, interactive=True)
            with gr.Accordion("Negative Prompt (Minimize/Expand)", open=False):
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="",
                    lines=5
                )
            batch_count = gr.Slider(minimum=1, maximum=10, step=1, label="Batch Count", value=1)
            use_controlnet = gr.Checkbox(label="Use ControlNet", value=False)
            controlnet_type = gr.Dropdown(choices=["Canny", "Depth", "OpenPose", "Reference"], label="ControlNet Type")
            controlnet_status = gr.Textbox(label="ControlNet Status", value="", interactive=False)
            mode = gr.Radio(choices=["Single Image", "Batch"], label="Mode", value="Single Image")

            with gr.Tabs() as tabs:
                with gr.TabItem("Single Image"):
                    control_image = gr.Image(label="Control Image", type='pil')

                with gr.TabItem("Batch"):
                    batch_images_input = gr.File(label="Upload Images", file_count='multiple')

                with gr.TabItem("Extract Metadata"):
                    png_image = gr.Image(label="Upload PNG Image", type='pil')
                    metadata_output = gr.JSON(label="PNG Metadata")

        with gr.Column(scale=2):
            style_images_gallery = gr.Gallery(
                label="Choose a Style",
                value=list(style_images.values()),
                interactive=True,
                elem_id="style-gallery",
                columns=5,
                object_fit="contain",
                allow_preview=False
            )
            gallery = gr.Gallery(label="Generated Images", show_label=False, elem_id="gallery", height="auto")

    selected_style = gr.State(value="Anime Studio Dance")

    def select_style(evt: gr.SelectData):
        style_names = list(styles.keys())
        if evt.index < 0 or evt.index >= len(style_names):
            raise ValueError(f"Invalid index: {evt.index}")
        selected_style = style_names[evt.index]
        return styles[selected_style]["prompt"], styles[selected_style]["negative_prompt"], selected_style

    style_images_gallery.select(fn=select_style, inputs=[], outputs=[prompt, negative_prompt, selected_style])

    def update_controlnet(controlnet_type):
        status = load_controlnet_model(controlnet_type)
        return status

    controlnet_type.change(fn=update_controlnet, inputs=controlnet_type, outputs=controlnet_status)

    generate_button = gr.Button("Generate Images")
    generate_button.click(
        generate_images_with_progress,
        inputs=[prompt, negative_prompt, batch_count, use_controlnet, controlnet_type, mode, control_image, batch_images_input],
        outputs=gallery
    )

    metadata_button = gr.Button("Extract Metadata")
    metadata_button.click(
        fn=extract_png_info,
        inputs=png_image,
        outputs=metadata_output
    )

    with gr.Row():
        generate_button


# At the end of your script:
if __name__ == "__main__":
    # Your Gradio interface setup here
    demo.launch(debug=True)
    clear_memory()
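
# Illustrative usage (sketch only, not executed by this Space): this mirrors how a Canny
# conditioning image is produced by preprocess_canny() above. "example.png" is a placeholder
# file name, not an asset shipped with this repository.
#
#   from PIL import Image
#   edges = preprocess_canny(Image.open("example.png").convert("RGB"))
#   edges.save("canny_conditioning.png")  # 3-channel edge map passed as `image=` to the ControlNet pipeline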