Keltezaa commited on
Commit
45a6998
·
verified ·
1 Parent(s): d683517

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -90
app.py CHANGED
@@ -5,10 +5,8 @@ import logging
5
  import torch
6
  from PIL import Image
7
  import spaces
8
- from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image
9
- from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
10
- from diffusers.utils import load_image
11
- from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
12
  from transformers import AutoModelForCausalLM, CLIPTokenizer, CLIPProcessor, CLIPModel, LongformerTokenizer, LongformerModel
13
  import copy
14
  import random
@@ -47,9 +45,7 @@ pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef
47
 
48
  MAX_SEED = 2**32 - 1
49
 
50
- pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
51
-
52
- def process_input(input_text):
53
  # Tokenize and truncate input
54
  inputs = clip_processor(text=input_text, return_tensors="pt", padding=True, truncation=True, max_length=77)
55
  return inputs
@@ -93,7 +89,7 @@ def download_file(url, directory=None):
93
  file.write(response.content)
94
 
95
  return filepath
96
-
97
  def update_selection(evt: gr.SelectData, selected_indices, loras_state, width, height):
98
  selected_index = evt.index
99
  selected_indices = selected_indices or []
@@ -288,39 +284,11 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress)
288
  generator = torch.Generator(device="cuda").manual_seed(seed)
289
  with calculateDuration("Generating image"):
290
  # Generate image
291
- for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
292
- prompt=prompt_mash,
293
- num_inference_steps=steps,
294
- guidance_scale=cfg_scale,
295
- width=width,
296
- height=height,
297
- generator=generator,
298
- joint_attention_kwargs={"scale": 1.0},
299
- output_type="pil",
300
- good_vae=good_vae,
301
- ):
302
  yield img
303
 
304
- def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, seed):
305
- pipe_i2i.to("cuda")
306
- generator = torch.Generator(device="cuda").manual_seed(seed)
307
- image_input = load_image(image_input_path)
308
- final_image = pipe_i2i(
309
- prompt=prompt_mash,
310
- image=image_input,
311
- strength=image_strength,
312
- num_inference_steps=steps,
313
- guidance_scale=cfg_scale,
314
- width=width,
315
- height=height,
316
- generator=generator,
317
- joint_attention_kwargs={"scale": 1.0},
318
- output_type="pil",
319
- ).images[0]
320
- return final_image
321
-
322
  @spaces.GPU(duration=75)
323
- def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, randomize_seed, seed, width, height, loras_state, progress=gr.Progress(track_tqdm=True)):
324
  if not selected_indices:
325
  raise gr.Error("You must select at least one LoRA before proceeding.")
326
 
@@ -338,12 +306,7 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
338
  appends.append(trigger_word)
339
  prompt_mash = " ".join(prepends + [prompt] + appends)
340
  print("Prompt Mash: ", prompt_mash)
341
- # Unload previous LoRA weights
342
- with calculateDuration("Unloading LoRA"):
343
- pipe.unload_lora_weights()
344
- pipe_i2i.unload_lora_weights()
345
-
346
- print(pipe.get_active_adapters())
347
  # Load LoRA weights with respective scales
348
  lora_names = []
349
  lora_weights = []
@@ -352,46 +315,27 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
352
  lora_name = f"lora_{idx}"
353
  lora_names.append(lora_name)
354
  print(f"Lora Name: {lora_name}")
355
- lora_weights.append(lora_scale_1 if idx == 0 else lora_scale_2)
356
  lora_path = lora['repo']
357
  weight_name = lora.get("weights")
358
  print(f"Lora Path: {lora_path}")
359
- pipe_to_use = pipe_i2i if image_input is not None else pipe
360
- pipe_to_use.load_lora_weights(
361
  lora_path,
362
  weight_name=weight_name if weight_name else None,
363
  low_cpu_mem_usage=True,
364
  adapter_name=lora_name
365
  )
366
- # if image_input is not None: pipe_i2i = pipe_to_use
367
- # else: pipe = pipe_to_use
368
  print("Loaded LoRAs:", lora_names)
369
  print("Adapter weights:", lora_weights)
