ura23 committed on
Commit
4803c81
·
verified ·
1 Parent(s): 9ccbf15

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -51
app.py CHANGED
@@ -40,8 +40,7 @@ LABEL_FILENAME = "selected_tags.csv"
40
  def parse_args() -> argparse.Namespace:
41
  parser = argparse.ArgumentParser()
42
  parser.add_argument("--score-slider-step", type=float, default=0.05)
43
- parser.add_argument("--score-general-threshold", type=float,
44
- default=0.25)
45
  parser.add_argument("--score-character-threshold", type=float, default=1.0)
46
  return parser.parse_args()
47
 
@@ -59,7 +58,6 @@ class Predictor:
59
  def download_model(self, model_repo):
60
  csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME, use_auth_token=HF_TOKEN)
61
  model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME, use_auth_token=HF_TOKEN)
62
-
63
  return csv_path, model_path
64
 
65
  def load_model(self, model_repo):
@@ -72,7 +70,6 @@ class Predictor:
72
 
73
  model = rt.InferenceSession(model_path)
74
  _, height, width, _ = model.get_inputs()[0].shape
75
-
76
  self.model_target_size = height
77
  self.last_loaded_repo = model_repo
78
  self.model = model
@@ -83,7 +80,6 @@ class Predictor:
83
 
84
  # Ensure the input image has an alpha channel for compositing
85
  if image.mode != "RGBA":
86
-
87
  image = image.convert("RGBA")
88
 
89
  # Composite the input image onto the canvas
@@ -94,7 +90,6 @@ class Predictor:
94
 
95
  # Resize the image to a square of size (model_target_size x model_target_size)
96
  max_dim = max(image.size)
97
-
98
  padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
99
  pad_left = (max_dim - image.width) // 2
100
  pad_top = (max_dim - image.height) // 2
@@ -108,7 +103,7 @@ class Predictor:
108
 
109
  def predict(self, images, model_repo, general_thresh, character_thresh):
110
  self.load_model(model_repo)
111
- results =
112
 
113
  for image in images:
114
  image = self.prepare_image(image)
@@ -116,10 +111,9 @@ class Predictor:
116
  label_name = self.model.get_outputs()[0].name
117
  preds = self.model.run([label_name], {input_name: image})[0]
118
 
119
-
120
  labels = list(zip(self.tag_names, preds[0].astype(float)))
121
  general_res = [x[0] for i, x in enumerate(labels) if i in self.general_indexes and x[1] > general_thresh]
122
- character_res = [x[0] for i, x in enumerate(labels) if i in the character_indexes and x[1] > character_thresh]
123
  results.append((general_res, character_res))
124
 
125
  return results
@@ -140,13 +134,113 @@ def main():
140
  CONV_MODEL_DSV2_REPO,
141
  CONV2_MODEL_DSV2_REPO,
142
  VIT_MODEL_DSV2_REPO,
143
-
144
  # ---
145
  SWINV2_MODEL_IS_DSV1_REPO,
146
  EVA02_LARGE_MODEL_IS_DSV1_REPO,
147
  ]
148
 
149
- predefined_tags = ["2024",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  "mosaic_censoring"]
151
 
152
  with gr.Blocks(title=TITLE) as demo:
@@ -154,7 +248,6 @@ def main():
154
  gr.Markdown(DESCRIPTION)
155
 
156
  with gr.Row():
157
-
158
  with gr.Column():
159
 
160
  submit = gr.Button(
@@ -184,37 +277,20 @@ def main():
184
  placeholder="Add tags to filter out (e.g., winter, red, from above)",
185
  lines=9
186
  )
187
-
188
- conditional_tags = gr.Textbox(
189
- label="Conditional Tag Rules",
190
- placeholder="Enter tag rules (e.g., sun: hot,day)",
191
- lines=3,
192
- )
193
 
194
 
195
  with gr.Column():
196
  output = gr.Textbox(label="Output", lines=10)
197
 
198
- def process_images(files, model_repo, general_thresh, character_thresh, filter_tags, conditional_tags):
199
  images = [Image.open(file.name) for file in files]
200
  results = predictor.predict(images, model_repo, general_thresh, character_thresh)
201
 
202
  # Parse filter tags
203
  filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))
