Spaces:
Paused
Paused
throaway2854
commited on
Commit
•
20fec7d
1
Parent(s):
93e965c
Update app.py
Browse files
app.py
CHANGED
@@ -6,12 +6,27 @@ from PIL import Image
|
|
6 |
import base64
|
7 |
import io
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
class DataManager:
|
10 |
-
def __init__(self, characters_file='characters.json',
|
11 |
-
persistent_tags_file='persistent_tags.json',
|
|
|
12 |
images_folder='character_images'):
|
13 |
self.characters_file = characters_file
|
14 |
self.persistent_tags_file = persistent_tags_file
|
|
|
15 |
self.images_folder = images_folder
|
16 |
|
17 |
# Make sure image folder exists
|
@@ -19,7 +34,7 @@ class DataManager:
|
|
19 |
os.makedirs(images_folder)
|
20 |
|
21 |
self.load_data()
|
22 |
-
|
23 |
def load_data(self):
|
24 |
# Load characters
|
25 |
if os.path.exists(self.characters_file):
|
@@ -27,7 +42,7 @@ class DataManager:
|
|
27 |
self.characters = json.load(f)
|
28 |
else:
|
29 |
self.characters = []
|
30 |
-
|
31 |
# Load persistent tags
|
32 |
if os.path.exists(self.persistent_tags_file):
|
33 |
with open(self.persistent_tags_file, 'r') as f:
|
@@ -35,6 +50,9 @@ class DataManager:
|
|
35 |
else:
|
36 |
self.persistent_tags = []
|
37 |
|
|
|
|
|
|
|
38 |
def save_characters(self):
|
39 |
with open(self.characters_file, 'w') as f:
|
40 |
json.dump(self.characters, f)
|
@@ -42,7 +60,43 @@ class DataManager:
|
|
42 |
def save_persistent_tags(self):
|
43 |
with open(self.persistent_tags_file, 'w') as f:
|
44 |
json.dump(self.persistent_tags, f)
|
45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
def add_character(self, character):
|
47 |
# Save image to disk and store the filename
|
48 |
image_data = character['image'] # This is base64 encoded string
|
@@ -66,24 +120,6 @@ class DataManager:
|
|
66 |
|
67 |
self.characters.append(character)
|
68 |
self.save_characters()
|
69 |
-
|
70 |
-
def get_characters(self):
|
71 |
-
# Load character images
|
72 |
-
for char in self.characters:
|
73 |
-
image_path = char.get('image_path')
|
74 |
-
if image_path and os.path.exists(image_path):
|
75 |
-
char['image'] = image_path
|
76 |
-
else:
|
77 |
-
char['image'] = None
|
78 |
-
return self.characters
|
79 |
-
|
80 |
-
def set_persistent_tags(self, tags_string):
|
81 |
-
if isinstance(tags_string, str):
|
82 |
-
self.persistent_tags = [t.strip() for t in tags_string.split(',') if t.strip()]
|
83 |
-
self.save_persistent_tags()
|
84 |
-
|
85 |
-
def get_persistent_tags(self):
|
86 |
-
return self.persistent_tags
|
87 |
|
88 |
def character_creation_app(data_manager):
|
89 |
with gr.Tab("Character Creation"):
|
@@ -110,30 +146,27 @@ def character_creation_app(data_manager):
|
|
110 |
character['image'] = None
|
111 |
|
112 |
data_manager.add_character(character)
|
113 |
-
# Clear inputs after saving
|
114 |
return f"Character '{name}' saved successfully."
|
115 |
|
116 |
save_button.click(save_character, inputs=[name_input, traits_input, image_input, gender_input], outputs=output)
|
117 |
|
118 |
def prompt_generator_app(data_manager):
|
119 |
with gr.Tab("Prompt Generator"):
|
120 |
-
|
121 |
-
categories = [
|
122 |
-
('Scene', 'scene_tags'),
|
123 |
-
('Position', 'position_tags'),
|
124 |
-
('Outfit', 'outfit_tags'),
|
125 |
-
('Camera View/Angle', 'camera_tags'),
|
126 |
-
('Concept', 'concept_tags'),
|
127 |
-
('Additional', 'additional_tags'),
|
128 |
-
('LORA', 'lora_tags')
|
129 |
-
]
|
130 |
|
131 |
inputs = {}
|
132 |
for category_name, var_name in categories:
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
inputs[f"{var_name}_num"] = tag_num
|
138 |
|
139 |
# For Character Selection
|
@@ -166,16 +199,15 @@ def prompt_generator_app(data_manager):
|
|
166 |
generate_button = gr.Button("Generate Prompt")
|
167 |
prompt_output = gr.Textbox(label="Generated Prompt", lines=5)
|
168 |
|
169 |
-
def generate_prompt(*args):
|
170 |
arg_idx = 0
|
171 |
|
172 |
prompt_tags = []
|
173 |
for category_name, var_name in categories:
|
174 |
-
|
175 |
-
tags_num = args[arg_idx
|
176 |
-
arg_idx +=
|
177 |
|
178 |
-
tags_list = [tag.strip() for tag in tags_input.split(',')] if tags_input else []
|
179 |
if tags_list and tags_num > 0:
|
180 |
selected_tags = random.sample(tags_list, min(len(tags_list), int(tags_num)))
|
181 |
prompt_tags.extend(selected_tags)
|
@@ -255,31 +287,56 @@ def prompt_generator_app(data_manager):
|
|
255 |
|
256 |
# Prepare the list of inputs for the generate_prompt function
|
257 |
inputs_list = []
|
258 |
-
for var_name in
|
259 |
-
inputs_list.append(inputs[var_name])
|
260 |
# Add character_select directly to inputs
|
261 |
inputs_list.extend([character_select, random_characters, num_characters])
|
262 |
|
263 |
generate_button.click(generate_prompt, inputs=inputs_list, outputs=prompt_output)
|
264 |
|
265 |
-
def
|
266 |
-
with gr.Tab("
|
267 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
268 |
save_persistent_tags_button = gr.Button("Save Persistent Tags")
|
269 |
-
|
270 |
|
271 |
def save_persistent_tags(tags_string):
|
272 |
-
|
|
|
273 |
return "Persistent tags saved successfully."
|
274 |
|
275 |
-
save_persistent_tags_button.click(save_persistent_tags, inputs=
|
276 |
|
277 |
def main():
|
278 |
data_manager = DataManager()
|
279 |
with gr.Blocks() as demo:
|
280 |
prompt_generator_app(data_manager)
|
281 |
character_creation_app(data_manager)
|
282 |
-
|
283 |
|
284 |
demo.launch()
|
285 |
|
|
|
6 |
import base64
|
7 |
import io
|
8 |
|
9 |
+
# Define categories at the top so they are accessible throughout the code
|
10 |
+
categories = [
|
11 |
+
('Scene', 'scene_tags'),
|
12 |
+
('Position', 'position_tags'),
|
13 |
+
('Outfit', 'outfit_tags'),
|
14 |
+
('Camera View/Angle', 'camera_tags'),
|
15 |
+
('Concept', 'concept_tags'),
|
16 |
+
('Facial Expression', 'facial_expression_tags'),
|
17 |
+
('Pose', 'pose_tags'),
|
18 |
+
('Additional', 'additional_tags'),
|
19 |
+
('LORA', 'lora_tags')
|
20 |
+
]
|
21 |
+
|
22 |
class DataManager:
|
23 |
+
def __init__(self, characters_file='characters.json',
|
24 |
+
persistent_tags_file='persistent_tags.json',
|
25 |
+
category_tags_file='category_tags.json',
|
26 |
images_folder='character_images'):
|
27 |
self.characters_file = characters_file
|
28 |
self.persistent_tags_file = persistent_tags_file
|
29 |
+
self.category_tags_file = category_tags_file
|
30 |
self.images_folder = images_folder
|
31 |
|
32 |
# Make sure image folder exists
|
|
|
34 |
os.makedirs(images_folder)
|
35 |
|
36 |
self.load_data()
|
37 |
+
|
38 |
def load_data(self):
|
39 |
# Load characters
|
40 |
if os.path.exists(self.characters_file):
|
|
|
42 |
self.characters = json.load(f)
|
43 |
else:
|
44 |
self.characters = []
|
45 |
+
|
46 |
# Load persistent tags
|
47 |
if os.path.exists(self.persistent_tags_file):
|
48 |
with open(self.persistent_tags_file, 'r') as f:
|
|
|
50 |
else:
|
51 |
self.persistent_tags = []
|
52 |
|
53 |
+
# Load category tags
|
54 |
+
self.load_category_tags()
|
55 |
+
|
56 |
def save_characters(self):
|
57 |
with open(self.characters_file, 'w') as f:
|
58 |
json.dump(self.characters, f)
|
|
|
60 |
def save_persistent_tags(self):
|
61 |
with open(self.persistent_tags_file, 'w') as f:
|
62 |
json.dump(self.persistent_tags, f)
|
63 |
+
|
64 |
+
def load_category_tags(self):
|
65 |
+
if os.path.exists(self.category_tags_file):
|
66 |
+
with open(self.category_tags_file, 'r') as f:
|
67 |
+
self.category_tags = json.load(f)
|
68 |
+
else:
|
69 |
+
self.category_tags = {}
|
70 |
+
|
71 |
+
def save_category_tags(self):
|
72 |
+
with open(self.category_tags_file, 'w') as f:
|
73 |
+
json.dump(self.category_tags, f)
|
74 |
+
|
75 |
+
def get_category_tags(self, category_var_name):
|
76 |
+
# Return the tags list for the given category variable name
|
77 |
+
return self.category_tags.get(category_var_name, [])
|
78 |
+
|
79 |
+
def set_category_tags(self, category_var_name, tags_list):
|
80 |
+
self.category_tags[category_var_name] = tags_list
|
81 |
+
self.save_category_tags()
|
82 |
+
|
83 |
+
def get_characters(self):
|
84 |
+
# Load character images
|
85 |
+
for char in self.characters:
|
86 |
+
image_path = char.get('image_path')
|
87 |
+
if image_path and os.path.exists(image_path):
|
88 |
+
char['image'] = image_path
|
89 |
+
else:
|
90 |
+
char['image'] = None
|
91 |
+
return self.characters
|
92 |
+
|
93 |
+
def set_persistent_tags(self, tags_list):
|
94 |
+
self.persistent_tags = tags_list
|
95 |
+
self.save_persistent_tags()
|
96 |
+
|
97 |
+
def get_persistent_tags(self):
|
98 |
+
return self.persistent_tags
|
99 |
+
|
100 |
def add_character(self, character):
|
101 |
# Save image to disk and store the filename
|
102 |
image_data = character['image'] # This is base64 encoded string
|
|
|
120 |
|
121 |
self.characters.append(character)
|
122 |
self.save_characters()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
|
124 |
def character_creation_app(data_manager):
|
125 |
with gr.Tab("Character Creation"):
|
|
|
146 |
character['image'] = None
|
147 |
|
148 |
data_manager.add_character(character)
|
|
|
149 |
return f"Character '{name}' saved successfully."
|
150 |
|
151 |
save_button.click(save_character, inputs=[name_input, traits_input, image_input, gender_input], outputs=output)
|
152 |
|
153 |
def prompt_generator_app(data_manager):
|
154 |
with gr.Tab("Prompt Generator"):
|
155 |
+
gr.Markdown("## Prompt Generator")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
|
157 |
inputs = {}
|
158 |
for category_name, var_name in categories:
|
159 |
+
tags_list = data_manager.get_category_tags(var_name)
|
160 |
+
tags_string = ', '.join(tags_list)
|
161 |
+
max_tags = len(tags_list)
|
162 |
+
if max_tags == 0:
|
163 |
+
default_value = 0
|
164 |
+
else:
|
165 |
+
default_value = min(1, max_tags)
|
166 |
+
with gr.Box():
|
167 |
+
gr.Markdown(f"### {category_name}")
|
168 |
+
gr.Markdown(f"**Tags:** {tags_string}")
|
169 |
+
tag_num = gr.Slider(minimum=0, maximum=max_tags, step=1, value=default_value, label=f"Number of {category_name} Tags to Select")
|
170 |
inputs[f"{var_name}_num"] = tag_num
|
171 |
|
172 |
# For Character Selection
|
|
|
199 |
generate_button = gr.Button("Generate Prompt")
|
200 |
prompt_output = gr.Textbox(label="Generated Prompt", lines=5)
|
201 |
|
202 |
+
def generate_prompt(*args, data_manager=data_manager, categories=categories):
|
203 |
arg_idx = 0
|
204 |
|
205 |
prompt_tags = []
|
206 |
for category_name, var_name in categories:
|
207 |
+
tags_list = data_manager.get_category_tags(var_name)
|
208 |
+
tags_num = args[arg_idx]
|
209 |
+
arg_idx += 1
|
210 |
|
|
|
211 |
if tags_list and tags_num > 0:
|
212 |
selected_tags = random.sample(tags_list, min(len(tags_list), int(tags_num)))
|
213 |
prompt_tags.extend(selected_tags)
|
|
|
287 |
|
288 |
# Prepare the list of inputs for the generate_prompt function
|
289 |
inputs_list = []
|
290 |
+
for category_name, var_name in categories:
|
291 |
+
inputs_list.append(inputs[f"{var_name}_num"])
|
292 |
# Add character_select directly to inputs
|
293 |
inputs_list.extend([character_select, random_characters, num_characters])
|
294 |
|
295 |
generate_button.click(generate_prompt, inputs=inputs_list, outputs=prompt_output)
|
296 |
|
297 |
+
def tags_app(data_manager):
|
298 |
+
with gr.Tab("Tags"):
|
299 |
+
gr.Markdown("## Edit Tags for Each Category")
|
300 |
+
|
301 |
+
for category_name, var_name in categories:
|
302 |
+
gr.Markdown(f"### {category_name} Tags")
|
303 |
+
tags_list = data_manager.get_category_tags(var_name)
|
304 |
+
tags_string = ', '.join(tags_list)
|
305 |
+
tag_input = gr.Textbox(label=f"{category_name} Tags (comma separated)", value=tags_string)
|
306 |
+
save_button = gr.Button(f"Save {category_name} Tags")
|
307 |
+
status_output = gr.Textbox(label="", interactive=False)
|
308 |
+
|
309 |
+
# Capture var_name and category_name in the closure
|
310 |
+
def make_save_category_tags_fn(var_name, category_name):
|
311 |
+
def fn(tags_string):
|
312 |
+
tags_list = [t.strip() for t in tags_string.split(',') if t.strip()]
|
313 |
+
data_manager.set_category_tags(var_name, tags_list)
|
314 |
+
return f"{category_name} tags saved successfully."
|
315 |
+
return fn
|
316 |
+
|
317 |
+
save_fn = make_save_category_tags_fn(var_name, category_name)
|
318 |
+
save_button.click(save_fn, inputs=tag_input, outputs=status_output)
|
319 |
+
|
320 |
+
# Persistent Tags
|
321 |
+
gr.Markdown(f"### Persistent Tags")
|
322 |
+
persistent_tags_string = ', '.join(data_manager.get_persistent_tags())
|
323 |
+
persistent_tags_input = gr.Textbox(label="Persistent Tags (comma separated)", value=persistent_tags_string)
|
324 |
save_persistent_tags_button = gr.Button("Save Persistent Tags")
|
325 |
+
persistent_status_output = gr.Textbox(label="", interactive=False)
|
326 |
|
327 |
def save_persistent_tags(tags_string):
|
328 |
+
tags_list = [t.strip() for t in tags_string.split(',') if t.strip()]
|
329 |
+
data_manager.set_persistent_tags(tags_list)
|
330 |
return "Persistent tags saved successfully."
|
331 |
|
332 |
+
save_persistent_tags_button.click(save_persistent_tags, inputs=persistent_tags_input, outputs=persistent_status_output)
|
333 |
|
334 |
def main():
|
335 |
data_manager = DataManager()
|
336 |
with gr.Blocks() as demo:
|
337 |
prompt_generator_app(data_manager)
|
338 |
character_creation_app(data_manager)
|
339 |
+
tags_app(data_manager)
|
340 |
|
341 |
demo.launch()
|
342 |
|