DucHaiten committed on
Commit
007aedf
1 Parent(s): 3734881

Update image_to_tag.py

Browse files
Files changed (1) hide show
  1. image_to_tag.py +88 -32
image_to_tag.py CHANGED
@@ -7,6 +7,7 @@ import threading
7
  import subprocess
8
  import sys
9
  import json
 
10
 
11
  # Global variables to control the process and track errors
12
  stop_processing = False
@@ -43,10 +44,11 @@ total_pages = 1 # Initialize total pages
43
  def update_and_save_config():
44
  """Update and save the configuration to JSON."""
45
  save_config_to_json(
46
- model="swinv2", # Default model
47
  general_threshold=general_threshold_var.get(),
48
  character_threshold=character_threshold_var.get(),
49
- model_dir="D:/test/models/wd-swinv2-tagger-v3" # Default model directory
 
50
  )
51
 
52
  def show_errors(root):
@@ -77,13 +79,14 @@ def show_errors(root):
77
 
78
  error_window.protocol("WM_DELETE_WINDOW", on_close_error_window)
79
 
80
- def save_config_to_json(model, general_threshold, character_threshold, model_dir, filepath='config.json'):
81
  """Save the model and threshold values to a JSON file."""
82
  config = {
83
  'model': model,
84
  'general_threshold': general_threshold,
85
  'character_threshold': character_threshold,
86
- 'model_dir': model_dir
 
87
  }
88
  try:
89
  with open(filepath, 'w') as f:
@@ -95,6 +98,7 @@ def open_image_to_tag():
95
  global stop_processing, error_messages, selected_files, save_directory, caption_window, caption_frame, thumbnails, caption_text_widgets, tag_dict, selected_tag, edit_buttons, tag_text_frame, current_page, total_pages, content_canvas
96
  global status_var, num_files_var, errors_var, progress, character_threshold_var, general_threshold_var, thread_count_var, batch_size_var
97
  global start_button, stop_button, prepend_text_var, append_text_var
 
98
 
99
  # Create Tkinter window
100
  root = tk.Tk()
@@ -107,16 +111,17 @@ def open_image_to_tag():
107
  progress = tk.IntVar()
108
  character_threshold_var = tk.DoubleVar(value=0.35)
109
  general_threshold_var = tk.DoubleVar(value=0.35)
110
- thread_count_var = tk.IntVar(value=4)
111
- batch_size_var = tk.IntVar(value=4)
112
  prepend_text_var = tk.StringVar()
113
  append_text_var = tk.StringVar()
 
114
  q = queue.Queue()
115
 
116
  def center_window(window, width_extra=0, height_extra=0):
117
  window.update_idletasks()
118
  width = 100 + width_extra
