# Captured from the Hugging Face Spaces file viewer (Space status at capture time: Paused).
# File size: 9,329 bytes.
# Per-line revision annotations from the viewer referenced commits 8ca7af9, db9ae99, 80c0c14 and e8f0231.
import json
import logging
import os
import random
import re

import gradio as gr
from huggingface_hub import HfApi, hf_hub_download, upload_file
# Hugging Face Space whose repo stores this app's data.
# Replace with your actual Space name.
REPO_ID: str = "throaway2854/promptGen"
# JSON file name used both locally and as the path inside the Space repo.
DATA_FILE: str = "saved_data.json"
# Directory name used both locally and inside the Space repo for character images.
IMAGES_DIR: str = "character_images"
def load_data():
    """Load the saved app data from the Space repo, falling back to empty defaults.

    Downloads DATA_FILE from the REPO_ID Space and parses it as JSON. If the
    download or parse fails, a warning (with traceback) is logged and an empty
    default structure is returned so the UI can still start. Keys missing from
    a successfully loaded file are filled in from the defaults, so callers can
    index e.g. ``data["lora_tags"]`` without a KeyError.

    Returns:
        dict with list-valued ``scene_tags``, ``position_tags``,
        ``outfit_tags``, ``camera_tags``, ``concept_tags``, ``lora_tags`` and a
        dict-valued ``characters``.
    """
    defaults = {
        "scene_tags": [], "position_tags": [], "outfit_tags": [],
        "camera_tags": [], "concept_tags": [], "lora_tags": [],
        "characters": {},
    }
    logger = logging.getLogger(__name__)
    try:
        # save_data uploads with repo_type="space"; the download must use the
        # same repo type, otherwise it looks in a *model* repo and always fails.
        file_path = hf_hub_download(
            repo_id=REPO_ID, filename=DATA_FILE, repo_type="space"
        )
        with open(file_path, "r", encoding="utf-8") as f:
            loaded = json.load(f)
    except Exception:
        # Best-effort load: starting with empty data is preferable to crashing,
        # but the failure is no longer silent.
        logger.warning(
            "Could not load %s from %s; starting with empty data",
            DATA_FILE, REPO_ID, exc_info=True,
        )
        return defaults
    if not isinstance(loaded, dict):
        logger.warning("%s does not contain a JSON object; ignoring it", DATA_FILE)
        return defaults
    for key, value in defaults.items():
        loaded.setdefault(key, value)
    return loaded
def save_data(data):
    """Write ``data`` to DATA_FILE as JSON, then upload that file to the Space repo.

    The local file is written first; the upload goes to ``DATA_FILE`` inside
    the ``REPO_ID`` repository with ``repo_type="space"``. Upload errors
    propagate to the caller.

    NOTE(review): this commits to the Space's own repository on every call;
    on Hugging Face a commit to a Space repo may trigger a rebuild/restart
    — confirm this is acceptable.
    """
    with open(DATA_FILE, 'w') as fh:
        json.dump(data, fh)
    hub = HfApi()
    hub.upload_file(
        repo_id=REPO_ID,
        repo_type="space",
        path_or_fileobj=DATA_FILE,
        path_in_repo=DATA_FILE,
    )
def _safe_image_stem(name):
    """Reduce a user-supplied character name to a filename stem that stays inside IMAGES_DIR."""
    # Keep (Unicode) word characters, spaces, dots and hyphens; replace path
    # separators and everything else with "_" so the name cannot form a path.
    stem = re.sub(r"[^\w .-]", "_", str(name))
    # Strip leading/trailing dots and spaces so ".", ".." or hidden names can't result.
    stem = stem.strip(". ")
    return stem or "unnamed"


