dgoot commited on
Commit
fbb30c1
·
1 Parent(s): 1789e44

Tune GPU duration based on model

Browse files
Files changed (1) hide show
  1. app.py +23 -9
app.py CHANGED
@@ -10,11 +10,20 @@ models = [
10
  "stabilityai/stable-diffusion-xl-refiner-1.0",
11
  "timbrooks/instruct-pix2pix",
12
  ]
13
- default_model = "stabilityai/stable-diffusion-xl-refiner-1.0"
 
 
 
 
 
14
 
15
 
16
- @logger.catch(reraise=True)
17
  @spaces.GPU(duration=180)
 
 
 
 
 
18
  def generate(
19
  model: str,
20
  prompt: str,
@@ -49,12 +58,17 @@ def generate(
49
  additional_args = (
50
  {} if model == "timbrooks/instruct-pix2pix" else dict(strength=strength)
51
  )
52
- images = pipe(
53
- prompt=prompt,
54
- image=init_image,
55
- callback_on_step_end=progress_callback,
56
- **additional_args,
57
- ).images
 
 
 
 
 
58
  return images[0]
59
 
60
 
@@ -62,7 +76,7 @@ demo = gr.Interface(
62
  fn=generate,
63
  inputs=[
64
  gr.Dropdown(
65
- label="Model", choices=models, value=default_model, allow_custom_value=True
66
  ),
67
  gr.Text(label="Prompt"),
68
  gr.Image(label="Init image", type="pil"),
 
10
  "stabilityai/stable-diffusion-xl-refiner-1.0",
11
  "timbrooks/instruct-pix2pix",
12
  ]
13
+ DEFAULT_MODEL = "stabilityai/stable-diffusion-xl-refiner-1.0"
14
+
15
+
16
+ @spaces.GPU
17
+ def gpu(fn):
18
+ return fn()
19
 
20
 
 
21
  @spaces.GPU(duration=180)
22
+ def gpu_3min(fn):
23
+ return fn()
24
+
25
+
26
+ @logger.catch(reraise=True)
27
  def generate(
28
  model: str,
29
  prompt: str,
 
58
  additional_args = (
59
  {} if model == "timbrooks/instruct-pix2pix" else dict(strength=strength)
60
  )
61
+
62
+ gpu_runner = gpu_3min if model == "timbrooks/instruct-pix2pix" else gpu
63
+
64
+ images = gpu_runner(
65
+ lambda: pipe(
66
+ prompt=prompt,
67
+ image=init_image,
68
+ callback_on_step_end=progress_callback,
69
+ **additional_args,
70
+ ).images
71
+ )
72
  return images[0]
73
 
74
 
 
76
  fn=generate,
77
  inputs=[
78
  gr.Dropdown(
79
+ label="Model", choices=models, value=DEFAULT_MODEL, allow_custom_value=True
80
  ),
81
  gr.Text(label="Prompt"),
82
  gr.Image(label="Init image", type="pil"),