ura23 committed on
Commit
67ae085
·
verified ·
1 Parent(s): c62d659

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -109
app.py CHANGED
@@ -40,7 +40,8 @@ LABEL_FILENAME = "selected_tags.csv"
40
  def parse_args() -> argparse.Namespace:
41
  parser = argparse.ArgumentParser()
42
  parser.add_argument("--score-slider-step", type=float, default=0.05)
43
- parser.add_argument("--score-general-threshold", type=float, default=0.25)
 
44
  parser.add_argument("--score-character-threshold", type=float, default=1.0)
45
  return parser.parse_args()
46
 
@@ -58,6 +59,7 @@ class Predictor:
58
  def download_model(self, model_repo):
59
  csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME, use_auth_token=HF_TOKEN)
60
  model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME, use_auth_token=HF_TOKEN)
 
61
  return csv_path, model_path
62
 
63
  def load_model(self, model_repo):
@@ -70,6 +72,7 @@ class Predictor:
70
 
71
  model = rt.InferenceSession(model_path)
72
  _, height, width, _ = model.get_inputs()[0].shape
 
73
  self.model_target_size = height
74
  self.last_loaded_repo = model_repo
75
  self.model = model
@@ -80,6 +83,7 @@ class Predictor:
80
 
81
  # Ensure the input image has an alpha channel for compositing
82
  if image.mode != "RGBA":
 
83
  image = image.convert("RGBA")
84
 
85
  # Composite the input image onto the canvas
@@ -90,6 +94,7 @@ class Predictor:
90
 
91
  # Resize the image to a square of size (model_target_size x model_target_size)
92
  max_dim = max(image.size)
 
93
  padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
94
  pad_left = (max_dim - image.width) // 2
95
  pad_top = (max_dim - image.height) // 2
@@ -103,7 +108,7 @@ class Predictor:
103
 
104
  def predict(self, images, model_repo, general_thresh, character_thresh):
105
  self.load_model(model_repo)
106
- results = []
107
 
108
  for image in images:
109
  image = self.prepare_image(image)
@@ -111,6 +116,7 @@ class Predictor:
111
  label_name = self.model.get_outputs()[0].name
112
  preds = self.model.run([label_name], {input_name: image})[0]
113
 
 
114
  labels = list(zip(self.tag_names, preds[0].astype(float)))
115
  general_res = [x[0] for i, x in enumerate(labels) if i in self.general_indexes and x[1] > general_thresh]
116
  character_res = [x[0] for i, x in enumerate(labels) if i in self.character_indexes and x[1] > character_thresh]
@@ -134,113 +140,13 @@ def main():
134
  CONV_MODEL_DSV2_REPO,
135
  CONV2_MODEL_DSV2_REPO,
136
  VIT_MODEL_DSV2_REPO,
 
137
  # ---
138
  SWINV2_MODEL_IS_DSV1_REPO,
139
  EVA02_LARGE_MODEL_IS_DSV1_REPO,
140
  ]
141
 
142
- predefined_tags = ["loli",
143
- "oppai_loli",
144
- "2024",
145
- "2023",
146
- "2025",
147
- "2022",
148
- "2021",
149
- "onee-shota",
150
- "incest",
151
- "furry",
152
- "twitter_strip_game_(meme)",
153
- "like_and_retweet",
154
- "furry_female",
155
- "realistic",
156
- "egg_vibrator",
157
- "tongue_piercing",
158
- "handheld_game_console",
159
- "game_controller",
160
- "nintendo_switch",
161
- "talking",
162
- "swastika",
163
- "character_name",
164
- "vibrator",
165
- "black-framed_eyewear",
166
- "heterochromia",
167
- "controller",
168
- "remote_control_vibrator",
169
- "vibrator_under_clothes",
170
- "thank_you",
171
- "vibrator_cord",
172
- "shota",
173
- "male_focus",
174
- "signature",
175
- "web_address",
176
- "censored_nipples",
177
- "rhodes_island_logo_(arknights)",
178
- "gothic_lolita",
179
- "glasses",
180
- "reference_inset",
181
- "twitter_logo",
182
- "mother_and_daughter",
183
- "holding_controller",
184
- "holding_game_controller",
185
- "baby",
186
- "heart_censor",
187
- "pixiv_username",
188
- "korean_text",
189
- "pixiv_logo",
190
- "greyscale_with_colored_background",
191
- "water_bottle",
192
- "body_writing",
193
- "used_condom",
194
- "multiple_condoms",
195
- "condom_belt",
196
- "holding_phone",
197
- "multiple_views",
198
- "phone",
199
- "cellphone",
200
- "zoom_layer",
201
- "smartphone",
202
- "lolita_hairband",
203
- "lactation",
204
- "otoko_no_ko",
205
- "minigirl",
206
- "babydoll",
207
- "domino_mask",
208
- "pixiv_id",
209
- "qr_code",
210
- "monochrome",
211
- "trick_or_treat",
212
- "happy_birthday",
213
- "lolita_fashion",
214
- "arrow_(symbol)",
215
- "happy_new_year",
216
- "dated",
217
- "thought_bubble",
218
- "greyscale",
219
- "speech_bubble",
220
- "mask",
221
- "bottle",
222
- "holding_bottle",
223
- "milk",
224
- "milk_bottle",
225
- "english_text",
226
- "copyright_name",
227
- "twitter_username",
228
- "fanbox_username",
229
- "patreon_username",
230
- "patreon_logo",
231
- "cover",
232
- "signature",
233
- "content_rating",
234
- "cover_page",
235
- "doujin_cover",
236
- "sex",
237
- "artist_name",
238
- "watermark",
239
- "censored",
240
- "bar_censor",
241
- "blank_censor",
242
- "blur_censor",
243
- "light_censor",
244
  "mosaic_censoring"]
