Update app.py
app.py
CHANGED

@@ -440,12 +440,7 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress)
     print("Generating image...")
     pipe.to("cuda")
     generator = torch.Generator(device="cuda").manual_seed(seed)
-    final_image = None  # Initialize final_image
     with calculateDuration("Generating image"):
-        # Ensure width and height are integers
-        width = int(width) if isinstance(width, (float, int)) else 1024
-        height = int(height) if isinstance(height, (float, int)) else 1024
-
         # Generate image
         for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
             prompt=prompt_mash,
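
This hunk removes the defensive width/height coercion and the unused final_image bookkeeping, so generate_image now trusts its caller to pass valid integers. The manually seeded torch.Generator is what keeps a run reproducible; a minimal sketch of that property (on CPU so it runs anywhere, while the app itself uses device="cuda"):

    import torch

    # Identically seeded generators produce identical noise, which is why a
    # fixed seed reproduces the same image for the same prompt and settings.
    g1 = torch.Generator(device="cpu").manual_seed(42)
    g2 = torch.Generator(device="cpu").manual_seed(42)
    assert torch.equal(torch.randn(4, generator=g1), torch.randn(4, generator=g2))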

@@ -458,17 +453,9 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress)
             output_type="pil",
             good_vae=good_vae,
         ):
-            print(f"Debug: Yielding image of type {type(img)}")  # Check image type
-            if isinstance(img, float):
-                print("Error: A float was returned instead of an image.")  # Log if img is a float
-                raise ValueError("Expected an image, but got a float.")  # Raise error if a float is found
             yield img
-            final_image = img  # Update final_image with the current image
-    return final_image
-

 def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, seed):
-    print("Generating image from input...")
     pipe_i2i.to("cuda")
     generator = torch.Generator(device="cuda").manual_seed(seed)
     image_input = load_image(image_input_path)
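
With the float guards and the trailing return gone, generate_image is a plain generator: every intermediate preview is yielded and the last yield is the final image. The removed return final_image was dead weight anyway, since a return value inside a generator only ends iteration and never reaches a plain for loop. A minimal consumer sketch (argument values are hypothetical):

    # Drain the generator, keeping the last (final-quality) frame.
    final_image = None
    for img in generate_image("a cat", steps=28, seed=42, cfg_scale=3.5,
                              width=1024, height=1024, progress=None):
        final_image = img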

@@ -484,22 +471,14 @@ def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps
             joint_attention_kwargs={"scale": 1.0},
             output_type="pil",
         ).images[0]
-    if isinstance(final_image, float):
-        print("Error: Expected an image but got a float.")
-        raise ValueError("Expected an image, but got a float.")
     return final_image

-
 @spaces.GPU(duration=75)
-def run_lora(prompt,
-    print("run_lora function called.")  # Debugging statement
-    print(f"Inputs received - Prompt: {prompt}, CFG Scale: {cfg_scale}, Steps: {steps}, Seed: {seed}, Width: {width}, Height: {height}")  # Debugging statement
-
+def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, randomize_seed, seed, width, height, loras_state, progress=gr.Progress(track_tqdm=True)):
     if not selected_indices:
         raise gr.Error("You must select at least one LoRA before proceeding.")

     selected_loras = [loras_state[idx] for idx in selected_indices]
-    print(f"Selected LoRAs: {selected_loras}")  # Debugging statement

     # Build the prompt with trigger words
     prepends = []
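
The rewritten run_lora signature names every UI input explicitly. Gradio passes event inputs positionally, so the parameter order must match the handler's inputs= list one-for-one; a hypothetical wiring, with component names that are assumptions rather than part of this diff:

    generate_button.click(
        run_lora,
        inputs=[prompt, image_input, image_strength, cfg_scale, steps,
                selected_indices, lora_scale_1, lora_scale_2,
                randomize_seed, seed, width, height, loras_state],
        outputs=[result, seed, progress_html],
    )

The progress parameter is not wired from the UI; gr.Progress(track_tqdm=True) is injected by Gradio at call time.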

@@ -507,39 +486,27 @@ def run_lora(prompt, cfg_scale, steps, selected_info_1, selected_info_2, selecte
     for lora in selected_loras:
         trigger_word = lora.get('trigger_word', '')
         if trigger_word:
-            if lora.get("trigger_position") == "
+            if lora.get("trigger_position") == "prepend":
                 prepends.append(trigger_word)
             else:
                 appends.append(trigger_word)
     prompt_mash = " ".join(prepends + [prompt] + appends)
-    print("Prompt Mash: ", prompt_mash)
-
-    # Ensure valid width and height values
-    if width is None or isinstance(width, gr.Progress):  # Check for Gradio Progress object
-        width = 1024  # Default value
-    if height is None:
-        height = 1024  # Default value
-
-    # Set seed value
-    if seed is None or randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-
+    print("Prompt Mash: ", prompt_mash)
     # Unload previous LoRA weights
     with calculateDuration("Unloading LoRA"):
         pipe.unload_lora_weights()
-
-
+        pipe_i2i.unload_lora_weights()
+
+    print(pipe.get_active_adapters())
     # Load LoRA weights with respective scales
     lora_names = []
     lora_weights = []
-    lora_scales = [lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4]  # List of scales
     with calculateDuration("Loading LoRA weights"):
         for idx, lora in enumerate(selected_loras):
             lora_name = f"lora_{idx}"
             lora_names.append(lora_name)
             print(f"Lora Name: {lora_name}")
-
-            lora_weights.append(lora_scales[idx] if idx < len(lora_scales) else lora_scale_4)  # Default to last scale if out of bounds
+            lora_weights.append(lora_scale_1 if idx == 0 else lora_scale_2)
             lora_path = lora['repo']
             weight_name = lora.get("weights")
             print(f"Lora Path: {lora_path}")
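
The trigger-word merge is easiest to see with concrete values. A worked example with hypothetical LoRA entries:

    selected_loras = [
        {"trigger_word": "style-a", "trigger_position": "prepend"},
        {"trigger_word": "style-b"},  # no position, so it is appended
    ]
    prompt = "a quiet harbor town"
    prepends, appends = [], []
    for lora in selected_loras:
        trigger_word = lora.get("trigger_word", "")
        if trigger_word:
            if lora.get("trigger_position") == "prepend":
                prepends.append(trigger_word)
            else:
                appends.append(trigger_word)
    prompt_mash = " ".join(prepends + [prompt] + appends)
    # -> "style-a a quiet harbor town style-b"

The scale handling shrinks with the UI: the old four-slot lora_scales list gives way to two sliders, index 0 taking lora_scale_1 and every later index lora_scale_2.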

@@ -550,47 +517,38 @@ def run_lora(prompt, cfg_scale, steps, selected_info_1, selected_info_2, selecte
             low_cpu_mem_usage=True,
             adapter_name=lora_name
         )
-
-
-
+        # if image_input is not None: pipe_i2i = pipe_to_use
+        # else: pipe = pipe_to_use
+    print("Loaded LoRAs:", lora_names)
+    print("Adapter weights:", lora_weights)
+    if image_input is not None:
+        pipe_i2i.set_adapters(lora_names, adapter_weights=lora_weights)
+    else:
+        pipe.set_adapters(lora_names, adapter_weights=lora_weights)
+    print(pipe.get_active_adapters())
     # Set random seed for reproducibility
     with calculateDuration("Randomizing seed"):
-
+        if randomize_seed:
+            seed = random.randint(0, MAX_SEED)

     # Generate image
-
-
-
-
-
-
-
-
-
-
-            final_image =
-
-
-
-            print(f"Debug: generate_image yielded value: {image}")  # Debugging
-            step_counter += 1
-            final_image = image
-            progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
-            yield image, seed, gr.update(value=progress_bar, visible=True)
-
-        if final_image is None:
-            print("No final image generated.")  # Debugging statement
-        else:
-            print(f"Debug: final_image type: {type(final_image)}")  # Debugging
-            yield final_image, seed, gr.update(value=progress_bar, visible=False)
-
-    except Exception as e:
-        print(f"Error during image generation: {e}")  # Error handling
-        raise gr.Error("An error occurred during image generation.")
+    if image_input is not None:
+        final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, seed)
+        yield final_image, seed, gr.update(visible=False)
+    else:
+        image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress)
+        # Consume the generator to get the final image
+        final_image = None
+        step_counter = 0
+        for image in image_generator:
+            step_counter += 1
+            final_image = image
+            progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
+            yield image, seed, gr.update(value=progress_bar, visible=True)
+        yield final_image, seed, gr.update(value=progress_bar, visible=False)

 run_lora.zerogpu = True

-
 def get_huggingface_safetensors(link):
     split_link = link.split("/")
     if len(split_link) == 4:
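
The core of the change is here: adapters are activated on whichever pipeline will actually run (pipe_i2i for image-to-image, pipe for text-to-image), and run_lora streams previews by yielding (image, seed, progress-bar HTML) tuples that Gradio renders incrementally. A self-contained sketch of the diffusers multi-adapter flow this relies on, with placeholder LoRA repo ids (the blend weights mirror lora_scale_1 and lora_scale_2):

    import torch
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    )
    pipe.load_lora_weights("user/lora-a", adapter_name="lora_0")  # placeholder repo
    pipe.load_lora_weights("user/lora-b", adapter_name="lora_1")  # placeholder repo
    # Blend both adapters; one weight per adapter name, in order.
    pipe.set_adapters(["lora_0", "lora_1"], adapter_weights=[1.0, 0.8])
    print(pipe.get_active_adapters())  # ['lora_0', 'lora_1']

One caveat in the new text-to-image branch: progress_bar is first bound inside the for loop, so the final yield after the loop would raise NameError if the generator produced no frames; initializing progress_bar = "" before the loop would be safer.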