# Spaces: Paused
import gradio as gr | |
import random | |
import json | |
import os | |
from huggingface_hub import HfApi, hf_hub_download, upload_file | |
# Hugging Face repo id that the data file and character images are uploaded to.
REPO_ID = "throaway2854/promptGen"  # Replace with your actual Space name
# Name of the JSON data file; used both as the local path and the path in the repo.
DATA_FILE = "saved_data.json"
# Directory name (locally and inside the repo) where character images are stored.
IMAGES_DIR = "character_images"
def load_data():
    """Load the saved tag and character data from the Space repository.

    Downloads ``DATA_FILE`` from ``REPO_ID`` and parses it as JSON. Loading is
    best-effort: on any failure (file not uploaded yet, network problem,
    invalid JSON, ...) an empty data set is returned so the UI can still start.

    Returns:
        dict: Mapping with the list-valued keys ``scene_tags``,
        ``position_tags``, ``outfit_tags``, ``camera_tags``, ``concept_tags``
        and ``lora_tags``, plus the dict-valued key ``characters``.
    """
    defaults = {
        "scene_tags": [], "position_tags": [], "outfit_tags": [],
        "camera_tags": [], "concept_tags": [], "lora_tags": [],
        "characters": {},
    }
    try:
        # The file is uploaded with repo_type="space", so it must be fetched
        # from the same repo type; without it hf_hub_download looks for a
        # model repo and the saved data is never found.
        file_path = hf_hub_download(
            repo_id=REPO_ID, filename=DATA_FILE, repo_type="space"
        )
        with open(file_path, "r", encoding="utf-8") as f:
            loaded = json.load(f)
    except Exception as exc:
        # Deliberate catch-all at this boundary: the app must start even when
        # nothing has been saved yet. Report the reason instead of hiding it.
        print(f"Could not load saved data; starting with empty data: {exc!r}")
        return defaults

    if not isinstance(loaded, dict):
        return defaults
    # Fill in any top-level key missing from the stored file so later
    # data["..."] lookups do not raise KeyError.
    defaults.update(loaded)
    return defaults
def save_data(data):
    """Write *data* to ``DATA_FILE`` as JSON and upload that file to the Space repo.

    Args:
        data: JSON-serialisable object to persist.
    """
    with open(DATA_FILE, 'w') as handle:
        json.dump(data, handle)

    # Push the freshly written file to the repo under the same name.
    upload_args = {
        "path_or_fileobj": DATA_FILE,
        "path_in_repo": DATA_FILE,
        "repo_id": REPO_ID,
        "repo_type": "space",
    }
    HfApi().upload_file(**upload_args)
def save_character_image(name, image):
    """Save a character image as PNG locally and upload it to the Space repo.

    Args:
        name: Character name; used as the file stem.
        image: Object exposing ``save(path)``.

    Returns:
        str: Local path of the saved image, ``IMAGES_DIR/<name>.png``.
    """
    # The name is free text; replace path separators so the file always lands
    # directly inside IMAGES_DIR (a name such as "a/b" or "../x" would
    # otherwise point at another directory).
    safe_name = "".join("_" if ch in "/\\" else ch for ch in str(name))

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(IMAGES_DIR, exist_ok=True)
    image_path = os.path.join(IMAGES_DIR, f"{safe_name}.png")
    image.save(image_path)

    api = HfApi()
    api.upload_file(
        path_or_fileobj=image_path,
        path_in_repo=f"{IMAGES_DIR}/{safe_name}.png",
        repo_id=REPO_ID,
        repo_type="space",
    )
    return image_path