204
-
205
- # Parse conditional tag rules
206
- tag_rules = {}
207
- if conditional_tags:
208
- for rule in conditional_tags.splitlines():
209
- if ":" in rule:
210
- trigger_tag, tags_to_add = rule.split(":", 1)
211
- tag_rules[trigger_tag.strip().lower()] = [
212
- tag.strip() for tag in tags_to_add.split(",")
213
- ]
214
-
215
 
216
  # Generate formatted output
217
- prompts =
218
  for i, (general_tags, character_tags) in enumerate(results):
219
  # Replace underscores with spaces for both character and general tags
220
  character_part = ", ".join(
@@ -225,34 +301,17 @@ def main():
225
  )
226
 
227
  # Construct the prompt based on the presence of character_part
228
- prompt = ""
229
  if character_part:
230
- prompt = f"{character_part}, {general_part}"
231
  else:
232
- prompt = general_part
233
-
234
- # Apply conditional tag rules
235
- found_trigger = False
236
- for trigger_tag, tags_to_add in tag_rules.items():
237
- if trigger_tag in prompt.lower():
238
- prompt += ", " + ", ".join(tags_to_add)
239
- found_trigger = True
240
- break
241
-
242
- if not found_trigger:
243
- for trigger_tag, tags_to_add in tag_rules.items():
244
- if trigger_tag not in prompt.lower():
245
- prompt += ", " + ", ".join(tags_to_add)
246
- break # Only apply the first rule that matches
247
-
248
- prompts.append(prompt)
249
 
250
  # Join all prompts with blank lines
251
  return "\n\n".join(prompts)
252
 
253
  submit.click(
254
  process_images,
255
- inputs=[image_files, model_repo, general_thresh, character_thresh, filter_tags, conditional_tags],
256
  outputs=output
257
  )
258
 
 
40
  def parse_args() -> argparse.Namespace:
41
  parser = argparse.ArgumentParser()
42
  parser.add_argument("--score-slider-step", type=float, default=0.05)
43
+ parser.add_argument("--score-general-threshold", type=float, default=0.25)
 
44
  parser.add_argument("--score-character-threshold", type=float, default=1.0)
45
  return parser.parse_args()
46
 
 
58
  def download_model(self, model_repo):
59
  csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME, use_auth_token=HF_TOKEN)
60
  model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME, use_auth_token=HF_TOKEN)
 
61
  return csv_path, model_path
62
 
63
  def load_model(self, model_repo):
 
70
 
71
  model = rt.InferenceSession(model_path)
72
  _, height, width, _ = model.get_inputs()[0].shape
 
73
  self.model_target_size = height
74
  self.last_loaded_repo = model_repo
75
  self.model = model
 
80
 
81
  # Ensure the input image has an alpha channel for compositing
82
  if image.mode != "RGBA":
 
83
  image = image.convert("RGBA")
84
 
85
  # Composite the input image onto the canvas
 
90
 
91
  # Resize the image to a square of size (model_target_size x model_target_size)
92
  max_dim = max(image.size)
 
93
  padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
94
  pad_left = (max_dim - image.width) // 2
95
  pad_top = (max_dim - image.height) // 2
 
103
 
104
  def predict(self, images, model_repo, general_thresh, character_thresh):
105
  self.load_model(model_repo)
106
+ results = []
107
 
108
  for image in images:
109
  image = self.prepare_image(image)
 
111
  label_name = self.model.get_outputs()[0].name
112
  preds = self.model.run([label_name], {input_name: image})[0]
113
 
 
114
  labels = list(zip(self.tag_names, preds[0].astype(float)))
115
  general_res = [x[0] for i, x in enumerate(labels) if i in self.general_indexes and x[1] > general_thresh]
116
+ character_res = [x[0] for i, x in enumerate(labels) if i in self.character_indexes and x[1] > character_thresh]
117
  results.append((general_res, character_res))
118
 
119
  return results
 
134
  CONV_MODEL_DSV2_REPO,
135
  CONV2_MODEL_DSV2_REPO,
136
  VIT_MODEL_DSV2_REPO,
 
137
  # ---
138
  SWINV2_MODEL_IS_DSV1_REPO,
139
  EVA02_LARGE_MODEL_IS_DSV1_REPO,
140
  ]
141
 
