mantrakp commited on
Commit
daf9c75
1 Parent(s): 07dc8e6

Refactor ControlNetReq class to remove unused import and add controlnets, control_images, and controlnet_conditioning_scale attributes

Browse files
modules/events/flux_events.py CHANGED
@@ -5,14 +5,15 @@ import spaces
5
  import gradio as gr
6
  from huggingface_hub import ModelCard
7
 
8
- from modules.helpers.flux_helpers import FluxReq, FluxImg2ImgReq, FluxInpaintReq, ControlNetReq, gen_img
9
- from config import flux_models, flux_loras
 
10
 
11
  loras = flux_loras
12
 
13
 
14
  # Event functions
15
- def update_fast_generation(model, fast_generation):
16
  if fast_generation:
17
  return (
18
  gr.update(
@@ -125,7 +126,7 @@ def update_selected_lora(custom_lora):
125
  )
126
 
127
 
128
- def add_to_enabled_loras(model, selected_lora, enabled_loras):
129
  lora_data = loras
130
  try:
131
  selected_lora = int(selected_lora)
@@ -233,7 +234,7 @@ def generate_image(
233
  "vae": vae,
234
  "controlnet_config": None,
235
  }
236
- base_args = FluxReq(**base_args)
237
 
238
  if len(enabled_loras) > 0:
239
  base_args.loras = []
@@ -252,7 +253,7 @@ def generate_image(
252
  image = img2img_image
253
  strength = float(img2img_strength)
254
 
255
- base_args = FluxImg2ImgReq(
256
  **base_args.__dict__,
257
  image=image,
258
  strength=strength
@@ -263,7 +264,7 @@ def generate_image(
263
  strength = float(inpaint_strength)
264
 
265
  if image and mask_image:
266
- base_args = FluxInpaintReq(
267
  **base_args.__dict__,
268
  image=image,
269
  mask_image=mask_image,
@@ -289,7 +290,7 @@ def generate_image(
289
  base_args.controlnet_config.control_images.append(depth_image)
290
  base_args.controlnet_config.controlnet_conditioning_scale.append(float(depth_strength))
291
  else:
292
- base_args = FluxReq(**base_args.__dict__)
293
 
294
  return gr.update(
295
  value=gen_img(base_args),
 
5
  import gradio as gr
6
  from huggingface_hub import ModelCard
7
 
8
+ from modules.helpers.common_helpers import ControlNetReq, BaseReq, BaseImg2ImgReq, BaseInpaintReq
9
+ from modules.helpers.flux_helpers import gen_img
10
+ from config import flux_loras
11
 
12
  loras = flux_loras
13
 
14
 
15
  # Event functions
16
+ def update_fast_generation(fast_generation):
17
  if fast_generation:
18
  return (
19
  gr.update(
 
126
  )
127
 
128
 
129
+ def add_to_enabled_loras(selected_lora, enabled_loras):
130
  lora_data = loras
131
  try:
132
  selected_lora = int(selected_lora)
 
234
  "vae": vae,
235
  "controlnet_config": None,
236
  }
237
+ base_args = BaseReq(**base_args)
238
 
239
  if len(enabled_loras) > 0:
240
  base_args.loras = []
 
253
  image = img2img_image
254
  strength = float(img2img_strength)
255
 
256
+ base_args = BaseImg2ImgReq(
257
  **base_args.__dict__,
258
  image=image,
259
  strength=strength
 
264
  strength = float(inpaint_strength)
265
 
266
  if image and mask_image:
267
+ base_args = BaseInpaintReq(
268
  **base_args.__dict__,
269
  image=image,
270
  mask_image=mask_image,
 
290
  base_args.controlnet_config.control_images.append(depth_image)
291
  base_args.controlnet_config.controlnet_conditioning_scale.append(float(depth_strength))
292
  else:
293
+ base_args = BaseReq(**base_args.__dict__)
294
 
295
  return gr.update(
296
  value=gen_img(base_args),
modules/helpers/flux_helpers.py CHANGED
@@ -6,10 +6,6 @@ from diffusers import (
6
  AutoPipelineForText2Image,
7
  AutoPipelineForImage2Image,
8
  AutoPipelineForInpainting,
9
- DiffusionPipeline,
10
- AutoencoderKL,
11
- FluxControlNetModel,
12
- FluxMultiControlNetModel,
13
  )
14
  from huggingface_hub import hf_hub_download
15
  from diffusers.schedulers import *
@@ -17,56 +13,8 @@ from huggingface_hub import hf_hub_download
17
  from sd_embed.embedding_funcs import get_weighted_text_embeddings_flux1
18
 
19
  from .common_helpers import ControlNetReq, BaseReq, BaseImg2ImgReq, BaseInpaintReq, cleanup, get_controlnet_images, resize_images
20
-
21
-
22
- def load_sd():
23
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
- device = "cuda" if torch.cuda.is_available() else "cpu"
25
-
26
- # Models
27
- models = [
28
- {
29
- "repo_id": "black-forest-labs/FLUX.1-dev",
30
- "loader": "flux",
31
- "compute_type": torch.bfloat16,
32
- }
33
- ]
34
-
35
- for model in models:
36
- try:
37
- model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
38
- model['repo_id'],
39
- vae=AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device),
40
- torch_dtype=model['compute_type'],
41
- safety_checker=None,
42
- variant="fp16"
43
- ).to(device)
44
- except:
45
- model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
46
- model['repo_id'],
47
- vae=AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device),
48
- torch_dtype=model['compute_type'],
49
- safety_checker=None
50
- ).to(device)
51
-
52
- model["pipeline"].enable_model_cpu_offload()
53
-
54
- # VAE n Refiner
55
- flux_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device)
56
- sdxl_vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
57
- refiner = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", vae=sdxl_vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to(device)
58
- refiner.enable_model_cpu_offload()
59
-
60
- # ControlNet
61
- controlnet = FluxMultiControlNetModel([FluxControlNetModel.from_pretrained(
62
- "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
63
- torch_dtype=torch.bfloat16
64
- ).to(device)])
65
-
66
- return device, models, flux_vae, sdxl_vae, refiner, controlnet
67
-
68
-
69
- device, models, flux_vae, sdxl_vae, refiner, controlnet = load_sd()
70
 
71
 
72
  def get_control_mode(controlnet_config: ControlNetReq):
 
6
  AutoPipelineForText2Image,
7
  AutoPipelineForImage2Image,
8
  AutoPipelineForInpainting,
 
 
 
 
9
  )
10
  from huggingface_hub import hf_hub_download
11
  from diffusers.schedulers import *
 
13
  from sd_embed.embedding_funcs import get_weighted_text_embeddings_flux1
14
 
15
  from .common_helpers import ControlNetReq, BaseReq, BaseImg2ImgReq, BaseInpaintReq, cleanup, get_controlnet_images, resize_images
16
+ from modules.pipelines.flux_pipelines import device, models, flux_vae, controlnet
17
+ from modules.pipelines.common_pipelines import refiner
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
  def get_control_mode(controlnet_config: ControlNetReq):
modules/pipelines/common_pipelines.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import (
3
+ DiffusionPipeline,
4
+ AutoencoderKL,
5
+ )
6
+ from diffusers.schedulers import *
7
+
8
+
9
+ def load_common():
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+
12
+ # VAE n Refiner
13
+ sdxl_vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
14
+ refiner = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", vae=sdxl_vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to(device)
15
+ refiner.enable_model_cpu_offload()
16
+
17
+ return refiner, sdxl_vae
18
+
19
+ refiner, sdxl_vae = load_common()
modules/pipelines/flux_pipelines.py CHANGED
@@ -1,19 +1,58 @@
1
- # modules/pipelines/flux_pipelines.py
2
 
3
  import torch
4
- from diffusers import AutoPipelineForText2Image, AutoencoderKL
 
 
 
 
 
 
 
 
 
5
 
6
  def load_flux():
7
- # Load FLUX models and pipelines
8
- # ...
9
- return device, models, flux_vae, controlnet
10
 
11
- # modules/pipelines/sdxl_pipelines.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- import torch
14
- from diffusers import AutoPipelineForText2Image, AutoencoderKL
15
 
16
- def load_sdxl():
17
- # Load SDXL models and pipelines
18
- # ...
19
- return device, models, sdxl_vae, controlnet
 
 
1
 
2
  import torch
3
+ from diffusers import (
4
+ AutoPipelineForText2Image,
5
+ DiffusionPipeline,
6
+ AutoencoderKL,
7
+ FluxControlNetModel,
8
+ FluxMultiControlNetModel,
9
+ )
10
+ from diffusers.schedulers import *
11
+
12
+
13
 
14
  def load_flux():
15
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
 
17
 
18
+ # Models
19
+ models = [
20
+ {
21
+ "repo_id": "black-forest-labs/FLUX.1-dev",
22
+ "loader": "flux",
23
+ "compute_type": torch.bfloat16,
24
+ }
25
+ ]
26
+
27
+ for model in models:
28
+ try:
29
+ model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
30
+ model['repo_id'],
31
+ vae=AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device),
32
+ torch_dtype=model['compute_type'],
33
+ safety_checker=None,
34
+ variant="fp16"
35
+ ).to(device)
36
+ except:
37
+ model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
38
+ model['repo_id'],
39
+ vae=AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device),
40
+ torch_dtype=model['compute_type'],
41
+ safety_checker=None
42
+ ).to(device)
43
+
44
+ model["pipeline"].enable_model_cpu_offload()
45
+
46
+ # VAE n Refiner
47
+ flux_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device)
48
+
49
+ # ControlNet
50
+ controlnet = FluxMultiControlNetModel([FluxControlNetModel.from_pretrained(
51
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
52
+ torch_dtype=torch.bfloat16
53
+ ).to(device)])
54
+
55
+ return device, models, flux_vae, controlnet
56
 
 
 
