Spaces:
Paused
Paused
throaway2854
commited on
Commit
•
80c0c14
1
Parent(s):
bfae348
Update app.py
Browse files
app.py
CHANGED
@@ -1,14 +1,16 @@
|
|
1 |
import gradio as gr
|
2 |
import random
|
3 |
import json
|
|
|
4 |
from huggingface_hub import HfApi, hf_hub_download, upload_file
|
5 |
|
6 |
REPO_ID = "throaway2854/promptGen" # Replace with your actual Space name
|
7 |
-
|
|
|
8 |
|
9 |
def load_data():
|
10 |
try:
|
11 |
-
file_path = hf_hub_download(repo_id=REPO_ID, filename=
|
12 |
with open(file_path, 'r') as f:
|
13 |
return json.load(f)
|
14 |
except:
|
@@ -19,17 +21,31 @@ def load_data():
|
|
19 |
}
|
20 |
|
21 |
def save_data(data):
|
22 |
-
with open(
|
23 |
json.dump(data, f)
|
24 |
api = HfApi()
|
25 |
api.upload_file(
|
26 |
-
path_or_fileobj=
|
27 |
-
path_in_repo=
|
28 |
repo_id=REPO_ID,
|
29 |
repo_type="space"
|
30 |
)
|
31 |
|
32 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
all_tags = {
|
34 |
"scene": scene_tags.split(','),
|
35 |
"position": position_tags.split(','),
|
@@ -41,19 +57,17 @@ def generate_prompt(scene_tags, num_people, position_tags, characters, outfit_ta
|
|
41 |
for category in all_tags:
|
42 |
all_tags[category] = [tag.strip() for tag in all_tags[category] if tag.strip()]
|
43 |
|
44 |
-
character_list = [char.strip() for char in characters.split(';') if char.strip()]
|
45 |
character_prompts = []
|
46 |
-
for
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
character_prompts.append(char_prompt)
|
52 |
|
53 |
selected_tags = []
|
54 |
for category, tags in all_tags.items():
|
55 |
if tags:
|
56 |
-
selected_tags.extend(f"{tag}:{random.uniform(0.8, 1.2):.2f}" for tag in random.sample(tags, min(
|
57 |
|
58 |
if num_people.strip():
|
59 |
selected_tags.append(f"{num_people} people:1.1")
|
@@ -73,54 +87,108 @@ def generate_prompt(scene_tags, num_people, position_tags, characters, outfit_ta
|
|
73 |
def update_data(data, key, value):
|
74 |
if isinstance(data[key], list):
|
75 |
data[key] = list(set(data[key] + [v.strip() for v in value.split(',') if v.strip()]))
|
76 |
-
elif isinstance(data[key], dict):
|
77 |
-
for char in [c.strip() for c in value.split(';') if c.strip()]:
|
78 |
-
char_name, *char_traits = [t.strip() for t in char.split(',')]
|
79 |
-
if char_name:
|
80 |
-
data[key][char_name] = list(set(data[key].get(char_name, []) + char_traits))
|
81 |
save_data(data)
|
82 |
return data
|
83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
def create_ui():
|
85 |
data = load_data()
|
86 |
|
87 |
with gr.Blocks() as demo:
|
88 |
-
gr.Markdown("# Advanced Pony SDXL Prompt Generator with
|
89 |
|
90 |
-
with gr.
|
91 |
-
with gr.
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
|
104 |
-
with gr.
|
105 |
-
|
106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
def update_and_generate(*args):
|
108 |
nonlocal data
|
109 |
-
|
110 |
-
data = update_data(data, "
|
111 |
-
data = update_data(data, "
|
112 |
-
data = update_data(data, "outfit_tags",
|
113 |
-
data = update_data(data, "camera_tags",
|
114 |
-
data = update_data(data, "concept_tags",
|
115 |
-
data = update_data(data, "lora_tags",
|
116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
generate_button.click(
|
119 |
update_and_generate,
|
120 |
-
inputs=[scene_input, num_people_input, position_input,
|
|
|
121 |
outputs=[output]
|
122 |
)
|
123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
return demo
|
125 |
|
126 |
if __name__ == "__main__":
|
|
|
1 |
import gradio as gr
|
2 |
import random
|
3 |
import json
|
4 |
+
import os
|
5 |
from huggingface_hub import HfApi, hf_hub_download, upload_file
|
6 |
|
7 |
REPO_ID = "throaway2854/promptGen" # Replace with your actual Space name
|
8 |
+
DATA_FILE = "saved_data.json"
|
9 |
+
IMAGES_DIR = "character_images"
|
10 |
|
11 |
def load_data():
|
12 |
try:
|
13 |
+
file_path = hf_hub_download(repo_id=REPO_ID, filename=DATA_FILE)
|
14 |
with open(file_path, 'r') as f:
|
15 |
return json.load(f)
|
16 |
except:
|
|
|
21 |
}
|
22 |
|
23 |
def save_data(data):
|
24 |
+
with open(DATA_FILE, 'w') as f:
|
25 |
json.dump(data, f)
|
26 |
api = HfApi()
|
27 |
api.upload_file(
|
28 |
+
path_or_fileobj=DATA_FILE,
|
29 |
+
path_in_repo=DATA_FILE,
|
30 |
repo_id=REPO_ID,
|
31 |
repo_type="space"
|
32 |
)
|
33 |
|
34 |
+
def save_character_image(name, image):
|
35 |
+
if not os.path.exists(IMAGES_DIR):
|
36 |
+
os.makedirs(IMAGES_DIR)
|
37 |
+
image_path = os.path.join(IMAGES_DIR, f"{name}.png")
|
38 |
+
image.save(image_path)
|
39 |
+
api = HfApi()
|
40 |
+
api.upload_file(
|
41 |
+
path_or_fileobj=image_path,
|
42 |
+
path_in_repo=f"{IMAGES_DIR}/{name}.png",
|
43 |
+
repo_id=REPO_ID,
|
44 |
+
repo_type="space"
|
45 |
+
)
|
46 |
+
return image_path
|
47 |
+
|
48 |
+
def generate_prompt(scene_tags, num_people, position_tags, selected_characters, outfit_tags, camera_tags, concept_tags, lora_tags, tag_counts, data):
|
49 |
all_tags = {
|
50 |
"scene": scene_tags.split(','),
|
51 |
"position": position_tags.split(','),
|
|
|
57 |
for category in all_tags:
|
58 |
all_tags[category] = [tag.strip() for tag in all_tags[category] if tag.strip()]
|
59 |
|
|
|
60 |
character_prompts = []
|
61 |
+
for char_name in selected_characters:
|
62 |
+
if char_name in data["characters"]:
|
63 |
+
char_traits = data["characters"][char_name]["traits"]
|
64 |
+
char_prompt = f"{char_name}, " + ", ".join(random.sample(char_traits, min(tag_counts["character"], len(char_traits))))
|
65 |
+
character_prompts.append(char_prompt)
|
|
|
66 |
|
67 |
selected_tags = []
|
68 |
for category, tags in all_tags.items():
|
69 |
if tags:
|
70 |
+
selected_tags.extend(f"{tag}:{random.uniform(0.8, 1.2):.2f}" for tag in random.sample(tags, min(tag_counts[category], len(tags))))
|
71 |
|
72 |
if num_people.strip():
|
73 |
selected_tags.append(f"{num_people} people:1.1")
|
|
|
87 |
def update_data(data, key, value):
|
88 |
if isinstance(data[key], list):
|
89 |
data[key] = list(set(data[key] + [v.strip() for v in value.split(',') if v.strip()]))
|
|
|
|
|
|
|
|
|
|
|
90 |
save_data(data)
|
91 |
return data
|
92 |
|
93 |
+
def create_character(name, traits, image, data):
|
94 |
+
if name:
|
95 |
+
data["characters"][name] = {
|
96 |
+
"traits": [trait.strip() for trait in traits.split(',') if trait.strip()],
|
97 |
+
"image": save_character_image(name, image) if image else None
|
98 |
+
}
|
99 |
+
save_data(data)
|
100 |
+
return data, gr.update(choices=list(data["characters"].keys()))
|
101 |
+
|
102 |
def create_ui():
|
103 |
data = load_data()
|
104 |
|
105 |
with gr.Blocks() as demo:
|
106 |
+
gr.Markdown("# Advanced Pony SDXL Prompt Generator with Character Creation")
|
107 |
|
108 |
+
with gr.Tabs():
|
109 |
+
with gr.TabItem("Prompt Generator"):
|
110 |
+
with gr.Row():
|
111 |
+
with gr.Column():
|
112 |
+
scene_input = gr.Textbox(label="Scene Tags (comma-separated)", value=", ".join(data["scene_tags"]))
|
113 |
+
scene_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of scene tags")
|
114 |
+
|
115 |
+
num_people_input = gr.Textbox(label="Number of People")
|
116 |
+
|
117 |
+
position_input = gr.Textbox(label="Position Tags (comma-separated)", value=", ".join(data["position_tags"]))
|
118 |
+
position_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of position tags")
|
119 |
+
|
120 |
+
character_select = gr.CheckboxGroup(label="Select Characters", choices=list(data["characters"].keys()))
|
121 |
+
character_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of character traits")
|
122 |
+
|
123 |
+
outfit_input = gr.Textbox(label="Outfit Tags (comma-separated)", value=", ".join(data["outfit_tags"]))
|
124 |
+
outfit_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of outfit tags")
|
125 |
+
|
126 |
+
camera_input = gr.Textbox(label="Camera View/Angle Tags (comma-separated)", value=", ".join(data["camera_tags"]))
|
127 |
+
camera_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of camera tags")
|
128 |
+
|
129 |
+
concept_input = gr.Textbox(label="Concept Tags (comma-separated)", value=", ".join(data["concept_tags"]))
|
130 |
+
concept_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of concept tags")
|
131 |
+
|
132 |
+
lora_input = gr.Textbox(label="LORA Tags (comma-separated)", value=", ".join(data["lora_tags"]))
|
133 |
+
lora_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of LORA tags")
|
134 |
+
|
135 |
+
generate_button = gr.Button("Generate Prompt")
|
136 |
+
|
137 |
+
with gr.Column():
|
138 |
+
output = gr.Textbox(label="Generated Prompt", lines=5)
|
139 |
+
|
140 |
+
with gr.Row():
|
141 |
+
for char_name, char_data in data["characters"].items():
|
142 |
+
if char_data["image"]:
|
143 |
+
gr.Image(value=char_data["image"], label=char_name, height=100, width=100)
|
144 |
|
145 |
+
with gr.TabItem("Character Creation"):
|
146 |
+
with gr.Row():
|
147 |
+
with gr.Column():
|
148 |
+
char_name_input = gr.Textbox(label="Character Name")
|
149 |
+
char_traits_input = gr.Textbox(label="Character Traits (comma-separated)")
|
150 |
+
char_image_input = gr.Image(label="Character Image", type="pil")
|
151 |
+
create_char_button = gr.Button("Create/Update Character")
|
152 |
+
|
153 |
+
with gr.Column():
|
154 |
+
char_gallery = gr.Gallery(label="Existing Characters", show_label=False, elem_id="char_gallery").style(grid=2, height="auto")
|
155 |
+
|
156 |
def update_and_generate(*args):
|
157 |
nonlocal data
|
158 |
+
scene_tags, num_people, position_tags, selected_characters, outfit_tags, camera_tags, concept_tags, lora_tags, *tag_counts = args
|
159 |
+
data = update_data(data, "scene_tags", scene_tags)
|
160 |
+
data = update_data(data, "position_tags", position_tags)
|
161 |
+
data = update_data(data, "outfit_tags", outfit_tags)
|
162 |
+
data = update_data(data, "camera_tags", camera_tags)
|
163 |
+
data = update_data(data, "concept_tags", concept_tags)
|
164 |
+
data = update_data(data, "lora_tags", lora_tags)
|
165 |
+
|
166 |
+
tag_count_dict = {
|
167 |
+
"scene": tag_counts[0], "position": tag_counts[1], "character": tag_counts[2],
|
168 |
+
"outfit": tag_counts[3], "camera": tag_counts[4], "concept": tag_counts[5], "lora": tag_counts[6]
|
169 |
+
}
|
170 |
+
|
171 |
+
return generate_prompt(scene_tags, num_people, position_tags, selected_characters, outfit_tags, camera_tags, concept_tags, lora_tags, tag_count_dict, data)
|
172 |
|
173 |
generate_button.click(
|
174 |
update_and_generate,
|
175 |
+
inputs=[scene_input, num_people_input, position_input, character_select, outfit_input, camera_input, concept_input, lora_input,
|
176 |
+
scene_count, position_count, character_count, outfit_count, camera_count, concept_count, lora_count],
|
177 |
outputs=[output]
|
178 |
)
|
179 |
+
|
180 |
+
def update_char_gallery():
|
181 |
+
return gr.Gallery(value=[char_data["image"] for char_data in data["characters"].values() if char_data["image"]])
|
182 |
+
|
183 |
+
create_char_button.click(
|
184 |
+
create_character,
|
185 |
+
inputs=[char_name_input, char_traits_input, char_image_input],
|
186 |
+
outputs=[gr.State(data), character_select]
|
187 |
+
).then(
|
188 |
+
update_char_gallery,
|
189 |
+
outputs=[char_gallery]
|
190 |
+
)
|
191 |
+
|
192 |
return demo
|
193 |
|
194 |
if __name__ == "__main__":
|