mirabarukaso commited on
Commit
e60fe6a
·
1 Parent(s): e4ad141

Update tag complete method

Browse files

More flexible and reliable tag complete

Files changed (3) hide show
  1. app.py +4 -23
  2. scripts/lib.py +8 -5
  3. scripts/tag_autocomplete.py +213 -11
app.py CHANGED
@@ -1,30 +1,11 @@
1
  import gradio as gr
2
  import sys
3
  sys.path.append("scripts/")
4
- from lib import init, refresh_character_thumb_image
5
  from lib import JAVA_SCRIPT, CSS_SCRIPT, TITLE
6
 
7
  if __name__ == '__main__':
8
- character_list, character_list_cn, LANG, TAG_AUTOCOMPLETE = init()
9
-
10
- def update_suggestions(text):
11
- matches = TAG_AUTOCOMPLETE.get_suggestions(text)
12
- items = []
13
- if matches:
14
- for _, m in enumerate(matches):
15
- key = f"{m['prompt']} ({m['heat']})"
16
- items.append([key])
17
- return gr.Dataset(samples=items)
18
-
19
- def apply_suggestion(evt: gr.SelectData, text, custom_prompt):
20
- suggestion = evt.value
21
- #print(f"You selected {evt.value} at {evt.index} from {evt.target} for {custom_prompt}")
22
- if not custom_prompt or not suggestion:
23
- return custom_prompt
24
-
25
- parts = custom_prompt.split(',')
26
- parts[-1] = suggestion[0].split(' ')[0].replace('_', ' ').replace(':', ' ')
27
- return ', '.join(parts) + ', '
28
 
29
  with gr.Blocks(js=JAVA_SCRIPT, css=CSS_SCRIPT, title=TITLE) as ui:
30
  with gr.Row():
@@ -70,7 +51,7 @@ if __name__ == '__main__':
70
  inputs=[character1,character2,character3],
71
  outputs=[thumb_image])
72
 
73
- custom_prompt.change(fn=update_suggestions, inputs=[custom_prompt], outputs=[suggestions])
74
- suggestions.select(fn=apply_suggestion, inputs=[suggestions, custom_prompt], outputs=[custom_prompt])
75
 
76
  ui.launch()
 
1
  import gradio as gr
2
  import sys
3
  sys.path.append("scripts/")
4
+ from lib import init, refresh_character_thumb_image, get_prompt_manager
5
  from lib import JAVA_SCRIPT, CSS_SCRIPT, TITLE
6
 
7
  if __name__ == '__main__':
8
+ character_list, character_list_cn, LANG = init()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  with gr.Blocks(js=JAVA_SCRIPT, css=CSS_SCRIPT, title=TITLE) as ui:
11
  with gr.Row():
 
51
  inputs=[character1,character2,character3],
52
  outputs=[thumb_image])
53
 
54
+ custom_prompt.change(fn=get_prompt_manager().update_suggestions, inputs=[custom_prompt], outputs=[suggestions])
55
+ suggestions.select(fn=get_prompt_manager().apply_suggestion, inputs=[suggestions, custom_prompt], outputs=[custom_prompt, suggestions])
56
 
57
  ui.launch()
scripts/lib.py CHANGED
@@ -7,7 +7,7 @@ import base64
7
 
8
  from io import BytesIO
9
  from PIL import Image
10
- from tag_autocomplete import PromptSuggester
11
 
12
  # Language
