ura23 commited on
Commit
6e17304
·
verified ·
1 Parent(s): fe45257

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -200
app.py CHANGED
@@ -17,21 +17,12 @@ HF_TOKEN = os.environ.get("HF_TOKEN", "")
17
 
18
  # Dataset v3 series of models:
19
  SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
20
- CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
21
- VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
22
- VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
23
- EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
24
 
25
  # Dataset v2 series of models:
26
  MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
27
- SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
28
- CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
29
- CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
30
- VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
31
 
32
  # IdolSankaku series of models:
33
  EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
34
- SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
35
 
36
  # Files to download from the repos
37
  MODEL_FILENAME = "model.onnx"
@@ -50,6 +41,25 @@ def load_labels(dataframe) -> list[str]:
50
  character_indexes = list(np.where(dataframe["category"] == 4)[0])
51
  return tag_names, general_indexes, character_indexes
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  class Predictor:
54
  def __init__(self):
55
  self.model_target_size = None
@@ -75,20 +85,10 @@ class Predictor:
75
  self.model = model
76
 
77
  def prepare_image(self, image):
78
- # Create a white canvas with the same size as the input image
79
- canvas = Image.new("RGBA", image.size, (255, 255, 255))
80
-
81
- # Ensure the input image has an alpha channel for compositing
82
  if image.mode != "RGBA":
83
  image = image.convert("RGBA")
84
-
85
- # Composite the input image onto the canvas
86
- canvas.alpha_composite(image)
87
-
88
- # Convert to RGB (alpha channel is no longer needed)
89
- image = canvas.convert("RGB")
90
 
91
- # Resize the image to a square of size (model_target_size x model_target_size)
92
  max_dim = max(image.size)
93
  padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
94
  pad_left = (max_dim - image.width) // 2
@@ -96,15 +96,12 @@ class Predictor:
96
  padded_image.paste(image, (pad_left, pad_top))
97
  padded_image = padded_image.resize((self.model_target_size, self.model_target_size), Image.BICUBIC)
98
 
99
- # Convert the image to a NumPy array
100
  image_array = np.asarray(padded_image, dtype=np.float32)[:, :, ::-1]
101
  return np.expand_dims(image_array, axis=0)
102
 
103
-
104
  def predict(self, images, model_repo, general_thresh, character_thresh):
105
  self.load_model(model_repo)
106
  results = []
107
-
108
  for image in images:
109
  image = self.prepare_image(image)
110
  input_name = self.model.get_inputs()[0].name
@@ -115,186 +112,39 @@ class Predictor:
115
  general_res = [x[0] for i, x in enumerate(labels) if i in self.general_indexes and x[1] > general_thresh]
116
  character_res = [x[0] for i, x in enumerate(labels) if i in self.character_indexes and x[1] > character_thresh]
117
  results.append((general_res, character_res))
118
-
119
  return results
120
 
