throaway2854 commited on
Commit
20fec7d
1 Parent(s): 93e965c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -52
app.py CHANGED
@@ -6,12 +6,27 @@ from PIL import Image
6
  import base64
7
  import io
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  class DataManager:
10
- def __init__(self, characters_file='characters.json',
11
- persistent_tags_file='persistent_tags.json',
 
12
  images_folder='character_images'):
13
  self.characters_file = characters_file
14
  self.persistent_tags_file = persistent_tags_file
 
15
  self.images_folder = images_folder
16
 
17
  # Make sure image folder exists
@@ -19,7 +34,7 @@ class DataManager:
19
  os.makedirs(images_folder)
20
 
21
  self.load_data()
22
-
23
  def load_data(self):
24
  # Load characters
25
  if os.path.exists(self.characters_file):
@@ -27,7 +42,7 @@ class DataManager:
27
  self.characters = json.load(f)
28
  else:
29
  self.characters = []
30
-
31
  # Load persistent tags
32
  if os.path.exists(self.persistent_tags_file):
33
  with open(self.persistent_tags_file, 'r') as f:
@@ -35,6 +50,9 @@ class DataManager:
35
  else:
36
  self.persistent_tags = []
37
 
 
 
 
38
  def save_characters(self):
39
  with open(self.characters_file, 'w') as f:
40
  json.dump(self.characters, f)
@@ -42,7 +60,43 @@ class DataManager:
42
  def save_persistent_tags(self):
43
  with open(self.persistent_tags_file, 'w') as f:
44
  json.dump(self.persistent_tags, f)
45
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  def add_character(self, character):
47
  # Save image to disk and store the filename
48
  image_data = character['image'] # This is base64 encoded string
@@ -66,24 +120,6 @@ class DataManager:
66
 
67
  self.characters.append(character)
68
  self.save_characters()
69
-
70
- def get_characters(self):
71
- # Load character images
72
- for char in self.characters:
73
- image_path = char.get('image_path')
74
- if image_path and os.path.exists(image_path):
75
- char['image'] = image_path
76
- else:
77
- char['image'] = None
78
- return self.characters
79
-
80
- def set_persistent_tags(self, tags_string):
81
- if isinstance(tags_string, str):
82
- self.persistent_tags = [t.strip() for t in tags_string.split(',') if t.strip()]
83
- self.save_persistent_tags()
84
-
85
- def get_persistent_tags(self):
86
- return self.persistent_tags
87
 
88
  def character_creation_app(data_manager):
89
  with gr.Tab("Character Creation"):
@@ -110,30 +146,27 @@ def character_creation_app(data_manager):
110
  character['image'] = None
111
 
112
  data_manager.add_character(character)
113
- # Clear inputs after saving
114
  return f"Character '{name}' saved successfully."
115
 
116
  save_button.click(save_character, inputs=[name_input, traits_input, image_input, gender_input], outputs=output)
117
 
118
  def prompt_generator_app(data_manager):
119
  with gr.Tab("Prompt Generator"):
120
- # Input fields for each category
121
- categories = [
122
- ('Scene', 'scene_tags'),
123
- ('Position', 'position_tags'),
124
- ('Outfit', 'outfit_tags'),
125
- ('Camera View/Angle', 'camera_tags'),
126
- ('Concept', 'concept_tags'),
127
- ('Additional', 'additional_tags'),
128
- ('LORA', 'lora_tags')
129
- ]
130
 
131
  inputs = {}
132
  for category_name, var_name in categories:
133
- with gr.Row():
134
- tag_input = gr.Textbox(label=f"{category_name} Tags (comma separated)")
135
- tag_num = gr.Slider(minimum=0, maximum=10, step=1, value=1, label=f"Number of {category_name} Tags to Select")
136
- inputs[f"{var_name}_input"] = tag_input
 
 
 
 
 
 
 
137
  inputs[f"{var_name}_num"] = tag_num
138
 
139
  # For Character Selection
@@ -166,16 +199,15 @@ def prompt_generator_app(data_manager):
166
  generate_button = gr.Button("Generate Prompt")
167
  prompt_output = gr.Textbox(label="Generated Prompt", lines=5)
168
 
169
- def generate_prompt(*args):
170
  arg_idx = 0
171
 
172
  prompt_tags = []
173
  for category_name, var_name in categories:
174
- tags_input = args[arg_idx]
175
- tags_num = args[arg_idx + 1]
176
- arg_idx += 2
177
 
178
- tags_list = [tag.strip() for tag in tags_input.split(',')] if tags_input else []
179
  if tags_list and tags_num > 0:
180
  selected_tags = random.sample(tags_list, min(len(tags_list), int(tags_num)))
181
  prompt_tags.extend(selected_tags)