13
  LANG_EN = {
@@ -53,7 +53,7 @@ character_list = ''
53
  character_dict = {}
54
  wai_image_dict = {}
55
  character_list_cn = ''
56
- TAG_AUTOCOMPLETE = None
57
 
58
  wai_illustrious_character_select_files = [
59
  {'name': 'wai_character', 'file_path': os.path.join(json_folder, 'wai_characters.csv'), 'url':'https://raw.githubusercontent.com/mirabarukaso/character_select_stand_alone_app/refs/heads/main/json/wai_characters.csv'},
@@ -83,7 +83,7 @@ def load_jsons():
83
  global character_dict
84
  global wai_image_dict
85
  global character_list_cn
86
- global TAG_AUTOCOMPLETE
87
 
88
  # download file
89
  for item in wai_illustrious_character_select_files:
@@ -95,7 +95,7 @@ def load_jsons():
95
  download_file(url, file_path)
96
 
97
  if 'e621_sfw' == name:
98
- TAG_AUTOCOMPLETE = PromptSuggester(file_path)
99
  else:
100
  with open(file_path, 'r', encoding='utf-8') as file:
101
  if 'wai_character' == name:
@@ -162,8 +162,11 @@ def refresh_character_thumb_image(character1, character2, character3):
162
  thumb_image.append(thumb_image3)
163
  return thumb_image
164
 
 
 
 
165
  def init():
166
  load_jsons()
167
  print(f'[{CAT}]:Starting...')
168
 
169
- return character_list, character_list_cn, LANG, TAG_AUTOCOMPLETE
 
7
 
8
  from io import BytesIO
9
  from PIL import Image
10
+ from tag_autocomplete import PromptManager
11
 
12
  # Language
13
  LANG_EN = {
 
53
  character_dict = {}
54
  wai_image_dict = {}
55
  character_list_cn = ''
56
+ PROMPT_MANAGER = None
57
 
58
  wai_illustrious_character_select_files = [
59
  {'name': 'wai_character', 'file_path': os.path.join(json_folder, 'wai_characters.csv'), 'url':'https://raw.githubusercontent.com/mirabarukaso/character_select_stand_alone_app/refs/heads/main/json/wai_characters.csv'},
 
83
  global character_dict
84
  global wai_image_dict
85
  global character_list_cn
86
+ global PROMPT_MANAGER
87
 
88
  # download file
89
  for item in wai_illustrious_character_select_files:
 
95
  download_file(url, file_path)
96
 
97
  if 'e621_sfw' == name:
98
+ PROMPT_MANAGER = PromptManager(file_path)
99
  else:
100
  with open(file_path, 'r', encoding='utf-8') as file:
101
  if 'wai_character' == name:
 
162
  thumb_image.append(thumb_image3)
163
  return thumb_image
164
 
165
+ def get_prompt_manager():
166
+ return PROMPT_MANAGER
167
+
168
  def init():
169
  load_jsons()
170
  print(f'[{CAT}]:Starting...')
171
 
172
+ return character_list, character_list_cn, LANG
scripts/tag_autocomplete.py CHANGED
@@ -1,16 +1,36 @@
1
  import os
 
2
  from typing import List, Dict
 
3
 
4
  CAT = "Auto Tag Complete"
5
 
6
- class PromptSuggester:
7
  def __init__(self, prompt_file_path):
 
 
 
8
  self.prompts = []
 
 
 
 
 
 
 
 
 
 
 
9
  self.load_prompts(prompt_file_path)
10
-
11
  def load_prompts(self, file_path: str):
 
 
 
12
  if not os.path.exists(file_path):
13
  print(f"File {file_path} not found.")
 
14
  return
15
 
16
  with open(file_path, 'r', encoding='utf-8') as f:
@@ -18,7 +38,8 @@ class PromptSuggester:
18
  line = line.strip()
19
  if not line:
20
  continue
21
-
 
22
  parts = line.split(',', 3)
23
  if len(parts) >= 2:
24
  prompt = parts[0].strip()
@@ -26,6 +47,7 @@ class PromptSuggester:
26
  heat = int(parts[2]) if len(parts) > 2 and parts[2].strip().isdigit() else 0
27
  aliases = parts[3].strip('"') if len(parts) > 3 else ""
28
 
 
29
  if heat == 0:
30
  heat = group
31
  group = 0
@@ -38,20 +60,33 @@ class PromptSuggester:
38
  })
39
 
40
  self.prompts.sort(key=lambda x: x['heat'], reverse=True)
 
41
  print(f"[{CAT}] Loaded {len(self.prompts)} prompts.")
42
-
43
- def get_suggestions(self, text: str) -> List[Dict]:
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  if not text:
45
  return []
46
-
 
47
  parts = text.split(',')
48
  last_word = parts[-1].strip().lower()
49
 
50
  if not last_word:
51
  return []
52
 
53
- #print(f"Getting suggestions for: {last_word}")
54
-
55
  matches = {}
56
  for prompt_info in self.prompts:
57
  prompt = prompt_info['prompt'].lower()
@@ -83,12 +118,179 @@ class PromptSuggester:
83
  if prompt not in matches or prompt_info['heat'] > matches[prompt]['heat']:
84
  matches[prompt] = {'prompt': prompt_info['prompt'], 'heat': prompt_info['heat']}
85
 
86
- # we only need 30 items
87
- if len(matches) == 30:
88
  break
89
 
90
  sorted_matches = sorted(matches.values(), key=lambda x: x['heat'], reverse=True)
91
- #print(f"Found {len(sorted_matches)} unique matches")
92
  return sorted_matches
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
 
 
94
 
 
1
  import os
2
+ import re
3
  from typing import List, Dict
4
+ import gradio as gr
5
 
6
  CAT = "Auto Tag Complete"
7
 
8
+ class PromptManager:
9
  def __init__(self, prompt_file_path):
10
+ """
11
+ Initialize the prompt manager with the path to the prompt file.
12
+ """
13
  self.prompts = []
14
+ # file path
15
+ self.prompt_file_path = prompt_file_path
16
+ # save user input to keep track of changes
17
+ self.last_custom_prompt = ""
18
+ self.previous_custom_prompt = ""
19
+ # flag for just applied suggestion
20
+ self.just_applied_suggestion = False
21
+ # flag for data loaded
22
+ self.data_loaded = False
23
+
24
+ # load prompts
25
  self.load_prompts(prompt_file_path)
26
+
27
  def load_prompts(self, file_path: str):
28
+ """
29
+ Load prompts from a file, and sort them by heat in descending order.
30
+ """
31
  if not os.path.exists(file_path):
32
  print(f"File {file_path} not found.")
33
+ self.data_loaded = False
34
  return
35
 
36
  with open(file_path, 'r', encoding='utf-8') as f:
 
38
  line = line.strip()
39
  if not line:
40
  continue
41
+
42
+ # split the line by comma, and get the prompt, group, heat, and aliases
43
  parts = line.split(',', 3)
44
  if len(parts) >= 2:
45
  prompt = parts[0].strip()
 
47
  heat = int(parts[2]) if len(parts) > 2 and parts[2].strip().isdigit() else 0
48
  aliases = parts[3].strip('"') if len(parts) > 3 else ""
49
 
50
+ # some tags didn't have a group, so we set the group to the heat
51
  if heat == 0:
52
  heat = group
53
  group = 0
 
60
  })