370
- if image_input is not None:
371
- pipe_i2i.set_adapters(lora_names, adapter_weights=lora_weights)
372
- else:
373
- pipe.set_adapters(lora_names, adapter_weights=lora_weights)
374
- print(pipe.get_active_adapters())
375
  # Set random seed for reproducibility
376
- with calculateDuration("Randomizing seed"):
377
- if randomize_seed:
378
- seed = random.randint(0, MAX_SEED)
379
 
380
  # Generate image
381
- if image_input is not None:
382
- final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, seed)
383
- yield final_image, seed, gr.update(visible=False)
384
- else:
385
- image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress)
386
- # Consume the generator to get the final image
387
- final_image = None
388
- step_counter = 0
389
- for image in image_generator:
390
- step_counter += 1
391
- final_image = image
392
- progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
393
- yield image, seed, gr.update(value=progress_bar, visible=True)
394
- yield final_image, seed, gr.update(value=progress_bar, visible=False)
395
 
396
  run_lora.zerogpu = True
397
 
@@ -451,7 +395,7 @@ def update_history(new_image, history):
451
  history.insert(0, new_image)
452
  return history
453
 
454
- css = '''
455
  #gen_btn{height: 100%}
456
  #title{text-align: center}
457
  #title h1{font-size: 3em; display:inline-flex; align-items:center}
@@ -500,7 +444,7 @@ with gr.Blocks(css=css, delete_cache=(60, 60)) as app:
500
  with gr.Column(scale=3, min_width=100):
501
  selected_info_1 = gr.Markdown("Select a LoRA 1")
502
  with gr.Column(scale=5, min_width=50):
503
- lora_scale_1 = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=3, step=0.01, value=1.15)
504
  with gr.Row():
505
  remove_button_1 = gr.Button("Remove", size="sm")
506
  with gr.Column(scale=8):
@@ -510,7 +454,7 @@ with gr.Blocks(css=css, delete_cache=(60, 60)) as app:
510
  with gr.Column(scale=3, min_width=100):
511
  selected_info_2 = gr.Markdown("Select a LoRA 2")
512
  with gr.Column(scale=5, min_width=50):
513
- lora_scale_2 = gr.Slider(label="LoRA 2 Scale", minimum=0, maximum=3, step=0.01, value=1.15)
514
  with gr.Row():
515
  remove_button_2 = gr.Button("Remove", size="sm")
516
  with gr.Row():
@@ -539,21 +483,16 @@ with gr.Blocks(css=css, delete_cache=(60, 60)) as app:
539
  with gr.Row():
540
  with gr.Accordion("Advanced Settings", open=False):
541
  with gr.Row():
542
- input_image = gr.Image(label="Input image", type="filepath", show_share_button=False)
543
- image_strength = gr.Slider(label="Denoise Strength", info="Lower means more image influence", minimum=0.1, maximum=1.0, step=0.01, value=0.75)
544
- with gr.Column():
545
- with gr.Row():
546
- cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
547
- steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
548
-
549
- with gr.Row():
550
- width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
551
- height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
552
-
553
- with gr.Row():
554
- randomize_seed = gr.Checkbox(True, label="Randomize seed")
555
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
556
-
557
 
558
  gallery.select(
559
  update_selection,
@@ -588,7 +527,7 @@ with gr.Blocks(css=css, delete_cache=(60, 60)) as app:
588
  gr.on(
589
  triggers=[generate_button.click, prompt.submit],
590
  fn=run_lora,
591
- inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, randomize_seed, seed, width, height, loras_state],
592
  outputs=[result, seed, progress_bar]
593
  ).then(
594
  fn=lambda x, history: update_history(x, history),
 
5
  import torch
6
  from PIL import Image
7
  import spaces
8
+ from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
9
+ from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard
 
 
10
  from transformers import AutoModelForCausalLM, CLIPTokenizer, CLIPProcessor, CLIPModel, LongformerTokenizer, LongformerModel
11
  import copy
12
  import random
 
45
 
46
  MAX_SEED = 2**32 - 1
47
 
48
+ ef process_input(input_text):
 
 
49
  # Tokenize and truncate input
50
  inputs = clip_processor(text=input_text, return_tensors="pt", padding=True, truncation=True, max_length=77)
51
  return inputs
 
89
  file.write(response.content)
90
 
91
  return filepath
92
+
93
  def update_selection(evt: gr.SelectData, selected_indices, loras_state, width, height):
94
  selected_index = evt.index
95
  selected_indices = selected_indices or []
 
284
  generator = torch.Generator(device="cuda").manual_seed(seed)
285
  with calculateDuration("Generating image"):
286
  # Generate image
287
+ for img in pipe(prompt=prompt_mash, num_inference_steps=steps, guidance_scale=cfg_scale, width=width, height=height, generator=generator):
 
 
 
 
 
 
 
 
 
 
288
  yield img
289
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  @spaces.GPU(duration=75)
291
+ def run_lora(prompt, selected_indices, loras_state):
292
  if not selected_indices:
293
  raise gr.Error("You must select at least one LoRA before proceeding.")
294
 
 
306
  appends.append(trigger_word)
307
  prompt_mash = " ".join(prepends + [prompt] + appends)
308
  print("Prompt Mash: ", prompt_mash)
309
+
 
 
 
 
 
310
  # Load LoRA weights with respective scales
311
  lora_names = []
312
  lora_weights = []
 
315
  lora_name = f"lora_{idx}"
316
  lora_names.append(lora_name)
317
  print(f"Lora Name: {lora_name}")
318
+ lora_weights.append(1.15) # Assuming a default scale
319
  lora_path = lora['repo']
320
  weight_name = lora.get("weights")
321
  print(f"Lora Path: {lora_path}")
322
+ pipe.load_lora_weights(
 
323
  lora_path,
324
  weight_name=weight_name if weight_name else None,
325
  low_cpu_mem_usage=True,
326
  adapter_name=lora_name
327
  )
 
 
328
  print("Loaded LoRAs:", lora_names)
329
  print("Adapter weights:", lora_weights)
330
+ pipe.set_adapters(lora_names, adapter_weights=lora_weights)
331
+ print(pipe.get_active_adapters())
332
+
 
 
333
  # Set random seed for reproducibility
334
+ seed = random.randint(0, MAX_SEED)
 
 
335
 
336
  # Generate image
337
+ final_image = generate_image(prompt_mash, 50, seed, 7.5, 512, 512, None) # Example parameters
338
+ yield final_image, seed, gr.update(visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
339
 
340
  run_lora.zerogpu = True
341
 
 
395
  history.insert(0, new_image)
396
  return history
397
 
398
+ ccss = '''
399
  #gen_btn{height: 100%}
