import gradio as gr
from diffusers import DiffusionPipeline, ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL, UniPCMultistepScheduler
from stable_diffusion_xl_reference import StableDiffusionXLReferencePipeline
from controlnet_aux import OpenposeDetector, MidasDetector, ZoeDetector
from tqdm import tqdm
import torch
import numpy as np
import cv2
from PIL import Image
import os
import random
import gc
import tempfile

def clear_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
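# Note: torch.cuda.ipc_collect() mainly reclaims memory released through CUDA IPC
# (multi-process sharing); in this single-process app, empty_cache() does most of
# the work. Keeping both is harmless.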

def reset_ui():
    clear_memory()
    return (
        "",                           # Reset prompt
        "",                           # Reset negative prompt
        1,                            # Reset batch count
        30,                           # Reset number of inference steps
        False,                        # Reset use controlnet
        None,                         # Reset controlnet type
        "Restart/Refresh completed",  # Reset controlnet status with message
        "Single Image",               # Reset mode
        False,                        # Reset use control folder
        None,                         # Reset control image
        [],                           # Reset selected folder images
        None,                         # Reset batch images input
    )

# Resize an image while preserving its aspect ratio
def resize_image(image, max_size=1024):
    width, height = image.size
    if max(width, height) > max_size:
        ratio = max_size / max(width, height)
        new_width = int(width * ratio)
        new_height = int(height * ratio)
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is its replacement.
        image = image.resize((new_width, new_height), Image.LANCZOS)
    return image
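# Example (hypothetical file name): resize_image(Image.open("photo.jpg")) returns
# a PIL image whose longest side is at most 1024 px; smaller images pass through
# unchanged.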

# Global pipeline handles, populated lazily as models are loaded and unloaded
controlnet_pipe = None
reference_pipe = None
pipe = None
current_controlnet_type = None

# Load the base model
model = "aicollective1/aicollective"
pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.float16)
pipe.to("cuda")

# Placeholders for ControlNet models to be loaded dynamically
controlnet_models = {
    "Canny": None,
    "Depth": None,
    "OpenPose": None,
    "Reference": None
}

# Load the annotators used for depth estimation and OpenPose preprocessing
processor_zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
processor_midas = MidasDetector.from_pretrained("lllyasviel/Annotators")
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
controlnet_model_shared = ControlNetModel.from_pretrained(
    "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16, use_safetensors=True
)
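# A single "union" ControlNet checkpoint is reused for Canny, Depth and OpenPose
# conditioning; only the preprocessing step below differs per type.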

# Prompts, negative prompts and step counts for each style preset
styles = {
    "Anime Studio Dance": {
        "prompt": (
            "anime screencap of a man wearing a white helmet with pointed ears,\n"
            "\n"
            "closed animal print shirt,\n"
            "\n"
            "anime style, looking at viewer, solo, upper body,\n"
            "\n"
            "((masterpiece)), (best quality), (extremely detailed), depth of field, sketch, "
            "dark intense shadows, sharp focus, soft lighting, hdr, colorful, good composition, spectacular,"
        ),
        "negative_prompt": (
            "realistic, (painting by bad-artist-anime:0.9), (painting by bad-artist:0.9), watermark, "
            "text, error, blurry, jpeg artifacts, cropped, worst quality, low quality, normal quality, "
            "jpeg artifacts, signature, watermark, username, artist name, (worst quality, low quality:1.4), "
            "bad anatomy, watermark, signature, text, logo"
        ),
        "steps": 40
    },
    "Vintage Realistic": {
        "prompt": (
            "a masterpiece close up shoot photography of a man wearing an animal print helmet with pointed ears,\n"
            "\n"
            "wearing a big oversized outfit, white leather jacket,\n"
            "\n"
            "sitting on steps,\n"
            "\n"
            "hyper realistic with detailed textures, cinematic film still of Photorealism, realistic skin texture, "
            "subsurface scattering, skinny, Photorealism, often for highly detailed representation, photographic accuracy, "
            "shallow depth of field, vignette, highly detailed, bokeh, epic, gorgeous, sharp, perfect hands,\n"
            "<lora:add-detail-xl:1> <lora:Vintage_Street_Photo:0.9>"
        ),
        "negative_prompt": (
            "deformed skin, skin veins, black skin, blurry, text, yellow, deformed, (worst quality, low resolution, "
            "bad hands, open mouth), text, watermark, artist name, distorted, twisted, watermark, 3d render, "
            "distorted, twisted, watermark, anime, cartoon, graphic, text, painting, crayon, graphite, abstract, "
            "glitch, deformed, mutated, ugly, disfigured, photoshopped skin, airbrushed skin, glossy skin, canvas frame, "
            "(high contrast:1.2), (over saturated:1.2), (glossy:1.1), cartoon, 3d, disfigured, Photoshop, video game, "
            "ugly, tiling, poorly drawn hands, 3d render, impressionism, digital art"
        ),
        "steps": 30
    },
    "Anime 90's Aesthetic": {
        "prompt": (
            "a man wearing a white helmet with pointed ears, perfect chin,\n"
            "\n"
            "wearing oversized hoodie, animal print pants,\n"
            "\n"
            "dancing in nature, music production, music instruments made of wood,\n"
            "\n"
            "A screengrab of an anime, 90's aesthetic,"
        ),
        "negative_prompt": (
            "photo, real, realistic, blurry, text, yellow, deformed, (worst quality, low resolution, bad hands,), "
            "text, watermark, artist name, distorted, twisted, watermark, 3d render, distorted, twisted, watermark, "
            "text, abstract, glitch, deformed, mutated, ugly, disfigured, photoshopped skin, airbrushed skin, glossy skin, "
            "canvas frame, (high contrast:1.2), (over saturated:1.2), (glossy:1.1), disfigured, Photoshop, video game, "
            "ugly, tiling, poorly drawn hands, 3d render, impressionism, eyes, mouth, black skin, pale skin, hair, beard"
        ),
        "steps": 34
    },
    "Anime Style": {
        "prompt": (
            "A man wearing a white helmet with pointed ears sitting on the steps of an Asian street shop,\n"
            "\n"
            "wearing blue pants and a yellow jacket with a red backpack, in the anime style with detailed "
            "character design in the style of Atey Ghailan, featured in CGSociety, character concept art in the style of Katsuhiro Otomo"
        ),
        "negative_prompt": (
            "real, deformed fingers, chin, deformed hands, blurry, text, yellow, deformed, (worst quality, low resolution, "
            "bad hands, open mouth), text, watermark, artist name, distorted, twisted, watermark, 3d, distorted, twisted, "
            "watermark, anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, "
            "ugly, disfigured, photoshopped skin, airbrushed skin, glossy skin, canvas frame, (high contrast:1.2), "
            "(over saturated:1.2), (glossy:1.1), cartoon, 3d, disfigured, Photoshop, video game, ugly, tiling, "
            "poorly drawn hands, 3d render, impressionism, digital art"
        ),
        "steps": 28
    },
    "Real 70s": {
        "prompt": (
            "a masterpiece close up shoot photography of a man wearing a white helmet with pointed ears,\n"
            "\n"
            "wearing an oversized trippy 70s shirt and scarf,\n"
            "\n"
            "standing on the ocean,\n"
            "\n"
            "shot in the style of Erwin Olaf, hyper realistic with detailed textures, cinematic film still of Photorealism, "
            "realistic skin texture, subsurface scattering, skinny, Photorealism, often for highly detailed representation, "
            "photographic accuracy, shallow depth of field, vignette, highly detailed, bokeh, epic, gorgeous, sharp,"
        ),
        "negative_prompt": (
            "deformed skin, skin veins, black skin, blurry, text, yellow, deformed, (worst quality, low resolution, "
            "bad hands, open mouth), text, watermark, artist name, distorted, twisted, watermark, 3d render, distorted, "
            "twisted, watermark, anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, "
            "mutated, ugly, disfigured, photoshopped skin, airbrushed skin, glossy skin, canvas frame, (high contrast:1.2), "
            "(over saturated:1.2), (glossy:1.1), cartoon, 3d, disfigured, Photoshop, video game, ugly, tiling, "
            "poorly drawn hands, 3d render, impressionism, digital art"
        ),
        "steps": 40
    }
}

# Thumbnail images for the style picker gallery
style_images = {
    "Anime Studio Dance": "style/Anime Studio Dance.png",
    "Vintage Realistic": "style/Vintage Realistic.png",
    "Anime 90's Aesthetic": "style/Anime 90's Aesthetic.png",
    "Anime Style": "style/Anime Style.png",
    "Real 70s": "style/Real 70s.png"
}

# Load a ControlNet model dynamically, swapping the previous pipeline out of VRAM first
def load_controlnet_model(controlnet_type):
    global controlnet_pipe, pipe, reference_pipe, controlnet_models, vae, model, current_controlnet_type, controlnet_model_shared
    clear_memory()

    if current_controlnet_type == controlnet_type:
        return f"{controlnet_type} model already loaded."

    if controlnet_models[controlnet_type] is None:
        if controlnet_type in ["Canny", "Depth", "OpenPose"]:
            controlnet_models[controlnet_type] = controlnet_model_shared
        elif controlnet_type == "Reference":
            controlnet_models[controlnet_type] = StableDiffusionXLReferencePipeline.from_pretrained(
                model, torch_dtype=torch.float16, use_safetensors=True
            )

    # Move any previously loaded pipelines off the GPU before loading the new one.
    if controlnet_pipe is not None:
        controlnet_pipe.to("cpu")
        controlnet_pipe = None
    if reference_pipe is not None:
        reference_pipe.to("cpu")
        reference_pipe = None
    if pipe is not None:
        pipe.to("cpu")
    clear_memory()

    if controlnet_type == "Reference":
        reference_pipe = controlnet_models[controlnet_type]
        reference_pipe.scheduler = UniPCMultistepScheduler.from_config(reference_pipe.scheduler.config)
        reference_pipe.to("cuda")
    else:
        controlnet_pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
            model, controlnet=controlnet_models[controlnet_type], vae=vae, torch_dtype=torch.float16, use_safetensors=True
        )
        controlnet_pipe.scheduler = UniPCMultistepScheduler.from_config(controlnet_pipe.scheduler.config)
        controlnet_pipe.to("cuda")

    current_controlnet_type = controlnet_type
    clear_memory()
    return f"Loaded {controlnet_type} model."

# Preprocessing functions for each ControlNet type
def preprocess_canny(image):
    if isinstance(image, str):
        image = Image.open(image).convert("RGB")
    if isinstance(image, Image.Image):
        image = np.array(image)
    if image.dtype != np.uint8:
        image = (image * 255).astype(np.uint8)
    image = cv2.Canny(image, 100, 200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)  # single channel -> 3-channel
    return Image.fromarray(image)
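# The 100/200 thresholds are a common cv2.Canny default; lower them if the edge
# map comes out too sparse, raise them if it is too noisy.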

def preprocess_depth(image, target_size=(1024, 1024)):
    if isinstance(image, str):
        image = Image.open(image).convert("RGB")
    if isinstance(image, Image.Image):
        img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    else:
        img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    depth_img = processor_zoe(img, output_type='cv2') if random.random() > 0.5 else processor_midas(img, output_type='cv2')
    height, width = depth_img.shape[:2]
    ratio = min(target_size[0] / width, target_size[1] / height)
    new_width, new_height = int(width * ratio), int(height * ratio)
    depth_img_resized = cv2.resize(depth_img, (new_width, new_height))
    return Image.fromarray(depth_img_resized)
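# ZoeDepth and MiDaS produce noticeably different depth maps; picking one at
# random per image adds variety across a batch at no extra cost.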

def preprocess_openpose(image):
    if isinstance(image, str):
        image = Image.open(image).convert("RGB")
    if isinstance(image, Image.Image):
        image = np.array(image)
    image = openpose_processor(image, hand_and_face=False, output_type='cv2')
    height, width = image.shape[:2]
    # Scale so the output covers roughly 1024x1024 pixels of area.
    ratio = np.sqrt(1024. * 1024. / (width * height))
    new_width, new_height = int(width * ratio), int(height * ratio)
    image = cv2.resize(image, (new_width, new_height))
    return Image.fromarray(image)

def process_image_batch(images, pipe, prompt, negative_prompt, num_inference_steps, progress, batch_size=2):
    all_processed_images = []
    for i in range(0, len(images), batch_size):
        batch = images[i:i+batch_size]
        batch_prompt = [prompt] * len(batch)
        batch_negative_prompt = [negative_prompt] * len(batch)
        if isinstance(pipe, StableDiffusionXLReferencePipeline):
            # The reference pipeline takes one reference image at a time.
            processed_batch = []
            for img in batch:
                result = pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    ref_image=img,
                    num_inference_steps=num_inference_steps,
                    reference_attn=True,
                    reference_adain=True
                ).images
                processed_batch.extend(result)
        else:
            processed_batch = pipe(
                prompt=batch_prompt,
                negative_prompt=batch_negative_prompt,
                image=batch,
                num_inference_steps=num_inference_steps
            ).images
        all_processed_images.extend(processed_batch)
        progress(min(1.0, (i + batch_size) / len(images)))  # Update progress bar, clamped for the final partial batch
        clear_memory()  # Clear memory after each batch
    return all_processed_images
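# batch_size=2 trades throughput for VRAM headroom; raise it on larger GPUs, or
# drop it to 1 if the ControlNet pass runs out of memory.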

# Save images as PNG files and return their paths
def save_images_as_png(images):
    temp_dir = tempfile.mkdtemp()
    png_paths = []
    for i, img in enumerate(images):
        png_path = os.path.join(temp_dir, f"image_{i}.png")
        img.save(png_path, "PNG")
        png_paths.append(png_path)
    return png_paths
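# Optional sketch (not wired in): embedding the prompt as a PNG text chunk would
# give the "Extract Metadata" tab something to read back via Image.info:
# from PIL.PngImagePlugin import PngInfo
# meta = PngInfo()
# meta.add_text("prompt", prompt)          # assumes the prompt is passed in
# img.save(png_path, "PNG", pnginfo=meta)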

# Generate images, optionally conditioned through the selected ControlNet
def generate_images_with_progress(prompt, negative_prompt, batch_count, use_controlnet, controlnet_type, mode, control_images, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
    global controlnet_pipe, pipe, reference_pipe
    clear_memory()
    chunk_size = 1  # Adjust this number based on your memory capacity

    if use_controlnet:
        if controlnet_type not in controlnet_models or controlnet_models[controlnet_type] is None:
            raise ValueError(f"{controlnet_type} model not loaded. Please load the model first.")

        if mode == "Single Image":
            control_images = [control_images] if isinstance(control_images, Image.Image) else control_images
        else:
            if not control_images:
                raise ValueError("No images provided for batch processing.")
            control_images = [Image.open(img).convert("RGB") if isinstance(img, str) else img for img in control_images]

        preprocessed_images = []
        for img in tqdm(control_images, desc="Preprocessing images"):
            img = resize_image(img)  # Resize the image before preprocessing
            if controlnet_type == "Canny":
                preprocessed_images.append(preprocess_canny(img))
            elif controlnet_type == "Depth":
                preprocessed_images.append(preprocess_depth(img))
            elif controlnet_type == "OpenPose":
                preprocessed_images.append(preprocess_openpose(img))
            else:  # Reference uses the image as-is
                preprocessed_images.append(img)

        images = []
        for i in range(0, len(preprocessed_images), chunk_size):
            chunk = preprocessed_images[i:i+chunk_size]
            if controlnet_type == "Reference":
                images_chunk = process_image_batch(chunk, reference_pipe, prompt, negative_prompt, num_inference_steps, progress)
            else:
                images_chunk = process_image_batch(chunk, controlnet_pipe, prompt, negative_prompt, num_inference_steps, progress)
            images.extend(images_chunk)
            clear_memory()
    else:
        # Unload any ControlNet pipelines before falling back to the base pipeline.
        if controlnet_pipe is not None:
            controlnet_pipe.to("cpu")
            controlnet_pipe = None
        if reference_pipe is not None:
            reference_pipe.to("cpu")
            reference_pipe = None
        clear_memory()
        if pipe is None:
            pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.float16)
        pipe.to("cuda")  # The base pipeline may have been moved to CPU while a ControlNet was active.
        images = []
        for i in tqdm(range(batch_count), desc="Generating images"):
            generated = pipe(prompt=[prompt], negative_prompt=[negative_prompt], num_inference_steps=num_inference_steps, width=1024, height=1024).images
            images.extend(generated)
            progress((i + 1) / batch_count)  # Update progress bar
            clear_memory()  # Clear memory after each image, even in single image mode

    clear_memory()
    # Save images as PNG and return their paths
    png_paths = save_images_as_png(images)
    return png_paths

# Extract metadata (PNG text chunks) from an uploaded image
def extract_png_info(image):
    # Gradio hands us a PIL image; .info is a dict of its key/value metadata.
    return image.info
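# Note: .info is only populated for formats that carry metadata chunks (e.g. PNG
# tEXt/iTXt); an image saved without any will simply return an empty dict.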

# Load images from the specified folder
def load_images_from_folder(folder_path):
    images = []
    for filename in os.listdir(folder_path):
        # lower() makes the extension check case-insensitive (e.g. .PNG, .JPG)
        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            img_path = os.path.join(folder_path, filename)
            img = Image.open(img_path).convert("RGB")
            img = resize_image(img)  # Resize the image before adding to the list
            images.append((filename, img))
    return images

# Folder where the control images are stored
image_folder_path = "control"  # Update this path to your folder

# Load images from folder
loaded_images = load_images_from_folder(image_folder_path)

# Define the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Image Generation with Custom Prompts and Styles")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", lines=8, interactive=True)
            with gr.Accordion("Negative Prompt (Minimize/Expand)", open=False):
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="",
                    lines=5
                )
            batch_count = gr.Slider(minimum=1, maximum=50, step=1, label="Batch Count", value=1)
            num_inference_steps = gr.Slider(minimum=1, maximum=100, step=1, label="Number of Inference Steps", value=30)
            use_controlnet = gr.Checkbox(label="Use ControlNet", value=False)
            controlnet_type = gr.Dropdown(choices=["Canny", "Depth", "OpenPose", "Reference"], label="ControlNet Type")
            controlnet_status = gr.Textbox(label="Status", value="", interactive=False)
            mode = gr.Radio(choices=["Single Image", "Batch", "Multiselect"], label="Mode", value="Single Image")
            use_control_folder = gr.Checkbox(label="Use Control Folder for Batch Processing", value=False)
            with gr.Tabs() as tabs:
                with gr.TabItem("Single Image"):
                    control_image = gr.Image(label="Control Image", type='pil')
                with gr.TabItem("Batch"):
                    batch_images_input = gr.File(label="Upload Images", file_count='multiple')
                with gr.TabItem("Extract Metadata"):
                    png_image = gr.Image(label="Upload PNG Image", type='pil')
                    metadata_output = gr.JSON(label="PNG Metadata")
                with gr.TabItem("Select from Folder"):
                    folder_images_gallery = gr.Gallery(
                        label="Images from Folder",
                        value=[img[1] for img in loaded_images],
                        interactive=True,
                        elem_id="folder-gallery",
                        columns=5,
                        object_fit="contain",
                        height=235,
                        allow_preview=False
                    )
                    clear_selection_button = gr.Button("Clear Selection")
        with gr.Column(scale=2):
            style_images_gallery = gr.Gallery(
                label="Choose a Style",
                value=list(style_images.values()),
                interactive=True,
                elem_id="style-gallery",
                columns=5,
                object_fit="contain",
                height=235,
                allow_preview=False
            )
            generate_button = gr.Button("Generate Images")
            gallery = gr.Gallery(label="Generated Images", show_label=False, elem_id="gallery", height=820)
    selected_style = gr.State(value="Anime Studio Dance")

    def select_style(evt: gr.SelectData):
        style_names = list(styles.keys())
        if evt.index < 0 or evt.index >= len(style_names):
            raise ValueError(f"Invalid index: {evt.index}")
        chosen = style_names[evt.index]  # local name avoids shadowing the gr.State above
        return styles[chosen]["prompt"], styles[chosen]["negative_prompt"], styles[chosen]["steps"], chosen

    style_images_gallery.select(fn=select_style, inputs=[], outputs=[prompt, negative_prompt, num_inference_steps, selected_style])

    def update_controlnet(controlnet_type):
        return load_controlnet_model(controlnet_type)

    controlnet_type.change(fn=update_controlnet, inputs=controlnet_type, outputs=controlnet_status)

    selected_folder_images = gr.State(value=[])

    def select_folder_image(evt: gr.SelectData, selected_folder_images, mode):
        folder_image_names = [img[0] for img in loaded_images]
        if evt.index < 0 or evt.index >= len(folder_image_names):
            raise ValueError(f"Invalid index: {evt.index}")
        selected_image_name = folder_image_names[evt.index]
        selected_image = next(img for img in loaded_images if img[0] == selected_image_name)
        current_images = selected_folder_images or []
        if mode == "Single Image":
            current_images = [selected_image]
        elif selected_image not in current_images:
            current_images.append(selected_image)
        return current_images

    def clear_selected_folder_images():
        return []

    folder_images_gallery.select(fn=select_folder_image, inputs=[selected_folder_images, mode], outputs=selected_folder_images)
    clear_selection_button.click(fn=clear_selected_folder_images, inputs=[], outputs=selected_folder_images)

    def generate_images_with_folder_images(prompt, negative_prompt, batch_count, use_controlnet, controlnet_type, mode, use_control_folder, selected_folder_images, batch_images_input, num_inference_steps, control_image, progress=gr.Progress(track_tqdm=True)):
        if mode == "Batch" and use_control_folder:
            selected_images = [img[1] for img in loaded_images]
        elif mode == "Batch":
            if not batch_images_input:
                raise ValueError("No images uploaded for batch processing.")
            selected_images = [resize_image(Image.open(img).convert("RGB")) for img in batch_images_input]
        elif mode == "Single Image" and control_image is not None:
            selected_images = [control_image]
        else:
            selected_images = [img[1] for img in selected_folder_images]
        # Repeat the selection batch_count times so the requested number of images is generated
        selected_images = selected_images * batch_count
        return generate_images_with_progress(prompt, negative_prompt, batch_count, use_controlnet, controlnet_type, mode, selected_images, num_inference_steps, progress)

    generate_button.click(
        generate_images_with_folder_images,
        inputs=[prompt, negative_prompt, batch_count, use_controlnet, controlnet_type, mode, use_control_folder, selected_folder_images, batch_images_input, num_inference_steps, control_image],
        outputs=gallery
    )

    metadata_button = gr.Button("Extract Metadata")
    metadata_button.click(
        fn=extract_png_info,
        inputs=png_image,
        outputs=metadata_output
    )

    with gr.Row():
        refresh_button = gr.Button("Restart/Refresh")
    refresh_button.click(
        fn=reset_ui,
        inputs=[],
        outputs=[
            prompt,
            negative_prompt,
            batch_count,
            num_inference_steps,
            use_controlnet,
            controlnet_type,
            controlnet_status,
            mode,
            use_control_folder,
            control_image,
            selected_folder_images,
            batch_images_input
        ]
    )

if __name__ == "__main__":
    demo.launch(auth=("roland", "roland"), debug=True)
    clear_memory()  # Runs once the blocking launch() returns (server shutdown)