throaway2854 commited on
Commit
46e41c9
1 Parent(s): cf60c6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +195 -159
app.py CHANGED
@@ -1,175 +1,211 @@
1
  import gradio as gr
2
- import random
3
  import json
4
  import os
 
5
 
6
- DATA_DIR = "/data"
7
- IMAGES_DIR = "/images"
8
- DATA_FILE = os.path.join(DATA_DIR, "saved_data.json")
9
-
10
- def load_data():
11
- if os.path.exists(DATA_FILE):
12
- with open(DATA_FILE, 'r') as f:
13
- return json.load(f)
14
- return {
15
- "scene_tags": [], "position_tags": [], "outfit_tags": [],
16
- "camera_tags": [], "concept_tags": [], "lora_tags": [],
17
- "characters": {}
18
- }
19
-
20
- def save_data(data):
21
- os.makedirs(DATA_DIR, exist_ok=True)
22
- with open(DATA_FILE, 'w') as f:
23
- json.dump(data, f)
24
-
25
- def save_character_image(name, image):
26
- os.makedirs(IMAGES_DIR, exist_ok=True)
27
- image_path = os.path.join(IMAGES_DIR, f"{name}.png")
28
- image.save(image_path)
29
- return image_path
30
-
31
- def generate_prompt(scene_tags, num_people, position_tags, selected_characters, outfit_tags, camera_tags, concept_tags, lora_tags, tag_counts, data):
32
- all_tags = {
33
- "scene": scene_tags.split(','),
34
- "position": position_tags.split(','),
35
- "outfit": outfit_tags.split(','),
36
- "camera": camera_tags.split(','),
37
- "concept": concept_tags.split(','),
38
- }
39
-
40
- all_tags = {k: [tag.strip() for tag in v if tag.strip()] for k, v in all_tags.items()}
41
-
42
- character_prompts = [
43
- f"{char_name}, " + ", ".join(random.sample(data["characters"][char_name]["traits"],
44
- min(tag_counts["character"], len(data["characters"][char_name]["traits"]))))
45
- for char_name in selected_characters if char_name in data["characters"]
46
- ]
47
-
48
- selected_tags = [
49
- f"{tag}:{random.uniform(0.8, 1.2):.2f}"
50
- for category, tags in all_tags.items()
51
- for tag in random.sample(tags, min(tag_counts[category], len(tags)))
52
- ]
53
 
54
- if num_people.strip():
55
- selected_tags.append(f"{num_people} people:1.1")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- prompt_parts = character_prompts + selected_tags
58
- random.shuffle(prompt_parts)
59
- main_prompt = ", ".join(prompt_parts)
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- lora_list = [lora.strip() for lora in lora_tags.split(',') if lora.strip()]
62
- lora_prompt = " ".join(f"<lora:{lora}:1>" for lora in lora_list)
 
 
 
 
 
 
 
63
 
64
- fixed_tags = "source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres, anime artwork, anime style, vibrant, studio anime, highly detailed"
 
 
 
65
 
66
- return f"{main_prompt}, {fixed_tags} {lora_prompt}".strip()
 
67
 
68
- def update_data(data, key, value):
69
- data[key] = list(set(data[key] + [v.strip() for v in value.split(',') if v.strip()]))
70
- save_data(data)
71
- return data
 
 
 
 
72
 
73
- def create_character(name, traits, image, data):
74
- if name:
75
- data["characters"][name] = {
76
- "traits": [trait.strip() for trait in traits.split(',') if trait.strip()],
77
- "image": save_character_image(name, image) if image else None
78
- }
79
- save_data(data)
80
- return data, gr.update(choices=list(data["characters"].keys()))
81
 
82
- def create_ui():
83
- data = load_data()
84
 