@@ -255,31 +287,56 @@ def prompt_generator_app(data_manager):
255
 
256
  # Prepare the list of inputs for the generate_prompt function
257
  inputs_list = []
258
- for var_name in inputs:
259
- inputs_list.append(inputs[var_name])
260
  # Add character_select directly to inputs
261
  inputs_list.extend([character_select, random_characters, num_characters])
262
 
263
  generate_button.click(generate_prompt, inputs=inputs_list, outputs=prompt_output)
264
 
265
- def settings_app(data_manager):
266
- with gr.Tab("Settings"):
267
- persistent_tags_input = gr.Textbox(label="Persistent Tags (comma separated)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
  save_persistent_tags_button = gr.Button("Save Persistent Tags")
269
- status_output = gr.Textbox(label="Status", interactive=False)
270
 
271
  def save_persistent_tags(tags_string):
272
- data_manager.set_persistent_tags(tags_string)
 
273
  return "Persistent tags saved successfully."
274
 
275
- save_persistent_tags_button.click(save_persistent_tags, inputs=[persistent_tags_input], outputs=status_output)
276
 
277
  def main():
278
  data_manager = DataManager()
279
  with gr.Blocks() as demo:
280
  prompt_generator_app(data_manager)
281
  character_creation_app(data_manager)
282
- settings_app(data_manager)
283
 
284
  demo.launch()
285
 
 
6
  import base64
7
  import io
8
 
9
+ # Define categories at the top so they are accessible throughout the code
10
+ categories = [
11
+ ('Scene', 'scene_tags'),
12
+ ('Position', 'position_tags'),
13
+ ('Outfit', 'outfit_tags'),
14
+ ('Camera View/Angle', 'camera_tags'),
15
+ ('Concept', 'concept_tags'),
16
+ ('Facial Expression', 'facial_expression_tags'),
17
+ ('Pose', 'pose_tags'),
18
+ ('Additional', 'additional_tags'),
19
+ ('LORA', 'lora_tags')
20
+ ]
21
+
22
  class DataManager:
23
+ def __init__(self, characters_file='characters.json',
24
+ persistent_tags_file='persistent_tags.json',
25
+ category_tags_file='category_tags.json',
26
  images_folder='character_images'):
27
  self.characters_file = characters_file
28
  self.persistent_tags_file = persistent_tags_file
29
+ self.category_tags_file = category_tags_file
30
  self.images_folder = images_folder
31
 
32
  # Make sure image folder exists
 
34
  os.makedirs(images_folder)
35
 
36
  self.load_data()
37
+
38
  def load_data(self):
39
  # Load characters
40
  if os.path.exists(self.characters_file):
 
42
  self.characters = json.load(f)
43
  else:
44
  self.characters = []
45
+
46
  # Load persistent tags
47
  if os.path.exists(self.persistent_tags_file):
48
  with open(self.persistent_tags_file, 'r') as f:
 
50
  else:
51
  self.persistent_tags = []
52
 
53
+ # Load category tags
54
+ self.load_category_tags()
55
+
56
  def save_characters(self):
57
  with open(self.characters_file, 'w') as f:
58
  json.dump(self.characters, f)
 
60
  def save_persistent_tags(self):
61
  with open(self.persistent_tags_file, 'w') as f:
62
  json.dump(self.persistent_tags, f)
63
+
64
+ def load_category_tags(self):
65
+ if os.path.exists(self.category_tags_file):
66
+ with open(self.category_tags_file, 'r') as f:
67
+ self.category_tags = json.load(f)
68
+ else:
69
+ self.category_tags = {}
70
+
71
+ def save_category_tags(self):
72
+ with open(self.category_tags_file, 'w') as f:
73
+ json.dump(self.category_tags, f)
74
+
75
+ def get_category_tags(self, category_var_name):
76
+ # Return the tags list for the given category variable name
77
+ return self.category_tags.get(category_var_name, [])
78
+
79
+ def set_category_tags(self, category_var_name, tags_list):
80
+ self.category_tags[category_var_name] = tags_list
81
+ self.save_category_tags()
82
+
83
+ def get_characters(self):
84
+ # Load character images
85
+ for char in self.characters:
86
+ image_path = char.get('image_path')
87
+ if image_path and os.path.exists(image_path):
88
+ char['image'] = image_path
89
+ else:
90
+ char['image'] = None
91
+ return self.characters
92
+
93
+ def set_persistent_tags(self, tags_list):
94
+ self.persistent_tags = tags_list
95
+ self.save_persistent_tags()
96
+
97
+ def get_persistent_tags(self):
98
+ return self.persistent_tags
99
+
100
  def add_character(self, character):
101
  # Save image to disk and store the filename
102
  image_data = character['image'] # This is base64 encoded string
 
120
 
121
  self.characters.append(character)
122
  self.save_characters()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  def character_creation_app(data_manager):
125
  with gr.Tab("Character Creation"):
 
146
  character['image'] = None
147
 
148
  data_manager.add_character(character)
 
149
  return f"Character '{name}' saved successfully."
150
 
151
  save_button.click(save_character, inputs=[name_input, traits_input, image_input, gender_input], outputs=output)
152
 
153
  def prompt_generator_app(data_manager):
154
  with gr.Tab("Prompt Generator"):
155
+ gr.Markdown("## Prompt Generator")
 
 
 
 
 
 
 
 
 
156
 
157
  inputs = {}
158
  for category_name, var_name in categories:
159
+ tags_list = data_manager.get_category_tags(var_name)
160
+ tags_string = ', '.join(tags_list)
161
+ max_tags = len(tags_list)
162
+ if max_tags == 0:
163
+ default_value = 0
164
+ else:
165
+ default_value = min(1, max_tags)
166
+ with gr.Box():
167
+ gr.Markdown(f"### {category_name}")
168
+ gr.Markdown(f"**Tags:** {tags_string}")
169
+ tag_num = gr.Slider(minimum=0, maximum=max_tags, step=1, value=default_value, label=f"Number of {category_name} Tags to Select")
170
  inputs[f"{var_name}_num"] = tag_num
171
 
172
  # For Character Selection
 
199
  generate_button = gr.Button("Generate Prompt")
200
  prompt_output = gr.Textbox(label="Generated Prompt", lines=5)
201
 
202
+ def generate_prompt(*args, data_manager=data_manager, categories=categories):
203
  arg_idx = 0
204
 
205
  prompt_tags = []
206
  for category_name, var_name in categories:
207
+ tags_list = data_manager.get_category_tags(var_name)
208
+ tags_num = args[arg_idx]
209
+ arg_idx += 1
210
 
 
211
  if tags_list and tags_num > 0:
212
  selected_tags = random.sample(tags_list, min(len(tags_list), int(tags_num)))
213
  prompt_tags.extend(selected_tags)
 
287
 
288
  # Prepare the list of inputs for the generate_prompt function
289
  inputs_list = []
290
+ for category_name, var_name in categories:
291
+ inputs_list.append(inputs[f"{var_name}_num"])
292
  # Add character_select directly to inputs
293
  inputs_list.extend([character_select, random_characters, num_characters])
294
 
295
  generate_button.click(generate_prompt, inputs=inputs_list, outputs=prompt_output)
296
 
297
+ def tags_app(data_manager):
298
+ with gr.Tab("Tags"):
299
+ gr.Markdown("## Edit Tags for Each Category")
300
+
301
+ for category_name, var_name in categories:
302
+ gr.Markdown(f"### {category_name} Tags")
303
+ tags_list = data_manager.get_category_tags(var_name)
304
+ tags_string = ', '.join(tags_list)
305
+ tag_input = gr.Textbox(label=f"{category_name} Tags (comma separated)", value=tags_string)
306
+ save_button = gr.Button(f"Save {category_name} Tags")
307
+ status_output = gr.Textbox(label="", interactive=False)
308
+
309
+ # Capture var_name and category_name in the closure
310
+ def make_save_category_tags_fn(var_name, category_name):
311
+ def fn(tags_string):
312
+ tags_list = [t.strip() for t in tags_string.split(',') if t.strip()]
313
+ data_manager.set_category_tags(var_name, tags_list)
314
+ return f"{category_name} tags saved successfully."
315
+ return fn
316
+
317
+ save_fn = make_save_category_tags_fn(var_name, category_name)
318
+ save_button.click(save_fn, inputs=tag_input, outputs=status_output)
319
+
320
+ # Persistent Tags
321
+ gr.Markdown(f"### Persistent Tags")
322
+ persistent_tags_string = ', '.join(data_manager.get_persistent_tags())
323
+ persistent_tags_input = gr.Textbox(label="Persistent Tags (comma separated)", value=persistent_tags_string)
324
  save_persistent_tags_button = gr.Button("Save Persistent Tags")
325
+ persistent_status_output = gr.Textbox(label="", interactive=False)
326
 
327
  def save_persistent_tags(tags_string):
328
+ tags_list = [t.strip() for t in tags_string.split(',') if t.strip()]
329
+ data_manager.set_persistent_tags(tags_list)
330
  return "Persistent tags saved successfully."
331
 
332
+ save_persistent_tags_button.click(save_persistent_tags, inputs=persistent_tags_input, outputs=persistent_status_output)
333
 
334
  def main():
335
  data_manager = DataManager()
336
  with gr.Blocks() as demo:
337
  prompt_generator_app(data_manager)
338
  character_creation_app(data_manager)
339
+ tags_app(data_manager)
340
 
341
  demo.launch()
342