def save_character_image(name, image):
    """Save a character image locally and upload it to the Space repo.

    The character name comes from user input, so it is sanitized before being
    used as a filename. Ordinary names such as ``"Alice Smith"`` are unchanged.
    Names containing path separators or other special characters are rewritten
    so the file cannot be written outside IMAGES_DIR.

    Args:
        name: Character name used to derive the file name.
        image: Object with a ``save(path)`` method (the UI passes a PIL image
            from ``gr.Image(type="pil")``).

    Returns:
        Local path of the saved ``.png`` file.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(IMAGES_DIR, exist_ok=True)
    filename = f"{_safe_image_stem(name)}.png"
    image_path = os.path.join(IMAGES_DIR, filename)
    image.save(image_path)
    HfApi().upload_file(
        path_or_fileobj=image_path,
        path_in_repo=f"{IMAGES_DIR}/{filename}",
        repo_id=REPO_ID,
        repo_type="space",
    )
    return image_path
def generate_prompt(scene_tags, num_people, position_tags, selected_characters, outfit_tags, camera_tags, concept_tags, lora_tags, tag_counts, data):
    """Build a randomized prompt string from comma-separated tag inputs.

    Args:
        scene_tags, position_tags, outfit_tags, camera_tags, concept_tags:
            Comma-separated tag strings. From each category up to
            ``tag_counts[<category>]`` tags are sampled at random, and each
            gets a random weight in [0.8, 1.2] formatted as ``tag:W.WW``.
        num_people: Free text. When non-blank it adds ``"<n> people:1.1"``.
        selected_characters: Iterable of character names (``None`` = none).
            Names found in ``data["characters"]`` contribute
            ``"<name>, <trait>, ..."`` with up to ``tag_counts["character"]``
            randomly chosen traits.
        lora_tags: Comma-separated LoRA names, emitted as ``<lora:NAME:1>``.
            At most ``tag_counts["lora"]`` are used. When that key is absent,
            all of them are used.
        tag_counts: Mapping of category name to the maximum number of items to
            sample. Values may be ints or integral floats (slider values).
        data: Saved app data; only ``data["characters"]`` is read.

    Returns:
        The shuffled character/tag parts, then the fixed quality tags, then
        the LoRA tokens.

    NOTE(review): many SD front-ends expect weighted tags as ``(tag:1.05)``;
    the bare ``tag:1.05`` form is kept unchanged. Confirm the target UI
    parses it.
    """
    def split_tags(text):
        # Comma-separated string -> non-blank, whitespace-trimmed tags.
        return [tag.strip() for tag in (text or "").split(",") if tag.strip()]

    def sample(items, limit):
        # Slider values may arrive as floats; random.sample requires an int k.
        return random.sample(items, max(0, min(int(limit), len(items))))

    all_tags = {
        "scene": split_tags(scene_tags),
        "position": split_tags(position_tags),
        "outfit": split_tags(outfit_tags),
        "camera": split_tags(camera_tags),
        "concept": split_tags(concept_tags),
    }

    character_prompts = []
    for char_name in selected_characters or ():
        if char_name in data["characters"]:
            char_traits = data["characters"][char_name]["traits"]
            # Joining name and traits together means a trait-less character
            # yields "Name" rather than "Name, " (which produced ", ,").
            character_prompts.append(
                ", ".join([char_name] + sample(char_traits, tag_counts["character"]))
            )

    selected_tags = []
    for category, tags in all_tags.items():
        if tags:
            selected_tags.extend(
                f"{tag}:{random.uniform(0.8, 1.2):.2f}"
                for tag in sample(tags, tag_counts[category])
            )

    people = (num_people or "").strip()
    if people:
        selected_tags.append(f"{people} people:1.1")

    prompt_parts = character_prompts + selected_tags
    random.shuffle(prompt_parts)
    main_prompt = ", ".join(prompt_parts)

    lora_list = split_tags(lora_tags)
    lora_limit = tag_counts.get("lora")
    # Honor the "Number of LORA tags" slider. When it does not limit anything,
    # keep every LoRA in its original order.
    if lora_limit is not None and int(lora_limit) < len(lora_list):
        lora_list = sample(lora_list, lora_limit)
    lora_prompt = " ".join(f"<lora:{lora}:1>" for lora in lora_list)

    fixed_tags = "source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres, anime artwork, anime style, vibrant, studio anime, highly detailed"
    # Without sampled parts, start directly with the fixed tags (no leading ", ").
    head = f"{main_prompt}, {fixed_tags}" if main_prompt else fixed_tags
    return f"{head} {lora_prompt}".strip()
def update_data(data, key, value):
    """Merge comma-separated ``value`` into the list at ``data[key]`` and persist it.

    Non-blank, trimmed entries from ``value`` are appended after the existing
    entries. Duplicates are dropped while keeping first-seen order; the
    previous ``set`` round-trip reshuffled the saved list on every call.
    Entries are only ever added, never removed. ``save_data`` runs only when
    the list actually changes, so an unchanged category no longer causes an
    upload. Non-list values at ``data[key]`` are left untouched.

    Args:
        data: App data dict, mutated in place.
        key: Key of the list to update (e.g. ``"scene_tags"``).
        value: Comma-separated string (``None`` is treated as empty).

    Returns:
        The same ``data`` object.
    """
    current = data[key]
    if not isinstance(current, list):
        return data
    new_items = [v.strip() for v in (value or "").split(",") if v.strip()]
    # dict.fromkeys de-duplicates while preserving insertion order.
    merged = list(dict.fromkeys(current + new_items))
    if merged != current:
        data[key] = merged
        save_data(data)
    return data
def create_character(name, traits, image, data):
    """Create or update a character in ``data`` and persist it.

    Args:
        name: Character name; surrounding whitespace is ignored. A blank name
            leaves ``data`` unchanged and saves nothing.
        traits: Comma-separated trait string (``None`` is treated as empty).
        image: Object with a ``save`` method (the UI passes a PIL image), or
            ``None``. When ``None`` and the character already exists, its
            previously saved image is kept instead of being dropped.
        data: App data dict; ``data["characters"]`` is updated in place.

    Returns:
        ``(data, update)``, where ``update`` is a ``gr.update`` setting the
        character CheckboxGroup choices to the current character names.
    """
    name = (name or "").strip()
    if name:
        existing = data["characters"].get(name, {})
        if image is not None:
            image_path = save_character_image(name, image)
        else:
            # "Create/Update" without a new upload must not wipe the old image.
            image_path = existing.get("image")
        data["characters"][name] = {
            "traits": [trait.strip() for trait in (traits or "").split(",") if trait.strip()],
            "image": image_path,
        }
        save_data(data)
    return data, gr.update(choices=list(data["characters"].keys()))
def create_ui():
    """Build and return the Gradio app (not launched).

    The saved data is loaded once here. The event handlers defined below
    share it through the enclosing ``data`` variable.

    Tabs:
        * "Prompt Generator": tag textboxes with per-category count sliders,
          a character picker, the generated prompt, and a character gallery.
        * "Character Creation": name/traits/image inputs that create or update
          a character, then refresh the picker and both galleries.
    """
    data = load_data()

    def character_gallery_items():
        # (image path, caption) pairs for every character that has an image.
        return [
            (char_data["image"], char_name)
            for char_name, char_data in data["characters"].items()
            if char_data.get("image")
        ]

    with gr.Blocks() as demo:
        gr.Markdown("# Advanced Pony SDXL Prompt Generator with Character Creation")
        with gr.Tabs():
            with gr.TabItem("Prompt Generator"):
                with gr.Row():
                    with gr.Column():
                        scene_input = gr.Textbox(label="Scene Tags (comma-separated)", value=", ".join(data["scene_tags"]))
                        scene_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of scene tags")
                        num_people_input = gr.Textbox(label="Number of People")
                        position_input = gr.Textbox(label="Position Tags (comma-separated)", value=", ".join(data["position_tags"]))
                        position_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of position tags")
                        character_select = gr.CheckboxGroup(label="Select Characters", choices=list(data["characters"].keys()))
                        character_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of character traits")
                        outfit_input = gr.Textbox(label="Outfit Tags (comma-separated)", value=", ".join(data["outfit_tags"]))
                        outfit_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of outfit tags")
                        camera_input = gr.Textbox(label="Camera View/Angle Tags (comma-separated)", value=", ".join(data["camera_tags"]))
                        camera_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of camera tags")
                        concept_input = gr.Textbox(label="Concept Tags (comma-separated)", value=", ".join(data["concept_tags"]))
                        concept_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of concept tags")
                        lora_input = gr.Textbox(label="LORA Tags (comma-separated)", value=", ".join(data["lora_tags"]))
                        lora_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of LORA tags")
                        generate_button = gr.Button("Generate Prompt")
                    with gr.Column():
                        output = gr.Textbox(label="Generated Prompt", lines=5)
                        with gr.Row():
                            prompt_gallery = gr.Gallery(value=character_gallery_items(), label="Character Images", show_label=True, elem_id="char_gallery", columns=2, rows=2, height="auto")
            with gr.TabItem("Character Creation"):
                with gr.Row():
                    with gr.Column():
                        char_name_input = gr.Textbox(label="Character Name")
                        char_traits_input = gr.Textbox(label="Character Traits (comma-separated)")
                        char_image_input = gr.Image(label="Character Image", type="pil")
                        create_char_button = gr.Button("Create/Update Character")
                    with gr.Column():
                        # Distinct elem_id: two elements must not share one HTML id.
                        # Pre-populated so existing characters show before any edit.
                        char_gallery = gr.Gallery(value=character_gallery_items(), label="Existing Characters", show_label=True, elem_id="existing_char_gallery", columns=2, rows=2, height="auto")

        def update_and_generate(*args):
            # Persist any new tags typed into the textboxes, then build a prompt.
            nonlocal data
            (scene_tags, num_people, position_tags, selected_characters,
             outfit_tags, camera_tags, concept_tags, lora_tags, *tag_counts) = args
            for key, value in (
                ("scene_tags", scene_tags),
                ("position_tags", position_tags),
                ("outfit_tags", outfit_tags),
                ("camera_tags", camera_tags),
                ("concept_tags", concept_tags),
                ("lora_tags", lora_tags),
            ):
                data = update_data(data, key, value)
            # Slider values arrive in the same order as the inputs list below.
            tag_count_dict = dict(zip(
                ("scene", "position", "character", "outfit", "camera", "concept", "lora"),
                tag_counts,
            ))
            return generate_prompt(scene_tags, num_people, position_tags, selected_characters, outfit_tags, camera_tags, concept_tags, lora_tags, tag_count_dict, data)

        generate_button.click(
            update_and_generate,
            inputs=[scene_input, num_people_input, position_input, character_select, outfit_input, camera_input, concept_input, lora_input,
                    scene_count, position_count, character_count, outfit_count, camera_count, concept_count, lora_count],
            outputs=[output],
        )

        def create_and_refresh(name, traits, image):
            # create_character needs the shared data dict, which is not a Gradio
            # input; passing only the three inputs directly raised a TypeError.
            nonlocal data
            data, choices_update = create_character(name, traits, image, data)
            items = character_gallery_items()
            return choices_update, items, items

        create_char_button.click(
            create_and_refresh,
            inputs=[char_name_input, char_traits_input, char_image_input],
            outputs=[character_select, prompt_gallery, char_gallery],
        )
    return demo
# Script entry point: build the UI and start the Gradio server.
# (The stray "|" scraped onto the launch line was a syntax error.)
if __name__ == "__main__":
    demo = create_ui()
    demo.launch()