Update app.py
app.py
CHANGED
@@ -112,20 +112,43 @@ base_model = "black-forest-labs/FLUX.1-dev"
 taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
 good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
 pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
-
-
-
-
-
-
-
-
-
-
+
+# Function to dynamically merge models
+def merge_models(base_model, enhancement_model, alpha=0.7):
+    for base_param, enhance_param in zip(base_model.parameters(), enhancement_model.parameters()):
+        base_param.data = alpha * base_param.data + (1 - alpha) * enhance_param.data
+    return base_model
+
+# Gradio interface function
+def process_image(enable_enhancement, weight_slider):
+    # Load the enhancement model if enabled
+    if enable_enhancement:
+        enhancement_model_path = "xey/sldr_flux_nsfw_v2-studio"
+        try:
+            enhancement_model = AutoencoderKL.from_pretrained(enhancement_model_path, torch_dtype=dtype).to(device)
+            # Merge with the base VAE using the weight from the slider
+            merged_vae = merge_models(good_vae, enhancement_model, alpha=weight_slider)
+        except Exception as e:
+            return f"Failed to load or merge enhancement model: {e}"
+    else:
+        merged_vae = good_vae  # Use the base VAE if no enhancement is enabled
+
+    # Create the image-to-image pipeline with the updated VAE
+    pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
+        base_model,
+        vae=merged_vae,
+        transformer=pipe.transformer,
+        text_encoder=pipe.text_encoder,
+        tokenizer=pipe.tokenizer,
+        text_encoder_2=pipe.text_encoder_2,
+        tokenizer_2=pipe.tokenizer_2,
+        torch_dtype=dtype
+    )
+    return "Pipeline updated with the selected enhancement model and weight."
 
 MAX_SEED = 2**32 - 1
 
 pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
 
 class calculateDuration:
     def __init__(self, activity_name=""):
@@ -622,11 +645,28 @@ with gr.Blocks(css=css, delete_cache=(60, 60)) as app:
     with gr.Row():
         randomize_seed = gr.Checkbox(True, label="Randomize seed")
         seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
+    # Gradio UI elements for the enhancement model and weight slider
+    with gr.Row():
+        with gr.Column():
+            enable_enhancement_checkbox = gr.Checkbox(
+                label="Enable Enhancement Model",
+                value=False,
+                elem_id="enable_enhancement_checkbox"
+            )
+            enhancement_weight_slider = gr.Slider(
+                label="Weight for Enhancement Model",
+                minimum=0.0,
+                maximum=1.0,
+                step=0.05,
+                value=0.75,  # Default weight
+                elem_id="enhancement_weight_slider"
+            )
 
     gallery.select(
         update_selection,
         inputs=[selected_indices, loras_state, width, height],
         outputs=[prompt, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, width, height, lora_image_1, lora_image_2])
+
     remove_button_1.click(
         remove_lora_1,
         inputs=[selected_indices, loras_state],
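
These two controls only hold state: in the hunks shown here no event listener is attached to them and process_image is never registered as a callback, so toggling the checkbox does nothing until some event routes the values into a function. A small stand-alone sketch of how such a pair can drive a callback (describe is a stand-in, not the app's process_image):

import gradio as gr

def describe(enable_enhancement, weight):
    # Stand-in callback: just echoes the control state.
    if not enable_enhancement:
        return "Enhancement disabled; base VAE in use."
    return f"Enhancement enabled at weight {weight:.2f}."

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            enable_enhancement_checkbox = gr.Checkbox(label="Enable Enhancement Model", value=False)
            enhancement_weight_slider = gr.Slider(
                label="Weight for Enhancement Model",
                minimum=0.0, maximum=1.0, step=0.05, value=0.75,
            )
    status = gr.Textbox(label="Status")
    # Re-run the callback whenever either control changes.
    for control in (enable_enhancement_checkbox, enhancement_weight_slider):
        control.change(
            describe,
            inputs=[enable_enhancement_checkbox, enhancement_weight_slider],
            outputs=status,
        )

if __name__ == "__main__":
    demo.launch()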
@@ -655,7 +695,7 @@ with gr.Blocks(css=css, delete_cache=(60, 60)) as app:
     gr.on(
         triggers=[generate_button.click, prompt.submit],
         fn=run_lora,
-        inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, randomize_seed, seed, width, height, loras_state],
+        inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_indices, enable_enhancement_checkbox, enhancement_weight_slider, lora_scale_1, lora_scale_2, randomize_seed, seed, width, height, loras_state],
         outputs=[result, seed, progress_bar]
     ).then(
         fn=lambda x, history: update_history(x, history),