142
+ predefined_tags = ["loli",
143
+ "oppai_loli",
144
+ "2024",
145
+ "2023",
146
+ "2025",
147
+ "2022",
148
+ "2021",
149
+ "onee-shota",
150
+ "incest",
151
+ "furry",
152
+ "twitter_strip_game_(meme)",
153
+ "like_and_retweet",
154
+ "furry_female",
155
+ "realistic",
156
+ "egg_vibrator",
157
+ "tongue_piercing",
158
+ "handheld_game_console",
159
+ "game_controller",
160
+ "nintendo_switch",
161
+ "talking",
162
+ "swastika",
163
+ "character_name",
164
+ "vibrator",
165
+ "black-framed_eyewear",
166
+ "heterochromia",
167
+ "controller",
168
+ "remote_control_vibrator",
169
+ "vibrator_under_clothes",
170
+ "thank_you",
171
+ "vibrator_cord",
172
+ "shota",
173
+ "male_focus",
174
+ "signature",
175
+ "web_address",
176
+ "censored_nipples",
177
+ "rhodes_island_logo_(arknights)",
178
+ "gothic_lolita",
179
+ "glasses",
180
+ "reference_inset",
181
+ "twitter_logo",
182
+ "mother_and_daughter",
183
+ "holding_controller",
184
+ "holding_game_controller",
185
+ "baby",
186
+ "heart_censor",
187
+ "pixiv_username",
188
+ "korean_text",
189
+ "pixiv_logo",
190
+ "greyscale_with_colored_background",
191
+ "water_bottle",
192
+ "body_writing",
193
+ "used_condom",
194
+ "multiple_condoms",
195
+ "condom_belt",
196
+ "holding_phone",
197
+ "multiple_views",
198
+ "phone",
199
+ "cellphone",
200
+ "zoom_layer",
201
+ "smartphone",
202
+ "lolita_hairband",
203
+ "lactation",
204
+ "otoko_no_ko",
205
+ "minigirl",
206
+ "babydoll",
207
+ "domino_mask",
208
+ "pixiv_id",
209
+ "qr_code",
210
+ "monochrome",
211
+ "trick_or_treat",
212
+ "happy_birthday",
213
+ "lolita_fashion",
214
+ "arrow_(symbol)",
215
+ "happy_new_year",
216
+ "dated",
217
+ "thought_bubble",
218
+ "greyscale",
219
+ "speech_bubble",
220
+ "mask",
221
+ "bottle",
222
+ "holding_bottle",
223
+ "milk",
224
+ "milk_bottle",
225
+ "english_text",
226
+ "copyright_name",
227
+ "twitter_username",
228
+ "fanbox_username",
229
+ "patreon_username",
230
+ "patreon_logo",
231
+ "cover",
232
+ "signature",
233
+ "content_rating",
234
+ "cover_page",
235
+ "doujin_cover",
236
+ "sex",
237
+ "artist_name",
238
+ "watermark",
239
+ "censored",
240
+ "bar_censor",
241
+ "blank_censor",
242
+ "blur_censor",
243
+ "light_censor",
244
  "mosaic_censoring"]
245
 
246
  with gr.Blocks(title=TITLE) as demo:
 
248
  gr.Markdown(DESCRIPTION)
249
 
250
  with gr.Row():
 
251
  with gr.Column():
252
 
253
  submit = gr.Button(
 
277
  placeholder="Add tags to filter out (e.g., winter, red, from above)",
278
  lines=9
279
  )
 
 
 
 
 
 
280
 
281
 
282
  with gr.Column():
283
  output = gr.Textbox(label="Output", lines=10)
284
 
285
+ def process_images(files, model_repo, general_thresh, character_thresh, filter_tags):
286
  images = [Image.open(file.name) for file in files]
287
  results = predictor.predict(images, model_repo, general_thresh, character_thresh)
288
 
289
  # Parse filter tags
290
  filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))
 
 
 
 
 
 
 
 
 
 
 
291
 
292
  # Generate formatted output
293
+ prompts = []
294
  for i, (general_tags, character_tags) in enumerate(results):
295
  # Replace underscores with spaces for both character and general tags
296
  character_part = ", ".join(
 
301
  )
302
 
303
  # Construct the prompt based on the presence of character_part
 
304
  if character_part:
305
+ prompts.append(f"{character_part}, {general_part}")
306
  else:
307
+ prompts.append(general_part)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
 
309
  # Join all prompts with blank lines
310
  return "\n\n".join(prompts)
311
 
312
  submit.click(
313
  process_images,
314
+ inputs=[image_files, model_repo, general_thresh, character_thresh, filter_tags],
315
  outputs=output
316
  )
317