import gradio as gr
from diffusers import DiffusionPipeline, ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL, UniPCMultistepScheduler
from stable_diffusion_xl_reference import StableDiffusionXLReferencePipeline
from controlnet_aux import OpenposeDetector, MidasDetector, ZoeDetector
from tqdm import tqdm
import torch
import numpy as np
import cv2
from PIL import Image
import random
import gc
def clear_memory():
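    """Release Python-level references and empty CUDA caches between pipeline swaps."""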
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
# Global variable definitions
controlnet_pipe = None
reference_pipe = None
pipe = None
current_controlnet_type = None
# Load the base model
model = "aicollective1/aicollective"
pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.float16)
pipe.to("cuda")
# Placeholder for ControlNet models to be loaded dynamically
controlnet_models = {
"Canny": None,
"Depth": None,
"OpenPose": None,
"Reference": None
}
# Load necessary models and feature extractors for depth estimation and OpenPose
processor_zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
processor_midas = MidasDetector.from_pretrained("lllyasviel/Annotators")
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
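# A single "union" ControlNet checkpoint serves all of the image-conditioned types
# (Canny, Depth, OpenPose), so those entries can share one model.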
controlnet_model_shared = ControlNetModel.from_pretrained(
"xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16, use_safetensors=True
)
# Define the prompts and negative prompts for each style
styles = {
"Anime Studio Dance": {
"prompt": ("anime screencap of a man wearing a white helmet with pointed ears,\n"
"Outfit: closed animal print shirt,\n"
"Action: anime style, looking at viewer, solo, upper body,\n"
"((masterpiece)), (best quality), (extremely detailed), depth of field, sketch, "
"dark intense shadows, sharp focus, soft lighting, hdr, colorful, good composition, spectacular,"),
"negative_prompt": ("realistic, (painting by bad-artist-anime:0.9), (painting by bad-artist:0.9), watermark, "
"text, error, blurry, jpeg artifacts, cropped, worst quality, low quality, normal quality, "
"jpeg artifacts, signature, watermark, username, artist name, (worst quality, low quality:1.4), "
"bad anatomy, watermark, signature, text, logo")
},
"Vintage Realistic": {
"prompt": ("a masterpiece close up shoot photography of an man wearing a animal print helmet with pointed ears,\n"
"Outfit: wearing an big oversized outfit, white leather jacket,\n"
"Action: sitting on steps,\n"
"hyper realistic with detailed textures, cinematic film still of Photorealism, realistic skin texture, "
"subsurface scattering, skinny, Photorealism, often for highly detailed representation, photographic accuracy, "
"shallow depth of field, vignette, highly detailed, bokeh, epic, gorgeous, sharp, perfect hands,\n"
"<lora:add-detail-xl:1> <lora:Vintage_Street_Photo:0.9>"),
"negative_prompt": ("deformed skin, skin veins, black skin, blurry, text, yellow, deformed, (worst quality, low resolution, "
"bad hands, open mouth), text, watermark, artist name, distorted, twisted, watermark, 3d render, "
"distorted, twisted, watermark, anime, cartoon, graphic, text, painting, crayon, graphite, abstract, "
"glitch, deformed, mutated, ugly, disfigured, photoshopped skin, airbrushed skin, glossy skin, canvas frame, "
"(high contrast:1.2), (over saturated:1.2), (glossy:1.1), cartoon, 3d, disfigured, Photoshop, video game, "
"ugly, tiling, poorly drawn hands, 3d render, impressionism, digital art")
},
"Anime 90's Aesthetic": {
"prompt": ("an man wearing a white helmet with pointed ears, perfect chin,\n"
"Outfit: wearing oversized hoodie, animal print pants,\n"
"Action: dancing in nature, music production, music instruments made of wood,\n"
"A screengrab of an anime, 90's aesthetic,"),
"negative_prompt": ("photo, real, realistic, blurry, text, yellow, deformed, (worst quality, low resolution, bad hands,), "
"text, watermark, artist name, distorted, twisted, watermark, 3d render, distorted, twisted, watermark, "
"text, abstract, glitch, deformed, mutated, ugly, disfigured, photoshopped skin, airbrushed skin, glossy skin, "
"canvas frame, (high contrast:1.2), (over saturated:1.2), (glossy:1.1), disfigured, Photoshop, video game, "
"ugly, tiling, poorly drawn hands, 3d render, impressionism, eyes, mouth, black skin, pale skin, hair, beard")
},
"Anime Style": {
"prompt": ("An man wearing a white helmet with pointed ears sitting on the steps of an Asian street shop,\n"
"Outfit: wearing blue pants and a yellow jacket with a red backpack, in the anime style with detailed "
"character design in the style of Atey Ghailan, featured in CGSociety, character concept art in the style of Katsuhiro Otomo"),
"negative_prompt": ("real, deformed fingers, chin, deformed hands, blurry, text, yellow, deformed, (worst quality, low resolution, "
"bad hands, open mouth), text, watermark, artist name, distorted, twisted, watermark, 3d, distorted, twisted, "
"watermark, anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, "
"ugly, disfigured, photoshopped skin, airbrushed skin, glossy skin, canvas frame, (high contrast:1.2), "
"(over saturated:1.2), (glossy:1.1), cartoon, 3d, disfigured, Photoshop, video game, ugly, tiling, "
"poorly drawn hands, 3d render, impressionism, digital art")
},
"Real 70s": {
"prompt": ("a masterpiece close up shoot photography of an man wearing a white helmet with pointed ears,\n"
"Outfit: wearing an oversized trippy 70s shirt and scarf,\n"
"Action: standing on the ocean,\n"
"shot in the style of Erwin Olaf, hyper realistic with detailed textures, cinematic film still of Photorealism, "
"realistic skin texture, subsurface scattering, skinny, Photorealism, often for highly detailed representation, "
"photographic accuracy, shallow depth of field, vignette, highly detailed, bokeh, epic, gorgeous, sharp,"),
"negative_prompt": ("deformed skin, skin veins, black skin, blurry, text, yellow, deformed, (worst quality, low resolution, "
"bad hands, open mouth), text, watermark, artist name, distorted, twisted, watermark, 3d render, distorted, "
"twisted, watermark, anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, "
"mutated, ugly, disfigured, photoshopped skin, airbrushed skin, glossy skin, canvas frame, (high contrast:1.2), "
"(over saturated:1.2), (glossy:1.1), cartoon, 3d, disfigured, Photoshop, video game, ugly, tiling, "
"poorly drawn hands, 3d render, impressionism, digital art")
}
}
# Define the style images
style_images = {
"Anime Studio Dance": "style/Anime Studio Dance.png",
"Vintage Realistic": "style/Vintage Realistic.png",
"Anime 90's Aesthetic": "style/Anime 90's Aesthetic.png",
"Anime Style": "style/Anime Style.png",
"Real 70s": "style/Real 70s.png"
}
# Function to load ControlNet models dynamically
def load_controlnet_model(controlnet_type):
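    """Lazily load the requested ControlNet/Reference pipeline and move it onto the GPU, offloading whatever was active before."""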
global controlnet_pipe, pipe, reference_pipe, controlnet_models, vae, model, current_controlnet_type, controlnet_model_shared
clear_memory()
    # Skip the reload entirely if the requested type is already active.
    if current_controlnet_type == controlnet_type:
        return f"{controlnet_type} model already loaded."
    if controlnet_models[controlnet_type] is None:
        if controlnet_type in ["Canny", "Depth", "OpenPose"]:
            controlnet_models[controlnet_type] = controlnet_model_shared
        elif controlnet_type == "Reference":
            controlnet_models[controlnet_type] = StableDiffusionXLReferencePipeline.from_pretrained(
                model, torch_dtype=torch.float16, use_safetensors=True
            )
    if controlnet_pipe is not None:
        controlnet_pipe.to("cpu")
        controlnet_pipe = None
    if reference_pipe is not None:
        reference_pipe.to("cpu")
        reference_pipe = None
if pipe is not None:
pipe.to("cpu")
clear_memory()
if controlnet_type == "Reference":
reference_pipe = controlnet_models[controlnet_type]
reference_pipe.scheduler = UniPCMultistepScheduler.from_config(reference_pipe.scheduler.config)
reference_pipe.to("cuda")
globals()['reference_pipe'] = reference_pipe
else:
controlnet_pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
model, controlnet=controlnet_models[controlnet_type], vae=vae, torch_dtype=torch.float16, use_safetensors=True
)
controlnet_pipe.scheduler = UniPCMultistepScheduler.from_config(controlnet_pipe.scheduler.config)
controlnet_pipe.to("cuda")
globals()['controlnet_pipe'] = controlnet_pipe
current_controlnet_type = controlnet_type
clear_memory()
return f"Loaded {controlnet_type} model."
# Preprocessing functions for each ControlNet type
def preprocess_canny(image):
if isinstance(image, Image.Image):
image = np.array(image)
if image.dtype != np.uint8:
image = (image * 255).astype(np.uint8)
image = cv2.Canny(image, 100, 200)
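    # Replicate the single-channel edge map across three channels, since the ControlNet expects an RGB conditioning image.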
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
return Image.fromarray(image)
def preprocess_depth(image, target_size=(1024, 1024)):
if isinstance(image, Image.Image):
img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
else:
img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
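    # Randomly alternate between the ZoeDepth and MiDaS estimators to vary the depth maps.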
depth_img = processor_zoe(img, output_type='cv2') if random.random() > 0.5 else processor_midas(img, output_type='cv2')
height, width = depth_img.shape[:2]
ratio = min(target_size[0] / width, target_size[1] / height)
new_width, new_height = int(width * ratio), int(height * ratio)
depth_img_resized = cv2.resize(depth_img, (new_width, new_height))
return Image.fromarray(depth_img_resized)
def preprocess_openpose(image):
if isinstance(image, Image.Image):
image = np.array(image)
image = openpose_processor(image, hand_and_face=False, output_type='cv2')
height, width = image.shape[:2]
ratio = np.sqrt(1024. * 1024. / (width * height))
new_width, new_height = int(width * ratio), int(height * ratio)
image = cv2.resize(image, (new_width, new_height))
return Image.fromarray(image)
def process_image_batch(images, pipe, prompt, negative_prompt, progress, batch_size=2):
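    """Run `pipe` over `images` in chunks of `batch_size`, updating the Gradio progress bar and freeing memory after each chunk."""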
all_processed_images = []
for i in range(0, len(images), batch_size):
batch = images[i:i+batch_size]
batch_prompt = [prompt] * len(batch)
batch_negative_prompt = [negative_prompt] * len(batch)
if isinstance(pipe, StableDiffusionXLReferencePipeline):
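            # The reference pipeline conditions on a single ref_image per call, so iterate instead of batching.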
processed_batch = []
for img in batch:
result = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
ref_image=img,
num_inference_steps=20
).images
processed_batch.extend(result)
else:
processed_batch = pipe(
prompt=batch_prompt,
negative_prompt=batch_negative_prompt,
image=batch,
num_inference_steps=20
).images
all_processed_images.extend(processed_batch)
        progress(min(1.0, (i + batch_size) / len(images)))  # Update progress bar; cap at 100% on a final partial batch
clear_memory() # Clear memory after each batch
return all_processed_images
# Define the function to generate images
def generate_images_with_progress(prompt, negative_prompt, batch_count, use_controlnet, controlnet_type, mode, control_image, batch_images_input, progress=gr.Progress(track_tqdm=True)):
global controlnet_pipe, pipe, reference_pipe
clear_memory()
if use_controlnet:
if controlnet_type not in controlnet_models or controlnet_models[controlnet_type] is None:
raise ValueError(f"{controlnet_type} model not loaded. Please load the model first.")
if mode == "Single Image":
control_images = [control_image]
else:
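            # Batch mode: open every uploaded file as an RGB image.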
control_images = [Image.open(img).convert("RGB") for img in batch_images_input]
preprocessed_images = []
for img in tqdm(control_images, desc="Preprocessing images"):
if controlnet_type == "Canny":
preprocessed_images.append(preprocess_canny(img))
elif controlnet_type == "Depth":
preprocessed_images.append(preprocess_depth(img))
elif controlnet_type == "OpenPose":
preprocessed_images.append(preprocess_openpose(img))
else: # Reference
preprocessed_images.append(img)
if controlnet_type == "Reference":
images = process_image_batch(preprocessed_images, reference_pipe, prompt, negative_prompt, progress)
else:
images = process_image_batch(preprocessed_images, controlnet_pipe, prompt, negative_prompt, progress)
else:
        if controlnet_pipe is not None:
            controlnet_pipe.to("cpu")
            controlnet_pipe = None
        if reference_pipe is not None:
            reference_pipe.to("cpu")
            reference_pipe = None
clear_memory()
        if pipe is None:
            pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.float16)
        # The base pipeline may have been offloaded to CPU while a ControlNet was active, so always move it back.
        pipe.to("cuda")
images = []
for i in tqdm(range(batch_count), desc="Generating images"):
generated = pipe(prompt=[prompt], negative_prompt=[negative_prompt], num_inference_steps=20, width=1024, height=1024).images
images.extend(generated)
progress((i + 1) / batch_count) # Update progress bar
clear_memory() # Clear memory after each image, even in single image mode
clear_memory()
return images
# Function to extract PNG metadata
def extract_png_info(image):
    # PIL exposes PNG text chunks (e.g. generation parameters) via the image's `info` dictionary.
    metadata = image.info
    return metadata
# Define the Gradio interface
with gr.Blocks() as demo:
gr.Markdown("# Image Generation with Custom Prompts and Styles")
with gr.Row():
with gr.Column():
prompt = gr.Textbox(label="Prompt", lines=8, interactive=True)
with gr.Accordion("Negative Prompt (Minimize/Expand)", open=False):
negative_prompt = gr.Textbox(
label="Negative Prompt",
value="",
lines=5
)
batch_count = gr.Slider(minimum=1, maximum=10, step=1, label="Batch Count", value=1)
use_controlnet = gr.Checkbox(label="Use ControlNet", value=False)
controlnet_type = gr.Dropdown(choices=["Canny", "Depth", "OpenPose", "Reference"], label="ControlNet Type")
controlnet_status = gr.Textbox(label="ControlNet Status", value="", interactive=False)
mode = gr.Radio(choices=["Single Image", "Batch"], label="Mode", value="Single Image")
with gr.Tabs() as tabs:
with gr.TabItem("Single Image"):
control_image = gr.Image(label="Control Image", type='pil')
with gr.TabItem("Batch"):
batch_images_input = gr.File(label="Upload Images", file_count='multiple')
with gr.TabItem("Extract Metadata"):
png_image = gr.Image(label="Upload PNG Image", type='pil')
metadata_output = gr.JSON(label="PNG Metadata")
with gr.Column(scale=2):
style_images_gallery = gr.Gallery(
label="Choose a Style",
value=list(style_images.values()),
interactive=True,
elem_id="style-gallery",
columns=5,
object_fit="contain",
height=235,
allow_preview=False
)
gallery = gr.Gallery(label="Generated Images", show_label=False, elem_id="gallery", height=785)
selected_style = gr.State(value="Anime Studio Dance")
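    # Clicking a thumbnail in the style gallery loads that style's prompt pair into the textboxes.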
    def select_style(evt: gr.SelectData):
        style_names = list(styles.keys())
        if evt.index < 0 or evt.index >= len(style_names):
            raise ValueError(f"Invalid gallery index: {evt.index}")
        style_name = style_names[evt.index]  # gallery order matches the insertion order of `styles`
        return styles[style_name]["prompt"], styles[style_name]["negative_prompt"], style_name
style_images_gallery.select(fn=select_style, inputs=[], outputs=[prompt, negative_prompt, selected_style])
def update_controlnet(controlnet_type):
status = load_controlnet_model(controlnet_type)
return status
controlnet_type.change(fn=update_controlnet, inputs=controlnet_type, outputs=controlnet_status)
generate_button = gr.Button("Generate Images")
generate_button.click(
generate_images_with_progress,
inputs=[prompt, negative_prompt, batch_count, use_controlnet, controlnet_type, mode, control_image, batch_images_input],
outputs=gallery
)
metadata_button = gr.Button("Extract Metadata")
metadata_button.click(
fn=extract_png_info,
inputs=png_image,
outputs=metadata_output
)
# At the end of your script:
if __name__ == "__main__":
auth = gr.auth.Basic(username="roland", password="roland")
demo.launch(auth=auth, debug=True)
clear_memory()