# promptGen / app.py  (Hugging Face Space by throaway2854)
# Last commit: "Update app.py" (e8f0231, verified) - 9.33 kB
import gradio as gr
import random
import json
import os
from huggingface_hub import HfApi, hf_hub_download, upload_file
# Hugging Face Space repo that the data file and character images are
# uploaded to (repo_type="space"). Replace with your actual Space name.
REPO_ID: str = "throaway2854/promptGen"
# JSON file, used both locally and as the path inside the repo, that holds
# the tag lists and the character mapping.
DATA_FILE: str = "saved_data.json"
# Directory, used both locally and inside the repo, for character PNG images.
IMAGES_DIR: str = "character_images"
def load_data():
    """Fetch the saved tag/character data from the Space repo.

    Downloads DATA_FILE from REPO_ID and parses it as JSON. Keys missing from
    the stored file are filled from empty defaults, so callers can always index
    ``scene_tags``, ``position_tags``, ``outfit_tags``, ``camera_tags``,
    ``concept_tags``, ``lora_tags`` and ``characters``.

    Returns:
        dict: the stored data merged over the defaults, or just the defaults
        when the file cannot be downloaded or parsed.
    """
    import logging

    defaults = {
        "scene_tags": [], "position_tags": [], "outfit_tags": [],
        "camera_tags": [], "concept_tags": [], "lora_tags": [],
        "characters": {}
    }
    try:
        # The data is uploaded with repo_type="space"; without the same
        # repo_type here hf_hub_download looks for a *model* repo and fails.
        file_path = hf_hub_download(repo_id=REPO_ID, filename=DATA_FILE, repo_type="space")
        with open(file_path, 'r', encoding='utf-8') as f:
            stored = json.load(f)
    except Exception:
        # Best-effort: a missing file (first run), network error or corrupt
        # JSON falls back to empty data instead of crashing the UI, but the
        # cause is logged rather than silently swallowed.
        logging.getLogger(__name__).warning(
            "Could not load %s from %s; starting with empty data",
            DATA_FILE, REPO_ID, exc_info=True,
        )
        return defaults
    if not isinstance(stored, dict):
        return defaults
    # Stored values win; defaults only supply keys the file does not have.
    defaults.update(stored)
    return defaults
def save_data(data):
    """Write ``data`` to DATA_FILE locally, then upload that file to the Space repo.

    Args:
        data: JSON-serialisable dict (tag lists plus the ``characters`` mapping).
    """
    # upload_file reads the file from disk, so the local copy is written first.
    with open(DATA_FILE, 'w') as out:
        json.dump(data, out)
    # NOTE(review): this needs write access to REPO_ID (e.g. an HF token in the
    # Space secrets) and makes one repo commit per call - confirm that is intended.
    HfApi().upload_file(
        path_or_fileobj=DATA_FILE,
        path_in_repo=DATA_FILE,
        repo_id=REPO_ID,
        repo_type="space",
    )