61
 
62
  self.prompts.sort(key=lambda x: x['heat'], reverse=True)
63
+ self.data_loaded = True
64
  print(f"[{CAT}] Loaded {len(self.prompts)} prompts.")
65
+
66
+ def reload_data(self):
67
+ """
68
+ Reset the state and reload the prompt data.
69
+ """
70
+ print(f"[{CAT}] Reloading prompts from {self.prompt_file_path}...")
71
+ self.prompts = []
72
+ self.last_custom_prompt = ""
73
+ self.previous_custom_prompt = ""
74
+ self.load_prompts(self.prompt_file_path)
75
+
76
+ def get_suggestions(self, text: str) -> List[Dict]:
77
+ """
78
+ Create suggestions based on the input text.
79
+ """
80
  if not text:
81
  return []
82
+
83
+ # split the text by comma, and get the last word
84
  parts = text.split(',')
85
  last_word = parts[-1].strip().lower()
86
 
87
  if not last_word:
88
  return []
89
 
 
 
90
  matches = {}
91
  for prompt_info in self.prompts:
92
  prompt = prompt_info['prompt'].lower()
 
118
  if prompt not in matches or prompt_info['heat'] > matches[prompt]['heat']:
119
  matches[prompt] = {'prompt': prompt_info['prompt'], 'heat': prompt_info['heat']}