85
- with gr.Blocks() as demo:
86
- gr.Markdown("# Advanced Pony SDXL Prompt Generator with Character Creation")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
- with gr.Tabs():
89
- with gr.TabItem("Prompt Generator"):
90
- with gr.Row():
91
- with gr.Column():
92
- scene_input = gr.Textbox(label="Scene Tags (comma-separated)", value=", ".join(data["scene_tags"]))
93
- scene_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of scene tags")
94
-
95
- num_people_input = gr.Textbox(label="Number of People")
96
-
97
- position_input = gr.Textbox(label="Position Tags (comma-separated)", value=", ".join(data["position_tags"]))
98
- position_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of position tags")
99
-
100
- character_select = gr.CheckboxGroup(label="Select Characters", choices=list(data["characters"].keys()))
101
- character_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of character traits")
102
-
103
- outfit_input = gr.Textbox(label="Outfit Tags (comma-separated)", value=", ".join(data["outfit_tags"]))
104
- outfit_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of outfit tags")
105
-
106
- camera_input = gr.Textbox(label="Camera View/Angle Tags (comma-separated)", value=", ".join(data["camera_tags"]))
107
- camera_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of camera tags")
108
-
109
- concept_input = gr.Textbox(label="Concept Tags (comma-separated)", value=", ".join(data["concept_tags"]))
110
- concept_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of concept tags")
111
-
112
- lora_input = gr.Textbox(label="LORA Tags (comma-separated)", value=", ".join(data["lora_tags"]))
113
- lora_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of LORA tags")
114
-
115
- generate_button = gr.Button("Generate Prompt")
116
-
117
- with gr.Column():
118
- output = gr.Textbox(label="Generated Prompt", lines=5)
119
-
120
- char_images = [char_data["image"] for char_data in data["characters"].values() if char_data["image"]]
121
- gr.Gallery(value=char_images, label="Character Images", show_label=True, elem_id="char_gallery", columns=2, rows=2, height="auto")
122
-
123
- with gr.TabItem("Character Creation"):
124
- with gr.Row():
125
- with gr.Column():
126
- char_name_input = gr.Textbox(label="Character Name")
127
- char_traits_input = gr.Textbox(label="Character Traits (comma-separated)")
128
- char_image_input = gr.Image(label="Character Image", type="pil")
129
- create_char_button = gr.Button("Create/Update Character")
130
-
131
- with gr.Column():
132
- char_gallery = gr.Gallery(label="Existing Characters", show_label=True, elem_id="char_gallery", columns=2, rows=2, height="auto")
133
-
134
- def update_and_generate(*args):
135
- nonlocal data
136
- scene_tags, num_people, position_tags, selected_characters, outfit_tags, camera_tags, concept_tags, lora_tags, *tag_counts = args
137
- data = update_data(data, "scene_tags", scene_tags)
138
- data = update_data(data, "position_tags", position_tags)
139
- data = update_data(data, "outfit_tags", outfit_tags)
140
- data = update_data(data, "camera_tags", camera_tags)
141
- data = update_data(data, "concept_tags", concept_tags)
142
- data = update_data(data, "lora_tags", lora_tags)
143
-
144
- tag_count_dict = {
145
- "scene": tag_counts[0], "position": tag_counts[1], "character": tag_counts[2],
146
- "outfit": tag_counts[3], "camera": tag_counts[4], "concept": tag_counts[5], "lora": tag_counts[6]
147
- }
148
-
149
- return generate_prompt(scene_tags, num_people, position_tags, selected_characters, outfit_tags, camera_tags, concept_tags, lora_tags, tag_count_dict, data)
150
-
151
- generate_button.click(
152
- update_and_generate,
153
- inputs=[scene_input, num_people_input, position_input, character_select, outfit_input, camera_input, concept_input, lora_input,
154
- scene_count, position_count, character_count, outfit_count, camera_count, concept_count, lora_count],
155
- outputs=[output]
156
- )
157
-
158
- def update_char_gallery():
159
- char_images = [char_data["image"] for char_data in data["characters"].values() if char_data["image"]]
160
- return gr.Gallery(value=char_images)
161
-
162
- create_char_button.click(
163
- create_character,
164
- inputs=[char_name_input, char_traits_input, char_image_input],
165
- outputs=[gr.State(data), character_select]
166
- ).then(
167
- update_char_gallery,
168
- outputs=[char_gallery]
169
- )
170
-
171
- return demo
172
 
173
  if __name__ == "__main__":
174
- demo = create_ui()
175
- demo.launch()
 
1
  import gradio as gr
 
2
  import json
3
  import os
4
+ import random
5
 