def save_character_image(name, image):
    """Save a character image as PNG under IMAGES_DIR and upload it to the Space repo.

    Args:
        name: character name typed in the UI; reduced to a filesystem-safe
            file stem before use.
        image: image object with a ``save(path)`` method (the UI's gr.Image
            uses type="pil").

    Returns:
        str: local path of the saved PNG.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(IMAGES_DIR, exist_ok=True)
    # The name is free user text: keep only letters, digits, space, '-' and
    # '_' so it cannot contain path separators or '..' and escape IMAGES_DIR.
    # NOTE(review): distinct names may map to the same stem (e.g. "a/b" and "a_b").
    stem = "".join(ch if ch.isalnum() or ch in " -_" else "_" for ch in name).strip() or "character"
    file_name = f"{stem}.png"
    image_path = os.path.join(IMAGES_DIR, file_name)
    image.save(image_path)
    HfApi().upload_file(
        path_or_fileobj=image_path,
        path_in_repo=f"{IMAGES_DIR}/{file_name}",
        repo_id=REPO_ID,
        repo_type="space",
    )
    return image_path
def _split_tags(text):
    """Split a comma-separated string into stripped, non-empty tags (None -> [])."""
    return [tag.strip() for tag in (text or "").split(",") if tag.strip()]


def _clamp_count(count, available):
    """Coerce a slider value (int or float) to an int in ``[0, available]``."""
    return max(0, min(int(count), available))


def generate_prompt(scene_tags, num_people, position_tags, selected_characters, outfit_tags, camera_tags, concept_tags, lora_tags, tag_counts, data):
    """Build a randomized prompt from comma-separated tag inputs.

    Args:
        scene_tags, position_tags, outfit_tags, camera_tags, concept_tags:
            comma-separated tag strings; a random subset of each category is
            kept, each tag suffixed with a random weight in [0.80, 1.20].
        num_people: free text; when non-blank adds "<value> people:1.1".
        selected_characters: names looked up in ``data["characters"]``;
            each contributes its name plus a random subset of its traits.
        lora_tags: comma-separated LoRA names, emitted as ``<lora:name:1>``.
        tag_counts: dict of maximum picks per key ("scene", "position",
            "character", "outfit", "camera", "concept", "lora"). Values may be
            floats. A missing "lora" entry keeps every LoRA.
        data: dict with a ``characters`` mapping of name -> {"traits": [...]}.

    Returns:
        str: the shuffled character/tag parts, then the fixed quality tags,
        then the LoRA tokens (in their input order).
    """
    all_tags = {
        "scene": _split_tags(scene_tags),
        "position": _split_tags(position_tags),
        "outfit": _split_tags(outfit_tags),
        "camera": _split_tags(camera_tags),
        "concept": _split_tags(concept_tags),
    }

    character_prompts = []
    for char_name in selected_characters or []:
        character = data["characters"].get(char_name)
        if character is None:
            continue
        char_traits = character["traits"]
        picked = random.sample(char_traits, _clamp_count(tag_counts["character"], len(char_traits)))
        # Joining name and traits together avoids a dangling ", " when a
        # character has no traits.
        character_prompts.append(", ".join([char_name, *picked]))

    selected_tags = []
    for category, tags in all_tags.items():
        if tags:
            # NOTE(review): in A1111-style prompt syntax a weight is written
            # "(tag:1.05)"; a bare "tag:1.05" may be read as literal text.
            # Confirm the target UI before changing this format.
            selected_tags.extend(
                f"{tag}:{random.uniform(0.8, 1.2):.2f}"
                for tag in random.sample(tags, _clamp_count(tag_counts[category], len(tags)))
            )

    people = (num_people or "").strip()
    if people:
        selected_tags.append(f"{people} people:1.1")

    prompt_parts = character_prompts + selected_tags
    random.shuffle(prompt_parts)
    main_prompt = ", ".join(prompt_parts)

    lora_list = _split_tags(lora_tags)
    lora_k = _clamp_count(tag_counts.get("lora", len(lora_list)), len(lora_list))
    # Sample positions, then sort them, so the chosen LoRAs keep input order.
    chosen = sorted(random.sample(range(len(lora_list)), lora_k))
    lora_prompt = " ".join(f"<lora:{lora_list[i]}:1>" for i in chosen)

    fixed_tags = "source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres, anime artwork, anime style, vibrant, studio anime, highly detailed"
    # Skip an empty main prompt so the result never starts with ", ".
    body = ", ".join(part for part in (main_prompt, fixed_tags) if part)
    return f"{body} {lora_prompt}".strip()
def update_data(data, key, value):
    """Merge comma-separated tags from ``value`` into the list ``data[key]`` and persist.

    Existing tags keep their positions, new tags are appended in input order,
    and duplicates are dropped. save_data (which uploads to the Hub) is only
    called when the list actually changed. Non-list values are left untouched.

    Args:
        data: the shared data dict; mutated in place.
        key: which entry to update, e.g. "scene_tags".
        value: comma-separated tag string from the UI.

    Returns:
        dict: the same ``data`` object.
    """
    current = data[key]
    if isinstance(current, list):
        new_tags = [v.strip() for v in value.split(',') if v.strip()]
        # dict.fromkeys dedupes while preserving first-seen order; set() would
        # scramble the stored order on every call.
        merged = list(dict.fromkeys(current + new_tags))
        if merged != current:
            data[key] = merged
            save_data(data)
    return data
def create_character(name, traits, image, data):
    """Add or replace a character in ``data`` and persist it.

    Args:
        name: character name from the UI. Surrounding whitespace is stripped
            and a blank name is ignored.
        traits: comma-separated trait tags.
        image: optional image from the UI. When given, it is saved and uploaded
            via save_character_image and the returned local path is stored.
        data: the shared data dict; its ``characters`` mapping is mutated.

    Returns:
        tuple: ``(data, gr.update(...))``, the update refreshing the choices of
        the character CheckboxGroup.
    """
    # Stripping keeps " Alice" and "Alice" from becoming separate characters.
    name = (name or "").strip()
    if name:
        data["characters"][name] = {
            "traits": [trait.strip() for trait in (traits or "").split(',') if trait.strip()],
            # Explicit None check instead of relying on the image object's truthiness.
            "image": save_character_image(name, image) if image is not None else None,
        }
        save_data(data)
    return data, gr.update(choices=list(data["characters"].keys()))
def create_ui():
    """Build the Gradio interface: a prompt-generator tab and a character-creation tab.

    Saved tags and characters are loaded once with load_data(). The event
    handlers defined below share and update that dict.

    Returns:
        gr.Blocks: the assembled app, not yet launched.
    """
    data = load_data()

    def character_images():
        """Return the stored image path of every character that has one."""
        return [char_data["image"] for char_data in data["characters"].values() if char_data.get("image")]

    with gr.Blocks() as demo:
        gr.Markdown("# Advanced Pony SDXL Prompt Generator with Character Creation")
        with gr.Tabs():
            with gr.TabItem("Prompt Generator"):
                with gr.Row():
                    with gr.Column():
                        scene_input = gr.Textbox(label="Scene Tags (comma-separated)", value=", ".join(data["scene_tags"]))
                        scene_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of scene tags")
                        num_people_input = gr.Textbox(label="Number of People")
                        position_input = gr.Textbox(label="Position Tags (comma-separated)", value=", ".join(data["position_tags"]))
                        position_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of position tags")
                        character_select = gr.CheckboxGroup(label="Select Characters", choices=list(data["characters"].keys()))
                        character_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of character traits")
                        outfit_input = gr.Textbox(label="Outfit Tags (comma-separated)", value=", ".join(data["outfit_tags"]))
                        outfit_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of outfit tags")
                        camera_input = gr.Textbox(label="Camera View/Angle Tags (comma-separated)", value=", ".join(data["camera_tags"]))
                        camera_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of camera tags")
                        concept_input = gr.Textbox(label="Concept Tags (comma-separated)", value=", ".join(data["concept_tags"]))
                        concept_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of concept tags")
                        lora_input = gr.Textbox(label="LORA Tags (comma-separated)", value=", ".join(data["lora_tags"]))
                        lora_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of LORA tags")
                        generate_button = gr.Button("Generate Prompt")
                    with gr.Column():
                        output = gr.Textbox(label="Generated Prompt", lines=5)
                        with gr.Row():
                            prompt_gallery = gr.Gallery(value=character_images(), label="Character Images", show_label=True, elem_id="char_gallery", columns=2, rows=2, height="auto")
            with gr.TabItem("Character Creation"):
                with gr.Row():
                    with gr.Column():
                        char_name_input = gr.Textbox(label="Character Name")
                        char_traits_input = gr.Textbox(label="Character Traits (comma-separated)")
                        char_image_input = gr.Image(label="Character Image", type="pil")
                        create_char_button = gr.Button("Create/Update Character")
                    with gr.Column():
                        # Pre-filled with existing characters; own elem_id because
                        # "char_gallery" is already used by the prompt tab's gallery.
                        char_gallery = gr.Gallery(value=character_images(), label="Existing Characters", show_label=True, elem_id="existing_char_gallery", columns=2, rows=2, height="auto")

        def update_and_generate(*args):
            """Persist any newly typed tags, then build a prompt from the current inputs."""
            nonlocal data
            (scene_tags, num_people, position_tags, selected_characters, outfit_tags,
             camera_tags, concept_tags, lora_tags, *tag_counts) = args
            for key, value in (
                ("scene_tags", scene_tags),
                ("position_tags", position_tags),
                ("outfit_tags", outfit_tags),
                ("camera_tags", camera_tags),
                ("concept_tags", concept_tags),
                ("lora_tags", lora_tags),
            ):
                data = update_data(data, key, value)
            # Slider values arrive in the same order as the *_count components in `inputs`.
            tag_count_dict = dict(zip(("scene", "position", "character", "outfit", "camera", "concept", "lora"), tag_counts))
            return generate_prompt(scene_tags, num_people, position_tags, selected_characters, outfit_tags, camera_tags, concept_tags, lora_tags, tag_count_dict, data)

        generate_button.click(
            update_and_generate,
            inputs=[scene_input, num_people_input, position_input, character_select, outfit_input, camera_input, concept_input, lora_input,
                    scene_count, position_count, character_count, outfit_count, camera_count, concept_count, lora_count],
            outputs=[output]
        )

        def handle_create_character(name, traits, image):
            """Create/update a character, then refresh the selector and both galleries."""
            # create_character's 4th argument is the shared data dict, which is
            # not a UI component, so it is supplied here from the closure.
            _, selector_update = create_character(name, traits, image, data)
            images = character_images()
            return selector_update, images, images

        create_char_button.click(
            handle_create_character,
            inputs=[char_name_input, char_traits_input, char_image_input],
            outputs=[character_select, char_gallery, prompt_gallery]
        )
    return demo
if __name__ == "__main__":
    # Build the UI and start the Gradio server with launch()'s default settings.
    # `demo` stays bound at module level so it remains reachable by name.
    demo = create_ui()
    demo.launch()