ura23 commited on
Commit
40964e6
·
verified ·
1 Parent(s): d49d87d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -6
app.py CHANGED
@@ -220,13 +220,16 @@ def main():
220
  with gr.Column():
221
  output = gr.Textbox(label="Output", lines=10)
222
 
223
- def process_images(files, model_repo, general_thresh, character_thresh, filter_tags):
224
  images = [Image.open(file.name) for file in files]
225
  results = predictor.predict(images, model_repo, general_thresh, character_thresh)
226
-
227
  # Parse filter tags
228
  filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))
229
-
 
 
 
230
  # Generate formatted output
231
  prompts = []
232
  for i, (general_tags, character_tags) in enumerate(results):
@@ -237,22 +240,29 @@ def main():
237
  general_part = ", ".join(
238
  tag.replace('_', ' ') for tag in general_tags if tag.lower() not in filter_set
239
  )
240
-
 
 
 
 
 
241
  # Construct the prompt based on the presence of character_part
242
  if character_part:
243
  prompts.append(f"{character_part}, {general_part}")
244
  else:
245
  prompts.append(general_part)
246
-
247
  # Join all prompts with blank lines
248
  return "\n\n".join(prompts)
249
 
 
250
  submit.click(
251
  process_images,
252
- inputs=[image_files, model_repo, general_thresh, character_thresh, filter_tags],
253
  outputs=output
254
  )
255
 
 
256
  demo.queue(max_size=10)
257
  demo.launch()
258
 
 
220
  with gr.Column():
221
  output = gr.Textbox(label="Output", lines=10)
222
 
223
+ def process_images(files, model_repo, general_thresh, character_thresh, filter_tags, custom_tags):
224
  images = [Image.open(file.name) for file in files]
225
  results = predictor.predict(images, model_repo, general_thresh, character_thresh)
226
+
227
  # Parse filter tags
228
  filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))
229
+
230
+ # Parse custom tags
231
+ custom_tag_set = set(tag.strip().lower() for tag in custom_tags.split(","))
232
+
233
  # Generate formatted output
234
  prompts = []
235
  for i, (general_tags, character_tags) in enumerate(results):
 
240
  general_part = ", ".join(
241
  tag.replace('_', ' ') for tag in general_tags if tag.lower() not in filter_set
242
  )
243
+
244
+ # Check for missing custom tags
245
+ missing_custom_tags = custom_tag_set - set(general_tags + character_tags)
246
+ if missing_custom_tags:
247
+ general_part += ", " + ", ".join(missing_custom_tags)
248
+
249
  # Construct the prompt based on the presence of character_part
250
  if character_part:
251
  prompts.append(f"{character_part}, {general_part}")
252
  else:
253
  prompts.append(general_part)
254
+
255
  # Join all prompts with blank lines
256
  return "\n\n".join(prompts)
257
 
258
+
259
  submit.click(
260
  process_images,
261
+ inputs=[image_files, model_repo, general_thresh, character_thresh, filter_tags, custom_tags],
262
  outputs=output
263
  )
264
 
265
+
266
  demo.queue(max_size=10)
267
  demo.launch()
268