demo version2: spatial condition added

app.py (CHANGED)
@@ -17,6 +17,7 @@ MAX_SEED = 100_000
 def generate(region1_concept,
              region2_concept,
              prompt,
+             pose_image_id,
              region1_prompt,
              region2_prompt,
              negative_prompt,
@@ -33,13 +34,19 @@ def generate(region1_concept,
     region1_concept, region2_concept = region1_concept.lower(), region2_concept.lower()
     pretrained_model = merge(region1_concept, region2_concept)

-
-
-
+    with open('multi-concept/pose_data/pose.json') as f:
+        d = json.load(f)
+
+    pose_image = {obj.pop('pose_id'):obj for obj in d}[int(pose_image_id)]
+    print(pose_image)
+    keypose_condition = pose_image['keypose_condition']
+    region1 = pose_image['region1']
+    region2 = pose_image['region2']

     region1_prompt = f'[<{region1_concept}1> <{region1_concept}2>, {region1_prompt}]'
     region2_prompt = f'[<{region2_concept}1> <{region2_concept}2>, {region2_prompt}]'
     prompt_rewrite=f"{region1_prompt}-*-{region_neg_prompt}-*-{region1}|{region2_prompt}-*-{region_neg_prompt}-*-{region2}"
+    print(prompt_rewrite)

     result = infer(pretrained_model,
                    prompt,
@@ -164,6 +171,10 @@ def infer(pretrained_model,

     return image[0]

+
+def on_select(evt: gr.SelectData): # SelectData is a subclass of EventData
+    return ''.join(c for c in evt.value['image']['orig_name'] if c.isdigit())
+
 examples_context = [
     'walking at Stanford university campus',
     'in a castle',
@@ -174,6 +185,10 @@ examples_context = [
 examples_region1 = ['wearing red hat, high resolution, best quality','bright smile, wearing pants, best quality']
 examples_region2 = ['smilling, wearing blue shirt, high resolution, best quality']

+with open('multi-concept/pose_data/pose.json') as f:
+    d = json.load(f)
+pose_image_list = [(obj['pose_id'],obj['keypose_condition']) for obj in d]
+
 css="""
 #col-container {
     margin: 0 auto;
@@ -182,124 +197,140 @@ css="""
 """

 with gr.Blocks(css=css) as demo:
+    gr.Markdown(f"""
+    # Orthogonal Adaptation
+    Currently running on : {power_device}
+    """)

-    with gr.
-        gr.
-
-
-
-        prompt = gr.Text(
-            label="ContextPrompt",
-            show_label=False,
-            max_lines=1,
-            placeholder="Enter your context prompt for overall image",
-            container=False,
-        )
-        with gr.Row():
-
-            region1_concept = gr.Dropdown(
-                ["Elsa", "Moana"],
-                label="Character 1",
-                info="Will add more characters later!"
-            )
-            region2_concept = gr.Dropdown(
-                ["Elsa", "Moana"],
-                label="Character 2",
-                info="Will add more characters later!"
-            )
-
-        with gr.Row():
-
-            region1_prompt = gr.Textbox(
-                label="Region1 Prompt",
-                show_label=False,
-                max_lines=2,
-                placeholder="Enter your prompt for character 1",
-                container=False,
-            )
-
-            region2_prompt = gr.Textbox(
-                label="Region2 Prompt",
-                show_label=False,
-                max_lines=2,
-                placeholder="Enter your prompt for character 2",
-                container=False,
-            )
-
-        run_button = gr.Button("Run", scale=1)
-
-        result = gr.Image(label="Result", show_label=False)
-
-        with gr.Accordion("Advanced Settings", open=False):
-
-            negative_prompt = gr.Text(
-                label="Context Negative prompt",
-                max_lines=1,
-                value = 'saturated, cropped, worst quality, low quality',
-                visible=False,
-            )
-
-            region_neg_prompt = gr.Text(
-                label="Regional Negative prompt",
-                max_lines=1,
-                value = 'shirtless, nudity, saturated, cropped, worst quality, low quality',
-                visible=False,
-            )
-
-            seed = gr.Slider(
-                label="Seed",
-                minimum=0,
-                maximum=MAX_SEED,
-                step=1,
-                value=0,
-            )
-
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+    with gr.Row():
+        with gr.Column(elem_id="col-container", scale=2):
+            gr.Markdown(f"""
+            ### 🕹️ Global and Region prompts:
+            """)

+            prompt = gr.Text(
+                label="ContextPrompt",
+                show_label=False,
+                max_lines=1,
+                placeholder="Enter your context(global) prompt",
+                container=False,
+            )
             with gr.Row():

-
-
-
-
-                step=0.01,
-                value=0,
+                region1_concept = gr.Dropdown(
+                    ["Elsa", "Moana"],
+                    label="Character 1",
+                    info="Will add more characters later!"
                )
-
-
-
-
-
-
-
+                region2_concept = gr.Dropdown(
+                    ["Elsa", "Moana"],
+                    label="Character 2",
+                    info="Will add more characters later!"
+                )
+
+            with gr.Row():
+
+                region1_prompt = gr.Textbox(
+                    label="Region1 Prompt",
+                    show_label=False,
+                    max_lines=2,
+                    placeholder="Enter your regional prompt for character 1",
+                    container=False,
+                )
+
+                region2_prompt = gr.Textbox(
+                    label="Region2 Prompt",
+                    show_label=False,
+                    max_lines=2,
+                    placeholder="Enter your regional prompt for character 2",
+                    container=False,
                )
+
+            gr.Markdown(f"### 🧭 Spatial Condition for regionally controllable sampling: ")
+            gallery = gr.Gallery(label = "Select pose for characters",
+                                 value = [obj[1]for obj in pose_image_list],
+                                 elem_id = [obj[0]for obj in pose_image_list],
+                                 interactive=False, show_download_button=False,
+                                 preview=True, height = 200, object_fit="scale-down")

+            pose_image_id = gr.Textbox(visible=False)
+            gallery.select(on_select, None, pose_image_id)

-
-                label = 'Context Prompt example',
-                examples = examples_context,
-                inputs = [prompt]
-            )
+            run_button = gr.Button("Run", scale=1)

-
-
-
-
-
-
+            with gr.Accordion("Advanced Settings", open=False):
+
+                negative_prompt = gr.Text(
+                    label="Context Negative prompt",
+                    max_lines=1,
+                    value = 'saturated, cropped, worst quality, low quality',
+                    visible=False,
+                )
+
+                region_neg_prompt = gr.Text(
+                    label="Regional Negative prompt",
+                    max_lines=1,
+                    value = 'shirtless, nudity, saturated, cropped, worst quality, low quality',
+                    visible=False,
+                )
+
+                seed = gr.Slider(
+                    label="Seed",
+                    minimum=0,
+                    maximum=MAX_SEED,
+                    step=1,
+                    value=0,
+                )
+
+                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+                with gr.Row():
+
+                    sketch_adaptor_weight = gr.Slider(
+                        label="Sketch Adapter Weight",
+                        minimum = 0,
+                        maximum = 1,
+                        step=0.01,
+                        value=0,
+                    )
+
+                    keypose_adaptor_weight = gr.Slider(
+                        label="Keypose Adapter Weight",
+                        minimum = 0,
+                        maximum = 1,
+                        step= 0.01,
+                        value=1.0,
+                    )
+
+        with gr.Column(scale=1):

             gr.Examples(
-                label = '
-                examples =
-                inputs = [
-            )
+                label = 'Global Prompt example',
+                examples = examples_context,
+                inputs = [prompt]
+            )
+
+            with gr.Row():
+                gr.Examples(
+                    label = 'Region1 Prompt example',
+                    examples = examples_region1,
+                    inputs = [region1_prompt]
+                )
+
+                gr.Examples(
+                    label = 'Region2 Prompt example',
+                    examples = [examples_region2],
+                    inputs = [region2_prompt]
+                )

+            result = gr.Image(label="Result", show_label=False)

     run_button.click(
         fn = generate,
         inputs = [region1_concept,
                   region2_concept,
                   prompt,
+                  pose_image_id,
                   region1_prompt,
                   region2_prompt,
                   negative_prompt,
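Notes on the change. generate() now receives the id of the pose selected in the gallery and looks it up in multi-concept/pose_data/pose.json; the new code only relies on four keys per entry. The sketch below shows a hypothetical entry and the same lookup; the file path and region strings are assumed values for illustration, not taken from the repo.

# Hypothetical pose.json entry; only the keys the diff reads are shown.
d = [
    {
        "pose_id": 1,
        "keypose_condition": "multi-concept/pose_data/pose1.png",  # image shown in the gallery (assumed path)
        "region1": "[10, 20, 500, 400]",   # placement spec for character 1 (format assumed)
        "region2": "[10, 420, 500, 800]",  # placement spec for character 2 (format assumed)
    },
]

# Same lookup as in generate(): re-key the list by pose_id, then pick the pose chosen
# in the gallery (pose_image_id arrives as a string from a hidden Textbox).
pose_image_id = "1"
pose_image = {obj.pop("pose_id"): obj for obj in d}[int(pose_image_id)]
keypose_condition = pose_image["keypose_condition"]
region1, region2 = pose_image["region1"], pose_image["region2"]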
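The gallery is wired to that lookup through the hidden Textbox: gallery.select(on_select, None, pose_image_id) invokes on_select with a gr.SelectData event, and on_select keeps only the digits of the selected image's original file name. This assumes the pose images are named so that those digits are exactly the pose_id (e.g. a hypothetical pose3.png). A minimal check of the extraction:

# The file name is a made-up example; the digits become the pose_image_id string,
# which generate() later converts with int().
orig_name = "pose3.png"
pose_image_id = "".join(c for c in orig_name if c.isdigit())
assert pose_image_id == "3"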