119
- height = 820 + height_extra
120
  x = (window.winfo_screenwidth() // 2) - (width // 2)
121
  y = (window.winfo_screenheight() // 2) - (height // 2)
122
  window.geometry(f'{width}x{height}+{x}+{y}')
@@ -158,6 +163,7 @@ def open_image_to_tag():
158
  # Stop button should always be enabled
159
  stop_button.config(state=tk.NORMAL)
160
 
 
161
  def generate_caption(image_path, save_directory, q):
162
  """Generate captions for a single image using the wd-swinv2-tagger-v3 model."""
163
  if stop_processing:
@@ -165,13 +171,18 @@ def open_image_to_tag():
165
 
166
  try:
167
  filename = os.path.splitext(os.path.basename(image_path))[0]
168
- output_path = os.path.join(save_directory, f"{filename}.txt")
 
 
 
 
 
169
 
170
  command = [
171
  sys.executable, 'D:/test/wdv3-timm-main/wdv3_timm.py',
172
- '--model', "swinv2",
173
  '--image_path', image_path,
174
- '--model_dir', "D:/test/models/wd-swinv2-tagger-v3",
175
  '--general_threshold', str(general_threshold_var.get()),
176
  '--character_threshold', str(character_threshold_var.get())
177
  ]
@@ -183,31 +194,54 @@ def open_image_to_tag():
183
  print(output) # In ra đầu ra từ lệnh subprocess
184
  print(error_output) # In ra đầu ra lỗi từ lệnh subprocess
185
 
186
- # Filter out information to contain only Caption or General tags
187
- filtered_output = []
188
- recording = False
189
  for line in output.split('\n'):
190
  if "General tags" in line:
191
- recording = True
192
  continue
193
- if recording:
194
  if line.startswith(' '):
195
  tag = line.strip().split(':')[0].replace('_', ' ')
196
- filtered_output.append(tag)
197
  else:
198
- recording = False
199
- break
 
 
 
 
 
 
 
 
 
 
200
 
201
- # Convert list of tags to comma-separated string
202
- final_tags = ','.join(filtered_output) if filtered_output else "No tags found"
 
 
 
 
 
 
203
  print("Filtered output:", final_tags) # Debug: In ra các nhãn cuối cùng sau khi lọc
204
 
205
- # Add prepend and append text
206
  final_tags = f"{prepend_text_var.get()},{final_tags},{append_text_var.get()}".strip(',')
207
 
208
- # Save result to text file
209
- with open(output_path, 'w', encoding='utf-8') as file:
210
- file.write(final_tags)
 
 
 
 
 
 
 
211
 
212
  q.put(image_path)
213
  except Exception as e:
@@ -432,7 +466,8 @@ def open_image_to_tag():
432
  file_label.grid(row=i*2, column=1, padx=5, pady=5, sticky="nsew")
433
 
434
  # Check and display caption if available
435
- caption_file = os.path.join(save_directory, f"{os.path.basename(file_path)}_tags.txt")
 
436
  if os.path.exists(caption_file):
437
  with open(caption_file, 'r', encoding='utf-8') as file:
438
  caption_text = file.read()
@@ -442,7 +477,7 @@ def open_image_to_tag():
442
  caption_text_widget = tk.Text(caption_frame, width=50, height=3, wrap=tk.WORD, font=('Helvetica', 12))
443
  caption_text_widget.insert(tk.END, caption_text)
444
  caption_text_widget.grid(row=i*2, column=2, padx=5, pady=5, sticky="nsew")
445
- caption_text_widget.bind("<FocusOut>", lambda e, fp=file_path: save_caption(fp, caption_text_widget.get("1.0", "end-1c")))
446
  caption_text_widgets.append(caption_text_widget)
447
 
448
  # Update tags in tag_dict
@@ -494,7 +529,8 @@ def open_image_to_tag():
494
 
495
  def save_caption(file_path, caption_text):
496
  """Save caption when user changes it."""
497
- output_path = os.path.join(save_directory, f"{os.path.basename(file_path)}_tags.txt")
 
498
  with open(output_path, 'w', encoding='utf-8') as file:
499
  file.write(caption_text)
500
 
@@ -580,7 +616,8 @@ def open_image_to_tag():
580
 
581
  # Update the captions in the respective files
582
  for file_path in selected_files:
583
- caption_file = os.path.join(save_directory, f"{os.path.basename(file_path)}_tags.txt")
 
584
  if os.path.exists(caption_file):
585
  with open(caption_file, 'r', encoding='utf-8') as file:
586
  caption_text = file.read()
@@ -614,7 +651,8 @@ def open_image_to_tag():
614
 
615
  # Update the captions in the respective files
616
  for i, file_path in enumerate(selected_files):
617
- caption_file = os.path.join(save_directory, f"{os.path.basename(file_path)}_tags.txt")
 
618
  if os.path.exists(caption_file):
619
  with open(caption_file, 'r', encoding='utf-8') as file:
620
  caption_text = file.read()
@@ -644,11 +682,14 @@ def open_image_to_tag():
644
 
645
  # Update the captions in the respective files
646
  for i, file_path in enumerate(selected_files):
647
- caption_file = os.path.join(save_directory, f"{os.path.basename(file_path)}_tags.txt")
 
648
  if os.path.exists(caption_file):
649
  with open(caption_file, 'r', encoding='utf-8') as file:
650
  caption_text = file.read()
651
- new_caption_text = caption_text.replace(tag_to_delete, "")
 
 
652
  with open(caption_file, 'w', encoding='utf-8') as file:
653
  file.write(new_caption_text)
654
 
@@ -664,7 +705,8 @@ def open_image_to_tag():
664
  # Delete the files containing the tag
665
  files_to_delete = []
666
  for i, file_path in enumerate(selected_files):
667
- caption_file = os.path.join(save_directory, f"{os.path.basename(file_path)}_tags.txt")
 
668
  if os.path.exists(caption_file):
669
  with open(caption_file, 'r', encoding='utf-8') as file:
670
  caption_text = file.read()
@@ -740,6 +782,19 @@ def open_image_to_tag():
740
  append_text_entry = tk.Entry(root, textvariable=append_text_var, justify='center', width=20)
741
  append_text_entry.pack(pady=5)
742
 
 
 
 
 
 
 
 
 
 
 
 
 
 
743
  thread_count_label = tk.Label(root, text="Thread Count:")
744
  thread_count_label.pack(pady=5)
745
  thread_count_entry = tk.Entry(root, textvariable=thread_count_var, justify='center', width=5, validate='key')
@@ -775,6 +830,7 @@ def open_image_to_tag():
775
  general_threshold_var.trace_add('write', lambda *args: update_and_save_config())
776
  character_threshold_var.trace_add('write', lambda *args: update_and_save_config())
777
  thread_count_var.trace_add('write', lambda *args: update_and_save_config())
 
778
 
779
  center_window(root, width_extra=200)
780
  root.protocol("WM_DELETE_WINDOW", on_closing)
 
7
  import subprocess
8
  import sys
9
  import json
10
+ import re
11
 
12
  # Global variables to control the process and track errors
13
  stop_processing = False
 
44
  def update_and_save_config():
45
  """Update and save the configuration to JSON."""
46
  save_config_to_json(
47
+ model="eva02", # Default model
48
  general_threshold=general_threshold_var.get(),
49
  character_threshold=character_threshold_var.get(),
50
+ model_dir="D:/test/models/wd-eva02-large-tagger-v3", # Default model directory
51
+ caption_mode=caption_mode_var.get()
52
  )
53
 
54
  def show_errors(root):
 
79
 
80
  error_window.protocol("WM_DELETE_WINDOW", on_close_error_window)
81
 
82
+ def save_config_to_json(model, general_threshold, character_threshold, model_dir, caption_mode, filepath='config.json'):
83
  """Save the model and threshold values to a JSON file."""
84
  config = {
85
  'model': model,
86
  'general_threshold': general_threshold,
87
  'character_threshold': character_threshold,
88
+ 'model_dir': model_dir,
89
+ 'caption_mode': caption_mode
90
  }
91
  try:
92
  with open(filepath, 'w') as f:
 
98
  global stop_processing, error_messages, selected_files, save_directory, caption_window, caption_frame, thumbnails, caption_text_widgets, tag_dict, selected_tag, edit_buttons, tag_text_frame, current_page, total_pages, content_canvas
99
  global status_var, num_files_var, errors_var, progress, character_threshold_var, general_threshold_var, thread_count_var, batch_size_var
100
  global start_button, stop_button, prepend_text_var, append_text_var
101
+ global caption_mode_var # Khai báo biến toàn cục
102
 
103
  # Create Tkinter window
104
  root = tk.Tk()
 
111
  progress = tk.IntVar()
112
  character_threshold_var = tk.DoubleVar(value=0.35)
113
  general_threshold_var = tk.DoubleVar(value=0.35)
114
+ thread_count_var = tk.IntVar(value=1)
115
+ batch_size_var = tk.IntVar(value=8)
116
  prepend_text_var = tk.StringVar()
117
  append_text_var = tk.StringVar()
118
+ caption_mode_var = tk.IntVar(value=1)
119
  q = queue.Queue()
120
 
121
  def center_window(window, width_extra=0, height_extra=0):
122
  window.update_idletasks()
123
  width = 100 + width_extra
124
+ height = 950 + height_extra
125
  x = (window.winfo_screenwidth() // 2) - (width // 2)
126
  y = (window.winfo_screenheight() // 2) - (height // 2)
127
  window.geometry(f'{width}x{height}+{x}+{y}')
 
163
  # Stop button should always be enabled
164
  stop_button.config(state=tk.NORMAL)
165
 
166
+
167
  def generate_caption(image_path, save_directory, q):
168
  """Generate captions for a single image using the wd-swinv2-tagger-v3 model."""
169
  if stop_processing:
 
171
 
172
  try:
173
  filename = os.path.splitext(os.path.basename(image_path))[0]
174
+ output_path = os.path.join(save_directory, f"{filename}.txt") # Sửa lại tên tệp caption
175
+
176
+ # Kiểm tra chế độ tạo caption
177
+ if caption_mode_var.get() == 2 and os.path.exists(output_path):
178
+ q.put(image_path)
179
+ return
180
 
181
  command = [
182
  sys.executable, 'D:/test/wdv3-timm-main/wdv3_timm.py',
183
+ '--model', "eva02",
184
  '--image_path', image_path,
185
+ '--model_dir', "D:/test/models/wd-eva02-large-tagger-v3",
186
  '--general_threshold', str(general_threshold_var.get()),
187
  '--character_threshold', str(character_threshold_var.get())
188
  ]
 
194
  print(output) # In ra đầu ra từ lệnh subprocess
195
  print(error_output) # In ra đầu ra lỗi từ lệnh subprocess
196
 
197
+ # Lọc thông tin "General tags"
198
+ general_tags = []
199
+ recording_general = False
200
  for line in output.split('\n'):
201
  if "General tags" in line:
202
+ recording_general = True
203
  continue
204
+ if recording_general:
205
  if line.startswith(' '):
206
  tag = line.strip().split(':')[0].replace('_', ' ')
207
+ general_tags.append(tag)
208
  else:
209
+ recording_general = False
210
+
211
+ # Lọc thông tin "Character tags"
212
+ character_tags = []
213
+ recording_character = False
214
+ for line in output.split('\n'):
215
+ if "Character tags" in line:
216
+ recording_character = True
217
+ continue
218
+ if recording_character:
219
+ if line.startswith(' '):
220
+ tag = line.strip().split(':')[0].replace('_', ' ')
221
 
222
+ # Loại bỏ từ khóa chứa từ 'costume'
223
+ if 'costume' not in tag.lower():
224
+ character_tags.append(tag) # Giữ lại từ khóa không chứa 'costume'
225
+ else:
226
+ recording_character = False
227
+
228
+ # Kết hợp cả general và character tags
229
+ final_tags = ','.join(general_tags + character_tags) if general_tags or character_tags else "No tags found"
230
  print("Filtered output:", final_tags) # Debug: In ra các nhãn cuối cùng sau khi lọc
231
 
232
+ # Thêm văn bản trước và sau
233
  final_tags = f"{prepend_text_var.get()},{final_tags},{append_text_var.get()}".strip(',')
234
 
235
+ # Xử ghi đè, nối thêm hoặc bỏ qua caption hiện có
236
+ if caption_mode_var.get() == 0: # Overwrite
237
+ with open(output_path, 'w', encoding='utf-8') as file:
238
+ file.write(final_tags)
239
+ elif caption_mode_var.get() == 1 and os.path.exists(output_path): # Append
240
+ with open(output_path, 'a', encoding='utf-8') as file:
241
+ file.write(f",{final_tags}")
242
+ else: # Tạo mới hoặc ghi đè nếu file không tồn tại
243
+ with open(output_path, 'w', encoding='utf-8') as file:
244
+ file.write(final_tags)
245
 
246
  q.put(image_path)
247
  except Exception as e:
 
466
  file_label.grid(row=i*2, column=1, padx=5, pady=5, sticky="nsew")
467
 
468
  # Check and display caption if available
469
+ filename = os.path.splitext(os.path.basename(file_path))[0] # Lấy tên tệp không có phần mở rộng
470
+ caption_file = os.path.join(save_directory, f"{filename}.txt")
471
  if os.path.exists(caption_file):
472
  with open(caption_file, 'r', encoding='utf-8') as file:
473
  caption_text = file.read()
 
477
  caption_text_widget = tk.Text(caption_frame, width=50, height=3, wrap=tk.WORD, font=('Helvetica', 12))
478
  caption_text_widget.insert(tk.END, caption_text)
479
  caption_text_widget.grid(row=i*2, column=2, padx=5, pady=5, sticky="nsew")
480
+ caption_text_widget.bind("<FocusOut>", lambda e, fp=file_path, w=caption_text_widget: save_caption(fp, w.get("1.0", "end-1c")))
481
  caption_text_widgets.append(caption_text_widget)
482
 
483
  # Update tags in tag_dict
 
529
 
530
  def save_caption(file_path, caption_text):
531
  """Save caption when user changes it."""
532
+ filename = os.path.splitext(os.path.basename(file_path))[0]
533
+ output_path = os.path.join(save_directory, f"{filename}.txt")
534
  with open(output_path, 'w', encoding='utf-8') as file:
535
  file.write(caption_text)
536
 
 
616
 
617
  # Update the captions in the respective files
618
  for file_path in selected_files:
619
+ filename = os.path.splitext(os.path.basename(file_path))[0]
620
+ caption_file = os.path.join(save_directory, f"{filename}.txt")
621
  if os.path.exists(caption_file):
622
  with open(caption_file, 'r', encoding='utf-8') as file:
623
  caption_text = file.read()
 
651
 
652
  # Update the captions in the respective files
653
  for i, file_path in enumerate(selected_files):
654
+ filename = os.path.splitext(os.path.basename(file_path))[0]
655
+ caption_file = os.path.join(save_directory, f"{filename}.txt")
656
  if os.path.exists(caption_file):
657
  with open(caption_file, 'r', encoding='utf-8') as file:
658
  caption_text = file.read()
 
682
 
683
  # Update the captions in the respective files
684
  for i, file_path in enumerate(selected_files):
685
+ filename = os.path.splitext(os.path.basename(file_path))[0]
686
+ caption_file = os.path.join(save_directory, f"{filename}.txt")
687
  if os.path.exists(caption_file):
688
  with open(caption_file, 'r', encoding='utf-8') as file:
689
  caption_text = file.read()
690
+ # Remove the tag from the caption
691
+ tags = [tag.strip() for tag in caption_text.split(',') if tag.strip() != tag_to_delete]
692
+ new_caption_text = ','.join(tags)
693
  with open(caption_file, 'w', encoding='utf-8') as file:
694
  file.write(new_caption_text)
695
 
 
705
  # Delete the files containing the tag
706
  files_to_delete = []
707
  for i, file_path in enumerate(selected_files):
708
+ filename = os.path.splitext(os.path.basename(file_path))[0]
709
+ caption_file = os.path.join(save_directory, f"{filename}.txt")
710
  if os.path.exists(caption_file):
711
  with open(caption_file, 'r', encoding='utf-8') as file:
712
  caption_text = file.read()
 
782
  append_text_entry = tk.Entry(root, textvariable=append_text_var, justify='center', width=20)
783
  append_text_entry.pack(pady=5)
784
 
785
+ # Add Radio buttons for caption mode
786
+ caption_mode_label = tk.Label(root, text="Caption Mode:")
787
+ caption_mode_label.pack(fill='x', pady=5)
788
+
789
+ overwrite_radio = tk.Radiobutton(root, text="Overwrite existing caption", variable=caption_mode_var, value=0)
790
+ overwrite_radio.pack(fill='x', pady=5)
791
+
792
+ append_radio = tk.Radiobutton(root, text="Append to existing caption", variable=caption_mode_var, value=1)
793
+ append_radio.pack(fill='x', pady=5)
794
+
795
+ skip_radio = tk.Radiobutton(root, text="Skip images with existing caption", variable=caption_mode_var, value=2)
796
+ skip_radio.pack(fill='x', pady=5)
797
+
798
  thread_count_label = tk.Label(root, text="Thread Count:")
799
  thread_count_label.pack(pady=5)
800
  thread_count_entry = tk.Entry(root, textvariable=thread_count_var, justify='center', width=5, validate='key')
 
830
  general_threshold_var.trace_add('write', lambda *args: update_and_save_config())
831
  character_threshold_var.trace_add('write', lambda *args: update_and_save_config())
832
  thread_count_var.trace_add('write', lambda *args: update_and_save_config())
833
+ caption_mode_var.trace_add('write', lambda *args: update_and_save_config())
834
 
835
  center_window(root, width_extra=200)
836
  root.protocol("WM_DELETE_WINDOW", on_closing)