400
  #title{text-align: center}
401
  #title h1{font-size: 3em; display:inline-flex; align-items:center}
 
444
  with gr.Column(scale=3, min_width=100):
445
  selected_info_1 = gr.Markdown("Select a LoRA 1")
446
  with gr.Column(scale=5, min_width=50):
447
+ lora_scale_1 = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=3, step=0.05, value=1.15)
448
  with gr.Row():
449
  remove_button_1 = gr.Button("Remove", size="sm")
450
  with gr.Column(scale=8):
 
454
  with gr.Column(scale=3, min_width=100):
455
  selected_info_2 = gr.Markdown("Select a LoRA 2")
456
  with gr.Column(scale=5, min_width=50):
457
+ lora_scale_2 = gr.Slider(label="LoRA 2 Scale", minimum=0, maximum=3, step=0.05, value=1.15)
458
  with gr.Row():
459
  remove_button_2 = gr.Button("Remove", size="sm")
460
  with gr.Row():
 
483
  with gr.Row():
484
  with gr.Accordion("Advanced Settings", open=False):
485
  with gr.Row():
486
+ cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=7.5)
487
+ steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
488
+
489
+ with gr.Row():
490
+ width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=768)
491
+ height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
492
+
493
+ with gr.Row():
494
+ randomize_seed = gr.Checkbox(True, label="Randomize seed")
495
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
 
 
 
 
 
496
 
497
  gallery.select(
498
  update_selection,
 
527
  gr.on(
528
  triggers=[generate_button.click, prompt.submit],
529
  fn=run_lora,
530
+ inputs=[prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, randomize_seed, seed, width, height, loras_state],
531
  outputs=[result, seed, progress_bar]
532
  ).then(
533
  fn=lambda x, history: update_history(x, history),