245
 
246
  with gr.Blocks(title=TITLE) as demo:
@@ -248,6 +154,7 @@ def main():
248
  gr.Markdown(DESCRIPTION)
249
 
250
  with gr.Row():
 
251
  with gr.Column():
252
 
253
  submit = gr.Button(
@@ -277,20 +184,37 @@ def main():
277
  placeholder="Add tags to filter out (e.g., winter, red, from above)",
278
  lines=9
279
  )
 
 
 
 
 
 
280
 
281
 
282
  with gr.Column():
283
  output = gr.Textbox(label="Output", lines=10)
284
 
285
- def process_images(files, model_repo, general_thresh, character_thresh, filter_tags):
286
  images = [Image.open(file.name) for file in files]
287
  results = predictor.predict(images, model_repo, general_thresh, character_thresh)
288
 
289
  # Parse filter tags
290
  filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))
 
 
 
 
 
 
 
 
 
 
 
291
 
292
  # Generate formatted output
293
- prompts = []
294
  for i, (general_tags, character_tags) in enumerate(results):
295
  # Replace underscores with spaces for both character and general tags
296
  character_part = ", ".join(
@@ -301,17 +225,34 @@ def main():
301
  )
302
 
303
  # Construct the prompt based on the presence of character_part
 
304
  if character_part:
305
- prompts.append(f"{character_part}, {general_part}")
306
  else:
307
- prompts.append(general_part)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
 
309
  # Join all prompts with blank lines
310
  return "\n\n".join(prompts)
311
 
312
  submit.click(
313
  process_images,
314
- inputs=[image_files, model_repo, general_thresh, character_thresh, filter_tags],
315
  outputs=output
316
  )
317
 
 
40
def parse_args() -> argparse.Namespace:
    """Parse command-line options controlling the tagger UI thresholds.

    Returns:
        argparse.Namespace with ``score_slider_step``,
        ``score_general_threshold`` and ``score_character_threshold``.
    """
    parser = argparse.ArgumentParser()
    # All three options are float-valued; declare them from one table so the
    # flag names and defaults stay side by side.
    for flag, default in (
        ("--score-slider-step", 0.05),
        ("--score-general-threshold", 0.25),
        ("--score-character-threshold", 1.0),
    ):
        parser.add_argument(flag, type=float, default=default)
    return parser.parse_args()
47
 
 
59
def download_model(self, model_repo):
    """Download the tag list CSV and ONNX model for *model_repo* from the Hub.

    Args:
        model_repo: Hugging Face repository id to fetch the artifacts from.

    Returns:
        Tuple ``(csv_path, model_path)`` of local paths to the downloaded
        label CSV and model file.
    """
    # Both artifacts come from the same repo with the same auth token.
    kwargs = {"use_auth_token": HF_TOKEN}
    csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME, **kwargs)
    model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME, **kwargs)
    return csv_path, model_path
64
 
65
  def load_model(self, model_repo):
 
72
 
73
  model = rt.InferenceSession(model_path)
74
  _, height, width, _ = model.get_inputs()[0].shape
75
+
76
  self.model_target_size = height
77
  self.last_loaded_repo = model_repo
