Ryukijano committed on
Commit 8b1e42e · verified · 1 Parent(s): 66cb29c

Update app.py

Files changed (1)
  1. app.py +30 -103
app.py CHANGED
@@ -7,7 +7,7 @@ import time
  from diffusers import DiffusionPipeline, AutoencoderTiny
  from diffusers.models.attention_processor import AttnProcessor2_0
  from custom_pipeline import FluxWithCFGPipeline
- from huggingface_hub import hf_hub_download
+ from huggingface_hub import login
 
  torch.backends.cuda.matmul.allow_tf32 = True
 
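Editor's note: the replacement import suggests the app now authenticates to the Hub at startup instead of downloading the LoRA file by hand. A minimal sketch of the usual pattern, assuming the token arrives via an environment variable (the name HF_TOKEN is an assumption, not shown in this commit):

import os
from huggingface_hub import login

# Authenticate before loading models; needed only for gated or private repos.
# HF_TOKEN is an assumed Space secret; adapt to your deployment.
login(token=os.environ.get("HF_TOKEN"))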
@@ -20,123 +20,51 @@ DEFAULT_INFERENCE_STEPS = 1
 
  # Device and model setup
  dtype = torch.float16
- device = "cuda"  # Explicitly set device to CUDA
-
- # Download the LoRA weights
- lora_weights_path = hf_hub_download(
-     repo_id="hugovntr/flux-schnell-realism",
-     filename="schnell-realism_v2.3.safetensors",
- )
-
  pipe = FluxWithCFGPipeline.from_pretrained(
-     "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
+     "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype, use_safetensors=True, variant="fp16"
  )
- pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
- pipe.to(device)  # Move the pipeline to CUDA
-
- # Load the LoRA weights
- pipe.load_lora_weights(lora_weights_path, adapter_name="better")
+ pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype, use_safetensors=True, variant="fp16")
+ pipe.to("cuda")
+ pipe.load_lora_weights('hugovntr/flux-schnell-realism', weight_name='schnell-realism_v2.3.safetensors', adapter_name="better")
  pipe.set_adapters(["better"], adapter_weights=[1.0])
  pipe.fuse_lora(adapter_name=["better"], lora_scale=1.0)
  pipe.unload_lora_weights()
 
- # Memory optimizations
- pipe.transformer.to(memory_format=torch.channels_last)
+ # Enable xformers
  pipe.enable_xformers_memory_efficient_attention()
 
- # CUDA Graph setup
- static_inputs = None
- static_model = None
- graph = None
-
- def setup_cuda_graph(prompt, height, width, num_inference_steps):
-     global static_inputs, static_model, graph
-
-     batch_size = 1 if isinstance(prompt, str) else len(prompt)
-     num_images_per_prompt = 1
-
-     prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
-         prompt=prompt,
-         prompt_2=None,
-         prompt_embeds=None,
-         pooled_prompt_embeds=None,
-         device=device,
-         num_images_per_prompt=num_images_per_prompt,
-         max_sequence_length=300,
-         lora_scale=None,
-     )
-
-     latents, latent_image_ids = pipe.prepare_latents(
-         batch_size * num_images_per_prompt,
-         pipe.transformer.config.in_channels // 4,
-         height,
-         width,
-         prompt_embeds.dtype,
-         device,
-         None,
-         None,
-     )
-     sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-     image_seq_len = latents.shape[1]
-     mu = calculate_timestep_shift(image_seq_len)
-
-     timesteps, num_inference_steps = prepare_timesteps(
-         pipe.scheduler,
-         num_inference_steps,
-         device,
-         None,
-         sigmas,
-         mu=mu,
-     )
-
-     guidance = torch.full([1], 3.5, device=device, dtype=torch.float16).expand(latents.shape[0]) if pipe.transformer.config.guidance_embeds else None
-
-     static_inputs = {
-         "hidden_states": latents.to(device),
-         "timestep": timesteps.to(device),
-         "guidance": guidance.to(device) if guidance is not None else None,
-         "pooled_projections": pooled_prompt_embeds.to(device),
-         "encoder_hidden_states": prompt_embeds.to(device),
-         "txt_ids": text_ids,
-         "img_ids": latent_image_ids,
-         "joint_attention_kwargs": None,
-     }
-
-     static_model = torch.cuda.make_graphed_callables(pipe.transformer, (static_inputs,))
-     graph = torch.cuda.CUDAGraph()
-
-     with torch.cuda.graph(graph):
-         static_output = static_model(**static_inputs)
+ # Compile the model (Optional, needs further testing for stability)
+ # pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True)
+
+ # Capture CUDA Graph (Warm-up)
+ static_inputs = {
+     "prompt": "warmup",
+     "width": DEFAULT_WIDTH,
+     "height": DEFAULT_HEIGHT,
+     "num_inference_steps": DEFAULT_INFERENCE_STEPS,
+     "generator": torch.Generator().manual_seed(0),
+ }
+
+ pipe.capture_cuda_graph(**static_inputs)
+ torch.cuda.empty_cache()
 
  # Inference function
- # @spaces.GPU(duration=25)  # Remove decorator
+ @spaces.GPU(duration=25)
  def generate_image(prompt, seed=24, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT, randomize_seed=False, num_inference_steps=2, progress=gr.Progress(track_tqdm=True)):
-     global static_inputs, graph
-
      if randomize_seed:
         seed = random.randint(0, MAX_SEED)
      generator = torch.Generator().manual_seed(int(float(seed)))
 
      start_time = time.time()
-
-     if static_inputs is None:
-         setup_cuda_graph(prompt, height, width, num_inference_steps)
-
-     static_inputs["hidden_states"].copy_(pipe.prepare_latents(
-         1,
-         pipe.transformer.config.in_channels // 4,
-         height,
-         width,
-         static_inputs["encoder_hidden_states"].dtype,
-         "cuda",
-         generator,
-         None,
-     )[0])
-
-     graph.replay()
-     latents = static_inputs["hidden_states"]
-
-     img = pipe._decode_latents_to_image(latents, height, width, "pil")
+     # Only generate the last image in the sequence
+     img = pipe.generate_images(
+         prompt=prompt,
+         width=width,
+         height=height,
+         num_inference_steps=num_inference_steps,
+         generator=generator
+     )
      latency = f"Latency: {(time.time()-start_time):.2f} seconds"
      return img, seed, latency
 
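Editor's note: `capture_cuda_graph` and `generate_images` belong to the custom `FluxWithCFGPipeline` in `custom_pipeline.py`; they are not stock diffusers API. The deleted block above shows the pattern such a helper wraps: warm up, capture one transformer forward into a `torch.cuda.CUDAGraph`, then replay it after copying fresh data into static buffers. A minimal, self-contained sketch of that capture/replay pattern (the toy module and shapes are illustrative assumptions, not the repo's code):

import torch

# CUDA Graphs require fixed shapes and static input/output buffers.
model = torch.nn.Linear(64, 64).half().cuda()
static_in = torch.randn(8, 64, device="cuda", dtype=torch.float16)

# Warm up on a side stream so kernels and allocations exist before capture.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    for _ in range(3):
        model(static_in)
torch.cuda.current_stream().wait_stream(side)

# Capture a single forward pass into a graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = model(static_in)

# Replay: overwrite the static input in place, then re-run the captured graph.
static_in.copy_(torch.randn(8, 64, device="cuda", dtype=torch.float16))
graph.replay()
result = static_out.clone()  # static_out now holds the refreshed output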
@@ -159,7 +87,7 @@ with gr.Blocks() as demo:
      gr.Markdown("<span style='color: red;'>Note: Sometimes it stucks or stops generating images (I don't know why). In that situation just refresh the site.</span>")
 
      with gr.Row():
-         with gr.Column(scale=2):  # Changed scale to 2
+         with gr.Column(scale=2.5):
              result = gr.Image(label="Generated Image", show_label=False, interactive=False)
          with gr.Column(scale=1):
              prompt = gr.Text(
@@ -192,8 +120,7 @@ with gr.Blocks() as demo:
          fn=generate_image,
          inputs=[prompt],
          outputs=[result, seed, latency],
-         cache_examples=True,  # Changed cache_examples
-         cache_mode="lazy"  # Added cache_mode
+         cache_examples="lazy"
      )
 
      enhanceBtn.click(
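Editor's note: the removed pair (`cache_examples=True` plus `cache_mode="lazy"`) matches newer Gradio releases, while the retained `cache_examples="lazy"` appears to be the Gradio 4.x spelling of the same lazy behavior: each example is cached the first time a user clicks it. A self-contained sketch of that usage (the echo function and example prompts are illustrative):

import gradio as gr

def echo(prompt):
    # Stand-in for generate_image; returns the prompt unchanged.
    return prompt

with gr.Blocks() as demo:
    inp = gr.Text(label="Prompt")
    out = gr.Text(label="Output")
    gr.Examples(
        examples=[["a photo of a cat"], ["a watercolor landscape"]],
        fn=echo,
        inputs=[inp],
        outputs=[out],
        cache_examples="lazy",  # cache each example on first use
    )

demo.launch()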
@@ -253,4 +180,4 @@ with gr.Blocks() as demo:
  )
 
  # Launch the app
- demo.queue(max_size=5).launch()  # Removed concurrency_count
+ demo.launch()
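Editor's note: without `demo.queue(...)`, requests no longer pass through Gradio's bounded queue, so the Space loses the back-pressure the old `max_size=5` provided. If queuing is still wanted, the pre-change call remains valid (this mirrors the removed line; it is not part of the new code):

# Optional: restore a bounded request queue before launching.
demo.queue(max_size=5).launch()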
 