prithivMLmods commited on
Commit
c9a1d71
·
verified ·
1 Parent(s): 4a4da06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -92
app.py CHANGED
@@ -45,7 +45,7 @@ style_list = [
45
  },
46
  ]
47
 
48
- STYLE_NAMES = [style["name"] for style in style_list]
49
  DEFAULT_STYLE_NAME = STYLE_NAMES[0]
50
 
51
  grid_sizes = {
@@ -79,11 +79,10 @@ def infer(
79
  seed = random.randint(0, MAX_SEED)
80
 
81
  generator = torch.Generator().manual_seed(seed)
 
 
82
 
83
- grid_size_x, grid_size_y = grid_sizes.get(grid_size, (1, 1))
84
- num_images = grid_size_x * grid_size_y
85
-
86
- options = {
87
  "prompt": styled_prompt,
88
  "negative_prompt": styled_negative_prompt,
89
  "guidance_scale": guidance_scale,
@@ -94,13 +93,14 @@ def infer(
94
  "num_images_per_prompt": num_images,
95
  }
96
 
97
- torch.cuda.empty_cache() # Clear GPU memory
98
- result = pipe(**options)
99
-
100
- grid_img = Image.new('RGB', (width * grid_size_x, height * grid_size_y))
101
 
102
- for i, img in enumerate(result.images[:num_images]):
103
- grid_img.paste(img, (i % grid_size_x * width, i // grid_size_x * height))
 
 
 
104
 
105
  return grid_img, seed
106
 
@@ -111,36 +111,37 @@ examples = [
111
  "A cat holding a sign that says hello world --ar 85:128 --v 6.0 --style raw"
112
  ]
113
 
114
- #css = '''
115
- #.gradio-container{max-width: 585px !important}
116
- #h1{text-align:center}
117
- #footer {
118
- # visibility: hidden
119
- #}
120
- #'''
 
 
 
 
 
121
 
122
- #with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
123
- with gr.Blocks() as demo:
124
  with gr.Column(elem_id="col-container"):
125
- gr.Markdown("## GRID 6X🪨")
126
 
127
  with gr.Row():
128
  prompt = gr.Text(
129
- label="Prompt",
130
  show_label=False,
131
  max_lines=1,
132
  placeholder="Enter your prompt",
133
  container=False,
134
  )
135
-
136
  run_button = gr.Button("Run", scale=0, variant="primary")
137
 
138
- result = gr.Image(label="Result", show_label=False)
139
-
140
 
141
- with gr.Row(visible=True):
142
  grid_size_selection = gr.Dropdown(
143
- choices=["2x1", "1x2", "2x2", "2x3", "3x2", "1x1"],
144
  value="1x1",
145
  label="Grid Size"
146
  )
@@ -151,82 +152,39 @@ with gr.Blocks() as demo:
151
  max_lines=1,
152
  placeholder="Enter a negative prompt",
153
  value="(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
154
- visible=False,
155
  )
156
-
157
- seed = gr.Slider(
158
- label="Seed",
159
- minimum=0,
160
- maximum=MAX_SEED,
161
- step=1,
162
- value=0,
163
- )
164
-
165
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
166
 
167
  with gr.Row():
168
- width = gr.Slider(
169
- label="Width",
170
- minimum=512,
171
- maximum=MAX_IMAGE_SIZE,
172
- step=32,
173
- value=1024,
174
- )
175
-
176
- height = gr.Slider(
177
- label="Height",
178
- minimum=512,
179
- maximum=MAX_IMAGE_SIZE,
180
- step=32,
181
- value=1024,
182
- )
183
 
184
  with gr.Row():
185
- guidance_scale = gr.Slider(
186
- label="Guidance scale",
187
- minimum=0.0,
188
- maximum=7.5,
189
- step=0.1,
190
- value=0.0,
191
- )
192
-
193
- num_inference_steps = gr.Slider(
194
- label="Number of inference steps",
195
- minimum=1,
196
- maximum=50,
197
- step=1,
198
- value=8,
199
- )
200
-
201
- style_selection = gr.Radio(
202
- show_label=True,
203
- container=True,
204
- interactive=True,
205
- choices=STYLE_NAMES,
206
- value=DEFAULT_STYLE_NAME,
207
- label="Quality Style",
208
- )
209
-
210
- gr.Examples(examples=examples,
211
- inputs=[prompt],
212
- outputs=[result, seed],
213
- fn=infer,
214
- cache_examples=False)
215
 
216
  gr.on(
217
  triggers=[run_button.click, prompt.submit],
218
  fn=infer,
219
  inputs=[
220
- prompt,
221
- negative_prompt,
222
- seed,
223
- randomize_seed,
224
- width,
225
- height,
226
- guidance_scale,
227
- num_inference_steps,
228
- style_selection,
229
- grid_size_selection,
230
  ],
231
  outputs=[result, seed],
232
  )
 
45
  },
46
  ]
47
 
48
+ STYLE_NAMES = [s["name"] for s in style_list]
49
  DEFAULT_STYLE_NAME = STYLE_NAMES[0]
50
 
51
  grid_sizes = {
 
79
  seed = random.randint(0, MAX_SEED)
80
 
81
  generator = torch.Generator().manual_seed(seed)
82
+ grid_x, grid_y = grid_sizes.get(grid_size, (1, 1))
83
+ num_images = grid_x * grid_y
84
 
85
+ opts = {
 
 
 
86
  "prompt": styled_prompt,
87
  "negative_prompt": styled_negative_prompt,
88
  "guidance_scale": guidance_scale,
 
93
  "num_images_per_prompt": num_images,
94
  }
95
 
96
+ torch.cuda.empty_cache()
97
+ res = pipe(**opts)
 
 
98
 
99
+ grid_img = Image.new('RGB', (width * grid_x, height * grid_y))
100
+ for i, img in enumerate(res.images[:num_images]):
101
+ x = (i % grid_x) * width
102
+ y = (i // grid_x) * height
103
+ grid_img.paste(img, (x, y))
104
 
105
  return grid_img, seed
106
 
 
111
  "A cat holding a sign that says hello world --ar 85:128 --v 6.0 --style raw"
112
  ]
113
 
114
+ css = '''
115
+ .gradio-container {
116
+ max-width: 585px !important;
117
+ margin: 0 auto !important;
118
+ display: flex;
119
+ flex-direction: column;
120
+ align-items: center;
121
+ justify-content: center;
122
+ }
123
+ h1 { text-align: center; }
124
+ footer { visibility: hidden; }
125
+ '''
126
 
127
+ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
 
128
  with gr.Column(elem_id="col-container"):
129
+ gr.Markdown("## GRID 6X🪨")
130
 
131
  with gr.Row():
132
  prompt = gr.Text(
 
133
  show_label=False,
134
  max_lines=1,
135
  placeholder="Enter your prompt",
136
  container=False,
137
  )
 
138
  run_button = gr.Button("Run", scale=0, variant="primary")
139
 
140
+ result = gr.Image(show_label=False)
 
141
 
142
+ with gr.Row():
143
  grid_size_selection = gr.Dropdown(
144
+ choices=list(grid_sizes.keys()),
145
  value="1x1",
146
  label="Grid Size"
147
  )
 
152
  max_lines=1,
153
  placeholder="Enter a negative prompt",
154
  value="(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
 
155
  )
156
+ seed = gr.Slider(0, MAX_SEED, value=0, label="Seed")
 
 
 
 
 
 
 
 
157
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
158
 
159
  with gr.Row():
160
+ width = gr.Slider(512, MAX_IMAGE_SIZE, step=32, value=1024, label="Width")
161
+ height = gr.Slider(512, MAX_IMAGE_SIZE, step=32, value=1024, label="Height")
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
  with gr.Row():
164
+ guidance_scale = gr.Slider(0.0, 7.5, step=0.1, value=7.5, label="Guidance scale")
165
+ num_inference_steps = gr.Slider(1, 50, step=1, value=10, label="Number of inference steps")
166
+
167
+ style_selection = gr.Radio(
168
+ choices=STYLE_NAMES,
169
+ value=DEFAULT_STYLE_NAME,
170
+ label="Quality Style",
171
+ )
172
+
173
+ gr.Examples(
174
+ examples=examples,
175
+ inputs=[prompt],
176
+ outputs=[result, seed],
177
+ fn=infer,
178
+ cache_examples=False
179
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
  gr.on(
182
  triggers=[run_button.click, prompt.submit],
183
  fn=infer,
184
  inputs=[
185
+ prompt, negative_prompt, seed, randomize_seed,
186
+ width, height, guidance_scale, num_inference_steps,
187
+ style_selection, grid_size_selection
 
 
 
 
 
 
 
188
  ],
189
  outputs=[result, seed],
190
  )