Keltezaa commited on
Commit
7816a02
·
verified ·
1 Parent(s): f36149d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -11
app.py CHANGED
@@ -112,20 +112,43 @@ base_model = "black-forest-labs/FLUX.1-dev"
112
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
113
  good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
114
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
115
- pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
116
- base_model,
117
- vae=good_vae,
118
- transformer=pipe.transformer,
119
- text_encoder=pipe.text_encoder,
120
- tokenizer=pipe.tokenizer,
121
- text_encoder_2=pipe.text_encoder_2,
122
- tokenizer_2=pipe.tokenizer_2,
123
- torch_dtype=dtype
124
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  MAX_SEED = 2**32 - 1
127
 
128
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
 
129
 
130
  class calculateDuration:
131
  def __init__(self, activity_name=""):
@@ -622,11 +645,28 @@ with gr.Blocks(css=css, delete_cache=(60, 60)) as app:
622
  with gr.Row():
623
  randomize_seed = gr.Checkbox(True, label="Randomize seed")
624
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
625
 
626
  gallery.select(
627
  update_selection,
628
  inputs=[selected_indices, loras_state, width, height],
629
  outputs=[prompt, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, width, height, lora_image_1, lora_image_2])
 
630
  remove_button_1.click(
631
  remove_lora_1,
632
  inputs=[selected_indices, loras_state],
@@ -655,7 +695,7 @@ with gr.Blocks(css=css, delete_cache=(60, 60)) as app:
655
  gr.on(
656
  triggers=[generate_button.click, prompt.submit],
657
  fn=run_lora,
658
- inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, randomize_seed, seed, width, height, loras_state],
659
  outputs=[result, seed, progress_bar]
660
  ).then(
661
  fn=lambda x, history: update_history(x, history),
 
112
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
113
  good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
114
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
115
+
116
+ # Function to dynamically merge models
117
+ def merge_models(base_model, enhancement_model, alpha=0.7):
118
+ for base_param, enhance_param in zip(base_model.parameters(), enhancement_model.parameters()):
119
+ base_param.data = alpha * base_param.data + (1 - alpha) * enhance_param.data
120
+ return base_model
121
+
122
+ # Gradio interface function
123
+ def process_image(enable_enhancement, weight_slider):
124
+ # Load enhancement model if enabled
125
+ if enable_enhancement:
126
+ enhancement_model_path = "xey/sldr_flux_nsfw_v2-studio"
127
+ try:
128
+ enhancement_model = AutoencoderKL.from_pretrained(enhancement_model_path, torch_dtype=dtype).to(device)
129
+ # Merge with the base VAE using the weight from the slider
130
+ merged_vae = merge_models(good_vae, enhancement_model, alpha=weight_slider)
131
+ except Exception as e:
132
+ return f"Failed to load or merge enhancement model: {e}"
133
+ else:
134
+ merged_vae = good_vae # Use the base VAE if no enhancement is enabled
135
+
136
+ # Create the image pipeline with the updated VAE
137
+ pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
138
+ base_model,
139
+ vae=merged_vae,
140
+ transformer=pipe.transformer,
141
+ text_encoder=pipe.text_encoder,
142
+ tokenizer=pipe.tokenizer,
143
+ text_encoder_2=pipe.text_encoder_2,
144
+ tokenizer_2=pipe.tokenizer_2,
145
+ torch_dtype=dtype
146
+ )
147
 
148
  MAX_SEED = 2**32 - 1
149
 
150
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
151
+ return "Pipeline updated with the selected enhancement model and weight."
152
 
153
  class calculateDuration:
154
  def __init__(self, activity_name=""):
 
645
  with gr.Row():
646
  randomize_seed = gr.Checkbox(True, label="Randomize seed")
647
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
648
+ # Gradio UI Elements for Enhancement Model and Weight Slider
649
+ with gr.Row():
650
+ with gr.Column():
651
+ enable_enhancement_checkbox = gr.Checkbox(
652
+ label="Enable Enhancement Model",
653
+ value=False,
654
+ elem_id="enable_enhancement_checkbox"
655
+ )
656
+ enhancement_weight_slider = gr.Slider(
657
+ label="Weight for Enhancement Model",
658
+ minimum=0.0,
659
+ maximum=1.0,
660
+ step=0.05,
661
+ value=0.75, # Default weight
662
+ elem_id="enhancement_weight_slider"
663
+ )
664
 
665
  gallery.select(
666
  update_selection,
667
  inputs=[selected_indices, loras_state, width, height],
668
  outputs=[prompt, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, width, height, lora_image_1, lora_image_2])
669
+
670
  remove_button_1.click(
671
  remove_lora_1,
672
  inputs=[selected_indices, loras_state],
 
695
  gr.on(
696
  triggers=[generate_button.click, prompt.submit],
697
  fn=run_lora,
698
+ inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_indices, enable_enhancement_checkbox, enhancement_weight_slider, lora_scale_1, lora_scale_2, randomize_seed, seed, width, height, loras_state],
699
  outputs=[result, seed, progress_bar]
700
  ).then(
701
  fn=lambda x, history: update_history(x, history),