def generate_prompt(scene_tags, num_people, position_tags, selected_characters, outfit_tags, camera_tags, concept_tags, lora_tags, tag_counts, data):
    """Assemble a randomised prompt string from tags and characters.

    Args:
        scene_tags, position_tags, outfit_tags, camera_tags, concept_tags:
            Comma-separated tag strings. From each category up to
            ``tag_counts[<category>]`` tags are drawn at random and given a
            random weight between 0.8 and 1.2 (``tag:weight``).
        num_people: Free text; when non-blank, ``"<num_people> people:1.1"``
            is added to the prompt.
        selected_characters: Iterable of character names. Names not present
            in ``data["characters"]`` are ignored.
        lora_tags: Comma-separated LoRA names; every one is appended as
            ``<lora:NAME:1>`` (``tag_counts["lora"]`` is not applied).
        tag_counts: Mapping with the keys ``scene``, ``position``, ``outfit``,
            ``camera``, ``concept`` and ``character`` giving the maximum
            number of tags/traits to draw.
        data: Mapping whose ``"characters"`` entry maps a character name to a
            dict with a ``"traits"`` list.

    Returns:
        str: The shuffled character and tag parts, followed by the fixed
        quality tags, followed by the LoRA tags.
    """
    raw_tags = {
        "scene": scene_tags,
        "position": position_tags,
        "outfit": outfit_tags,
        "camera": camera_tags,
        "concept": concept_tags,
    }
    all_tags = {
        category: [tag.strip() for tag in (text or "").split(",") if tag.strip()]
        for category, text in raw_tags.items()
    }

    character_prompts = []
    for char_name in selected_characters or []:
        if char_name not in data["characters"]:
            continue
        char_traits = data["characters"][char_name]["traits"]
        # int(): counts may arrive as floats, which random.sample rejects.
        trait_count = min(int(tag_counts["character"]), len(char_traits))
        # Join name and traits in one go so a character without traits does
        # not leave a dangling ", " in the prompt.
        character_prompts.append(
            ", ".join([char_name] + random.sample(char_traits, trait_count))
        )

    selected_tags = []
    for category, tags in all_tags.items():
        if not tags:
            continue
        count = min(int(tag_counts[category]), len(tags))
        for tag in random.sample(tags, count):
            selected_tags.append(f"{tag}:{random.uniform(0.8, 1.2):.2f}")

    people = (num_people or "").strip()
    if people:
        selected_tags.append(f"{people} people:1.1")

    prompt_parts = character_prompts + selected_tags
    random.shuffle(prompt_parts)

    lora_list = [lora.strip() for lora in (lora_tags or "").split(",") if lora.strip()]
    lora_prompt = " ".join(f"<lora:{lora}:1>" for lora in lora_list)

    fixed_tags = "source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres, anime artwork, anime style, vibrant, studio anime, highly detailed"

    # Joining the parts together with the fixed tags avoids a leading ", "
    # when no tags or characters were selected.
    main_section = ", ".join(prompt_parts + [fixed_tags])
    return f"{main_section} {lora_prompt}".strip()
def update_data(data, key, value):
    """Merge comma-separated tags from *value* into the list at ``data[key]``.

    Tags are stripped, blanks are dropped and duplicates are removed while
    keeping first-seen order. The data is saved only when the stored list
    actually changes; a non-list ``data[key]`` is left untouched.

    Args:
        data: Mapping holding the saved tag lists; modified in place.
        key: Key in *data* to merge into.
        value: Comma-separated tag string (``None`` is treated as empty).

    Returns:
        The same *data* object.

    Raises:
        KeyError: If *key* is not present in *data*.
    """
    existing = data[key]
    if not isinstance(existing, list):
        return data

    new_tags = [v.strip() for v in (value or "").split(",") if v.strip()]
    # dict.fromkeys de-duplicates while preserving order; a set() would hand
    # the tags back in arbitrary order.
    merged = list(dict.fromkeys(existing + new_tags))

    # save_data uploads to the repo, so skip it when nothing changed.
    if merged != existing:
        data[key] = merged
        save_data(data)
    return data
def create_character(name, traits, image, data):
    """Create or update a character entry and build a choices update.

    When *name* is non-blank, ``data["characters"][name]`` is replaced by a
    dict with the parsed ``"traits"`` list and the ``"image"`` path, and the
    data is saved. A blank name leaves *data* unchanged.

    Args:
        name: Character name; surrounding whitespace is ignored.
        traits: Comma-separated trait string.
        image: Image object to store via ``save_character_image``, or ``None``.
        data: Mapping with a ``"characters"`` dict; modified in place.

    Returns:
        tuple: ``(data, update)`` where ``update`` is a ``gr.update`` carrying
        the current character names as ``choices``.
    """
    name = (name or "").strip()
    if name:
        previous = data["characters"].get(name, {})
        if image is not None:
            image_path = save_character_image(name, image)
        else:
            # Editing only the traits of an existing character must not wipe
            # the image stored for it earlier.
            image_path = previous.get("image")
        data["characters"][name] = {
            "traits": [trait.strip() for trait in (traits or "").split(',') if trait.strip()],
            "image": image_path,
        }
        save_data(data)
    return data, gr.update(choices=list(data["characters"].keys()))
