# Hugging Face Spaces page status captured along with this file: Runtime error
import gc
import os
import random

import cv2
import gradio as gr
import numpy as np
import torch
from controlnet_aux import MidasDetector, OpenposeDetector, ZoeDetector
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    DiffusionPipeline,
    StableDiffusionXLControlNetPipeline,
    UniPCMultistepScheduler,
)
from PIL import Image
from tqdm import tqdm

from stable_diffusion_xl_reference import StableDiffusionXLReferencePipeline
def clear_memory():
    """Run the garbage collector and, when CUDA is available, free cached GPU memory.

    The Python collector runs first so that objects with no remaining
    references are released before the CUDA cache is emptied and CUDA IPC
    handles are collected.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
# Global variable definitions.
# controlnet_pipe / reference_pipe are populated by load_controlnet_model();
# current_controlnet_type records which ControlNet type was activated last.
controlnet_pipe = None
reference_pipe = None
pipe = None
current_controlnet_type = None

# Load the base model (half precision) and place it on the GPU.
model = "aicollective1/aicollective"
pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.float16)
pipe.to("cuda")

# Placeholder for ControlNet models to be loaded dynamically.
# A value of None means "not loaded yet".
controlnet_models = {
    "Canny": None,
    "Depth": None,
    "OpenPose": None,
    "Reference": None,
}

# Load necessary models and feature extractors for depth estimation and OpenPose
processor_zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
processor_midas = MidasDetector.from_pretrained("lllyasviel/Annotators")
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")

# VAE and ControlNet weights handed to the ControlNet pipeline when it is built.
# The same ControlNet instance is used for Canny, Depth and OpenPose.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True
)
controlnet_model_shared = ControlNetModel.from_pretrained(
    "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16, use_safetensors=True
)
# Define the prompts and negative prompts for each style.
# Keys are the style names shown in the UI; each value holds the "prompt" and
# "negative_prompt" text that is copied into the prompt boxes on selection.
# NOTE(review): the "<lora:...>" tags in "Vintage Realistic" look like
# web-UI LoRA syntax; no code in this file loads LoRA weights, so they are
# presumably passed through as plain prompt text - confirm.
styles = {
    "Anime Studio Dance": {
        "prompt": ("anime screencap of a man wearing a white helmet with pointed ears,\n"
                   "Outfit: closed animal print shirt,\n"
                   "Action: anime style, looking at viewer, solo, upper body,\n"
                   "((masterpiece)), (best quality), (extremely detailed), depth of field, sketch, "
                   "dark intense shadows, sharp focus, soft lighting, hdr, colorful, good composition, spectacular,"),
        "negative_prompt": ("realistic, (painting by bad-artist-anime:0.9), (painting by bad-artist:0.9), watermark, "
                            "text, error, blurry, jpeg artifacts, cropped, worst quality, low quality, normal quality, "
                            "jpeg artifacts, signature, watermark, username, artist name, (worst quality, low quality:1.4), "
                            "bad anatomy, watermark, signature, text, logo")
    },
    "Vintage Realistic": {
        "prompt": ("a masterpiece close up shoot photography of an man wearing a animal print helmet with pointed ears,\n"
                   "Outfit: wearing an big oversized outfit, white leather jacket,\n"
                   "Action: sitting on steps,\n"
                   "hyper realistic with detailed textures, cinematic film still of Photorealism, realistic skin texture, "
                   "subsurface scattering, skinny, Photorealism, often for highly detailed representation, photographic accuracy, "
                   "shallow depth of field, vignette, highly detailed, bokeh, epic, gorgeous, sharp, perfect hands,\n"
                   "<lora:add-detail-xl:1> <lora:Vintage_Street_Photo:0.9>"),
        "negative_prompt": ("deformed skin, skin veins, black skin, blurry, text, yellow, deformed, (worst quality, low resolution, "
                            "bad hands, open mouth), text, watermark, artist name, distorted, twisted, watermark, 3d render, "
                            "distorted, twisted, watermark, anime, cartoon, graphic, text, painting, crayon, graphite, abstract, "
                            "glitch, deformed, mutated, ugly, disfigured, photoshopped skin, airbrushed skin, glossy skin, canvas frame, "
                            "(high contrast:1.2), (over saturated:1.2), (glossy:1.1), cartoon, 3d, disfigured, Photoshop, video game, "
                            "ugly, tiling, poorly drawn hands, 3d render, impressionism, digital art")
    },
    "Anime 90's Aesthetic": {
        "prompt": ("an man wearing a white helmet with pointed ears, perfect chin,\n"
                   "Outfit: wearing oversized hoodie, animal print pants,\n"
                   "Action: dancing in nature, music production, music instruments made of wood,\n"
                   "A screengrab of an anime, 90's aesthetic,"),
        "negative_prompt": ("photo, real, realistic, blurry, text, yellow, deformed, (worst quality, low resolution, bad hands,), "
                            "text, watermark, artist name, distorted, twisted, watermark, 3d render, distorted, twisted, watermark, "
                            "text, abstract, glitch, deformed, mutated, ugly, disfigured, photoshopped skin, airbrushed skin, glossy skin, "
                            "canvas frame, (high contrast:1.2), (over saturated:1.2), (glossy:1.1), disfigured, Photoshop, video game, "
                            "ugly, tiling, poorly drawn hands, 3d render, impressionism, eyes, mouth, black skin, pale skin, hair, beard")
    },
    "Anime Style": {
        "prompt": ("An man wearing a white helmet with pointed ears sitting on the steps of an Asian street shop,\n"
                   "Outfit: wearing blue pants and a yellow jacket with a red backpack, in the anime style with detailed "
                   "character design in the style of Atey Ghailan, featured in CGSociety, character concept art in the style of Katsuhiro Otomo"),
        "negative_prompt": ("real, deformed fingers, chin, deformed hands, blurry, text, yellow, deformed, (worst quality, low resolution, "
                            "bad hands, open mouth), text, watermark, artist name, distorted, twisted, watermark, 3d, distorted, twisted, "
                            "watermark, anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, "
                            "ugly, disfigured, photoshopped skin, airbrushed skin, glossy skin, canvas frame, (high contrast:1.2), "
                            "(over saturated:1.2), (glossy:1.1), cartoon, 3d, disfigured, Photoshop, video game, ugly, tiling, "
                            "poorly drawn hands, 3d render, impressionism, digital art")
    },
    "Real 70s": {
        "prompt": ("a masterpiece close up shoot photography of an man wearing a white helmet with pointed ears,\n"
                   "Outfit: wearing an oversized trippy 70s shirt and scarf,\n"
                   "Action: standing on the ocean,\n"
                   "shot in the style of Erwin Olaf, hyper realistic with detailed textures, cinematic film still of Photorealism, "
                   "realistic skin texture, subsurface scattering, skinny, Photorealism, often for highly detailed representation, "
                   "photographic accuracy, shallow depth of field, vignette, highly detailed, bokeh, epic, gorgeous, sharp,"),
        "negative_prompt": ("deformed skin, skin veins, black skin, blurry, text, yellow, deformed, (worst quality, low resolution, "
                            "bad hands, open mouth), text, watermark, artist name, distorted, twisted, watermark, 3d render, distorted, "
                            "twisted, watermark, anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, "
                            "mutated, ugly, disfigured, photoshopped skin, airbrushed skin, glossy skin, canvas frame, (high contrast:1.2), "
                            "(over saturated:1.2), (glossy:1.1), cartoon, 3d, disfigured, Photoshop, video game, ugly, tiling, "
                            "poorly drawn hands, 3d render, impressionism, digital art")
    }
}
# Define the style images.
# Maps each style name to its preview image, stored as "style/<name>.png".
# The insertion order here is the order in which the previews are listed.
style_images = {
    style_name: f"style/{style_name}.png"
    for style_name in (
        "Anime Studio Dance",
        "Vintage Realistic",
        "Anime 90's Aesthetic",
        "Anime Style",
        "Real 70s",
    )
}
# Function to load ControlNet models dynamically
def load_controlnet_model(controlnet_type):
    """Activate the pipeline for ``controlnet_type`` on the GPU.

    The underlying model is created on first use and cached in
    ``controlnet_models``: "Canny", "Depth" and "OpenPose" all use
    ``controlnet_model_shared``; "Reference" gets its own
    ``StableDiffusionXLReferencePipeline``. Any previously active ControlNet
    or Reference pipeline, and the base ``pipe``, are moved to the CPU before
    the requested pipeline is moved to CUDA.

    Args:
        controlnet_type: One of the keys of ``controlnet_models``.

    Returns:
        A human-readable status string for display in the UI.
    """
    global controlnet_pipe, reference_pipe, current_controlnet_type

    # A cleared dropdown delivers None; report it instead of raising KeyError.
    if controlnet_type not in controlnet_models:
        return f"No ControlNet model available for {controlnet_type!r}."

    clear_memory()
    wants_reference = controlnet_type == "Reference"

    # Lazily create and cache the model backing this type.
    if controlnet_models[controlnet_type] is None:
        if wants_reference:
            controlnet_models[controlnet_type] = StableDiffusionXLReferencePipeline.from_pretrained(
                model, torch_dtype=torch.float16, use_safetensors=True
            )
        else:
            controlnet_models[controlnet_type] = controlnet_model_shared

    # Only short-circuit when the matching pipeline object really exists;
    # the recorded type alone can be stale if the pipeline was dropped elsewhere.
    active_pipe = reference_pipe if wants_reference else controlnet_pipe
    if current_controlnet_type == controlnet_type and active_pipe is not None:
        return f"{controlnet_type} model already loaded."

    # Move everything off the GPU before bringing the requested pipeline in.
    if controlnet_pipe is not None:
        controlnet_pipe.to("cpu")
        controlnet_pipe = None
    if reference_pipe is not None:
        reference_pipe.to("cpu")
        reference_pipe = None
    if pipe is not None:
        pipe.to("cpu")
    clear_memory()

    if wants_reference:
        reference_pipe = controlnet_models[controlnet_type]
        reference_pipe.scheduler = UniPCMultistepScheduler.from_config(reference_pipe.scheduler.config)
        reference_pipe.to("cuda")
    else:
        controlnet_pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
            model,
            controlnet=controlnet_models[controlnet_type],
            vae=vae,
            torch_dtype=torch.float16,
            use_safetensors=True,
        )
        controlnet_pipe.scheduler = UniPCMultistepScheduler.from_config(controlnet_pipe.scheduler.config)
        controlnet_pipe.to("cuda")

    current_controlnet_type = controlnet_type
    clear_memory()
    return f"Loaded {controlnet_type} model."
# Preprocessing functions for each ControlNet type
def preprocess_canny(image, low_threshold=100, high_threshold=200):
    """Build a three-channel Canny edge map from ``image``.

    Args:
        image: A PIL image or a NumPy array. Arrays that are not ``uint8`` are
            multiplied by 255 and cast to ``uint8``.
            NOTE(review): that scaling presumably expects floats in [0, 1] - confirm.
        low_threshold: Lower hysteresis threshold passed to ``cv2.Canny``.
        high_threshold: Upper hysteresis threshold passed to ``cv2.Canny``.

    Returns:
        A PIL image whose three channels each hold the same edge map.
    """
    pixels = np.array(image) if isinstance(image, Image.Image) else image
    if pixels.dtype != np.uint8:
        pixels = (pixels * 255).astype(np.uint8)
    edges = cv2.Canny(pixels, low_threshold, high_threshold)
    # Replicate the single-channel edge map into three identical channels.
    return Image.fromarray(np.repeat(edges[:, :, None], 3, axis=2))
def preprocess_depth(image, target_size=(1024, 1024)):
    """Estimate a depth map for ``image`` and scale it to fit ``target_size``.

    The input is converted from RGB to BGR channel order, then passed to
    either the Zoe or the Midas detector, chosen at random on each call.
    The resulting map is resized with its aspect ratio preserved so that it
    fits inside ``target_size``.

    Args:
        image: A PIL image or an RGB NumPy array.
        target_size: ``(max_width, max_height)`` bounding box for the output.

    Returns:
        The resized depth map as a PIL image.
    """
    rgb = np.array(image) if isinstance(image, Image.Image) else image
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

    detector = processor_zoe if random.random() > 0.5 else processor_midas
    depth_map = detector(bgr, output_type='cv2')

    height, width = depth_map.shape[:2]
    scale = min(target_size[0] / width, target_size[1] / height)
    # Clamp to 1 pixel so extreme aspect ratios cannot yield a zero-sized resize.
    new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
    return Image.fromarray(cv2.resize(depth_map, new_size))
def preprocess_openpose(image, target_pixels=1024 * 1024):
    """Render an OpenPose skeleton for ``image`` and rescale it to a fixed area.

    The detector runs with ``hand_and_face=False``. The rendered pose map is
    then resized, keeping its aspect ratio, so that width * height is
    approximately ``target_pixels``.

    Args:
        image: A PIL image or a NumPy array.
        target_pixels: Desired pixel area of the output (default 1024 * 1024).

    Returns:
        The resized pose map as a PIL image.
    """
    pixels = np.array(image) if isinstance(image, Image.Image) else image
    pose_map = openpose_processor(pixels, hand_and_face=False, output_type='cv2')

    height, width = pose_map.shape[:2]
    scale = np.sqrt(float(target_pixels) / (width * height))
    # Clamp to 1 pixel so extreme aspect ratios cannot yield a zero-sized resize.
    new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
    return Image.fromarray(cv2.resize(pose_map, new_size))
def process_image_batch(images, pipe, prompt, negative_prompt, progress, batch_size=2):
    """Run ``pipe`` over ``images`` in chunks and collect every generated image.

    A ``StableDiffusionXLReferencePipeline`` is called once per image with the
    image as ``ref_image``. Any other pipeline is called once per chunk with
    the chunk as ``image`` and the prompts repeated once per image.
    NOTE(review): the chunked call presumably needs all images in a chunk to
    share one size - confirm against the pipeline.

    Args:
        images: Sequence of control/reference images.
        pipe: The pipeline to call.
        prompt: Prompt text used for every image.
        negative_prompt: Negative prompt text used for every image.
        progress: Callable receiving the completed fraction (0 to 1) after
            each chunk.
        batch_size: Number of images per chunk.

    Returns:
        A list with all images produced, in input order. Empty input yields
        an empty list without calling ``pipe`` or ``progress``.
    """
    total = len(images)
    results = []
    for start in range(0, total, batch_size):
        chunk = images[start:start + batch_size]
        if isinstance(pipe, StableDiffusionXLReferencePipeline):
            for reference in chunk:
                results.extend(
                    pipe(
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        ref_image=reference,
                        num_inference_steps=20,
                    ).images
                )
        else:
            results.extend(
                pipe(
                    prompt=[prompt] * len(chunk),
                    negative_prompt=[negative_prompt] * len(chunk),
                    image=chunk,
                    num_inference_steps=20,
                ).images
            )
        # Count the images actually processed so a short final chunk never
        # reports a fraction above 1.0.
        progress((start + len(chunk)) / total)
        clear_memory()  # Clear memory after each batch
    return results
# Define the function to generate images
def generate_images_with_progress(prompt, negative_prompt, batch_count, use_controlnet, controlnet_type, mode, control_image, batch_images_input, progress=gr.Progress(track_tqdm=True)):
    """Generate images with the base pipeline or with a ControlNet/Reference pipeline.

    With ``use_controlnet`` set, the control images (one image in
    "Single Image" mode, otherwise every uploaded file) are preprocessed
    according to ``controlnet_type`` and one pass is run over them;
    ``batch_count`` is not used on this path. Otherwise the base pipeline
    generates ``batch_count`` images at 1024x1024.

    Args:
        prompt: Prompt text.
        negative_prompt: Negative prompt text.
        batch_count: Number of images to generate without ControlNet.
        use_controlnet: Whether to condition on control images.
        controlnet_type: "Canny", "Depth", "OpenPose" or "Reference".
        mode: "Single Image" or any other value for batch mode.
        control_image: PIL image used in single-image mode.
        batch_images_input: Uploaded files (paths or objects exposing ``.name``).
        progress: Gradio progress tracker.

    Returns:
        The list of generated images.

    Raises:
        ValueError: If the ControlNet type is unknown, its pipeline cannot be
            activated, or the required input image(s) are missing.
    """
    global controlnet_pipe, pipe, reference_pipe, current_controlnet_type
    clear_memory()

    if use_controlnet:
        not_loaded = f"{controlnet_type} model not loaded. Please load the model first."
        if controlnet_type not in controlnet_models:
            raise ValueError(not_loaded)

        wants_reference = controlnet_type == "Reference"
        active_pipe = reference_pipe if wants_reference else controlnet_pipe
        # The pipeline may have been dropped by an earlier run without
        # ControlNet; bring it back instead of calling a None pipeline.
        if active_pipe is None or current_controlnet_type != controlnet_type:
            load_controlnet_model(controlnet_type)
            active_pipe = reference_pipe if wants_reference else controlnet_pipe
        if active_pipe is None:
            raise ValueError(not_loaded)

        if mode == "Single Image":
            if control_image is None:
                raise ValueError("No control image provided. Please upload an image first.")
            control_images = [control_image]
        else:
            if not batch_images_input:
                raise ValueError("No batch images provided. Please upload at least one image.")
            control_images = []
            for uploaded in batch_images_input:
                # Uploads arrive either as paths or as file objects with a .name path.
                with Image.open(getattr(uploaded, "name", uploaded)) as opened:
                    control_images.append(opened.convert("RGB"))

        preprocessors = {
            "Canny": preprocess_canny,
            "Depth": preprocess_depth,
            "OpenPose": preprocess_openpose,
        }
        preprocess = preprocessors.get(controlnet_type)  # None for "Reference"
        preprocessed_images = [
            img if preprocess is None else preprocess(img)
            for img in tqdm(control_images, desc="Preprocessing images")
        ]
        images = process_image_batch(preprocessed_images, active_pipe, prompt, negative_prompt, progress)
    else:
        # Free the GPU from any ControlNet/Reference pipeline before using the base model.
        if controlnet_pipe is not None:
            controlnet_pipe.to("cpu")
            controlnet_pipe = None
        if reference_pipe is not None:
            reference_pipe.to("cpu")
            reference_pipe = None
        # Nothing is active any more, so the recorded type must not claim otherwise.
        current_controlnet_type = None
        clear_memory()

        if pipe is None:
            pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.float16)
        # load_controlnet_model() moves the base pipeline to the CPU, so it has
        # to be moved back on every run, not only when it is first created.
        pipe.to("cuda")

        batch_count = int(batch_count)
        images = []
        for i in tqdm(range(batch_count), desc="Generating images"):
            generated = pipe(
                prompt=[prompt],
                negative_prompt=[negative_prompt],
                num_inference_steps=20,
                width=1024,
                height=1024,
            ).images
            images.extend(generated)
            progress((i + 1) / batch_count)  # Update progress bar
            clear_memory()  # Clear memory after each image, even in single image mode

    clear_memory()
    return images
# Function to extract PNG metadata
def _json_safe(value):
    """Convert ``value`` into something a JSON encoder accepts."""
    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    if isinstance(value, (bytes, bytearray)):
        # Binary blobs are summarized rather than dumped.
        return f"<{len(value)} bytes>"
    if isinstance(value, dict):
        return {str(key): _json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe(item) for item in value]
    return str(value)


def extract_png_info(image_path):
    """Return the metadata dictionary attached to an image object.

    Despite the parameter name, the argument is an image object whose
    ``info`` attribute holds key-value metadata, not a filesystem path.

    Args:
        image_path: Object exposing an ``info`` mapping, or ``None`` when no
            image was supplied.

    Returns:
        A new dict with string keys and JSON-serializable values; empty when
        there is no image or no metadata.
    """
    if image_path is None:
        return {}
    metadata = getattr(image_path, "info", None) or {}
    return {str(key): _json_safe(value) for key, value in metadata.items()}
# Define the Gradio interface
# NOTE(review): the captured source had lost its indentation, so the nesting
# of layout containers below is reconstructed from context - confirm against
# the deployed layout.
with gr.Blocks() as demo:
    gr.Markdown("# Image Generation with Custom Prompts and Styles")

    with gr.Row():
        # Left column: prompts, generation settings and input images.
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", lines=8, interactive=True)
            with gr.Accordion("Negative Prompt (Minimize/Expand)", open=False):
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="",
                    lines=5
                )
            batch_count = gr.Slider(minimum=1, maximum=10, step=1, label="Batch Count", value=1)
            use_controlnet = gr.Checkbox(label="Use ControlNet", value=False)
            controlnet_type = gr.Dropdown(choices=["Canny", "Depth", "OpenPose", "Reference"], label="ControlNet Type")
            controlnet_status = gr.Textbox(label="ControlNet Status", value="", interactive=False)
            mode = gr.Radio(choices=["Single Image", "Batch"], label="Mode", value="Single Image")
            with gr.Tabs() as tabs:
                with gr.TabItem("Single Image"):
                    control_image = gr.Image(label="Control Image", type='pil')
                with gr.TabItem("Batch"):
                    batch_images_input = gr.File(label="Upload Images", file_count='multiple')
                with gr.TabItem("Extract Metadata"):
                    png_image = gr.Image(label="Upload PNG Image", type='pil')
                    metadata_output = gr.JSON(label="PNG Metadata")

        # Right column: style picker on top, generated images below.
        with gr.Column(scale=2):
            style_images_gallery = gr.Gallery(
                label="Choose a Style",
                value=list(style_images.values()),
                interactive=True,
                elem_id="style-gallery",
                columns=5,
                object_fit="contain",
                height=235,
                allow_preview=False
            )
            gallery = gr.Gallery(label="Generated Images", show_label=False, elem_id="gallery", height=785)

    selected_style = gr.State(value="Anime Studio Dance")

    def select_style(evt: gr.SelectData):
        """Return the prompt, negative prompt and name of the clicked style."""
        # The gallery is populated from style_images, so its order is the one
        # that maps a clicked index back to a style name.
        style_names = list(style_images)
        if evt.index < 0 or evt.index >= len(style_names):
            raise ValueError(f"Invalid index: {evt.index}")
        chosen = style_names[evt.index]
        return styles[chosen]["prompt"], styles[chosen]["negative_prompt"], chosen

    style_images_gallery.select(fn=select_style, inputs=[], outputs=[prompt, negative_prompt, selected_style])

    def update_controlnet(controlnet_type):
        """Load the selected ControlNet type and return its status message."""
        return load_controlnet_model(controlnet_type)

    controlnet_type.change(fn=update_controlnet, inputs=controlnet_type, outputs=controlnet_status)

    generate_button = gr.Button("Generate Images")
    generate_button.click(
        generate_images_with_progress,
        inputs=[prompt, negative_prompt, batch_count, use_controlnet, controlnet_type, mode, control_image, batch_images_input],
        outputs=gallery
    )

    metadata_button = gr.Button("Extract Metadata")
    metadata_button.click(
        fn=extract_png_info,
        inputs=png_image,
        outputs=metadata_output
    )
# Script entry point: launch the Gradio app behind basic authentication.
if __name__ == "__main__":
    # Credentials can be overridden through the environment so they need not
    # live in the source; the defaults keep the previous login working.
    auth_user = os.environ.get("GRADIO_AUTH_USER", "roland")
    auth_password = os.environ.get("GRADIO_AUTH_PASSWORD", "roland")
    try:
        demo.launch(auth=(auth_user, auth_password), debug=True)
    finally:
        # Release memory even when the server stops because of an error or interrupt.
        clear_memory()