Update image_to_tag.py
Browse files- image_to_tag.py +88 -32
image_to_tag.py
CHANGED
@@ -7,6 +7,7 @@ import threading
|
|
7 |
import subprocess
|
8 |
import sys
|
9 |
import json
|
|
|
10 |
|
11 |
# Global variables to control the process and track errors
|
12 |
stop_processing = False
|
@@ -43,10 +44,11 @@ total_pages = 1 # Initialize total pages
|
|
43 |
def update_and_save_config():
|
44 |
"""Update and save the configuration to JSON."""
|
45 |
save_config_to_json(
|
46 |
-
model="
|
47 |
general_threshold=general_threshold_var.get(),
|
48 |
character_threshold=character_threshold_var.get(),
|
49 |
-
model_dir="D:/test/models/wd-
|
|
|
50 |
)
|
51 |
|
52 |
def show_errors(root):
|
@@ -77,13 +79,14 @@ def show_errors(root):
|
|
77 |
|
78 |
error_window.protocol("WM_DELETE_WINDOW", on_close_error_window)
|
79 |
|
80 |
-
def save_config_to_json(model, general_threshold, character_threshold, model_dir, filepath='config.json'):
|
81 |
"""Save the model and threshold values to a JSON file."""
|
82 |
config = {
|
83 |
'model': model,
|
84 |
'general_threshold': general_threshold,
|
85 |
'character_threshold': character_threshold,
|
86 |
-
'model_dir': model_dir
|
|
|
87 |
}
|
88 |
try:
|
89 |
with open(filepath, 'w') as f:
|
@@ -95,6 +98,7 @@ def open_image_to_tag():
|
|
95 |
global stop_processing, error_messages, selected_files, save_directory, caption_window, caption_frame, thumbnails, caption_text_widgets, tag_dict, selected_tag, edit_buttons, tag_text_frame, current_page, total_pages, content_canvas
|
96 |
global status_var, num_files_var, errors_var, progress, character_threshold_var, general_threshold_var, thread_count_var, batch_size_var
|
97 |
global start_button, stop_button, prepend_text_var, append_text_var
|
|
|
98 |
|
99 |
# Create Tkinter window
|
100 |
root = tk.Tk()
|
@@ -107,16 +111,17 @@ def open_image_to_tag():
|
|
107 |
progress = tk.IntVar()
|
108 |
character_threshold_var = tk.DoubleVar(value=0.35)
|
109 |
general_threshold_var = tk.DoubleVar(value=0.35)
|
110 |
-
thread_count_var = tk.IntVar(value=
|
111 |
-
batch_size_var = tk.IntVar(value=
|
112 |
prepend_text_var = tk.StringVar()
|
113 |
append_text_var = tk.StringVar()
|
|
|
114 |
q = queue.Queue()
|
115 |
|
116 |
def center_window(window, width_extra=0, height_extra=0):
|
117 |
window.update_idletasks()
|
118 |
width = 100 + width_extra
|
119 |
-
height =
|
120 |
x = (window.winfo_screenwidth() // 2) - (width // 2)
|
121 |
y = (window.winfo_screenheight() // 2) - (height // 2)
|
122 |
window.geometry(f'{width}x{height}+{x}+{y}')
|
@@ -158,6 +163,7 @@ def open_image_to_tag():
|
|
158 |
# Stop button should always be enabled
|
159 |
stop_button.config(state=tk.NORMAL)
|
160 |
|
|
|
161 |
def generate_caption(image_path, save_directory, q):
|
162 |
"""Generate captions for a single image using the wd-swinv2-tagger-v3 model."""
|
163 |
if stop_processing:
|
@@ -165,13 +171,18 @@ def open_image_to_tag():
|
|
165 |
|
166 |
try:
|
167 |
filename = os.path.splitext(os.path.basename(image_path))[0]
|
168 |
-
output_path = os.path.join(save_directory, f"{filename}.txt")
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
command = [
|
171 |
sys.executable, 'D:/test/wdv3-timm-main/wdv3_timm.py',
|
172 |
-
'--model', "
|
173 |
'--image_path', image_path,
|
174 |
-
'--model_dir', "D:/test/models/wd-
|
175 |
'--general_threshold', str(general_threshold_var.get()),
|
176 |
'--character_threshold', str(character_threshold_var.get())
|
177 |
]
|
@@ -183,31 +194,54 @@ def open_image_to_tag():
|
|
183 |
print(output) # In ra đầu ra từ lệnh subprocess
|
184 |
print(error_output) # In ra đầu ra lỗi từ lệnh subprocess
|
185 |
|
186 |
-
#
|
187 |
-
|
188 |
-
|
189 |
for line in output.split('\n'):
|
190 |
if "General tags" in line:
|
191 |
-
|
192 |
continue
|
193 |
-
if
|
194 |
if line.startswith(' '):
|
195 |
tag = line.strip().split(':')[0].replace('_', ' ')
|
196 |
-
|
197 |
else:
|
198 |
-
|
199 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
200 |
|
201 |
-
|
202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
print("Filtered output:", final_tags) # Debug: In ra các nhãn cuối cùng sau khi lọc
|
204 |
|
205 |
-
#
|
206 |
final_tags = f"{prepend_text_var.get()},{final_tags},{append_text_var.get()}".strip(',')
|
207 |
|
208 |
-
#
|
209 |
-
|
210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
|
212 |
q.put(image_path)
|
213 |
except Exception as e:
|
@@ -432,7 +466,8 @@ def open_image_to_tag():
|
|
432 |
file_label.grid(row=i*2, column=1, padx=5, pady=5, sticky="nsew")
|
433 |
|
434 |
# Check and display caption if available
|
435 |
-
|
|
|
436 |
if os.path.exists(caption_file):
|
437 |
with open(caption_file, 'r', encoding='utf-8') as file:
|
438 |
caption_text = file.read()
|
@@ -442,7 +477,7 @@ def open_image_to_tag():
|
|
442 |
caption_text_widget = tk.Text(caption_frame, width=50, height=3, wrap=tk.WORD, font=('Helvetica', 12))
|
443 |
caption_text_widget.insert(tk.END, caption_text)
|
444 |
caption_text_widget.grid(row=i*2, column=2, padx=5, pady=5, sticky="nsew")
|
445 |
-
caption_text_widget.bind("<FocusOut>", lambda e, fp=file_path: save_caption(fp,
|
446 |
caption_text_widgets.append(caption_text_widget)
|
447 |
|
448 |
# Update tags in tag_dict
|
@@ -494,7 +529,8 @@ def open_image_to_tag():
|
|
494 |
|
495 |
def save_caption(file_path, caption_text):
|
496 |
"""Save caption when user changes it."""
|
497 |
-
|
|
|
498 |
with open(output_path, 'w', encoding='utf-8') as file:
|
499 |
file.write(caption_text)
|
500 |
|
@@ -580,7 +616,8 @@ def open_image_to_tag():
|
|
580 |
|
581 |
# Update the captions in the respective files
|
582 |
for file_path in selected_files:
|
583 |
-
|
|
|
584 |
if os.path.exists(caption_file):
|
585 |
with open(caption_file, 'r', encoding='utf-8') as file:
|
586 |
caption_text = file.read()
|
@@ -614,7 +651,8 @@ def open_image_to_tag():
|
|
614 |
|
615 |
# Update the captions in the respective files
|
616 |
for i, file_path in enumerate(selected_files):
|
617 |
-
|
|
|
618 |
if os.path.exists(caption_file):
|
619 |
with open(caption_file, 'r', encoding='utf-8') as file:
|
620 |
caption_text = file.read()
|
@@ -644,11 +682,14 @@ def open_image_to_tag():
|
|
644 |
|
645 |
# Update the captions in the respective files
|
646 |
for i, file_path in enumerate(selected_files):
|
647 |
-
|
|
|
648 |
if os.path.exists(caption_file):
|
649 |
with open(caption_file, 'r', encoding='utf-8') as file:
|
650 |
caption_text = file.read()
|
651 |
-
|
|
|
|
|
652 |
with open(caption_file, 'w', encoding='utf-8') as file:
|
653 |
file.write(new_caption_text)
|
654 |
|
@@ -664,7 +705,8 @@ def open_image_to_tag():
|
|
664 |
# Delete the files containing the tag
|
665 |
files_to_delete = []
|
666 |
for i, file_path in enumerate(selected_files):
|
667 |
-
|
|
|
668 |
if os.path.exists(caption_file):
|
669 |
with open(caption_file, 'r', encoding='utf-8') as file:
|
670 |
caption_text = file.read()
|
@@ -740,6 +782,19 @@ def open_image_to_tag():
|
|
740 |
append_text_entry = tk.Entry(root, textvariable=append_text_var, justify='center', width=20)
|
741 |
append_text_entry.pack(pady=5)
|
742 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
743 |
thread_count_label = tk.Label(root, text="Thread Count:")
|
744 |
thread_count_label.pack(pady=5)
|
745 |
thread_count_entry = tk.Entry(root, textvariable=thread_count_var, justify='center', width=5, validate='key')
|
@@ -775,6 +830,7 @@ def open_image_to_tag():
|
|
775 |
general_threshold_var.trace_add('write', lambda *args: update_and_save_config())
|
776 |
character_threshold_var.trace_add('write', lambda *args: update_and_save_config())
|
777 |
thread_count_var.trace_add('write', lambda *args: update_and_save_config())
|
|
|
778 |
|
779 |
center_window(root, width_extra=200)
|
780 |
root.protocol("WM_DELETE_WINDOW", on_closing)
|
|
|
7 |
import subprocess
|
8 |
import sys
|
9 |
import json
|
10 |
+
import re
|
11 |
|
12 |
# Global variables to control the process and track errors
|
13 |
stop_processing = False
|
|
|
44 |
def update_and_save_config():
    """Persist the current UI settings to the JSON config file.

    Reads the general/character thresholds and the caption mode from their
    Tk variables and forwards them, together with the fixed default model
    name and model directory, to ``save_config_to_json``.
    """
    settings = {
        "model": "eva02",  # Default model
        "general_threshold": general_threshold_var.get(),
        "character_threshold": character_threshold_var.get(),
        "model_dir": "D:/test/models/wd-eva02-large-tagger-v3",  # Default model directory
        "caption_mode": caption_mode_var.get(),
    }
    save_config_to_json(**settings)
|
53 |
|
54 |
def show_errors(root):
|
|
|
79 |
|
80 |
error_window.protocol("WM_DELETE_WINDOW", on_close_error_window)
|
81 |
|
82 |
+
def save_config_to_json(model, general_threshold, character_threshold, model_dir, caption_mode, filepath='config.json'):
|
83 |
"""Save the model and threshold values to a JSON file."""
|
84 |
config = {
|
85 |
'model': model,
|
86 |
'general_threshold': general_threshold,
|
87 |
'character_threshold': character_threshold,
|
88 |
+
'model_dir': model_dir,
|
89 |
+
'caption_mode': caption_mode
|
90 |
}
|
91 |
try:
|
92 |
with open(filepath, 'w') as f:
|
|
|
98 |
global stop_processing, error_messages, selected_files, save_directory, caption_window, caption_frame, thumbnails, caption_text_widgets, tag_dict, selected_tag, edit_buttons, tag_text_frame, current_page, total_pages, content_canvas
|
99 |
global status_var, num_files_var, errors_var, progress, character_threshold_var, general_threshold_var, thread_count_var, batch_size_var
|
100 |
global start_button, stop_button, prepend_text_var, append_text_var
|
101 |
+
global caption_mode_var # Khai báo biến toàn cục
|
102 |
|
103 |
# Create Tkinter window
|
104 |
root = tk.Tk()
|
|
|
111 |
progress = tk.IntVar()
|
112 |
character_threshold_var = tk.DoubleVar(value=0.35)
|
113 |
general_threshold_var = tk.DoubleVar(value=0.35)
|
114 |
+
thread_count_var = tk.IntVar(value=1)
|
115 |
+
batch_size_var = tk.IntVar(value=8)
|
116 |
prepend_text_var = tk.StringVar()
|
117 |
append_text_var = tk.StringVar()
|
118 |
+
caption_mode_var = tk.IntVar(value=1)
|
119 |
q = queue.Queue()
|
120 |
|
121 |
def center_window(window, width_extra=0, height_extra=0):
    """Resize *window* and place it in the middle of the screen.

    The window is given a base size of 100x950 pixels, enlarged by
    *width_extra* and *height_extra*, and positioned so that its centre
    coincides with the centre of the screen reported by the window itself.
    """
    # Flush pending geometry work before querying the screen size.
    window.update_idletasks()

    win_w = 100 + width_extra
    win_h = 950 + height_extra

    # Top-left corner = screen centre minus half the window size.
    left = (window.winfo_screenwidth() // 2) - (win_w // 2)
    top = (window.winfo_screenheight() // 2) - (win_h // 2)

    window.geometry(f'{win_w}x{win_h}+{left}+{top}')
|
|
|
163 |
# Stop button should always be enabled
|
164 |
stop_button.config(state=tk.NORMAL)
|
165 |
|
166 |
+
|
167 |
def generate_caption(image_path, save_directory, q):
|
168 |
"""Generate captions for a single image using the wd-eva02-large-tagger-v3 model."""
|
169 |
if stop_processing:
|
|
|
171 |
|
172 |
try:
|
173 |
filename = os.path.splitext(os.path.basename(image_path))[0]
|
174 |
+
output_path = os.path.join(save_directory, f"{filename}.txt") # Sửa lại tên tệp caption
|
175 |
+
|
176 |
+
# Check the caption-generation mode
|
177 |
+
if caption_mode_var.get() == 2 and os.path.exists(output_path):
|
178 |
+
q.put(image_path)
|
179 |
+
return
|
180 |
|
181 |
command = [
|
182 |
sys.executable, 'D:/test/wdv3-timm-main/wdv3_timm.py',
|
183 |
+
'--model', "eva02",
|
184 |
'--image_path', image_path,
|
185 |
+
'--model_dir', "D:/test/models/wd-eva02-large-tagger-v3",
|
186 |
'--general_threshold', str(general_threshold_var.get()),
|
187 |
'--character_threshold', str(character_threshold_var.get())
|
188 |
]
|
|
|
194 |
print(output) # In ra đầu ra từ lệnh subprocess
|
195 |
print(error_output) # In ra đầu ra lỗi từ lệnh subprocess
|
196 |
|
197 |
+
# Extract the "General tags" section
|
198 |
+
general_tags = []
|
199 |
+
recording_general = False
|
200 |
for line in output.split('\n'):
|
201 |
if "General tags" in line:
|
202 |
+
recording_general = True
|
203 |
continue
|
204 |
+
if recording_general:
|
205 |
if line.startswith(' '):
|
206 |
tag = line.strip().split(':')[0].replace('_', ' ')
|
207 |
+
general_tags.append(tag)
|
208 |
else:
|
209 |
+
recording_general = False
|
210 |
+
|
211 |
+
# Extract the "Character tags" section
|
212 |
+
character_tags = []
|
213 |
+
recording_character = False
|
214 |
+
for line in output.split('\n'):
|
215 |
+
if "Character tags" in line:
|
216 |
+
recording_character = True
|
217 |
+
continue
|
218 |
+
if recording_character:
|
219 |
+
if line.startswith(' '):
|
220 |
+
tag = line.strip().split(':')[0].replace('_', ' ')
|
221 |
|
222 |
+
# Drop tags that contain the word 'costume'
|
223 |
+
if 'costume' not in tag.lower():
|
224 |
+
character_tags.append(tag) # Giữ lại từ khóa không chứa 'costume'
|
225 |
+
else:
|
226 |
+
recording_character = False
|
227 |
+
|
228 |
+
# Combine the general and character tags
|
229 |
+
final_tags = ','.join(general_tags + character_tags) if general_tags or character_tags else "No tags found"
|
230 |
print("Filtered output:", final_tags) # Debug: In ra các nhãn cuối cùng sau khi lọc
|
231 |
|
232 |
+
# Add the prepend and append text
|
233 |
final_tags = f"{prepend_text_var.get()},{final_tags},{append_text_var.get()}".strip(',')
|
234 |
|
235 |
+
# Handle overwriting, appending to, or skipping an existing caption
|
236 |
+
if caption_mode_var.get() == 0: # Overwrite
|
237 |
+
with open(output_path, 'w', encoding='utf-8') as file:
|
238 |
+
file.write(final_tags)
|
239 |
+
elif caption_mode_var.get() == 1 and os.path.exists(output_path): # Append
|
240 |
+
with open(output_path, 'a', encoding='utf-8') as file:
|
241 |
+
file.write(f",{final_tags}")
|
242 |
+
else: # Tạo mới hoặc ghi đè nếu file không tồn tại
|
243 |
+
with open(output_path, 'w', encoding='utf-8') as file:
|
244 |
+
file.write(final_tags)
|
245 |
|
246 |
q.put(image_path)
|
247 |
except Exception as e:
|
|
|
466 |
file_label.grid(row=i*2, column=1, padx=5, pady=5, sticky="nsew")
|
467 |
|
468 |
# Check and display caption if available
|
469 |
+
filename = os.path.splitext(os.path.basename(file_path))[0] # Lấy tên tệp không có phần mở rộng
|
470 |
+
caption_file = os.path.join(save_directory, f"{filename}.txt")
|
471 |
if os.path.exists(caption_file):
|
472 |
with open(caption_file, 'r', encoding='utf-8') as file:
|
473 |
caption_text = file.read()
|
|
|
477 |
caption_text_widget = tk.Text(caption_frame, width=50, height=3, wrap=tk.WORD, font=('Helvetica', 12))
|
478 |
caption_text_widget.insert(tk.END, caption_text)
|
479 |
caption_text_widget.grid(row=i*2, column=2, padx=5, pady=5, sticky="nsew")
|
480 |
+
caption_text_widget.bind("<FocusOut>", lambda e, fp=file_path, w=caption_text_widget: save_caption(fp, w.get("1.0", "end-1c")))
|
481 |
caption_text_widgets.append(caption_text_widget)
|
482 |
|
483 |
# Update tags in tag_dict
|
|
|
529 |
|
530 |
def save_caption(file_path, caption_text):
    """Write *caption_text* to the caption file that belongs to *file_path*.

    The caption file lives in ``save_directory`` and shares the base name of
    *file_path* with its extension replaced by ``.txt``. Existing content is
    overwritten.
    """
    stem, _ext = os.path.splitext(os.path.basename(file_path))
    caption_path = os.path.join(save_directory, stem + ".txt")
    with open(caption_path, 'w', encoding='utf-8') as handle:
        handle.write(caption_text)
|
536 |
|
|
|
616 |
|
617 |
# Update the captions in the respective files
|
618 |
for file_path in selected_files:
|
619 |
+
filename = os.path.splitext(os.path.basename(file_path))[0]
|
620 |
+
caption_file = os.path.join(save_directory, f"{filename}.txt")
|
621 |
if os.path.exists(caption_file):
|
622 |
with open(caption_file, 'r', encoding='utf-8') as file:
|
623 |
caption_text = file.read()
|
|
|
651 |
|
652 |
# Update the captions in the respective files
|
653 |
for i, file_path in enumerate(selected_files):
|
654 |
+
filename = os.path.splitext(os.path.basename(file_path))[0]
|
655 |
+
caption_file = os.path.join(save_directory, f"{filename}.txt")
|
656 |
if os.path.exists(caption_file):
|
657 |
with open(caption_file, 'r', encoding='utf-8') as file:
|
658 |
caption_text = file.read()
|
|
|
682 |
|
683 |
# Update the captions in the respective files
|
684 |
for i, file_path in enumerate(selected_files):
|
685 |
+
filename = os.path.splitext(os.path.basename(file_path))[0]
|
686 |
+
caption_file = os.path.join(save_directory, f"{filename}.txt")
|
687 |
if os.path.exists(caption_file):
|
688 |
with open(caption_file, 'r', encoding='utf-8') as file:
|
689 |
caption_text = file.read()
|
690 |
+
# Remove the tag from the caption
|
691 |
+
tags = [tag.strip() for tag in caption_text.split(',') if tag.strip() != tag_to_delete]
|
692 |
+
new_caption_text = ','.join(tags)
|
693 |
with open(caption_file, 'w', encoding='utf-8') as file:
|
694 |
file.write(new_caption_text)
|
695 |
|
|
|
705 |
# Delete the files containing the tag
|
706 |
files_to_delete = []
|
707 |
for i, file_path in enumerate(selected_files):
|
708 |
+
filename = os.path.splitext(os.path.basename(file_path))[0]
|
709 |
+
caption_file = os.path.join(save_directory, f"{filename}.txt")
|
710 |
if os.path.exists(caption_file):
|
711 |
with open(caption_file, 'r', encoding='utf-8') as file:
|
712 |
caption_text = file.read()
|
|
|
782 |
append_text_entry = tk.Entry(root, textvariable=append_text_var, justify='center', width=20)
|
783 |
append_text_entry.pack(pady=5)
|
784 |
|
785 |
+
# Add Radio buttons for caption mode
|
786 |
+
caption_mode_label = tk.Label(root, text="Caption Mode:")
|
787 |
+
caption_mode_label.pack(fill='x', pady=5)
|
788 |
+
|
789 |
+
overwrite_radio = tk.Radiobutton(root, text="Overwrite existing caption", variable=caption_mode_var, value=0)
|
790 |
+
overwrite_radio.pack(fill='x', pady=5)
|
791 |
+
|
792 |
+
append_radio = tk.Radiobutton(root, text="Append to existing caption", variable=caption_mode_var, value=1)
|
793 |
+
append_radio.pack(fill='x', pady=5)
|
794 |
+
|
795 |
+
skip_radio = tk.Radiobutton(root, text="Skip images with existing caption", variable=caption_mode_var, value=2)
|
796 |
+
skip_radio.pack(fill='x', pady=5)
|
797 |
+
|
798 |
thread_count_label = tk.Label(root, text="Thread Count:")
|
799 |
thread_count_label.pack(pady=5)
|
800 |
thread_count_entry = tk.Entry(root, textvariable=thread_count_var, justify='center', width=5, validate='key')
|
|
|
830 |
general_threshold_var.trace_add('write', lambda *args: update_and_save_config())
|
831 |
character_threshold_var.trace_add('write', lambda *args: update_and_save_config())
|
832 |
thread_count_var.trace_add('write', lambda *args: update_and_save_config())
|
833 |
+
caption_mode_var.trace_add('write', lambda *args: update_and_save_config())
|
834 |
|
835 |
center_window(root, width_extra=200)
|
836 |
root.protocol("WM_DELETE_WINDOW", on_closing)
|