def create_ui():
    """Build the Gradio Blocks interface for the prompt generator.

    The interface has two tabs: "Prompt Generator" (tag inputs, per-category
    count sliders, character selection, generated prompt output and a gallery
    of character images) and "Character Creation" (name, traits and image
    inputs plus a gallery of existing characters).

    Returns:
        The assembled ``gr.Blocks`` app (not launched).
    """
    # Loaded once; the handlers below share this single object.
    data = load_data()

    def character_images():
        """Return the stored image paths of all characters that have one."""
        return [
            char_data["image"]
            for char_data in data["characters"].values()
            if char_data.get("image")
        ]

    with gr.Blocks() as demo:
        gr.Markdown("# Advanced Pony SDXL Prompt Generator with Character Creation")

        with gr.Tabs():
            with gr.TabItem("Prompt Generator"):
                with gr.Row():
                    with gr.Column():
                        scene_input = gr.Textbox(label="Scene Tags (comma-separated)", value=", ".join(data["scene_tags"]))
                        scene_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of scene tags")
                        num_people_input = gr.Textbox(label="Number of People")
                        position_input = gr.Textbox(label="Position Tags (comma-separated)", value=", ".join(data["position_tags"]))
                        position_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of position tags")
                        character_select = gr.CheckboxGroup(label="Select Characters", choices=list(data["characters"].keys()))
                        character_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of character traits")
                        outfit_input = gr.Textbox(label="Outfit Tags (comma-separated)", value=", ".join(data["outfit_tags"]))
                        outfit_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of outfit tags")
                        camera_input = gr.Textbox(label="Camera View/Angle Tags (comma-separated)", value=", ".join(data["camera_tags"]))
                        camera_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of camera tags")
                        concept_input = gr.Textbox(label="Concept Tags (comma-separated)", value=", ".join(data["concept_tags"]))
                        concept_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of concept tags")
                        lora_input = gr.Textbox(label="LORA Tags (comma-separated)", value=", ".join(data["lora_tags"]))
                        lora_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of LORA tags")
                        generate_button = gr.Button("Generate Prompt")
                    with gr.Column():
                        output = gr.Textbox(label="Generated Prompt", lines=5)
                with gr.Row():
                    # Distinct elem_id: the gallery on the other tab already
                    # uses "char_gallery", and element ids must be unique.
                    prompt_gallery = gr.Gallery(value=character_images(), label="Character Images", show_label=True, elem_id="prompt_char_gallery", columns=2, rows=2, height="auto")

            with gr.TabItem("Character Creation"):
                with gr.Row():
                    with gr.Column():
                        char_name_input = gr.Textbox(label="Character Name")
                        char_traits_input = gr.Textbox(label="Character Traits (comma-separated)")
                        char_image_input = gr.Image(label="Character Image", type="pil")
                        create_char_button = gr.Button("Create/Update Character")
                    with gr.Column():
                        # Start with the already stored images instead of an
                        # empty gallery.
                        char_gallery = gr.Gallery(value=character_images(), label="Existing Characters", show_label=True, elem_id="char_gallery", columns=2, rows=2, height="auto")

        def update_and_generate(*args):
            """Store the entered tags, then build a prompt from them."""
            nonlocal data
            (scene_tags, num_people, position_tags, selected_characters,
             outfit_tags, camera_tags, concept_tags, lora_tags, *tag_counts) = args

            entered = (
                ("scene_tags", scene_tags),
                ("position_tags", position_tags),
                ("outfit_tags", outfit_tags),
                ("camera_tags", camera_tags),
                ("concept_tags", concept_tags),
                ("lora_tags", lora_tags),
            )
            for key, value in entered:
                data = update_data(data, key, value)

            # Slider values arrive in the order they are listed in `inputs`.
            tag_count_dict = {
                "scene": tag_counts[0], "position": tag_counts[1], "character": tag_counts[2],
                "outfit": tag_counts[3], "camera": tag_counts[4], "concept": tag_counts[5], "lora": tag_counts[6],
            }
            return generate_prompt(scene_tags, num_people, position_tags, selected_characters, outfit_tags, camera_tags, concept_tags, lora_tags, tag_count_dict, data)

        generate_button.click(
            update_and_generate,
            inputs=[scene_input, num_people_input, position_input, character_select, outfit_input, camera_input, concept_input, lora_input,
                    scene_count, position_count, character_count, outfit_count, camera_count, concept_count, lora_count],
            outputs=[output],
        )

        def on_create_character(name, traits, image):
            """Create/update a character, then refresh the selector and galleries."""
            nonlocal data
            # create_character needs the data mapping as its fourth argument;
            # Gradio supplies only the three component values, so pass it here.
            data, selector_update = create_character(name, traits, image, data)
            images = character_images()
            return selector_update, gr.update(value=images), gr.update(value=images)

        create_char_button.click(
            on_create_character,
            inputs=[char_name_input, char_traits_input, char_image_input],
            outputs=[character_select, prompt_gallery, char_gallery],
        )

    return demo
# Build the interface and start the Gradio server when run as a script.
if __name__ == "__main__":
    demo = create_ui()
    demo.launch()