57
 
58
+ device, models, flux_vae, controlnet = load_flux()
 
 
 
tabs/image_tab.py CHANGED
@@ -144,13 +144,13 @@ def flux_tab():
144
 
145
  # Events
146
  # Base Options
147
- fast_generation.change(update_fast_generation, [model, fast_generation], [image_guidance_scale, image_num_inference_steps]) # Fast Generation # type: ignore
148
 
149
 
150
  # Lora Gallery
151
  lora_gallery.select(selected_lora_from_gallery, None, selected_lora)
152
  custom_lora.change(update_selected_lora, custom_lora, [custom_lora, selected_lora])
153
- add_lora.click(add_to_enabled_loras, [model, selected_lora, enabled_loras], [selected_lora, custom_lora_info, enabled_loras])
154
  enabled_loras.change(update_lora_sliders, enabled_loras, [lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5, lora_remove_0, lora_remove_1, lora_remove_2, lora_remove_3, lora_remove_4, lora_remove_5]) # type: ignore
155
 
156
  for i in range(6):
 
144
 
145
  # Events
146
  # Base Options
147
+ fast_generation.change(update_fast_generation, [fast_generation], [image_guidance_scale, image_num_inference_steps]) # Fast Generation # type: ignore
148
 
149
 
150
  # Lora Gallery
151
  lora_gallery.select(selected_lora_from_gallery, None, selected_lora)
152
  custom_lora.change(update_selected_lora, custom_lora, [custom_lora, selected_lora])
153
+ add_lora.click(add_to_enabled_loras, [selected_lora, enabled_loras], [selected_lora, custom_lora_info, enabled_loras])
154
  enabled_loras.change(update_lora_sliders, enabled_loras, [lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5, lora_remove_0, lora_remove_1, lora_remove_2, lora_remove_3, lora_remove_4, lora_remove_5]) # type: ignore
155
 
156
  for i in range(6):