Vijish committed
Commit 872a7b3 · verified · 1 Parent(s): 682647c

Update app.py

Files changed (1)
app.py +13 -18
app.py CHANGED
@@ -1,9 +1,7 @@
 import gradio as gr
 from diffusers import DiffusionPipeline, ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL, UniPCMultistepScheduler
 from stable_diffusion_xl_reference import StableDiffusionXLReferencePipeline
-from controlnet_aux import OpenposeDetector
-#from transformers import DPTFeatureExtractor, DPTForDepthEstimation
-from controlnet_aux import MidasDetector, ZoeDetector
+from controlnet_aux import OpenposeDetector, MidasDetector, ZoeDetector
 from tqdm import tqdm
 
 import torch
@@ -25,6 +23,7 @@ def clear_memory():
 controlnet_pipe = None
 reference_pipe = None
 pipe = None
+current_controlnet_type = None
 
 # Load the base model
 model = "aicollective1/aicollective"
@@ -40,18 +39,15 @@ controlnet_models = {
 }
 
 # Load necessary models and feature extractors for depth estimation and OpenPose
-#feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
-#depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
-
 processor_zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
 processor_midas = MidasDetector.from_pretrained("lllyasviel/Annotators")
-
-
 openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
 
 vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
-controlnet_pipe = None  # Initial placeholder, will be loaded dynamically
-reference_pipe = None  # Initial placeholder for reference pipeline
+
+controlnet_model_shared = ControlNetModel.from_pretrained(
+    "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16, use_safetensors=True
+)
 
 # Define the prompts and negative prompts for each style
 styles = {
@@ -130,20 +126,21 @@ style_images = {
 
 # Function to load ControlNet models dynamically
 def load_controlnet_model(controlnet_type):
-    global controlnet_pipe, pipe, reference_pipe, controlnet_models, vae, model
+    global controlnet_pipe, pipe, reference_pipe, controlnet_models, vae, model, current_controlnet_type, controlnet_model_shared
 
     clear_memory()
 
     if controlnet_models[controlnet_type] is None:
         if controlnet_type in ["Canny", "Depth", "OpenPose"]:
-            controlnet_models[controlnet_type] = ControlNetModel.from_pretrained(
-                "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16, use_safetensors=True
-            )
+            controlnet_models[controlnet_type] = controlnet_model_shared
         elif controlnet_type == "Reference":
             controlnet_models[controlnet_type] = StableDiffusionXLReferencePipeline.from_pretrained(
                 model, torch_dtype=torch.float16, use_safetensors=True
            )
 
+    if current_controlnet_type == controlnet_type:
+        return f"{controlnet_type} model already loaded."
+
     if 'controlnet_pipe' in globals() and controlnet_pipe is not None:
         controlnet_pipe.to("cpu")
         del controlnet_pipe
@@ -172,10 +169,10 @@ def load_controlnet_model(controlnet_type):
     controlnet_pipe.to("cuda")
     globals()['controlnet_pipe'] = controlnet_pipe
 
+    current_controlnet_type = controlnet_type
     clear_memory()
     return f"Loaded {controlnet_type} model."
 
-
 # Preprocessing functions for each ControlNet type
 def preprocess_canny(image):
     if isinstance(image, Image.Image):
@@ -384,10 +381,8 @@ with gr.Blocks() as demo:
     with gr.Row():
         generate_button
 
-
-
 # At the end of your script:
 if __name__ == "__main__":
     # Your Gradio interface setup here
     demo.launch(debug=True)
-    clear_memory()
+    clear_memory()
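
In short, the commit does two things: the union ControlNet ("xinsir/controlnet-union-sdxl-1.0") is now loaded once at startup and shared across the Canny, Depth, and OpenPose types, and a new current_controlnet_type global lets load_controlnet_model return early instead of tearing down and rebuilding a pipeline that is already active. A minimal, self-contained sketch of that pattern, with the pipeline-rebuild step elided:

import torch
from diffusers import ControlNetModel

# One shared union ControlNet instance serves Canny, Depth, and OpenPose,
# so the weights are fetched and instantiated only once.
controlnet_model_shared = ControlNetModel.from_pretrained(
    "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16, use_safetensors=True
)

current_controlnet_type = None  # which pipeline is currently active

def load_controlnet_model(controlnet_type):
    global current_controlnet_type
    # Early return: skip the expensive CPU<->GPU moves and pipeline
    # rebuild when the requested type is already loaded.
    if current_controlnet_type == controlnet_type:
        return f"{controlnet_type} model already loaded."
    # ... unload the old pipeline and build the new one here, as in the diff above ...
    current_controlnet_type = controlnet_type
    return f"Loaded {controlnet_type} model."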
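clear_memory() is called before and after each load but is not defined in the hunks shown. A hedged sketch of what such a helper typically does in GPU-bound Gradio apps (an assumption, not the repository's actual code):

import gc
import torch

def clear_memory():
    # Assumed implementation: drop unreachable Python objects first,
    # then release cached CUDA memory back to the driver.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()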