ura23 committed
Commit d1d64c5 · verified · 1 Parent(s): 8d94560

Update app.py

Files changed (1): app.py (+70 -87)
app.py CHANGED
@@ -1,5 +1,6 @@
 import argparse
 import os
+
 import gradio as gr
 import huggingface_hub
 import numpy as np
@@ -8,90 +9,40 @@ import pandas as pd
 from PIL import Image

 TITLE = "WaifuDiffusion Tagger"
-DESCRIPTION = "Demo for the WaifuDiffusion tagger models"
+DESCRIPTION = """
+Demo for the WaifuDiffusion tagger models
+"""

 HF_TOKEN = os.environ.get("HF_TOKEN", "")

+# Dataset v3 series of models:
+SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
+CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
 VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
+VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
+EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
+
+# Dataset v2 series of models:
+MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
+SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
+CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
+CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
+VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
+
+# IdolSankaku series of models:
+EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
+SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
+
 MODEL_FILENAME = "model.onnx"
 LABEL_FILENAME = "selected_tags.csv"

-PREDEFINED_FILTER_TAGS = [
-    "loli", "oppai_loli", "onee-shota", "incest", "furry", "furry_female", "shota",
-    "male_focus", "signature", "otoko_no_ko", "minigirl", "patreon_username", "babydoll",
-    "monochrome", "happy_birthday", "happy_new_year", "thought_bubble", "greyscale",
-    "speech_bubble", "english_text", "copyright_name", "twitter_username",
-    "patreon username", "patreon logo", "cover", "content_rating", "cover_page",
-    "doujin_cover", "sex", "artist_name", "watermark", "censored", "bar_censor",
-    "blank_censor", "blur_censor", "light_censor", "mosaic_censoring"
-]
-
-def parse_args():
+def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
     parser.add_argument("--score-slider-step", type=float, default=0.05)
     parser.add_argument("--score-general-threshold", type=float, default=0.3)
     parser.add_argument("--score-character-threshold", type=float, default=1.0)
     return parser.parse_args()

-def load_labels(dataframe):
-    tag_names = dataframe["name"].tolist()
-    general_indexes = list(np.where(dataframe["category"] == 0)[0])
-    character_indexes = list(np.where(dataframe["category"] == 4)[0])
-    return tag_names, general_indexes, character_indexes
-
-class Predictor:
-    def __init__(self):
-        self.model_target_size = None
-        self.last_loaded_repo = None
-
-    def download_model(self, model_repo):
-        csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME, use_auth_token=HF_TOKEN)
-        model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME, use_auth_token=HF_TOKEN)
-        return csv_path, model_path
-
-    def load_model(self, model_repo):
-        if model_repo == self.last_loaded_repo:
-            return
-
-        csv_path, model_path = self.download_model(model_repo)
-        tags_df = pd.read_csv(csv_path)
-        self.tag_names, self.general_indexes, self.character_indexes = load_labels(tags_df)
-
-        model = rt.InferenceSession(model_path)
-        _, height, width, _ = model.get_inputs()[0].shape
-        self.model_target_size = height
-        self.last_loaded_repo = model_repo
-        self.model = model
-
-    def prepare_image(self, image):
-        max_dim = max(image.size)
-        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
-        pad_left = (max_dim - image.width) // 2
-        pad_top = (max_dim - image.height) // 2
-        padded_image.paste(image, (pad_left, pad_top))
-        padded_image = padded_image.resize((self.model_target_size, self.model_target_size), Image.BICUBIC)
-
-        return np.expand_dims(np.asarray(padded_image, dtype=np.float32)[:, :, ::-1], axis=0)
-
-    def predict(self, images, model_repo, general_thresh, character_thresh):
-        self.load_model(model_repo)
-        results = []
-
-        for image in images:
-            image = self.prepare_image(image)
-            input_name = self.model.get_inputs()[0].name
-            label_name = self.model.get_outputs()[0].name
-            preds = self.model.run([label_name], {input_name: image})[0]
-
-            labels = list(zip(self.tag_names, preds[0].astype(float)))
-            general_res = [x[0] for i, x in enumerate(labels) if i in self.general_indexes and x[1] > general_thresh]
-            character_res = [x[0] for i, x in enumerate(labels) if i in self.character_indexes and x[1] > character_thresh]
-            results.append((general_res, character_res))
-
-        return results
-
-predictor = Predictor()
-
 def parse_replacement_rules(rules_text):
     rules = {}
     for line in rules_text.strip().split("\n"):
@@ -113,23 +64,43 @@ def parse_fallback_rules(fallback_text):

 def apply_replacements(tags, replacement_rules):
     tags_set = set(tags)
-
     for old_tags, new_tags in replacement_rules.items():
         if set(old_tags).issubset(tags_set):
             tags_set.difference_update(old_tags)
             tags_set.update(new_tags)
-
     return list(tags_set)

 def apply_fallbacks(tags, fallback_rules):
     tags_set = set(tags)
-
     for expected_tags, fallback_tag in fallback_rules.items():
         if not any(tag in tags_set for tag in expected_tags):
             tags_set.add(fallback_tag)
-
     return list(tags_set)

+class Predictor:
+    def __init__(self):
+        self.model_target_size = None
+        self.last_loaded_repo = None
+
+    def download_model(self, model_repo):
+        csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME, use_auth_token=HF_TOKEN)
+        model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME, use_auth_token=HF_TOKEN)
+        return csv_path, model_path
+
+    def load_model(self, model_repo):
+        if model_repo == self.last_loaded_repo:
+            return
+
+        csv_path, model_path = self.download_model(model_repo)
+        tags_df = pd.read_csv(csv_path)
+        self.tag_names, self.general_indexes, self.character_indexes = tags_df["name"].tolist(), list(np.where(tags_df["category"] == 0)[0]), list(np.where(tags_df["category"] == 4)[0])
+
+        model = rt.InferenceSession(model_path)
+        _, height, width, _ = model.get_inputs()[0].shape
+        self.model_target_size = height
+        self.last_loaded_repo = model_repo
+        self.model = model
+
 def process_images(files, model_repo, general_thresh, character_thresh, filter_tags, replacement_rules_text, fallback_rules_text):
     images = [Image.open(file.name) for file in files]
     results = predictor.predict(images, model_repo, general_thresh, character_thresh)
@@ -142,34 +113,46 @@ def process_images(files, model_repo, general_thresh, character_thresh, filter_tags, replacement_rules_text, fallback_rules_text):
     for general_tags, character_tags in results:
         general_tags = apply_replacements(general_tags, replacement_rules)
         character_tags = apply_replacements(character_tags, replacement_rules)
-
         general_tags = apply_fallbacks(general_tags, fallback_rules)
         character_tags = apply_fallbacks(character_tags, fallback_rules)

         general_tags = [tag.replace('_', ' ') for tag in general_tags if tag.lower() not in filter_set]
         character_tags = [tag.replace('_', ' ') for tag in character_tags if tag.lower() not in filter_set]

-        if character_tags:
-            prompts.append(f"{', '.join(character_tags)}, {', '.join(general_tags)}")
-        else:
-            prompts.append(", ".join(general_tags))
+        prompts.append(f"{', '.join(character_tags)}, {', '.join(general_tags)}" if character_tags else ", ".join(general_tags))

     return "\n\n".join(prompts)

 args = parse_args()
+predictor = Predictor()
+
+model_repos = [SWINV2_MODEL_DSV3_REPO, CONV_MODEL_DSV3_REPO, VIT_MODEL_DSV3_REPO, VIT_LARGE_MODEL_DSV3_REPO, EVA02_LARGE_MODEL_DSV3_REPO, MOAT_MODEL_DSV2_REPO, SWIN_MODEL_DSV2_REPO, CONV_MODEL_DSV2_REPO, CONV2_MODEL_DSV2_REPO, VIT_MODEL_DSV2_REPO, SWINV2_MODEL_IS_DSV1_REPO, EVA02_LARGE_MODEL_IS_DSV1_REPO]

 with gr.Blocks(title=TITLE) as demo:
     gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
     gr.Markdown(DESCRIPTION)

-    with gr.Accordion("Settings", open=False):
-        filter_tags = gr.Textbox(value=", ".join(PREDEFINED_FILTER_TAGS), label="Filter Tags (comma-separated)", lines=3)
-        replacement_rules_text = gr.Textbox(label="Replacement Rules", value="1boy -> 1girl", lines=5)
-        fallback_rules_text = gr.Textbox(label="Fallback Rules", value="sad, happy -> smile", lines=5)
+    with gr.Row():
+        with gr.Column():
+            image_files = gr.File(file_types=["image"], label="Upload Images", file_count="multiple")
+
+            with gr.Accordion("Advanced Settings", open=False):
+                model_repo = gr.Dropdown(model_repos, value=VIT_MODEL_DSV3_REPO, label="Select Model")
+                general_thresh = gr.Slider(0, 1, step=args.score_slider_step, value=args.score_general_threshold, label="General Tags Threshold")
+                character_thresh = gr.Slider(0, 1, step=args.score_slider_step, value=args.score_character_threshold, label="Character Tags Threshold")
+
+            submit = gr.Button(value="Process Images", variant="primary")
+
+        with gr.Column():
+            output = gr.Textbox(label="Output", lines=10)
+
+            with gr.Accordion("Tag Replacements", open=False):
+                replacement_rules_text = gr.Textbox(label="Replacement Rules", lines=5)

-    submit = gr.Button(value="Process Images")
-    output = gr.Textbox(label="Output", lines=10)
+            with gr.Accordion("Fallback Rules", open=False):
+                fallback_rules_text = gr.Textbox(label="Fallback Rules", lines=5)

-    submit.click(process_images, inputs=[[], VIT_MODEL_DSV3_REPO, args.score_general_threshold, args.score_character_threshold, filter_tags, replacement_rules_text, fallback_rules_text], outputs=output)
+    submit.click(process_images, inputs=[image_files, model_repo, general_thresh, character_thresh, replacement_rules_text, fallback_rules_text], outputs=output)

+demo.queue()
 demo.launch()
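
Note on this commit: the rewritten Predictor class keeps only __init__, download_model, and load_model, while process_images still calls predictor.predict(...). Below is a minimal sketch of the two inference methods the class would still need, adapted directly from the code removed in this same commit (same white-square padding, RGB-to-BGR conversion, float32 batch of one, and per-category thresholding as the deleted version); it is an illustration, not part of the committed change.

    def prepare_image(self, image):
        # Pad to a white square, resize to the model input size,
        # convert RGB -> BGR, and add a batch dimension.
        max_dim = max(image.size)
        padded = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded.paste(image, ((max_dim - image.width) // 2, (max_dim - image.height) // 2))
        padded = padded.resize((self.model_target_size, self.model_target_size), Image.BICUBIC)
        return np.expand_dims(np.asarray(padded, dtype=np.float32)[:, :, ::-1], axis=0)

    def predict(self, images, model_repo, general_thresh, character_thresh):
        # Run the ONNX session on each image and split the scores into
        # general and character tags using the two thresholds.
        self.load_model(model_repo)
        results = []
        for image in images:
            batch = self.prepare_image(image)
            input_name = self.model.get_inputs()[0].name
            label_name = self.model.get_outputs()[0].name
            preds = self.model.run([label_name], {input_name: batch})[0]
            labels = list(zip(self.tag_names, preds[0].astype(float)))
            general = [name for i, (name, score) in enumerate(labels) if i in self.general_indexes and score > general_thresh]
            character = [name for i, (name, score) in enumerate(labels) if i in self.character_indexes and score > character_thresh]
            results.append((general, character))
        return results

Separately, the new submit.click wiring passes six inputs while process_images still declares seven parameters (filter_tags is no longer supplied by the UI).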
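
For reference, the apply_replacements and apply_fallbacks helpers kept by this commit operate on plain tag lists and dictionaries. The sketch below is illustrative only: it assumes parse_replacement_rules and parse_fallback_rules (whose bodies fall outside the displayed hunks) yield tuple-keyed dicts in the shape these two functions iterate over, and the example rules simply mirror the old UI defaults "1boy -> 1girl" and "sad, happy -> smile". It would run with app.py's helpers in scope.

# apply_replacements: if every tag on the left side is present, swap it for the right side.
tags = ["1boy", "solo", "short_hair"]
replacements = {("1boy",): ("1girl",)}            # hypothetical parsed form of "1boy -> 1girl"
print(apply_replacements(tags, replacements))     # "1boy" replaced by "1girl"

# apply_fallbacks: if none of the expected tags are present, add the fallback tag.
fallbacks = {("sad", "happy"): "smile"}           # hypothetical parsed form of "sad, happy -> smile"
print(apply_fallbacks(["standing"], fallbacks))   # neither "sad" nor "happy" present, so "smile" is added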