ura23 committed on
Commit
0cb1e93
·
verified ·
1 Parent(s): 6e17304

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -50
app.py CHANGED
@@ -17,12 +17,21 @@ HF_TOKEN = os.environ.get("HF_TOKEN", "")
17
 
18
  # Dataset v3 series of models:
19
  SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
 
 
 
 
20
 
21
  # Dataset v2 series of models:
22
  MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
 
 
 
 
23
 
24
  # IdolSankaku series of models:
25
  EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
 
26
 
27
  # Files to download from the repos
28
  MODEL_FILENAME = "model.onnx"
@@ -41,25 +50,6 @@ def load_labels(dataframe) -> list[str]:
41
  character_indexes = list(np.where(dataframe["category"] == 4)[0])
42
  return tag_names, general_indexes, character_indexes
43
 
44
- def parse_replacements(replacement_text):
45
- replacements = {}
46
- for line in replacement_text.strip().split("\n"):
47
- parts = line.split("->")
48
- if len(parts) == 2:
49
- old_tags = tuple(tag.strip().lower() for tag in parts[0].split(","))
50
- new_tags = [tag.strip() for tag in parts[1].split(",")]
51
- replacements[old_tags] = new_tags
52
- return replacements
53
-
54
- def apply_replacements(tags, replacements):
55
- modified_tags = set(tags)
56
- for old_tags, new_tags in replacements.items():
57
- if all(tag in modified_tags for tag in old_tags):
58
- for tag in old_tags:
59
- modified_tags.discard(tag)
60
- modified_tags.update(new_tags)
61
- return list(modified_tags)
62
-
63
  class Predictor:
64
  def __init__(self):
65
  self.model_target_size = None
@@ -85,10 +75,20 @@ class Predictor:
85
  self.model = model
86
 
87
  def prepare_image(self, image):
 
 
 
 
88
  if image.mode != "RGBA":
89
  image = image.convert("RGBA")
90
- image = image.convert("RGB")
 
 
 
 
 
91
 
 
92
  max_dim = max(image.size)
93
  padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
94
  pad_left = (max_dim - image.width) // 2
@@ -96,12 +96,15 @@ class Predictor:
96
  padded_image.paste(image, (pad_left, pad_top))
97
  padded_image = padded_image.resize((self.model_target_size, self.model_target_size), Image.BICUBIC)
98
 
 
99
  image_array = np.asarray(padded_image, dtype=np.float32)[:, :, ::-1]
100
  return np.expand_dims(image_array, axis=0)
101
 
 
102
  def predict(self, images, model_repo, general_thresh, character_thresh):
103
  self.load_model(model_repo)
104
  results = []
 
105
  for image in images:
106
  image = self.prepare_image(image)
107
  input_name = self.model.get_inputs()[0].name
@@ -112,39 +115,142 @@ class Predictor:
112
  general_res = [x[0] for i, x in enumerate(labels) if i in self.general_indexes and x[1] > general_thresh]
113
  character_res = [x[0] for i, x in enumerate(labels) if i in self.character_indexes and x[1] > character_thresh]
114
  results.append((general_res, character_res))
 
115
  return results
116
 
117
- def process_images(files, model_repo, general_thresh, character_thresh, filter_tags, replacement_text):
118
- images = [Image.open(file.name) for file in files]
119
- results = predictor.predict(images, model_repo, general_thresh, character_thresh)
120
-
121
- filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))
122
- replacements = parse_replacements(replacement_text)
123
-
124
- prompts = []
125
- for general_tags, character_tags in results:
126
- character_tags = apply_replacements([tag.replace("_", " ") for tag in character_tags if tag.lower() not in filter_set], replacements)
127
- general_tags = apply_replacements([tag.replace("_", " ") for tag in general_tags if tag.lower() not in filter_set], replacements)
128
- prompt = ", ".join(character_tags + general_tags)
129
- prompts.append(prompt)
130
-
131
- return "\n\n".join(prompts)
132
 