6
+ class DataManager:
7
+ def __init__(self, characters_file='characters.json', persistent_tags_file='persistent_tags.json', images_folder='character_images'):
8
+ self.characters_file = characters_file
9
+ self.persistent_tags_file = persistent_tags_file
10
+ self.images_folder = images_folder
11
+
12
+ # Make sure image folder exists
13
+ if not os.path.exists(images_folder):
14
+ os.makedirs(images_folder)
15
+
16
+ self.load_data()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
+ def load_data(self):
19
+ # Load characters
20
+ if os.path.exists(self.characters_file):
21
+ with open(self.characters_file, 'r') as f:
22
+ self.characters = json.load(f)
23
+ else:
24
+ self.characters = []
25
+
26
+ # Load persistent tags
27
+ if os.path.exists(self.persistent_tags_file):
28
+ with open(self.persistent_tags_file, 'r') as f:
29
+ self.persistent_tags = json.load(f)
30
+ else:
31
+ self.persistent_tags = []
32
+
33
+ def save_characters(self):
34
+ with open(self.characters_file, 'w') as f:
35
+ json.dump(self.characters, f)
36
+
37
+ def save_persistent_tags(self):
38
+ with open(self.persistent_tags_file, 'w') as f:
39
+ json.dump(self.persistent_tags, f)
40
 
41
+ def add_character(self, character):
42
+ # Save image to disk and store the filename
43
+ image_data = character['image'] # This is a PIL Image
44
+ image_filename = f"{character['name']}.png"
45
+ image_path = os.path.join(self.images_folder, image_filename)
46
+ image_data.save(image_path)
47
+ character['image_path'] = image_path
48
+ character.pop('image', None)
49
+
50
+ # Assuming traits is a string, split into list if necessary
51
+ if isinstance(character['traits'], str):
52
+ character['traits'] = character['traits'].split(',')
53
+ character['traits'] = [t.strip() for t in character['traits']]
54
+
55
+ self.characters.append(character)
56
+ self.save_characters()
57
 
58
+ def get_characters(self):
59
+ # Load character images
60
+ for char in self.characters:
61
+ image_path = char.get('image_path')
62
+ if image_path and os.path.exists(image_path):
63
+ char['image'] = image_path
64
+ else:
65
+ char['image'] = None
66
+ return self.characters
67
 
68
+ def set_persistent_tags(self, tags_string):
69
+ if isinstance(tags_string, str):
70
+ self.persistent_tags = [t.strip() for t in tags_string.split(',') if t.strip()]
71
+ self.save_persistent_tags()
72
 
73
+ def get_persistent_tags(self):
74
+ return self.persistent_tags
75
 
76
+ def character_creation_app(data_manager):
77
+ with gr.Tab("Character Creation"):
78
+ with gr.Row():
79
+ name_input = gr.Textbox(label="Character Name")
80
+ traits_input = gr.Textbox(label="Traits/Appearance Tags (comma separated)")
81
+ image_input = gr.Image(label="Upload Character Image", type="pil")
82
+ save_button = gr.Button("Save Character")
83
+ output = gr.Textbox(label="Status", interactive=False)
84
 
85
+ def save_character(name, traits, image):
86
+ if not name or not traits or image is None:
87
+ return "Please enter all fields."
88
+ character = {'name': name, 'traits': traits, 'image': image}
89
+ data_manager.add_character(character)
90
+ return f"Character '{name}' saved successfully."
 
 
91
 
92
+ save_button.click(save_character, inputs=[name_input, traits_input, image_input], outputs=output)
 
93
 