120
 
121
+ # we only need 12 items
122
+ if len(matches) == 12:
123
  break
124
 
125
  sorted_matches = sorted(matches.values(), key=lambda x: x['heat'], reverse=True)
 
126
  return sorted_matches
127
+
128
+ def process_parts(self, parts):
129
+ """
130
+ Process each part of the prompt, handling ':' and '_'.
131
+ """
132
+ processed_parts = []
133
+ for part in parts:
134
+ # Remove extra spaces
135
+ part = part.strip()
136
+
137
+ # If the part starts with ':', keep it as is
138
+ if part.startswith(':'):
139
+ processed_parts.append(part)
140
+ continue
141
+
142
+ # If the part is enclosed in parentheses, process the content inside
143
+ if part.startswith('(') and part.endswith(')'):
144
+ # Remove parentheses and split the content by commas
145
+ inner_content = part[1:-1]
146
+ inner_parts = inner_content.split(',')
147
+ # Recursively process the parts inside the parentheses
148
+ processed_inner_parts = self.process_parts(inner_parts)
149
+ # Rejoin the processed parts and re-enclose them in parentheses
150
+ processed_parts.append(f"({', '.join(processed_inner_parts)})")
151
+ continue
152
+
153
+ # If the part contains ':', handle it based on the content after ':'
154
+ if ':' in part:
155
+ # Match the content before and after ':'
156
+ match = re.match(r'^(.*?):([-+]?\d*\.?\d+)$', part)
157
+ if match:
158
+ prefix, number = match.groups()
159
+ try:
160
+ # Try to convert the number to a float
161
+ number = float(number)
162
+ # If the number is an integer, convert it to an integer
163
+ if number.is_integer():
164
+ number = int(number)
165
+ # If the absolute value of the number is greater than 10, remove ':'
166
+ if abs(number) > 10:
167
+ processed_parts.append(f"{prefix} {number}")
168
+ else:
169
+ # If the absolute value is less than or equal to 10, keep ':'
170
+ processed_parts.append(part)
171
+ except ValueError:
172
+ # If conversion fails, replace ':' with a space
173
+ processed_parts.append(part.replace(':', ' '))
174
+ else:
175
+ # If the content after ':' does not match the pattern, replace ':' with a space
176
+ processed_parts.append(part.replace(':', ' '))
177
+ else:
178
+ # If the part does not contain ':', add it directly
179
+ processed_parts.append(part)
180
+ return processed_parts
181
+
182
+ def update_suggestions(self, text):
183
+ """
184
+ Update suggestions based on the current input and update global variables.
185
+ """
186
+ # If data is not loaded, return an empty dataset
187
+ if not self.data_loaded:
188
+ print(f"[{CAT}] No data loaded. Returning empty dataset.")
189
+ return gr.Dataset(samples=[])
190
+
191
+ # If apply_suggestion was just executed, return an empty dataset
192
+ if self.just_applied_suggestion:
193
+ # Reset the flag
194
+ self.just_applied_suggestion = False
195
+ return gr.Dataset(samples=[])
196
+
197
+ matches = []
198
+ items = []
199
+
200
+ # Split the text by commas
201
+ current_parts = text.split(',') if text else []
202
+ previous_parts = self.previous_custom_prompt.split(',') if self.previous_custom_prompt else []
203
+
204
+ # Locate the position of the word modified by the user
205
+ modified_index = -1
206
+ for i, (current, previous) in enumerate(zip(current_parts, previous_parts)):
207
+ if current.strip() != previous.strip():
208
+ modified_index = i
209
+ break
210
+
211
+ # If no modified word is found and the current input is longer than the previous input, set the modified index to the last index
212
+ if modified_index == -1 and len(current_parts) > len(previous_parts):
213
+ modified_index = len(current_parts) - 1
214
+
215
+ # If a modified word is found, get suggestions
216
+ target_word = None
217
+ if 0 <= modified_index < len(current_parts):
218
+ target_word = current_parts[modified_index].strip()
219
+ matches = self.get_suggestions(target_word)
220
+
221
+ # Create a list of suggestions
222
+ if matches:
223
+ for _, m in enumerate(matches):
224
+ key = f"{m['prompt']} ({m['heat']})"
225
+ items.append([key])
226
+
227
+ # Update global variables to save the current input
228
+ self.previous_custom_prompt = self.last_custom_prompt
229
+ self.last_custom_prompt = text
230
+
231
+ # Debugging
232
+ """
233
+ print(f"CURRENT_CUSTOM_PROMPT: {text}")
234
+ print(f"PREVIOUS_CUSTOM_PROMPT: {self.previous_custom_prompt}")
235
+ print(f"LAST_CUSTOM_PROMPT: {self.last_custom_prompt}")
236
+ print(f"Modified index: {modified_index}")
237
+ if target_word is not None:
238
+ print(f"Suggestions for '{target_word}': {items}")
239
+ """
240
+ return gr.Dataset(samples=items)
241
+
242
+ def apply_suggestion(self, evt: gr.SelectData, text, custom_prompt):
243
+ """
244
+ Apply the suggestion selected by the user to the prompt and clear the prompt box.
245
+ """
246
+ # If data is not loaded, return the original prompt and an empty dataset
247
+ if not self.data_loaded:
248
+ print(f"[{CAT}] No data loaded. Cannot apply suggestion.")
249
+ return custom_prompt, gr.Dataset(samples=[])
250
+
251
+ # Check the type of evt.value
252
+ if isinstance(evt.value, list):
253
+ suggestion = evt.value[0]
254
+ elif isinstance(evt.value, str):
255
+ suggestion = evt.value
256
+ else:
257
+ # Should not reach here
258
+ raise ValueError(f"Unexpected value type: {type(evt.value)}. Expected a string or list.")
259
+
260
+ # Get the suggestion content
261
+ suggestion = suggestion.split(' ')[0]
262
+ if not suggestion:
263
+ return custom_prompt, gr.Dataset(samples=[])
264
+
265
+ # Split custom_prompt by commas
266
+ parts = custom_prompt.split(',') if custom_prompt else []
267
+ previous_parts = self.previous_custom_prompt.split(',') if self.previous_custom_prompt else []
268
+
269
+ # Locate the position of the word modified by the user
270
+ modified_index = -1
271
+ for i, (current, previous) in enumerate(zip(parts, previous_parts)):
272
+ if current.strip() != previous.strip():
273
+ modified_index = i
274
+ break
275
+
276
+ # If no modified word is found and the current input is longer than the previous input, set the modified index to the last index
277
+ if modified_index == -1 and len(parts) > len(previous_parts):
278
+ modified_index = len(parts) - 1
279
+
280
+ # If the modified word is not found, set the modified index to the last index
281
+ if modified_index < 0 or modified_index >= len(parts):
282
+ modified_index = len(parts) - 1
283
+
284
+ # Replace the modified word with the suggestion
285
+ parts[modified_index] = suggestion
286
+
287
+ # Update the global variables
288
+ self.previous_custom_prompt = self.last_custom_prompt
289
+ self.last_custom_prompt = ', '.join(self.process_parts(parts)).replace('_', ' ')
290
+
291
+ # Set just_applied_suggestion to True
292
+ self.just_applied_suggestion = True
293
 
294
+ # Return the updated prompt and an empty dataset
295
+ return self.last_custom_prompt, gr.Dataset(samples=[])
296