Spaces:
Paused
Paused
throaway2854
commited on
Commit
•
46e41c9
1
Parent(s):
cf60c6f
Update app.py
Browse files
app.py
CHANGED
@@ -1,175 +1,211 @@
|
|
1 |
import gradio as gr
|
2 |
-
import random
|
3 |
import json
|
4 |
import os
|
|
|
5 |
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
"characters": {}
|
18 |
-
}
|
19 |
-
|
20 |
-
def save_data(data):
|
21 |
-
os.makedirs(DATA_DIR, exist_ok=True)
|
22 |
-
with open(DATA_FILE, 'w') as f:
|
23 |
-
json.dump(data, f)
|
24 |
-
|
25 |
-
def save_character_image(name, image):
|
26 |
-
os.makedirs(IMAGES_DIR, exist_ok=True)
|
27 |
-
image_path = os.path.join(IMAGES_DIR, f"{name}.png")
|
28 |
-
image.save(image_path)
|
29 |
-
return image_path
|
30 |
-
|
31 |
-
def generate_prompt(scene_tags, num_people, position_tags, selected_characters, outfit_tags, camera_tags, concept_tags, lora_tags, tag_counts, data):
|
32 |
-
all_tags = {
|
33 |
-
"scene": scene_tags.split(','),
|
34 |
-
"position": position_tags.split(','),
|
35 |
-
"outfit": outfit_tags.split(','),
|
36 |
-
"camera": camera_tags.split(','),
|
37 |
-
"concept": concept_tags.split(','),
|
38 |
-
}
|
39 |
-
|
40 |
-
all_tags = {k: [tag.strip() for tag in v if tag.strip()] for k, v in all_tags.items()}
|
41 |
-
|
42 |
-
character_prompts = [
|
43 |
-
f"{char_name}, " + ", ".join(random.sample(data["characters"][char_name]["traits"],
|
44 |
-
min(tag_counts["character"], len(data["characters"][char_name]["traits"]))))
|
45 |
-
for char_name in selected_characters if char_name in data["characters"]
|
46 |
-
]
|
47 |
-
|
48 |
-
selected_tags = [
|
49 |
-
f"{tag}:{random.uniform(0.8, 1.2):.2f}"
|
50 |
-
for category, tags in all_tags.items()
|
51 |
-
for tag in random.sample(tags, min(tag_counts[category], len(tags)))
|
52 |
-
]
|
53 |
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
-
|
58 |
-
|
59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
-
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
|
64 |
-
|
|
|
|
|
|
|
65 |
|
66 |
-
|
|
|
67 |
|
68 |
-
def
|
69 |
-
|
70 |
-
|
71 |
-
|
|
|
|
|
|
|
|
|
72 |
|
73 |
-
def
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
}
|
79 |
-
save_data(data)
|
80 |
-
return data, gr.update(choices=list(data["characters"].keys()))
|
81 |
|
82 |
-
|
83 |
-
data = load_data()
|
84 |
|
85 |
-
|
86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
update_and_generate,
|
153 |
-
inputs=[scene_input, num_people_input, position_input, character_select, outfit_input, camera_input, concept_input, lora_input,
|
154 |
-
scene_count, position_count, character_count, outfit_count, camera_count, concept_count, lora_count],
|
155 |
-
outputs=[output]
|
156 |
-
)
|
157 |
-
|
158 |
-
def update_char_gallery():
|
159 |
-
char_images = [char_data["image"] for char_data in data["characters"].values() if char_data["image"]]
|
160 |
-
return gr.Gallery(value=char_images)
|
161 |
-
|
162 |
-
create_char_button.click(
|
163 |
-
create_character,
|
164 |
-
inputs=[char_name_input, char_traits_input, char_image_input],
|
165 |
-
outputs=[gr.State(data), character_select]
|
166 |
-
).then(
|
167 |
-
update_char_gallery,
|
168 |
-
outputs=[char_gallery]
|
169 |
-
)
|
170 |
-
|
171 |
-
return demo
|
172 |
|
173 |
if __name__ == "__main__":
|
174 |
-
|
175 |
-
demo.launch()
|
|
|
1 |
import gradio as gr
|
|
|
2 |
import json
|
3 |
import os
|
4 |
+
import random
|
5 |
|
6 |
+
class DataManager:
|
7 |
+
def __init__(self, characters_file='characters.json', persistent_tags_file='persistent_tags.json', images_folder='character_images'):
|
8 |
+
self.characters_file = characters_file
|
9 |
+
self.persistent_tags_file = persistent_tags_file
|
10 |
+
self.images_folder = images_folder
|
11 |
+
|
12 |
+
# Make sure image folder exists
|
13 |
+
if not os.path.exists(images_folder):
|
14 |
+
os.makedirs(images_folder)
|
15 |
+
|
16 |
+
self.load_data()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
+
def load_data(self):
|
19 |
+
# Load characters
|
20 |
+
if os.path.exists(self.characters_file):
|
21 |
+
with open(self.characters_file, 'r') as f:
|
22 |
+
self.characters = json.load(f)
|
23 |
+
else:
|
24 |
+
self.characters = []
|
25 |
+
|
26 |
+
# Load persistent tags
|
27 |
+
if os.path.exists(self.persistent_tags_file):
|
28 |
+
with open(self.persistent_tags_file, 'r') as f:
|
29 |
+
self.persistent_tags = json.load(f)
|
30 |
+
else:
|
31 |
+
self.persistent_tags = []
|
32 |
+
|
33 |
+
def save_characters(self):
|
34 |
+
with open(self.characters_file, 'w') as f:
|
35 |
+
json.dump(self.characters, f)
|
36 |
+
|
37 |
+
def save_persistent_tags(self):
|
38 |
+
with open(self.persistent_tags_file, 'w') as f:
|
39 |
+
json.dump(self.persistent_tags, f)
|
40 |
|
41 |
+
def add_character(self, character):
|
42 |
+
# Save image to disk and store the filename
|
43 |
+
image_data = character['image'] # This is a PIL Image
|
44 |
+
image_filename = f"{character['name']}.png"
|
45 |
+
image_path = os.path.join(self.images_folder, image_filename)
|
46 |
+
image_data.save(image_path)
|
47 |
+
character['image_path'] = image_path
|
48 |
+
character.pop('image', None)
|
49 |
+
|
50 |
+
# Assuming traits is a string, split into list if necessary
|
51 |
+
if isinstance(character['traits'], str):
|
52 |
+
character['traits'] = character['traits'].split(',')
|
53 |
+
character['traits'] = [t.strip() for t in character['traits']]
|
54 |
+
|
55 |
+
self.characters.append(character)
|
56 |
+
self.save_characters()
|
57 |
|
58 |
+
def get_characters(self):
|
59 |
+
# Load character images
|
60 |
+
for char in self.characters:
|
61 |
+
image_path = char.get('image_path')
|
62 |
+
if image_path and os.path.exists(image_path):
|
63 |
+
char['image'] = image_path
|
64 |
+
else:
|
65 |
+
char['image'] = None
|
66 |
+
return self.characters
|
67 |
|
68 |
+
def set_persistent_tags(self, tags_string):
|
69 |
+
if isinstance(tags_string, str):
|
70 |
+
self.persistent_tags = [t.strip() for t in tags_string.split(',') if t.strip()]
|
71 |
+
self.save_persistent_tags()
|
72 |
|
73 |
+
def get_persistent_tags(self):
|
74 |
+
return self.persistent_tags
|
75 |
|
76 |
+
def character_creation_app(data_manager):
|
77 |
+
with gr.Tab("Character Creation"):
|
78 |
+
with gr.Row():
|
79 |
+
name_input = gr.Textbox(label="Character Name")
|
80 |
+
traits_input = gr.Textbox(label="Traits/Appearance Tags (comma separated)")
|
81 |
+
image_input = gr.Image(label="Upload Character Image", type="pil")
|
82 |
+
save_button = gr.Button("Save Character")
|
83 |
+
output = gr.Textbox(label="Status", interactive=False)
|
84 |
|
85 |
+
def save_character(name, traits, image):
|
86 |
+
if not name or not traits or image is None:
|
87 |
+
return "Please enter all fields."
|
88 |
+
character = {'name': name, 'traits': traits, 'image': image}
|
89 |
+
data_manager.add_character(character)
|
90 |
+
return f"Character '{name}' saved successfully."
|
|
|
|
|
91 |
|
92 |
+
save_button.click(save_character, inputs=[name_input, traits_input, image_input], outputs=output)
|
|
|
93 |
|
94 |
+
def prompt_generator_app(data_manager):
|
95 |
+
with gr.Tab("Prompt Generator"):
|
96 |
+
# Input fields for each category
|
97 |
+
categories = [
|
98 |
+
('Scene', 'scene_tags'),
|
99 |
+
('Position', 'position_tags'),
|
100 |
+
('Outfit', 'outfit_tags'),
|
101 |
+
('Camera View/Angle', 'camera_tags'),
|
102 |
+
('Concept', 'concept_tags'),
|
103 |
+
('Additional', 'additional_tags'),
|
104 |
+
('LORA', 'lora_tags')
|
105 |
+
]
|
106 |
+
|
107 |
+
inputs = {}
|
108 |
+
for category_name, var_name in categories:
|
109 |
+
with gr.Row():
|
110 |
+
tag_input = gr.Textbox(label=f"{category_name} Tags (comma separated)")
|
111 |
+
tag_num = gr.Slider(minimum=0, maximum=10, step=1, value=1, label=f"Number of {category_name} Tags to Select")
|
112 |
+
inputs[f"{var_name}_input"] = tag_input
|
113 |
+
inputs[f"{var_name}_num"] = tag_num
|
114 |
+
|
115 |
+
# For Character Selection
|
116 |
+
with gr.Box():
|
117 |
+
gr.Markdown("### Character Selection")
|
118 |
+
character_options = [char['name'] for char in data_manager.get_characters()]
|
119 |
+
character_select = gr.CheckboxGroup(choices=character_options, label="Select Characters")
|
120 |
+
random_characters = gr.Checkbox(label="Select Random Characters")
|
121 |
+
num_characters = gr.Slider(minimum=0, maximum=10, step=1, value=1, label="Number of Characters (if random)")
|
122 |
+
|
123 |
+
# Number of people in the scene
|
124 |
+
num_people = gr.Slider(minimum=0, maximum=10, step=1, value=1, label="Number of People in the Scene")
|
125 |
+
|
126 |
+
generate_button = gr.Button("Generate Prompt")
|
127 |
+
prompt_output = gr.Textbox(label="Generated Prompt", lines=5)
|
128 |
+
|
129 |
+
def generate_prompt(*args):
|
130 |
+
# args correspond to inputs in the order they are defined
|
131 |
+
# Need to map args to variables
|
132 |
+
arg_idx = 0
|
133 |
+
|
134 |
+
prompt_tags = []
|
135 |
+
for category_name, var_name in categories:
|
136 |
+
tags_input = args[arg_idx]
|
137 |
+
tags_num = args[arg_idx + 1]
|
138 |
+
arg_idx += 2
|
139 |
+
|
140 |
+
tags_list = [tag.strip() for tag in tags_input.split(',')] if tags_input else []
|
141 |
+
if tags_list and tags_num > 0:
|
142 |
+
selected_tags = random.sample(tags_list, min(len(tags_list), int(tags_num)))
|
143 |
+
prompt_tags.extend(selected_tags)
|
144 |
|
145 |
+
# Handle Characters
|
146 |
+
selected_characters = args[arg_idx]
|
147 |
+
random_chars = args[arg_idx + 1]
|
148 |
+
num_random_chars = args[arg_idx + 2]
|
149 |
+
num_people_in_scene = args[arg_idx + 3]
|
150 |
+
|
151 |
+
arg_idx += 4
|
152 |
+
|
153 |
+
characters = data_manager.get_characters()
|
154 |
+
selected_chars = []
|
155 |
+
if random_chars:
|
156 |
+
num = min(len(characters), int(num_random_chars))
|
157 |
+
selected_chars = random.sample(characters, num)
|
158 |
+
else:
|
159 |
+
selected_chars = [char for char in characters if char['name'] in selected_characters]
|
160 |
+
|
161 |
+
# Adjust the number of people in the scene
|
162 |
+
if num_people_in_scene > len(selected_chars):
|
163 |
+
# Add generic people or adjust as needed
|
164 |
+
pass # This part can be customized based on requirements
|
165 |
+
|
166 |
+
for idx, char in enumerate(selected_chars):
|
167 |
+
prompt_tags.append(f"{char['name']}")
|
168 |
+
prompt_tags.extend(char['traits'])
|
169 |
+
|
170 |
+
# Load persistent tags
|
171 |
+
persistent_tags = data_manager.get_persistent_tags()
|
172 |
+
prompt_tags.extend(persistent_tags)
|
173 |
+
|
174 |
+
# Assemble prompt with classic syntax (word:weight)
|
175 |
+
prompt = ', '.join(f"({tag}:1.0)" for tag in prompt_tags if tag)
|
176 |
+
# Add ending tags
|
177 |
+
ending_tags = "source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres, anime artwork, anime style, vibrant, studio anime, highly detailed"
|
178 |
+
prompt = f"{prompt}, {ending_tags}"
|
179 |
+
return prompt
|
180 |
+
|
181 |
+
# Prepare the list of inputs for the generate_prompt function
|
182 |
+
inputs_list = []
|
183 |
+
for var_name in inputs:
|
184 |
+
inputs_list.append(inputs[var_name])
|
185 |
+
inputs_list.extend([character_select, random_characters, num_characters, num_people])
|
186 |
+
|
187 |
+
generate_button.click(generate_prompt, inputs=inputs_list, outputs=prompt_output)
|
188 |
+
|
189 |
+
def settings_app(data_manager):
|
190 |
+
with gr.Tab("Settings"):
|
191 |
+
persistent_tags_input = gr.Textbox(label="Persistent Tags (comma separated)")
|
192 |
+
save_persistent_tags_button = gr.Button("Save Persistent Tags")
|
193 |
+
status_output = gr.Textbox(label="Status", interactive=False)
|
194 |
+
|
195 |
+
def save_persistent_tags(tags_string):
|
196 |
+
data_manager.set_persistent_tags(tags_string)
|
197 |
+
return "Persistent tags saved successfully."
|
198 |
+
|
199 |
+
save_persistent_tags_button.click(save_persistent_tags, inputs=persistent_tags_input, outputs=status_output)
|
200 |
+
|
201 |
+
def main():
|
202 |
+
data_manager = DataManager()
|
203 |
+
with gr.Blocks() as demo:
|
204 |
+
character_creation_app(data_manager)
|
205 |
+
prompt_generator_app(data_manager)
|
206 |
+
settings_app(data_manager)
|
207 |
+
|
208 |
+
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
209 |
|
210 |
if __name__ == "__main__":
|
211 |
+
main()
|
|