78
  self.model = model
 
83
 
84
  # Ensure the input image has an alpha channel for compositing
85
  if image.mode != "RGBA":
86
+
87
  image = image.convert("RGBA")
88
 
89
  # Composite the input image onto the canvas
 
94
 
95
  # Resize the image to a square of size (model_target_size x model_target_size)
96
  max_dim = max(image.size)
97
+
98
  padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
99
  pad_left = (max_dim - image.width) // 2
100
  pad_top = (max_dim - image.height) // 2
 
108
 
109
  def predict(self, images, model_repo, general_thresh, character_thresh):
110
  self.load_model(model_repo)
111
+ results = []
112
 
113
  for image in images:
114
  image = self.prepare_image(image)
 
116
  label_name = self.model.get_outputs()[0].name
117
  preds = self.model.run([label_name], {input_name: image})[0]
118
 
119
+
120
  labels = list(zip(self.tag_names, preds[0].astype(float)))
121
  general_res = [x[0] for i, x in enumerate(labels) if i in self.general_indexes and x[1] > general_thresh]
122
  character_res = [x[0] for i, x in enumerate(labels) if i in self.character_indexes and x[1] > character_thresh]
 
140
  CONV_MODEL_DSV2_REPO,
141
  CONV2_MODEL_DSV2_REPO,
142
  VIT_MODEL_DSV2_REPO,
143
+
144
  # ---
145
  SWINV2_MODEL_IS_DSV1_REPO,
146
  EVA02_LARGE_MODEL_IS_DSV1_REPO,
147
  ]
148
 
149
+ predefined_tags = ["2024",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  "mosaic_censoring"]
151
 
152
  with gr.Blocks(title=TITLE) as demo:
 
154
  gr.Markdown(DESCRIPTION)
155
 
156
  with gr.Row():
157
+
158
  with gr.Column():
159
 
160
  submit = gr.Button(
 
184
  placeholder="Add tags to filter out (e.g., winter, red, from above)",
185
  lines=9
186
  )
187
+
188
+ conditional_tags = gr.Textbox(
189
+ label="Conditional Tag Rules",
190
+ placeholder="Enter tag rules (e.g., sun: hot,day)",
191
+ lines=3,
192
+ )
193
 
194
 
195
  with gr.Column():
196
  output = gr.Textbox(label="Output", lines=10)
197
 
198
+ def process_images(files, model_repo, general_thresh, character_thresh, filter_tags, conditional_tags):
199
  images = [Image.open(file.name) for file in files]
200
  results = predictor.predict(images, model_repo, general_thresh, character_thresh)
201
 
202
  # Parse filter tags
203
  filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))
204
+
205
+ # Parse conditional tag rules
206
+ tag_rules = {}
207
+ if conditional_tags:
208
+ for rule in conditional_tags.splitlines():
209
+ if ":" in rule:
210
+ trigger_tag, tags_to_add = rule.split(":", 1)
211
+ tag_rules[trigger_tag.strip().lower()] = [
212
+ tag.strip() for tag in tags_to_add.split(",")
213
+ ]
214
+
215
 
216
  # Generate formatted output
217
+ prompts = []
218
  for i, (general_tags, character_tags) in enumerate(results):
219
  # Replace underscores with spaces for both character and general tags
220
  character_part = ", ".join(
 
225
  )
226
 
227
  # Construct the prompt based on the presence of character_part
228
+ prompt = ""
229
  if character_part:
230
+ prompt = f"{character_part}, {general_part}"
231
  else:
232
+ prompt = general_part
233
+
234
+ # Apply conditional tag rules
235
+ found_trigger = False
236
+ for trigger_tag, tags_to_add in tag_rules.items():
237
+ if trigger_tag in prompt.lower():
238
+ prompt += ", " + ", ".join(tags_to_add)
239
+ found_trigger = True
240
+ break
241
+
242
+ if not found_trigger:
243
+ for trigger_tag, tags_to_add in tag_rules.items():
244
+ if trigger_tag not in prompt.lower():
245
+ prompt += ", " + ", ".join(tags_to_add)
246
+ break # Only apply the first rule that matches
247
+
248
+ prompts.append(prompt)
249
 
250
  # Join all prompts with blank lines
251
  return "\n\n".join(prompts)
252
 
253
  submit.click(
254
  process_images,
255
+ inputs=[image_files, model_repo, general_thresh, character_thresh, filter_tags, conditional_tags],
256
  outputs=output
257
  )
258