Ryukijano commited on
Commit
378f95f
·
verified ·
1 Parent(s): ff06f7f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -9
app.py CHANGED
@@ -5,10 +5,11 @@ torch.backends.cudnn.allow_tf32 = True
5
  import gradio as gr
6
  import numpy as np
7
  import random
 
8
  import time
9
  from diffusers import DiffusionPipeline, AutoencoderTiny
 
10
  from custom_pipeline import FluxWithCFGPipeline
11
- import asyncio
12
 
13
  # Constants
14
  MAX_SEED = np.iinfo(np.int32).max
@@ -49,10 +50,9 @@ if hasattr(pipe, "transformer") and torch.cuda.is_available():
49
 
50
  torch.cuda.empty_cache()
51
 
52
-
53
-
54
- # Inference function (async)
55
- async def generate_image(
56
  prompt,
57
  seed=24,
58
  width=DEFAULT_WIDTH,
@@ -118,7 +118,7 @@ async def generate_image(
118
  static_latents_out, height, width, "pil"
119
  )
120
 
121
- # Graph-based generation function (synchronous)
122
  def generate_with_graph(
123
  latents,
124
  prompt_embeds,
@@ -136,6 +136,7 @@ async def generate_image(
136
  g.replay()
137
  return static_output
138
 
 
139
  img = pipe.generate_images(
140
  prompt=prompt,
141
  width=width,
@@ -264,10 +265,10 @@ with gr.Blocks() as demo:
264
  concurrency_limit=None,
265
  )
266
 
267
- async def realtime_generation(*args):
268
- print("realtime_generation")
269
  if args[0]: # If realtime is enabled
270
- return await generate_image(*args[1:])
 
271
 
272
  prompt.submit(
273
  fn=generate_image,
 
5
  import gradio as gr
6
  import numpy as np
7
  import random
8
+ import spaces
9
  import time
10
  from diffusers import DiffusionPipeline, AutoencoderTiny
11
+ from diffusers.models.attention_processor import AttnProcessor2_0
12
  from custom_pipeline import FluxWithCFGPipeline
 
13
 
14
  # Constants
15
  MAX_SEED = np.iinfo(np.int32).max
 
50
 
51
  torch.cuda.empty_cache()
52
 
53
+ # Inference function
54
+ @spaces.GPU(duration=25)
55
+ def generate_image(
 
56
  prompt,
57
  seed=24,
58
  width=DEFAULT_WIDTH,
 
118
  static_latents_out, height, width, "pil"
119
  )
120
 
121
+ # Graph-based generation function
122
  def generate_with_graph(
123
  latents,
124
  prompt_embeds,
 
136
  g.replay()
137
  return static_output
138
 
139
+ # Only generate the last image in the sequence
140
  img = pipe.generate_images(
141
  prompt=prompt,
142
  width=width,
 
265
  concurrency_limit=None,
266
  )
267
 
268
+ def realtime_generation(*args):
 
269
  if args[0]: # If realtime is enabled
270
+ img, seed, latency = generate_image(*args[1:])
271
+ return img, seed, latency
272
 
273
  prompt.submit(
274
  fn=generate_image,