Update app.py
Browse files
- implemented error message
- json structure modification
app.py
CHANGED
@@ -17,7 +17,7 @@ MAX_SEED = 100_000
|
|
17 |
def generate(region1_concept,
|
18 |
region2_concept,
|
19 |
prompt,
|
20 |
-
|
21 |
region1_prompt,
|
22 |
region2_prompt,
|
23 |
negative_prompt,
|
@@ -27,6 +27,15 @@ def generate(region1_concept,
|
|
27 |
sketch_adaptor_weight,
|
28 |
keypose_adaptor_weight
|
29 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
if randomize_seed:
|
32 |
seed = random.randint(0, MAX_SEED)
|
@@ -37,9 +46,10 @@ def generate(region1_concept,
|
|
37 |
with open('multi-concept/pose_data/pose.json') as f:
|
38 |
d = json.load(f)
|
39 |
|
40 |
-
pose_image = {
|
|
|
41 |
print(pose_image)
|
42 |
-
keypose_condition = pose_image['
|
43 |
region1 = pose_image['region1']
|
44 |
region2 = pose_image['region2']
|
45 |
|
@@ -173,7 +183,7 @@ def infer(pretrained_model,
|
|
173 |
|
174 |
|
175 |
def on_select(evt: gr.SelectData): # SelectData is a subclass of EventData
|
176 |
-
return
|
177 |
|
178 |
examples_context = [
|
179 |
'walking at Stanford university campus',
|
@@ -187,7 +197,7 @@ examples_region2 = ['smilling, wearing blue shirt, high resolution, best quality
|
|
187 |
|
188 |
with open('multi-concept/pose_data/pose.json') as f:
|
189 |
d = json.load(f)
|
190 |
-
pose_image_list = [(obj['
|
191 |
|
192 |
css="""
|
193 |
#col-container {
|
@@ -210,7 +220,7 @@ with gr.Blocks(css=css) as demo:
|
|
210 |
# gr.Markdown(f"""
|
211 |
# ### 🪄 Global and Region prompts
|
212 |
# """)
|
213 |
-
# with gr.Group():
|
214 |
with gr.Tab('🪄 Global and Region prompts'):
|
215 |
prompt = gr.Text(
|
216 |
label="ContextPrompt",
|
@@ -282,10 +292,10 @@ with gr.Blocks(css=css) as demo:
|
|
282 |
value = [obj[1]for obj in pose_image_list],
|
283 |
elem_id = [obj[0]for obj in pose_image_list],
|
284 |
interactive=False, show_download_button=False,
|
285 |
-
preview=True, height =
|
286 |
|
287 |
-
|
288 |
-
gallery.select(on_select, None,
|
289 |
|
290 |
run_button = gr.Button("Run", scale=1)
|
291 |
|
@@ -346,7 +356,7 @@ with gr.Blocks(css=css) as demo:
|
|
346 |
inputs = [region1_concept,
|
347 |
region2_concept,
|
348 |
prompt,
|
349 |
-
|
350 |
region1_prompt,
|
351 |
region2_prompt,
|
352 |
negative_prompt,
|
|
|
17 |
def generate(region1_concept,
|
18 |
region2_concept,
|
19 |
prompt,
|
20 |
+
pose_image_name,
|
21 |
region1_prompt,
|
22 |
region2_prompt,
|
23 |
negative_prompt,
|
|
|
27 |
sketch_adaptor_weight,
|
28 |
keypose_adaptor_weight
|
29 |
):
|
30 |
+
|
31 |
+
if region1_concept==region2_concept:
|
32 |
+
raise gr.Error("Please choose two different characters for merging weights.")
|
33 |
+
if len(pose_image_name)==0:
|
34 |
+
raise gr.Error("Please select one spatial condition!")
|
35 |
+
if len(region1_prompt)==0 or len(region2_prompt)==0:
|
36 |
+
raise gr.Error("Your regional prompt cannot be empty.")
|
37 |
+
if len(prompt)==0:
|
38 |
+
raise gr.Error("Your global prompt cannot be empty.")
|
39 |
|
40 |
if randomize_seed:
|
41 |
seed = random.randint(0, MAX_SEED)
|
|
|
46 |
with open('multi-concept/pose_data/pose.json') as f:
|
47 |
d = json.load(f)
|
48 |
|
49 |
+
pose_image = {os.path.basename(obj['img_dir']):obj for obj in d}[pose_image_name]
|
50 |
+
# pose_image = {obj.pop('pose_id'):obj for obj in d}[int(pose_image_id)]
|
51 |
print(pose_image)
|
52 |
+
keypose_condition = pose_image['img_dir']
|
53 |
region1 = pose_image['region1']
|
54 |
region2 = pose_image['region2']
|
55 |
|
|
|
183 |
|
184 |
|
185 |
def on_select(evt: gr.SelectData): # SelectData is a subclass of EventData
|
186 |
+
return evt.value['image']['orig_name']
|
187 |
|
188 |
examples_context = [
|
189 |
'walking at Stanford university campus',
|
|
|
197 |
|
198 |
with open('multi-concept/pose_data/pose.json') as f:
|
199 |
d = json.load(f)
|
200 |
+
pose_image_list = [(obj['img_id'],obj['img_dir']) for obj in d]
|
201 |
|
202 |
css="""
|
203 |
#col-container {
|
|
|
220 |
# gr.Markdown(f"""
|
221 |
# ### 🪄 Global and Region prompts
|
222 |
# """)
|
223 |
+
# with gr.Group():
|
224 |
with gr.Tab('🪄 Global and Region prompts'):
|
225 |
prompt = gr.Text(
|
226 |
label="ContextPrompt",
|
|
|
292 |
value = [obj[1]for obj in pose_image_list],
|
293 |
elem_id = [obj[0]for obj in pose_image_list],
|
294 |
interactive=False, show_download_button=False,
|
295 |
+
preview=True, height = 400, object_fit="scale-down")
|
296 |
|
297 |
+
pose_image_name = gr.Textbox(visible=False)
|
298 |
+
gallery.select(on_select, None, pose_image_name)
|
299 |
|
300 |
run_button = gr.Button("Run", scale=1)
|
301 |
|
|
|
356 |
inputs = [region1_concept,
|
357 |
region2_concept,
|
358 |
prompt,
|
359 |
+
pose_image_name,
|
360 |
region1_prompt,
|
361 |
region2_prompt,
|
362 |
negative_prompt,
|