94
+ def prompt_generator_app(data_manager):
95
+ with gr.Tab("Prompt Generator"):
96
+ # Input fields for each category
97
+ categories = [
98
+ ('Scene', 'scene_tags'),
99
+ ('Position', 'position_tags'),
100
+ ('Outfit', 'outfit_tags'),
101
+ ('Camera View/Angle', 'camera_tags'),
102
+ ('Concept', 'concept_tags'),
103
+ ('Additional', 'additional_tags'),
104
+ ('LORA', 'lora_tags')
105
+ ]
106
+
107
+ inputs = {}
108
+ for category_name, var_name in categories:
109
+ with gr.Row():
110
+ tag_input = gr.Textbox(label=f"{category_name} Tags (comma separated)")
111
+ tag_num = gr.Slider(minimum=0, maximum=10, step=1, value=1, label=f"Number of {category_name} Tags to Select")
112
+ inputs[f"{var_name}_input"] = tag_input
113
+ inputs[f"{var_name}_num"] = tag_num
114
+
115
+ # For Character Selection
116
+ with gr.Box():
117
+ gr.Markdown("### Character Selection")
118
+ character_options = [char['name'] for char in data_manager.get_characters()]
119
+ character_select = gr.CheckboxGroup(choices=character_options, label="Select Characters")
120
+ random_characters = gr.Checkbox(label="Select Random Characters")
121
+ num_characters = gr.Slider(minimum=0, maximum=10, step=1, value=1, label="Number of Characters (if random)")
122
+
123
+ # Number of people in the scene
124
+ num_people = gr.Slider(minimum=0, maximum=10, step=1, value=1, label="Number of People in the Scene")
125
+
126
+ generate_button = gr.Button("Generate Prompt")
127
+ prompt_output = gr.Textbox(label="Generated Prompt", lines=5)
128
+
129
+ def generate_prompt(*args):
130
+ # args correspond to inputs in the order they are defined
131
+ # Need to map args to variables
132
+ arg_idx = 0
133
+
134
+ prompt_tags = []
135
+ for category_name, var_name in categories:
136
+ tags_input = args[arg_idx]
137
+ tags_num = args[arg_idx + 1]
138
+ arg_idx += 2
139
+
140
+ tags_list = [tag.strip() for tag in tags_input.split(',')] if tags_input else []
141
+ if tags_list and tags_num > 0:
142
+ selected_tags = random.sample(tags_list, min(len(tags_list), int(tags_num)))
143
+ prompt_tags.extend(selected_tags)
144
 
145
+ # Handle Characters
146
+ selected_characters = args[arg_idx]
147
+ random_chars = args[arg_idx + 1]
148
+ num_random_chars = args[arg_idx + 2]
149
+ num_people_in_scene = args[arg_idx + 3]
150
+
151
+ arg_idx += 4
152
+
153
+ characters = data_manager.get_characters()
154
+ selected_chars = []
155
+ if random_chars:
156
+ num = min(len(characters), int(num_random_chars))
157
+ selected_chars = random.sample(characters, num)
158
+ else:
159
+ selected_chars = [char for char in characters if char['name'] in selected_characters]
160
+
161
+ # Adjust the number of people in the scene
162
+ if num_people_in_scene > len(selected_chars):
163
+ # Add generic people or adjust as needed
164
+ pass # This part can be customized based on requirements
165
+
166
+ for idx, char in enumerate(selected_chars):
167
+ prompt_tags.append(f"{char['name']}")
168
+ prompt_tags.extend(char['traits'])
169
+
170
+ # Load persistent tags
171
+ persistent_tags = data_manager.get_persistent_tags()
172
+ prompt_tags.extend(persistent_tags)
173
+
174
+ # Assemble prompt with classic syntax (word:weight)
175
+ prompt = ', '.join(f"({tag}:1.0)" for tag in prompt_tags if tag)
176
+ # Add ending tags
177
+ ending_tags = "source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres, anime artwork, anime style, vibrant, studio anime, highly detailed"
178
+ prompt = f"{prompt}, {ending_tags}"
179
+ return prompt
180
+
181
+ # Prepare the list of inputs for the generate_prompt function
182
+ inputs_list = []
183
+ for var_name in inputs:
184
+ inputs_list.append(inputs[var_name])
185
+ inputs_list.extend([character_select, random_characters, num_characters, num_people])
186
+
187
+ generate_button.click(generate_prompt, inputs=inputs_list, outputs=prompt_output)
188
+
189
+ def settings_app(data_manager):
190
+ with gr.Tab("Settings"):
191
+ persistent_tags_input = gr.Textbox(label="Persistent Tags (comma separated)")
192
+ save_persistent_tags_button = gr.Button("Save Persistent Tags")
193
+ status_output = gr.Textbox(label="Status", interactive=False)
194
+
195
+ def save_persistent_tags(tags_string):
196
+ data_manager.set_persistent_tags(tags_string)
197
+ return "Persistent tags saved successfully."
198
+
199
+ save_persistent_tags_button.click(save_persistent_tags, inputs=persistent_tags_input, outputs=status_output)
200
+
201
+ def main():
202
+ data_manager = DataManager()
203
+ with gr.Blocks() as demo:
204
+ character_creation_app(data_manager)
205
+ prompt_generator_app(data_manager)
206
+ settings_app(data_manager)
207
+
208
+ demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
  if __name__ == "__main__":
211
+ main()