nevreal commited on
Commit
9dda882
·
verified ·
1 Parent(s): 60d5dd6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -34
app.py CHANGED
@@ -2,62 +2,70 @@ import gradio as gr
2
  from random import randint
3
  from all_models import models
4
 
5
- def load_fn(models):
6
- global models_load
7
  models_load = {}
8
-
9
  for model in models:
10
- if model not in models_load.keys():
11
  try:
12
  m = gr.load(f'models/{model}')
13
  except Exception as error:
14
  m = gr.Interface(lambda txt: None, ['text'], ['image'])
15
- models_load.update({model: m})
16
-
17
-
18
- load_fn(models)
19
 
 
20
 
21
  num_models = 6
22
  default_models = models[:num_models]
23
 
24
-
25
  def extend_choices(choices):
26
- return choices + (num_models - len(choices)) * ['NA']
27
-
28
 
 
29
  def update_imgbox(choices):
30
- choices_plus = extend_choices(choices)
31
- return [gr.Image(None, label = m, visible = (m != 'NA')) for m in choices_plus]
32
 
33
-
34
- def gen_fn(model_str, prompt):
35
  if model_str == 'NA':
36
  return None
37
  noise = str(randint(0, 99999999999))
38
  return models_load[model_str](f'{prompt} {noise}')
39
 
40
-
41
  with gr.Blocks() as demo:
42
- model_choice2 = gr.Dropdown(models, label = 'Choose model', value = models[0], filterable = False)
43
- txt_input2 = gr.Textbox(label = 'Prompt text')
44
-
45
  max_images = 6
46
- num_images = gr.Slider(1, max_images, value = max_images, step = 1, label = 'Number of images')
47
-
48
- gen_button2 = gr.Button('Generate')
49
- stop_button2 = gr.Button('Stop', variant = 'secondary', interactive = False)
50
- gen_button2.click(lambda s: gr.update(interactive = True), None, stop_button2)
51
-
 
 
52
  with gr.Row():
53
- output2 = [gr.Image(label = '') for _ in range(max_images)]
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- for i, o in enumerate(output2):
56
- img_i = gr.Number(i, visible = False)
57
- num_images.change(lambda i, n: gr.update(visible = (i < n)), [img_i, num_images], o)
58
- gen_event2 = gen_button2.click(lambda i, n, m, t: gen_fn(m, t) if (i < n) else None, [img_i, num_images, model_choice2, txt_input2], o)
59
- stop_button2.click(lambda s: gr.update(interactive = False), None, stop_button2, cancels = [gen_event2])
60
 
61
-
62
- demo.queue(concurrency_count = 36)
63
- demo.launch()
 
2
  from random import randint
3
  from all_models import models
4
 
5
+ # Load models
6
+ def load_models(models):
7
  models_load = {}
 
8
  for model in models:
9
+ if model not in models_load:
10
  try:
11
  m = gr.load(f'models/{model}')
12
  except Exception as error:
13
  m = gr.Interface(lambda txt: None, ['text'], ['image'])
14
+ models_load[model] = m
15
+ return models_load
 
 
16
 
17
+ models_load = load_models(models)
18
 
19
  num_models = 6
20
  default_models = models[:num_models]
21
 
22
+ # Extend choices to a fixed number of models
23
  def extend_choices(choices):
24
+ return choices + ['NA'] * (num_models - len(choices))
 
25
 
26
+ # Dynamically update image boxes based on number of choices
27
  def update_imgbox(choices):
28
+ extended_choices = extend_choices(choices)
29
+ return [gr.Image(None, label=m, visible=(m != 'NA')) for m in extended_choices]
30
 
31
+ # Generate function with noise added to prompt
32
+ def generate_image(model_str, prompt):
33
  if model_str == 'NA':
34
  return None
35
  noise = str(randint(0, 99999999999))
36
  return models_load[model_str](f'{prompt} {noise}')
37
 
38
+ # Gradio interface setup
39
  with gr.Blocks() as demo:
40
+ model_dropdown = gr.Dropdown(models, label='Choose model', value=models[0], filterable=False)
41
+ text_input = gr.Textbox(label='Prompt text')
42
+
43
  max_images = 6
44
+ num_images_slider = gr.Slider(1, max_images, value=max_images, step=1, label='Number of images')
45
+
46
+ generate_button = gr.Button('Generate')
47
+ stop_button = gr.Button('Stop', variant='secondary', interactive=False)
48
+
49
+ # Enable the stop button when generation starts
50
+ generate_button.click(lambda: gr.update(interactive=True), None, stop_button)
51
+
52
  with gr.Row():
53
+ output_images = [gr.Image(label='') for _ in range(max_images)]
54
+
55
+ for i, output in enumerate(output_images):
56
+ img_index = gr.Number(i, visible=False)
57
+ num_images_slider.change(
58
+ lambda idx, n: gr.update(visible=(idx < n)),
59
+ [img_index, num_images_slider], output
60
+ )
61
+ generate_event = generate_button.click(
62
+ lambda idx, n, model, prompt: generate_image(model, prompt) if idx < n else None,
63
+ [img_index, num_images_slider, model_dropdown, text_input], output
64
+ )
65
 
66
+ # Stop button functionality to cancel image generation
67
+ stop_button.click(lambda: gr.update(interactive=False), None, stop_button, cancels=[generate_event])
 
 
 
68
 
69
+ # Set queue concurrency and launch the demo
70
+ demo.queue(concurrency_count=36)
71
+ demo.launch()