andsteing commited on
Commit
a6ee350
·
1 Parent(s): bc8a162

Simplifies reactive UI logic.

Browse files
Files changed (1) hide show
  1. app.py +25 -45
app.py CHANGED
@@ -10,6 +10,7 @@ Features:
10
  - Use of `gr.State()` for better use of progress bars.
11
  """
12
  import dataclasses
 
13
  import json
14
  import logging
15
  import os
@@ -171,29 +172,17 @@ def create_app():
171
  prompts = gr.Textbox(label='Prompts (press Shift-ENTER to add a prompt)')
172
  with gr.Row():
173
 
174
- values = {}
175
-
176
  family = gr.Dropdown(value='lit', choices=list(MODEL_MAP), label='Model family')
177
- values['family'] = family.value
178
-
179
- # Unfortunately below reactive UI code is a bit convoluted, because:
180
- # 1. When e.g. `family.change()` updates `variant`, then that does not
181
- # trigger a `varaint.change()`.
182
- # 2. The widget values like `family.value` are *not* updated when the
183
- # widget is updated. Therefore, we keep a manual copy in `values`.
184
 
185
- def make_variant(family_value):
186
  choices = list(MODEL_MAP[family_value])
187
- values['variant'] = choices[0]
188
- return gr.Dropdown(value=values['variant'], choices=choices, label='Variant')
189
- variant = make_variant(family.value)
 
190
 
191
- def make_res(family, variant):
192
- choices = list(MODEL_MAP[family][variant])
193
- values['res'] = choices[0]
194
- return gr.Dropdown(value=values['res'], choices=choices, label='Resolution')
195
- res = make_res(family.value, variant.value)
196
- values['res'] = res.value
197
 
198
  def make_bias(family, variant, res):
199
  visible = family == 'siglip'
@@ -205,37 +194,28 @@ def create_app():
205
  }.get((family, variant, res), -10.0)
206
  return gr.Slider(value=value, minimum=-20, maximum=0, step=0.05, label='Bias', visible=visible)
207
  bias = make_bias(family.value, variant.value, res.value)
208
- values['bias'] = bias.value
209
-
210
- def family_changed(family):
211
- variant = list(MODEL_MAP[family])[0]
212
- res = list(MODEL_MAP[family][variant])[0]
213
- values['family'] = family
214
- values['variant'] = variant
215
- values['res'] = res
216
- return [
217
- make_variant(family),
218
- make_res(family, variant),
219
- make_bias(family, variant, res),
220
- ]
221
 
222
- def variant_changed(variant):
223
- res = list(MODEL_MAP[values['family']][variant])[0]
224
- values['variant'] = variant
225
- values['res'] = res
 
 
 
226
  return [
227
- make_res(values['family'], variant),
228
- make_bias(values['family'], variant, res),
 
229
  ]
230
 
231
- def res_changed(res):
232
- return make_bias(values['family'], values['variant'], res)
233
-
234
- family.change(family_changed, family, [variant, res, bias])
235
- variant.change(variant_changed, variant, [res, bias])
236
- res.change(res_changed, res, bias)
237
 
238
- # (end of code for reactive UI code)
239
 
240
  run = gr.Button('Run')
241
  answers = [
 
10
  - Use of `gr.State()` for better use of progress bars.
11
  """
12
  import dataclasses
13
+ import functools
14
  import json
15
  import logging
16
  import os
 
172
  prompts = gr.Textbox(label='Prompts (press Shift-ENTER to add a prompt)')
173
  with gr.Row():
174
 
 
 
175
  family = gr.Dropdown(value='lit', choices=list(MODEL_MAP), label='Model family')
 
 
 
 
 
 
 
176
 
177
+ def make_variant(family_value, value=None):
178
  choices = list(MODEL_MAP[family_value])
179
+ if value is None:
180
+ value = choices
181
+ make_variant = functools.partial(gr.Dropdown, label='Variant')
182
+ variant = make_variant(list(MODEL_MAP['lit']), value='B/16')
183
 
184
+ make_res = functools.partial(gr.Dropdown, label='Resolution')
185
+ res = make_res(list(MODEL_MAP['lit']['B/16']), value=224)
 
 
 
 
186
 
187
  def make_bias(family, variant, res):
188
  visible = family == 'siglip'
 
194
  }.get((family, variant, res), -10.0)
195
  return gr.Slider(value=value, minimum=-20, maximum=0, step=0.05, label='Bias', visible=visible)
196
  bias = make_bias(family.value, variant.value, res.value)
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
+ def update_inputs(family, variant, res):
199
+ d = MODEL_MAP[family]
200
+ variants = list(d)
201
+ variant = variant if variant in variants else variants[0]
202
+ d = d[variant]
203
+ ress = list(d)
204
+ res = res if res in ress else ress[0]
205
  return [
206
+ make_variant(variants, value=variant),
207
+ make_res(ress, value=res),
208
+ make_bias(family, variant, res),
209
  ]
210
 
211
+ gr.on(
212
+ [family.change, variant.change, res.change],
213
+ update_inputs,
214
+ [family, variant, res],
215
+ [variant, res, bias],
216
+ )
217
 
218
+ # (end of code for reactive UI)
219
 
220
  run = gr.Button('Run')
221
  answers = [