Gainward777 commited on
Commit
465f026
·
verified ·
1 Parent(s): f0e6dd6

Upload helper.py

Browse files
Files changed (1) hide show
  1. sd/prompt_helper/helper.py +71 -2
sd/prompt_helper/helper.py CHANGED
@@ -1,11 +1,19 @@
1
  import onnx
2
  import onnxruntime as ort
3
  import numpy as np
 
 
 
4
 
5
  VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
6
 
 
 
 
7
 
8
- def preprocess_image(img):
 
 
9
  bgr_img = np.array(img)[:, :, ::-1].copy()
10
  size = max(bgr_img.shape[0:2])
11
  pad_x = size - bgr_img.shape[1]
@@ -20,4 +28,65 @@ def preprocess_image(img):
20
  interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
21
  bgr_img = cv2.resize(bgr_img, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
22
 
23
- bgr_img = bgr_img.astype(np.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import onnx
2
  import onnxruntime as ort
3
  import numpy as np
4
+ import cv2
5
+ import os
6
+ import csv
7
 
8
  VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
9
 
10
+ MODEL_DIR='sd/prompt_helper/model'
11
+ ORT_SESSION=ort.InferenceSession(f'{MODEL_DIR}/model.onnx', providers=['CPUExecutionProvider'])
12
+ #ORT_INPUT_NAME=ort_session.get_inputs()[0].name
13
 
14
+ IMAGE_SIZE = 448
15
+
16
+ def img_preporation(img):
17
  bgr_img = np.array(img)[:, :, ::-1].copy()
18
  size = max(bgr_img.shape[0:2])
19
  pad_x = size - bgr_img.shape[1]
 
28
  interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
29
  bgr_img = cv2.resize(bgr_img, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
30
 
31
+ bgr_img = bgr_img.astype(np.float32)
32
+
33
+
34
+ def analysis(img):
35
+ # = model[0]
36
+ ort_input_name = ORT_SESSION.get_inputs()[0].name
37
+
38
+ with open(os.path.join(MODEL_DIR, "selected_tags.csv"), "r", encoding="utf-8") as f:
39
+ reader = csv.reader(f)
40
+ l = [row for row in reader]
41
+ header = l[0] # tag_id,name,category,count
42
+ rows = l[1:]
43
+ assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
44
+
45
+ general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
46
+ character_tags = [row[1] for row in rows[1:] if row[2] == "4"]
47
+
48
+ tag_freq = {}
49
+ undesired_tags = ["transparent background"]
50
+
51
+ #img = Image.open(image_path)
52
+ preped_img = img_preporation(img)
53
+ preped_img = np.expand_dims(preped_img, axis=0)
54
+
55
+ # Run inference
56
+ prob = ORT_SESSION.run(None, {ort_input_name: preped_img})[0][0]
57
+ # Generate Tags
58
+ combined_tags = []
59
+ general_tag_text = ""
60
+ character_tag_text = ""
61
+ remove_underscore = True
62
+ caption_separator = ", "
63
+ general_threshold = 0.35
64
+ character_threshold = 0.35
65
+
66
+ for i, p in enumerate(prob[4:]):
67
+ if i < len(general_tags) and p >= general_threshold:
68
+ tag_name = general_tags[i]
69
+ if remove_underscore and len(tag_name) > 3: # ignore emoji tags like >_< and ^_^
70
+ tag_name = tag_name.replace("_", " ")
71
+
72
+ if tag_name not in undesired_tags:
73
+ tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
74
+ general_tag_text += caption_separator + tag_name
75
+ combined_tags.append(tag_name)
76
+ elif i >= len(general_tags) and p >= character_threshold:
77
+ tag_name = character_tags[i - len(general_tags)]
78
+ if remove_underscore and len(tag_name) > 3:
79
+ tag_name = tag_name.replace("_", " ")
80
+
81
+ if tag_name not in undesired_tags:
82
+ tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
83
+ character_tag_text += caption_separator + tag_name
84
+ combined_tags.append(tag_name)
85
+
86
+ # Remove leading comma
87
+ if len(general_tag_text) > 0:
88
+ general_tag_text = general_tag_text[len(caption_separator) :]
89
+ if len(character_tag_text) > 0:
90
+ character_tag_text = character_tag_text[len(caption_separator) :]
91
+ tag_text = caption_separator.join(combined_tags)
92
+ return tag_text