throaway2854 commited on
Commit
80c0c14
1 Parent(s): bfae348

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -46
app.py CHANGED
@@ -1,14 +1,16 @@
1
  import gradio as gr
2
  import random
3
  import json
 
4
  from huggingface_hub import HfApi, hf_hub_download, upload_file
5
 
6
  REPO_ID = "throaway2854/promptGen" # Replace with your actual Space name
7
- FILE_NAME = "saved_data.json"
 
8
 
9
  def load_data():
10
  try:
11
- file_path = hf_hub_download(repo_id=REPO_ID, filename=FILE_NAME)
12
  with open(file_path, 'r') as f:
13
  return json.load(f)
14
  except:
@@ -19,17 +21,31 @@ def load_data():
19
  }
20
 
21
  def save_data(data):
22
- with open(FILE_NAME, 'w') as f:
23
  json.dump(data, f)
24
  api = HfApi()
25
  api.upload_file(
26
- path_or_fileobj=FILE_NAME,
27
- path_in_repo=FILE_NAME,
28
  repo_id=REPO_ID,
29
  repo_type="space"
30
  )
31
 
32
- def generate_prompt(scene_tags, num_people, position_tags, characters, outfit_tags, camera_tags, concept_tags, lora_tags, num_tags=3):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  all_tags = {
34
  "scene": scene_tags.split(','),
35
  "position": position_tags.split(','),
@@ -41,19 +57,17 @@ def generate_prompt(scene_tags, num_people, position_tags, characters, outfit_ta
41
  for category in all_tags:
42
  all_tags[category] = [tag.strip() for tag in all_tags[category] if tag.strip()]
43
 
44
- character_list = [char.strip() for char in characters.split(';') if char.strip()]
45
  character_prompts = []
46
- for character in character_list:
47
- char_tags = character.split(',')
48
- char_name = char_tags[0].strip()
49
- char_traits = [tag.strip() for tag in char_tags[1:] if tag.strip()]
50
- char_prompt = f"{char_name}, " + ", ".join(random.sample(char_traits, min(num_tags, len(char_traits))))
51
- character_prompts.append(char_prompt)
52
 
53
  selected_tags = []
54
  for category, tags in all_tags.items():
55
  if tags:
56
- selected_tags.extend(f"{tag}:{random.uniform(0.8, 1.2):.2f}" for tag in random.sample(tags, min(num_tags, len(tags))))
57
 
58
  if num_people.strip():
59
  selected_tags.append(f"{num_people} people:1.1")
@@ -73,54 +87,108 @@ def generate_prompt(scene_tags, num_people, position_tags, characters, outfit_ta
73
  def update_data(data, key, value):
74
  if isinstance(data[key], list):
75
  data[key] = list(set(data[key] + [v.strip() for v in value.split(',') if v.strip()]))
76
- elif isinstance(data[key], dict):
77
- for char in [c.strip() for c in value.split(';') if c.strip()]:
78
- char_name, *char_traits = [t.strip() for t in char.split(',')]
79
- if char_name:
80
- data[key][char_name] = list(set(data[key].get(char_name, []) + char_traits))
81
  save_data(data)
82
  return data
83
 
 
 
 
 
 
 
 
 
 
84
  def create_ui():
85
  data = load_data()
86
 
87
  with gr.Blocks() as demo:
88
- gr.Markdown("# Advanced Pony SDXL Prompt Generator with Persistent Storage")
89
 
90
- with gr.Row():
91
- with gr.Column():
92
- scene_input = gr.Textbox(label="Scene Tags (comma-separated)", value=", ".join(data["scene_tags"]))
93
- num_people_input = gr.Textbox(label="Number of People")
94
- position_input = gr.Textbox(label="Position Tags (comma-separated)", value=", ".join(data["position_tags"]))
95
- characters_input = gr.Textbox(label="Characters (Format: Name, Trait1, Trait2; Name2, Trait1, Trait2)",
96
- value="; ".join([f"{name}, {', '.join(traits)}" for name, traits in data["characters"].items()]))
97
- outfit_input = gr.Textbox(label="Outfit Tags (comma-separated)", value=", ".join(data["outfit_tags"]))
98
- camera_input = gr.Textbox(label="Camera View/Angle Tags (comma-separated)", value=", ".join(data["camera_tags"]))
99
- concept_input = gr.Textbox(label="Concept Tags (comma-separated)", value=", ".join(data["concept_tags"]))
100
- lora_input = gr.Textbox(label="LORA Tags (comma-separated)", value=", ".join(data["lora_tags"]))
101
- num_tags_slider = gr.Slider(minimum=1, maximum=5, step=1, value=3, label="Number of tags per category")
102
- generate_button = gr.Button("Generate Prompt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
- with gr.Column():
105
- output = gr.Textbox(label="Generated Prompt", lines=5)
106
-
 
 
 
 
 
 
 
 
107
  def update_and_generate(*args):
108
  nonlocal data
109
- data = update_data(data, "scene_tags", args[0])
110
- data = update_data(data, "position_tags", args[2])
111
- data = update_data(data, "characters", args[3])
112
- data = update_data(data, "outfit_tags", args[4])
113
- data = update_data(data, "camera_tags", args[5])
114
- data = update_data(data, "concept_tags", args[6])
115
- data = update_data(data, "lora_tags", args[7])
116
- return generate_prompt(*args)
 
 
 
 
 
 
117
 
118
  generate_button.click(
119
  update_and_generate,
120
- inputs=[scene_input, num_people_input, position_input, characters_input, outfit_input, camera_input, concept_input, lora_input, num_tags_slider],
 
121
  outputs=[output]
122
  )
123
-
 
 
 
 
 
 
 
 
 
 
 
 
124
  return demo
125
 
126
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import random
3
  import json
4
+ import os
5
  from huggingface_hub import HfApi, hf_hub_download, upload_file
6
 
7
  REPO_ID = "throaway2854/promptGen" # Replace with your actual Space name
8
+ DATA_FILE = "saved_data.json"
9
+ IMAGES_DIR = "character_images"
10
 
11
  def load_data():
12
  try:
13
+ file_path = hf_hub_download(repo_id=REPO_ID, filename=DATA_FILE)
14
  with open(file_path, 'r') as f:
15
  return json.load(f)
16
  except:
 
21
  }
22
 
23
  def save_data(data):
24
+ with open(DATA_FILE, 'w') as f:
25
  json.dump(data, f)
26
  api = HfApi()
27
  api.upload_file(
28
+ path_or_fileobj=DATA_FILE,
29
+ path_in_repo=DATA_FILE,
30
  repo_id=REPO_ID,
31
  repo_type="space"
32
  )
33
 
34
+ def save_character_image(name, image):
35
+ if not os.path.exists(IMAGES_DIR):
36
+ os.makedirs(IMAGES_DIR)
37
+ image_path = os.path.join(IMAGES_DIR, f"{name}.png")
38
+ image.save(image_path)
39
+ api = HfApi()
40
+ api.upload_file(
41
+ path_or_fileobj=image_path,
42
+ path_in_repo=f"{IMAGES_DIR}/{name}.png",
43
+ repo_id=REPO_ID,
44
+ repo_type="space"
45
+ )
46
+ return image_path
47
+
48
+ def generate_prompt(scene_tags, num_people, position_tags, selected_characters, outfit_tags, camera_tags, concept_tags, lora_tags, tag_counts, data):
49
  all_tags = {
50
  "scene": scene_tags.split(','),
51
  "position": position_tags.split(','),
 
57
  for category in all_tags:
58
  all_tags[category] = [tag.strip() for tag in all_tags[category] if tag.strip()]
59
 
 
60
  character_prompts = []
61
+ for char_name in selected_characters:
62
+ if char_name in data["characters"]:
63
+ char_traits = data["characters"][char_name]["traits"]
64
+ char_prompt = f"{char_name}, " + ", ".join(random.sample(char_traits, min(tag_counts["character"], len(char_traits))))
65
+ character_prompts.append(char_prompt)
 
66
 
67
  selected_tags = []
68
  for category, tags in all_tags.items():
69
  if tags:
70
+ selected_tags.extend(f"{tag}:{random.uniform(0.8, 1.2):.2f}" for tag in random.sample(tags, min(tag_counts[category], len(tags))))
71
 
72
  if num_people.strip():
73
  selected_tags.append(f"{num_people} people:1.1")
 
87
  def update_data(data, key, value):
88
  if isinstance(data[key], list):
89
  data[key] = list(set(data[key] + [v.strip() for v in value.split(',') if v.strip()]))
 
 
 
 
 
90
  save_data(data)
91
  return data
92
 
93
+ def create_character(name, traits, image, data):
94
+ if name:
95
+ data["characters"][name] = {
96
+ "traits": [trait.strip() for trait in traits.split(',') if trait.strip()],
97
+ "image": save_character_image(name, image) if image else None
98
+ }
99
+ save_data(data)
100
+ return data, gr.update(choices=list(data["characters"].keys()))
101
+
102
  def create_ui():
103
  data = load_data()
104
 
105
  with gr.Blocks() as demo:
106
+ gr.Markdown("# Advanced Pony SDXL Prompt Generator with Character Creation")
107
 
108
+ with gr.Tabs():
109
+ with gr.TabItem("Prompt Generator"):
110
+ with gr.Row():
111
+ with gr.Column():
112
+ scene_input = gr.Textbox(label="Scene Tags (comma-separated)", value=", ".join(data["scene_tags"]))
113
+ scene_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of scene tags")
114
+
115
+ num_people_input = gr.Textbox(label="Number of People")
116
+
117
+ position_input = gr.Textbox(label="Position Tags (comma-separated)", value=", ".join(data["position_tags"]))
118
+ position_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of position tags")
119
+
120
+ character_select = gr.CheckboxGroup(label="Select Characters", choices=list(data["characters"].keys()))
121
+ character_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of character traits")
122
+
123
+ outfit_input = gr.Textbox(label="Outfit Tags (comma-separated)", value=", ".join(data["outfit_tags"]))
124
+ outfit_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of outfit tags")
125
+
126
+ camera_input = gr.Textbox(label="Camera View/Angle Tags (comma-separated)", value=", ".join(data["camera_tags"]))
127
+ camera_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of camera tags")
128
+
129
+ concept_input = gr.Textbox(label="Concept Tags (comma-separated)", value=", ".join(data["concept_tags"]))
130
+ concept_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of concept tags")
131
+
132
+ lora_input = gr.Textbox(label="LORA Tags (comma-separated)", value=", ".join(data["lora_tags"]))
133
+ lora_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of LORA tags")
134
+
135
+ generate_button = gr.Button("Generate Prompt")
136
+
137
+ with gr.Column():
138
+ output = gr.Textbox(label="Generated Prompt", lines=5)
139
+
140
+ with gr.Row():
141
+ for char_name, char_data in data["characters"].items():
142
+ if char_data["image"]:
143
+ gr.Image(value=char_data["image"], label=char_name, height=100, width=100)
144
 
145
+ with gr.TabItem("Character Creation"):
146
+ with gr.Row():
147
+ with gr.Column():
148
+ char_name_input = gr.Textbox(label="Character Name")
149
+ char_traits_input = gr.Textbox(label="Character Traits (comma-separated)")
150
+ char_image_input = gr.Image(label="Character Image", type="pil")
151
+ create_char_button = gr.Button("Create/Update Character")
152
+
153
+ with gr.Column():
154
+ char_gallery = gr.Gallery(label="Existing Characters", show_label=False, elem_id="char_gallery").style(grid=2, height="auto")
155
+
156
  def update_and_generate(*args):
157
  nonlocal data
158
+ scene_tags, num_people, position_tags, selected_characters, outfit_tags, camera_tags, concept_tags, lora_tags, *tag_counts = args
159
+ data = update_data(data, "scene_tags", scene_tags)
160
+ data = update_data(data, "position_tags", position_tags)
161
+ data = update_data(data, "outfit_tags", outfit_tags)
162
+ data = update_data(data, "camera_tags", camera_tags)
163
+ data = update_data(data, "concept_tags", concept_tags)
164
+ data = update_data(data, "lora_tags", lora_tags)
165
+
166
+ tag_count_dict = {
167
+ "scene": tag_counts[0], "position": tag_counts[1], "character": tag_counts[2],
168
+ "outfit": tag_counts[3], "camera": tag_counts[4], "concept": tag_counts[5], "lora": tag_counts[6]
169
+ }
170
+
171
+ return generate_prompt(scene_tags, num_people, position_tags, selected_characters, outfit_tags, camera_tags, concept_tags, lora_tags, tag_count_dict, data)
172
 
173
  generate_button.click(
174
  update_and_generate,
175
+ inputs=[scene_input, num_people_input, position_input, character_select, outfit_input, camera_input, concept_input, lora_input,
176
+ scene_count, position_count, character_count, outfit_count, camera_count, concept_count, lora_count],
177
  outputs=[output]
178
  )
179
+
180
+ def update_char_gallery():
181
+ return gr.Gallery(value=[char_data["image"] for char_data in data["characters"].values() if char_data["image"]])
182
+
183
+ create_char_button.click(
184
+ create_character,
185
+ inputs=[char_name_input, char_traits_input, char_image_input],
186
+ outputs=[gr.State(data), character_select]
187
+ ).then(
188
+ update_char_gallery,
189
+ outputs=[char_gallery]
190
+ )
191
+
192
  return demo
193
 
194
  if __name__ == "__main__":