Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,9 +1,7 @@
|
|
1 |
import gradio as gr
|
2 |
from diffusers import DiffusionPipeline, ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL, UniPCMultistepScheduler
|
3 |
from stable_diffusion_xl_reference import StableDiffusionXLReferencePipeline
|
4 |
-
from controlnet_aux import OpenposeDetector
|
5 |
-
#from transformers import DPTFeatureExtractor, DPTForDepthEstimation
|
6 |
-
from controlnet_aux import MidasDetector, ZoeDetector
|
7 |
from tqdm import tqdm
|
8 |
|
9 |
import torch
|
@@ -25,6 +23,7 @@ def clear_memory():
|
|
25 |
controlnet_pipe = None
|
26 |
reference_pipe = None
|
27 |
pipe = None
|
|
|
28 |
|
29 |
# Load the base model
|
30 |
model = "aicollective1/aicollective"
|
@@ -40,18 +39,15 @@ controlnet_models = {
|
|
40 |
}
|
41 |
|
42 |
# Load necessary models and feature extractors for depth estimation and OpenPose
|
43 |
-
#feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
|
44 |
-
#depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
|
45 |
-
|
46 |
processor_zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
|
47 |
processor_midas = MidasDetector.from_pretrained("lllyasviel/Annotators")
|
48 |
-
|
49 |
-
|
50 |
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
|
51 |
|
52 |
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
|
53 |
-
|
54 |
-
|
|
|
|
|
55 |
|
56 |
# Define the prompts and negative prompts for each style
|
57 |
styles = {
|
@@ -130,20 +126,21 @@ style_images = {
|
|
130 |
|
131 |
# Function to load ControlNet models dynamically
|
132 |
def load_controlnet_model(controlnet_type):
|
133 |
-
global controlnet_pipe, pipe, reference_pipe, controlnet_models, vae, model
|
134 |
|
135 |
clear_memory()
|
136 |
|
137 |
if controlnet_models[controlnet_type] is None:
|
138 |
if controlnet_type in ["Canny", "Depth", "OpenPose"]:
|
139 |
-
controlnet_models[controlnet_type] =
|
140 |
-
"xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16, use_safetensors=True
|
141 |
-
)
|
142 |
elif controlnet_type == "Reference":
|
143 |
controlnet_models[controlnet_type] = StableDiffusionXLReferencePipeline.from_pretrained(
|
144 |
model, torch_dtype=torch.float16, use_safetensors=True
|
145 |
)
|
146 |
|
|
|
|
|
|
|
147 |
if 'controlnet_pipe' in globals() and controlnet_pipe is not None:
|
148 |
controlnet_pipe.to("cpu")
|
149 |
del controlnet_pipe
|
@@ -172,10 +169,10 @@ def load_controlnet_model(controlnet_type):
|
|
172 |
controlnet_pipe.to("cuda")
|
173 |
globals()['controlnet_pipe'] = controlnet_pipe
|
174 |
|
|
|
175 |
clear_memory()
|
176 |
return f"Loaded {controlnet_type} model."
|
177 |
|
178 |
-
|
179 |
# Preprocessing functions for each ControlNet type
|
180 |
def preprocess_canny(image):
|
181 |
if isinstance(image, Image.Image):
|
@@ -384,10 +381,8 @@ with gr.Blocks() as demo:
|
|
384 |
with gr.Row():
|
385 |
generate_button
|
386 |
|
387 |
-
|
388 |
-
|
389 |
# At the end of your script:
|
390 |
if __name__ == "__main__":
|
391 |
# Your Gradio interface setup here
|
392 |
demo.launch(debug=True)
|
393 |
-
clear_memory()
|
|
|
1 |
import gradio as gr
|
2 |
from diffusers import DiffusionPipeline, ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL, UniPCMultistepScheduler
|
3 |
from stable_diffusion_xl_reference import StableDiffusionXLReferencePipeline
|
4 |
+
from controlnet_aux import OpenposeDetector, MidasDetector, ZoeDetector
|
|
|
|
|
5 |
from tqdm import tqdm
|
6 |
|
7 |
import torch
|
|
|
23 |
controlnet_pipe = None
|
24 |
reference_pipe = None
|
25 |
pipe = None
|
26 |
+
current_controlnet_type = None
|
27 |
|
28 |
# Load the base model
|
29 |
model = "aicollective1/aicollective"
|
|
|
39 |
}
|
40 |
|
41 |
# Load necessary models and feature extractors for depth estimation and OpenPose
|
|
|
|
|
|
|
42 |
processor_zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
|
43 |
processor_midas = MidasDetector.from_pretrained("lllyasviel/Annotators")
|
|
|
|
|
44 |
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
|
45 |
|
46 |
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
|
47 |
+
|
48 |
+
controlnet_model_shared = ControlNetModel.from_pretrained(
|
49 |
+
"xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16, use_safetensors=True
|
50 |
+
)
|
51 |
|
52 |
# Define the prompts and negative prompts for each style
|
53 |
styles = {
|
|
|
126 |
|
127 |
# Function to load ControlNet models dynamically
|
128 |
def load_controlnet_model(controlnet_type):
|
129 |
+
global controlnet_pipe, pipe, reference_pipe, controlnet_models, vae, model, current_controlnet_type, controlnet_model_shared
|
130 |
|
131 |
clear_memory()
|
132 |
|
133 |
if controlnet_models[controlnet_type] is None:
|
134 |
if controlnet_type in ["Canny", "Depth", "OpenPose"]:
|
135 |
+
controlnet_models[controlnet_type] = controlnet_model_shared
|
|
|
|
|
136 |
elif controlnet_type == "Reference":
|
137 |
controlnet_models[controlnet_type] = StableDiffusionXLReferencePipeline.from_pretrained(
|
138 |
model, torch_dtype=torch.float16, use_safetensors=True
|
139 |
)
|
140 |
|
141 |
+
if current_controlnet_type == controlnet_type:
|
142 |
+
return f"{controlnet_type} model already loaded."
|
143 |
+
|
144 |
if 'controlnet_pipe' in globals() and controlnet_pipe is not None:
|
145 |
controlnet_pipe.to("cpu")
|
146 |
del controlnet_pipe
|
|
|
169 |
controlnet_pipe.to("cuda")
|
170 |
globals()['controlnet_pipe'] = controlnet_pipe
|
171 |
|
172 |
+
current_controlnet_type = controlnet_type
|
173 |
clear_memory()
|
174 |
return f"Loaded {controlnet_type} model."
|
175 |
|
|
|
176 |
# Preprocessing functions for each ControlNet type
|
177 |
def preprocess_canny(image):
|
178 |
if isinstance(image, Image.Image):
|
|
|
381 |
with gr.Row():
|
382 |
generate_button
|
383 |
|
|
|
|
|
384 |
# At the end of your script:
|
385 |
if __name__ == "__main__":
|
386 |
# Your Gradio interface setup here
|
387 |
demo.launch(debug=True)
|
388 |
+
clear_memory()
|