Gainward777 commited on
Commit
a509e1d
·
verified ·
1 Parent(s): 8ea0f99

Update sd/prompt_helper/helper.py

Browse files
Files changed (1) hide show
  1. sd/prompt_helper/helper.py +91 -91
sd/prompt_helper/helper.py CHANGED
@@ -1,92 +1,92 @@
1
- import onnx
2
- import onnxruntime as ort
3
- import numpy as np
4
- import cv2
5
- import os
6
- import csv
7
-
8
- VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
9
-
10
- MODEL_DIR='sd/prompt_helper/model'
11
- ORT_SESSION=ort.InferenceSession(f'{MODEL_DIR}/model.onnx', providers=['CPUExecutionProvider'])
12
- #ORT_INPUT_NAME=ort_session.get_inputs()[0].name
13
-
14
- IMAGE_SIZE = 448
15
-
16
- def img_preporation(img):
17
- bgr_img = np.array(img)[:, :, ::-1].copy()
18
- size = max(bgr_img.shape[0:2])
19
- pad_x = size - bgr_img.shape[1]
20
- pad_y = size - bgr_img.shape[0]
21
- pad_l = pad_x // 2
22
- pad_t = pad_y // 2
23
-
24
- #add paddings to squaring image
25
- np.pad(bgr_img, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
26
-
27
- #adaptive resize
28
- interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
29
- bgr_img = cv2.resize(bgr_img, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
30
-
31
- bgr_img = bgr_img.astype(np.float32)
32
-
33
-
34
- def get_help(img):
35
- # = model[0]
36
- ort_input_name = ORT_SESSION.get_inputs()[0].name
37
-
38
- with open(os.path.join(MODEL_DIR, "selected_tags.csv"), "r", encoding="utf-8") as f:
39
- reader = csv.reader(f)
40
- l = [row for row in reader]
41
- header = l[0] # tag_id,name,category,count
42
- rows = l[1:]
43
- assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
44
-
45
- general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
46
- character_tags = [row[1] for row in rows[1:] if row[2] == "4"]
47
-
48
- tag_freq = {}
49
- undesired_tags = ["transparent background"]
50
-
51
- #img = Image.open(image_path)
52
- preped_img = img_preporation(img)
53
- preped_img = np.expand_dims(preped_img, axis=0)
54
-
55
- # Run inference
56
- prob = ORT_SESSION.run(None, {ort_input_name: preped_img})[0][0]
57
- # Generate Tags
58
- combined_tags = []
59
- general_tag_text = ""
60
- character_tag_text = ""
61
- remove_underscore = True
62
- caption_separator = ", "
63
- general_threshold = 0.35
64
- character_threshold = 0.35
65
-
66
- for i, p in enumerate(prob[4:]):
67
- if i < len(general_tags) and p >= general_threshold:
68
- tag_name = general_tags[i]
69
- if remove_underscore and len(tag_name) > 3: # ignore emoji tags like >_< and ^_^
70
- tag_name = tag_name.replace("_", " ")
71
-
72
- if tag_name not in undesired_tags:
73
- tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
74
- general_tag_text += caption_separator + tag_name
75
- combined_tags.append(tag_name)
76
- elif i >= len(general_tags) and p >= character_threshold:
77
- tag_name = character_tags[i - len(general_tags)]
78
- if remove_underscore and len(tag_name) > 3:
79
- tag_name = tag_name.replace("_", " ")
80
-
81
- if tag_name not in undesired_tags:
82
- tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
83
- character_tag_text += caption_separator + tag_name
84
- combined_tags.append(tag_name)
85
-
86
- # Remove leading comma
87
- if len(general_tag_text) > 0:
88
- general_tag_text = general_tag_text[len(caption_separator) :]
89
- if len(character_tag_text) > 0:
90
- character_tag_text = character_tag_text[len(caption_separator) :]
91
- tag_text = caption_separator.join(combined_tags)
92
  return tag_text
 
1
+ import onnx
2
+ import onnxruntime as ort
3
+ import numpy as np
4
+ import cv2
5
+ import os
6
+ import csv
7
+
8
+ VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
9
+
10
+ MODEL_DIR='sd/prompt_helper/model'
11
+ ORT_SESSION=ort.InferenceSession(f'{MODEL_DIR}/model.onnx', providers=['CPUExecutionProvider'])
12
+ #ORT_INPUT_NAME=ort_session.get_inputs()[0].name
13
+
14
+ IMAGE_SIZE = 448
15
+
16
+ def img_preporation(img):
17
+ bgr_img = np.array(img)[:, :, ::-1].copy()
18
+ size = max(bgr_img.shape[0:2])
19
+ pad_x = size - bgr_img.shape[1]
20
+ pad_y = size - bgr_img.shape[0]
21
+ pad_l = pad_x // 2
22
+ pad_t = pad_y // 2
23
+
24
+ #add paddings to squaring image
25
+ np.pad(bgr_img, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
26
+
27
+ #adaptive resize
28
+ interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
29
+ bgr_img = cv2.resize(bgr_img, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
30
+
31
+ return bgr_img.astype(np.float32)
32
+
33
+
34
+ def get_help(img):
35
+ # = model[0]
36
+ ort_input_name = ORT_SESSION.get_inputs()[0].name
37
+
38
+ with open(os.path.join(MODEL_DIR, "selected_tags.csv"), "r", encoding="utf-8") as f:
39
+ reader = csv.reader(f)
40
+ l = [row for row in reader]
41
+ header = l[0] # tag_id,name,category,count
42
+ rows = l[1:]
43
+ assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
44
+
45
+ general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
46
+ character_tags = [row[1] for row in rows[1:] if row[2] == "4"]
47
+
48
+ tag_freq = {}
49
+ undesired_tags = ["transparent background"]
50
+
51
+ #img = Image.open(image_path)
52
+ preped_img = img_preporation(img)
53
+ preped_img = np.expand_dims(preped_img, axis=0)
54
+
55
+ # Run inference
56
+ prob = ORT_SESSION.run(None, {ort_input_name: preped_img})[0][0]
57
+ # Generate Tags
58
+ combined_tags = []
59
+ general_tag_text = ""
60
+ character_tag_text = ""
61
+ remove_underscore = True
62
+ caption_separator = ", "
63
+ general_threshold = 0.35
64
+ character_threshold = 0.35
65
+
66
+ for i, p in enumerate(prob[4:]):
67
+ if i < len(general_tags) and p >= general_threshold:
68
+ tag_name = general_tags[i]
69
+ if remove_underscore and len(tag_name) > 3: # ignore emoji tags like >_< and ^_^
70
+ tag_name = tag_name.replace("_", " ")
71
+
72
+ if tag_name not in undesired_tags:
73
+ tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
74
+ general_tag_text += caption_separator + tag_name
75
+ combined_tags.append(tag_name)
76
+ elif i >= len(general_tags) and p >= character_threshold:
77
+ tag_name = character_tags[i - len(general_tags)]
78
+ if remove_underscore and len(tag_name) > 3:
79
+ tag_name = tag_name.replace("_", " ")
80
+
81
+ if tag_name not in undesired_tags:
82
+ tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
83
+ character_tag_text += caption_separator + tag_name
84
+ combined_tags.append(tag_name)
85
+
86
+ # Remove leading comma
87
+ if len(general_tag_text) > 0:
88
+ general_tag_text = general_tag_text[len(caption_separator) :]
89
+ if len(character_tag_text) > 0:
90
+ character_tag_text = character_tag_text[len(caption_separator) :]
91
+ tag_text = caption_separator.join(combined_tags)
92
  return tag_text