121
- def main():
122
- args = parse_args()
123
- predictor = Predictor()
124
-
125
- model_repos = [
126
- SWINV2_MODEL_DSV3_REPO,
127
- CONV_MODEL_DSV3_REPO,
128
- VIT_MODEL_DSV3_REPO,
129
- VIT_LARGE_MODEL_DSV3_REPO,
130
- EVA02_LARGE_MODEL_DSV3_REPO,
131
- # ---
132
- MOAT_MODEL_DSV2_REPO,
133
- SWIN_MODEL_DSV2_REPO,
134
- CONV_MODEL_DSV2_REPO,
135
- CONV2_MODEL_DSV2_REPO,
136
- VIT_MODEL_DSV2_REPO,
137
- # ---
138
- SWINV2_MODEL_IS_DSV1_REPO,
139
- EVA02_LARGE_MODEL_IS_DSV1_REPO,
140
- ]
141
-
142
- predefined_tags = ["loli",
143
- "oppai_loli",
144
- "onee-shota",
145
- "incest",
146
- "furry",
147
- "furry_female",
148
- "shota",
149
- "male_focus",
150
- "signature",
151
- "lolita_hairband",
152
- "otoko_no_ko",
153
- "minigirl",
154
- "patreon_username",
155
- "babydoll",
156
- "monochrome",
157
- "happy_birthday",
158
- "happy_new_year",
159
- "dated",
160
- "thought_bubble",
161
- "greyscale",
162
- "speech_bubble",
163
- "english_text",
164
- "copyright_name",
165
- "twitter_username",
166
- "patreon username",
167
- "patreon logo",
168
- "cover",
169
- "content_rating"
170
- "cover_page",
171
- "doujin_cover",
172
- "sex",
173
- "artist_name",
174
- "watermark",
175
- "censored",
176
- "bar_censor",
177
- "blank_censor",
178
- "blur_censor",
179
- "light_censor",
180
- "mosaic_censoring"]
181
-
182
- with gr.Blocks(title=TITLE) as demo:
183
- gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
184
- gr.Markdown(DESCRIPTION)
185
-
186
- with gr.Row():
187
- with gr.Column():
188
- image_files = gr.File(
189
- file_types=["image"], label="Upload Images", file_count="multiple",
190
- )
191
-
192
- # Wrap the model selection and sliders in an Accordion
193
- with gr.Accordion("Advanced Settings", open=False): # Collapsible by default
194
- model_repo = gr.Dropdown(
195
- model_repos,
196
- value=VIT_MODEL_DSV3_REPO,
197
- label="Select Model",
198
- )
199
- general_thresh = gr.Slider(
200
- 0, 1, step=args.score_slider_step, value=args.score_general_threshold, label="General Tags Threshold"
201
- )
202
- character_thresh = gr.Slider(
203
- 0, 1, step=args.score_slider_step, value=args.score_character_threshold, label="Character Tags Threshold"
204
- )
205
- filter_tags = gr.Textbox(
206
- value=", ".join(predefined_tags),
207
- label="Filter Tags (comma-separated)",
208
- placeholder="Add tags to filter out (e.g., winter, red, from above)",
209
- lines=3
210
- )
211
-
212
- submit = gr.Button(
213
- value="Process Images", variant="primary"
214
- )
215
-
216
- with gr.Column():
217
- output = gr.Textbox(label="Output", lines=10)
218
-
219
- def parse_replacement_rules(rules_text):
220
- """Parse user-defined tag replacement rules into a dictionary."""
221
- rules = {}
222
- for line in rules_text.strip().split("\n"):
223
- if "->" in line:
224
- old_tags, new_tags = map(str.strip, line.split("->"))
225
- old_tags_list = tuple(map(str.strip, old_tags.lower().split(",")))
226
- new_tags_list = [tag.strip() for tag in new_tags.split(",")]
227
- rules[old_tags_list] = new_tags_list
228
- return rules
229
-
230
- def apply_replacements(tags, replacement_rules):
231
- """Apply replacement rules to a set of tags."""
232
- tags_set = set(tags)
233
-
234
- for old_tags, new_tags in replacement_rules.items():
235
- if set(old_tags).issubset(tags_set): # If all old tags exist in the set
236
- tags_set.difference_update(old_tags) # Remove old tags
237
- tags_set.update(new_tags) # Add new ones
238
-
239
- return list(tags_set)
240
 
241
- def process_images(files, model_repo, general_thresh, character_thresh, filter_tags, replacement_rules_text):
242
- images = [Image.open(file.name) for file in files]
243
- results = predictor.predict(images, model_repo, general_thresh, character_thresh)
244
-
245
- # Parse filter tags
246
- filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))
247
-
248
- # Parse user-defined replacements
249
- replacement_rules = parse_replacement_rules(replacement_rules_text)
250
-
251
- # Generate formatted output
252
- prompts = []
253
- for general_tags, character_tags in results:
254
- # Apply replacements
255
- general_tags = apply_replacements(general_tags, replacement_rules)
256
- character_tags = apply_replacements(character_tags, replacement_rules)
257
-
258
- # Remove filtered tags and format
259
- general_tags = [tag.replace('_', ' ') for tag in general_tags if tag.lower() not in filter_set]
260
- character_tags = [tag.replace('_', ' ') for tag in character_tags if tag.lower() not in filter_set]
261
-
262
- # Construct final prompt
263
- if character_tags:
264
- prompts.append(f"{', '.join(character_tags)}, {', '.join(general_tags)}")
265
- else:
266
- prompts.append(", ".join(general_tags))
267
-
268
- return "\n\n".join(prompts)
269
 
270
- # Modify UI to include replacement rules input
271
- with gr.Blocks(title=TITLE) as demo:
272
- gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
273
- gr.Markdown(DESCRIPTION)
274
-
275
- with gr.Row():
276
- with gr.Column():
277
- image_files = gr.File(file_types=["image"], label="Upload Images", file_count="multiple")
278
-
279
- with gr.Accordion("Advanced Settings", open=False):
280
- model_repo = gr.Dropdown(model_repos, value=VIT_MODEL_DSV3_REPO, label="Select Model")
281
- general_thresh = gr.Slider(0, 1, step=args.score_slider_step, value=args.score_general_threshold, label="General Tags Threshold")
282
- character_thresh = gr.Slider(0, 1, step=args.score_slider_step, value=args.score_character_threshold, label="Character Tags Threshold")
283
- filter_tags = gr.Textbox(value=", ".join(predefined_tags), label="Filter Tags (comma-separated)", lines=3)
284
-
285
- submit = gr.Button(value="Process Images", variant="primary")
286
-
287
- with gr.Column():
288
- output = gr.Textbox(label="Output", lines=10)
289
-
290
- # Separate input for tag replacements
291
- with gr.Accordion("Tag Replacements", open=False):
292
- replacement_rules_text = gr.Textbox(label="Enter replacement rules (one per line)", placeholder="e.g.,\n1boy -> 1girl\nwinter, indoors, living room -> summer, outdoors", lines=5)
293
 