133
- predictor = Predictor()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
- with gr.Blocks(title=TITLE) as demo:
136
- gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
137
- gr.Markdown(DESCRIPTION)
138
-
139
- with gr.Row():
140
- with gr.Column():
141
- image_files = gr.File(file_types=["image"], label="Upload Images", file_count="multiple")
142
- replacement_text = gr.Textbox(label="Tag Replacements", placeholder="e.g., 1boy -> 1girl\nwinter, indoors -> summer, outdoors", lines=5)
143
- submit = gr.Button("Process Images", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
- with gr.Column():
146
- output = gr.Textbox(label="Output", lines=10)
147
-
148
- submit.click(process_images, inputs=[image_files, replacement_text], outputs=output)
149
-
 
 
 
 
 
 
 
 
 
 
 
150
  demo.launch()
 
 
 
 
17
 
18
  # Dataset v3 series of models:
19
  SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
20
+ CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
21
+ VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
22
+ VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
23
+ EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
24
 
25
  # Dataset v2 series of models:
26
  MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
27
+ SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
28
+ CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
29
+ CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
30
+ VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
31
 
32
  # IdolSankaku series of models:
33
  EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
34
+ SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
35
 
36
  # Files to download from the repos
37
  MODEL_FILENAME = "model.onnx"
 
50
  character_indexes = list(np.where(dataframe["category"] == 4)[0])
51
  return tag_names, general_indexes, character_indexes
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  class Predictor:
54
  def __init__(self):
55
  self.model_target_size = None
 
75
  self.model = model
76
 
77
  def prepare_image(self, image):
78
+ # Create a white canvas with the same size as the input image
79
+ canvas = Image.new("RGBA", image.size, (255, 255, 255))
80
+
81
+ # Ensure the input image has an alpha channel for compositing
82
  if image.mode != "RGBA":
83
  image = image.convert("RGBA")
84
+
85
+ # Composite the input image onto the canvas
86
+ canvas.alpha_composite(image)
87
+
88
+ # Convert to RGB (alpha channel is no longer needed)
89
+ image = canvas.convert("RGB")
90
 
91
+ # Resize the image to a square of size (model_target_size x model_target_size)
92
  max_dim = max(image.size)
93
  padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
94
  pad_left = (max_dim - image.width) // 2
 
96
  padded_image.paste(image, (pad_left, pad_top))
97
  padded_image = padded_image.resize((self.model_target_size, self.model_target_size), Image.BICUBIC)
98
 
99
+ # Convert the image to a NumPy array
100
  image_array = np.asarray(padded_image, dtype=np.float32)[:, :, ::-1]
101
  return np.expand_dims(image_array, axis=0)
102
 
103
+
104
  def predict(self, images, model_repo, general_thresh, character_thresh):
105
  self.load_model(model_repo)
106
  results = []
107
+
108
  for image in images:
109
  image = self.prepare_image(image)
110
  input_name = self.model.get_inputs()[0].name
 
115
  general_res = [x[0] for i, x in enumerate(labels) if i in self.general_indexes and x[1] > general_thresh]
116
  character_res = [x[0] for i, x in enumerate(labels) if i in self.character_indexes and x[1] > character_thresh]
117
  results.append((general_res, character_res))
118
+
119
  return results
120
 
121
+ def main():
122
+ args = parse_args()
123
+ predictor = Predictor()
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
+ model_repos = [
126
+ SWINV2_MODEL_DSV3_REPO,
127
+ CONV_MODEL_DSV3_REPO,
128
+ VIT_MODEL_DSV3_REPO,
129
+ VIT_LARGE_MODEL_DSV3_REPO,
130
+ EVA02_LARGE_MODEL_DSV3_REPO,
131
+ # ---
132
+ MOAT_MODEL_DSV2_REPO,
133
+ SWIN_MODEL_DSV2_REPO,
134
+ CONV_MODEL_DSV2_REPO,
135
+ CONV2_MODEL_DSV2_REPO,
136
+ VIT_MODEL_DSV2_REPO,
137
+ # ---
138
+ SWINV2_MODEL_IS_DSV1_REPO,
139
+ EVA02_LARGE_MODEL_IS_DSV1_REPO,
140
+ ]
141
 
142
+ predefined_tags = ["loli",
143
+ "oppai_loli",
144
+ "onee-shota",
145
+ "incest",
146
+ "furry",
147
+ "furry_female",
148
+ "shota",
149
+ "male_focus",
150
+ "signature",
151
+ "lolita_hairband",
152
+ "otoko_no_ko",
153
+ "minigirl",
154
+ "patreon_username",
155
+ "babydoll",
156
+ "monochrome",
157
+ "happy_birthday",
158
+ "happy_new_year",
159
+ "dated",
160
+ "thought_bubble",
161
+ "greyscale",
162
+ "speech_bubble",
163
+ "english_text",
164
+ "copyright_name",
165
+ "twitter_username",
166
+ "patreon username",
167
+ "patreon logo",
168
+ "cover",
169
+ "content_rating",
170
+ "cover_page",
171
+ "doujin_cover",
172
+ "sex",
173
+ "artist_name",
174
+ "watermark",
175
+ "censored",
176
+ "bar_censor",
177
+ "blank_censor",
178
+ "blur_censor",
179
+ "light_censor",
180
+ "mosaic_censoring"]
181
+
182
+ with gr.Blocks(title=TITLE) as demo:
183
+ gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
184
+ gr.Markdown(DESCRIPTION)
185
+
186
+ with gr.Row():
187
+ with gr.Column():
188
+ image_files = gr.File(
189
+ file_types=["image"], label="Upload Images", file_count="multiple",
190
+ )
191
+
192
+ # Wrap the model selection and sliders in an Accordion
193
+ with gr.Accordion("Advanced Settings", open=False): # Collapsible by default
194
+ model_repo = gr.Dropdown(
195
+ model_repos,
196
+ value=VIT_MODEL_DSV3_REPO,
197
+ label="Select Model",
198
+ )
199
+ general_thresh = gr.Slider(
200
+ 0, 1, step=args.score_slider_step, value=args.score_general_threshold, label="General Tags Threshold"
201
+ )
202
+ character_thresh = gr.Slider(
203
+ 0, 1, step=args.score_slider_step, value=args.score_character_threshold, label="Character Tags Threshold"
204
+ )
205
+ filter_tags = gr.Textbox(
206
+ value=", ".join(predefined_tags),
207
+ label="Filter Tags (comma-separated)",
208
+ placeholder="Add tags to filter out (e.g., winter, red, from above)",
209
+ lines=3
210
+ )
211
+
212
+ submit = gr.Button(
213
+ value="Process Images", variant="primary"
214
+ )
215
+
216
+ with gr.Column():
217
+ output = gr.Textbox(label="Output", lines=10)
218
+
219
+ def process_images(files, model_repo, general_thresh, character_thresh, filter_tags):
220
+ images = [Image.open(file.name) for file in files]
221
+ results = predictor.predict(images, model_repo, general_thresh, character_thresh)
222
+
223
+ # Parse filter tags
224
+ filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))
225
+
226
+ # Generate formatted output
227
+ prompts = []
228
+ for i, (general_tags, character_tags) in enumerate(results):
229
+ # Replace underscores with spaces for both character and general tags
230
+ character_part = ", ".join(
231
+ tag.replace('_', ' ') for tag in character_tags if tag.lower() not in filter_set
232
+ )
233
+ general_part = ", ".join(
234
+ tag.replace('_', ' ') for tag in general_tags if tag.lower() not in filter_set
235
+ )
236
 
237
+ # Construct the prompt based on the presence of character_part
238
+ if character_part:
239
+ prompts.append(f"{character_part}, {general_part}")
240
+ else:
241
+ prompts.append(general_part)
242
+
243
+ # Join all prompts with blank lines
244
+ return "\n\n".join(prompts)
245
+
246
+ submit.click(
247
+ process_images,
248
+ inputs=[image_files, model_repo, general_thresh, character_thresh, filter_tags],
249
+ outputs=output
250
+ )
251
+
252
+ demo.queue(max_size=10)
253
  demo.launch()
254
+
255
+ if __name__ == "__main__":
256
+ main()