294
- submit.click(process_images, inputs=[image_files, model_repo, general_thresh, character_thresh, filter_tags, replacement_rules_text], outputs=output)
295
-
296
- demo.queue(max_size=10)
 
 
297
  demo.launch()
298
-
299
- if __name__ == "__main__":
300
- main()
 
17
 
18
  # Dataset v3 series of models:
19
  SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
 
 
 
 
20
 
21
  # Dataset v2 series of models:
22
  MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
 
 
 
 
23
 
24
  # IdolSankaku series of models:
25
  EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
 
26
 
27
  # Files to download from the repos
28
  MODEL_FILENAME = "model.onnx"
 
41
  character_indexes = list(np.where(dataframe["category"] == 4)[0])
42
  return tag_names, general_indexes, character_indexes
43
 
44
+ def parse_replacements(replacement_text):
45
+ replacements = {}
46
+ for line in replacement_text.strip().split("\n"):
47
+ parts = line.split("->")
48
+ if len(parts) == 2:
49
+ old_tags = tuple(tag.strip().lower() for tag in parts[0].split(","))
50
+ new_tags = [tag.strip() for tag in parts[1].split(",")]
51
+ replacements[old_tags] = new_tags
52
+ return replacements
53
+
54
+ def apply_replacements(tags, replacements):
55
+ modified_tags = set(tags)
56
+ for old_tags, new_tags in replacements.items():
57
+ if all(tag in modified_tags for tag in old_tags):
58
+ for tag in old_tags:
59
+ modified_tags.discard(tag)
60
+ modified_tags.update(new_tags)
61
+ return list(modified_tags)
62
+
63
  class Predictor:
64
  def __init__(self):
65
  self.model_target_size = None
 
85
  self.model = model
86
 
87
  def prepare_image(self, image):
 
 
 
 
88
  if image.mode != "RGBA":
89
  image = image.convert("RGBA")
90
+ image = image.convert("RGB")
 
 
 
 
 
91
 
 
92
  max_dim = max(image.size)
93
  padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
94
  pad_left = (max_dim - image.width) // 2
 
96
  padded_image.paste(image, (pad_left, pad_top))
97
  padded_image = padded_image.resize((self.model_target_size, self.model_target_size), Image.BICUBIC)
98
 
 
99
  image_array = np.asarray(padded_image, dtype=np.float32)[:, :, ::-1]
100
  return np.expand_dims(image_array, axis=0)
101
 
 
102
  def predict(self, images, model_repo, general_thresh, character_thresh):
103
  self.load_model(model_repo)
104
  results = []
 
105
  for image in images:
106
  image = self.prepare_image(image)
107
  input_name = self.model.get_inputs()[0].name
 
112
  general_res = [x[0] for i, x in enumerate(labels) if i in self.general_indexes and x[1] > general_thresh]
113
  character_res = [x[0] for i, x in enumerate(labels) if i in self.character_indexes and x[1] > character_thresh]
114
  results.append((general_res, character_res))
 
115
  return results
116
 
117
+ def process_images(files, model_repo, general_thresh, character_thresh, filter_tags, replacement_text):
118
+ images = [Image.open(file.name) for file in files]
119
+ results = predictor.predict(images, model_repo, general_thresh, character_thresh)
120
+
121
+ filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))
122
+ replacements = parse_replacements(replacement_text)
123
+
124
+ prompts = []
125
+ for general_tags, character_tags in results:
126
+ character_tags = apply_replacements([tag.replace("_", " ") for tag in character_tags if tag.lower() not in filter_set], replacements)
127
+ general_tags = apply_replacements([tag.replace("_", " ") for tag in general_tags if tag.lower() not in filter_set], replacements)
128
+ prompt = ", ".join(character_tags + general_tags)
129
+ prompts.append(prompt)
130
+
131
+ return "\n\n".join(prompts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
+ predictor = Predictor()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
+ with gr.Blocks(title=TITLE) as demo:
136
+ gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
137
+ gr.Markdown(DESCRIPTION)
138
+
139
+ with gr.Row():
140
+ with gr.Column():
141
+ image_files = gr.File(file_types=["image"], label="Upload Images", file_count="multiple")
142
+ replacement_text = gr.Textbox(label="Tag Replacements", placeholder="e.g., 1boy -> 1girl\nwinter, indoors -> summer, outdoors", lines=5)
143
+ submit = gr.Button("Process Images", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
+ with gr.Column():
146
+ output = gr.Textbox(label="Output", lines=10)
147
+
148
+ submit.click(process_images, inputs=[image_files, replacement_text], outputs=output)
149
+
150
  demo.launch()