promptGen / app.py
throaway2854's picture
Update app.py
3d1e39b verified
raw
history blame
26.1 kB
import gradio as gr
import json
import os
import random
from PIL import Image
import base64
import io
# Prompt categories shared by all tabs, as (display label, storage key) pairs.
# The storage key is the name passed to DataManager.get_category_tags /
# set_category_tags; list order is the order in which categories are shown
# and in which their sampled tags are added to a generated prompt.
categories = [
    ("Setting", "scene_tags"),
    ("Position", "position_tags"),
    ("Outfit", "outfit_tags"),
    ("Camera View/Angle", "camera_tags"),
    ("Concept", "concept_tags"),
    ("Facial Expression", "facial_expression_tags"),
    ("Pose", "pose_tags"),
    ("Additional", "additional_tags"),
    ("LORA", "lora_tags"),
]
class DataManager:
    """File-backed store for characters, tags and character images.

    Everything lives under ``base_dir``:

    * ``characters.json``      -- list of character dicts
      (``name``, ``traits``, ``gender``, ``image_path``)
    * ``persistent_tags.json`` -- list of tags
    * ``category_tags.json``   -- dict mapping a category key to its tag list
    * ``character_images/``    -- one PNG per character, named after it
    """

    # Pillow modes the PNG writer stores directly; anything else (e.g. CMYK)
    # is converted before saving instead of failing.
    _PNG_MODES = frozenset({'1', 'L', 'LA', 'I', 'I;16', 'P', 'RGB', 'RGBA'})

    def __init__(self, base_dir='/data'):
        self.base_dir = base_dir
        # exist_ok replaces the exists()/makedirs() pair, which could race.
        os.makedirs(self.base_dir, exist_ok=True)
        self.characters_file = os.path.join(self.base_dir, 'characters.json')
        self.persistent_tags_file = os.path.join(self.base_dir, 'persistent_tags.json')
        self.category_tags_file = os.path.join(self.base_dir, 'category_tags.json')
        self.images_folder = os.path.join(self.base_dir, 'character_images')
        os.makedirs(self.images_folder, exist_ok=True)
        self.load_data()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _read_json(path, default):
        """Return the parsed JSON stored in *path*, or *default* if the file is missing."""
        if not os.path.exists(path):
            return default
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def _write_json(path, data):
        """Write *data* as indented JSON to *path*.

        The data goes to a temporary file first and is then moved over
        *path*, so an interrupted write never leaves a truncated file behind.
        """
        tmp_path = f"{path}.tmp"
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=4)
        os.replace(tmp_path, path)

    @staticmethod
    def _safe_filename(name):
        """Reduce *name* to a file stem made of alphanumerics, spaces, '_' and '-'.

        Falls back to ``'character'`` when nothing usable remains, so a name
        such as ``'???'`` does not produce a bare ``.png`` file.
        """
        safe = "".join(c for c in name if c.isalnum() or c in (' ', '_', '-')).rstrip()
        return safe or 'character'

    @staticmethod
    def _normalize_traits(traits):
        """Return *traits* (comma-separated string or list) as trimmed, non-empty strings."""
        if isinstance(traits, str):
            traits = traits.split(',')
        return [t.strip() for t in traits if t.strip()]

    def _save_image(self, source_path, name):
        """Copy the image at *source_path* into ``images_folder`` as ``<safe name>.png``.

        Returns the destination path. Pillow's errors propagate (e.g. OSError
        for a missing or unreadable file).
        """
        image_path = os.path.join(self.images_folder, f"{self._safe_filename(name)}.png")
        # The context manager closes the source file handle after saving.
        with Image.open(source_path) as image:
            to_save = image if image.mode in self._PNG_MODES else image.convert('RGBA')
            to_save.save(image_path)
        return image_path

    # ------------------------------------------------------------------
    # Loading / saving
    # ------------------------------------------------------------------
    def load_data(self):
        """(Re)load characters, persistent tags and category tags from disk."""
        self.characters = self._read_json(self.characters_file, [])
        self.persistent_tags = self._read_json(self.persistent_tags_file, [])
        self.load_category_tags()

    def save_characters(self):
        """Persist ``self.characters`` to ``characters.json``."""
        self._write_json(self.characters_file, self.characters)

    def save_persistent_tags(self):
        """Persist ``self.persistent_tags`` to ``persistent_tags.json``."""
        self._write_json(self.persistent_tags_file, self.persistent_tags)

    def load_category_tags(self):
        """Load the category -> tags mapping (empty dict if the file is missing)."""
        self.category_tags = self._read_json(self.category_tags_file, {})

    def save_category_tags(self):
        """Persist ``self.category_tags`` to ``category_tags.json``."""
        self._write_json(self.category_tags_file, self.category_tags)

    # ------------------------------------------------------------------
    # Tags
    # ------------------------------------------------------------------
    def get_category_tags(self, category_var_name):
        """Return the tag list stored for *category_var_name* (``[]`` if none)."""
        return self.category_tags.get(category_var_name, [])

    def set_category_tags(self, category_var_name, tags_list):
        """Replace the tags of *category_var_name* and save immediately."""
        self.category_tags[category_var_name] = tags_list
        self.save_category_tags()

    def set_persistent_tags(self, tags_list):
        """Replace the persistent tags and save immediately."""
        self.persistent_tags = tags_list
        self.save_persistent_tags()

    def get_persistent_tags(self):
        """Return the persistent tag list."""
        return self.persistent_tags

    # ------------------------------------------------------------------
    # Characters
    # ------------------------------------------------------------------
    def get_characters(self):
        """Return the stored characters, refreshing each one's ``'image'`` key.

        ``'image'`` is set to ``image_path`` when that file exists and to
        ``None`` otherwise. This mutates and returns the live list, so callers
        see (and may modify) the stored dicts.
        """
        for char in self.characters:
            image_path = char.get('image_path')
            char['image'] = image_path if image_path and os.path.exists(image_path) else None
        return self.characters

    def add_character(self, character):
        """Append *character* and persist it.

        *character* needs ``'name'``, ``'traits'`` and ``'gender'``; an optional
        ``'image'`` is the path of an uploaded image. That image is copied into
        ``images_folder`` and the copy's path is stored as ``'image_path'``
        (``None`` when absent or unreadable); ``'image'`` is removed.
        ``'traits'`` may be a comma-separated string and is normalized to a
        list. The dict is modified in place.
        """
        image_source = character.pop('image', None)
        character['image_path'] = None
        if image_source:
            try:
                character['image_path'] = self._save_image(image_source, character['name'])
            except (OSError, ValueError) as e:
                # Best effort: keep the character even if its image can't be stored.
                print(f"Error saving image: {e}")
        character['traits'] = self._normalize_traits(character['traits'])
        self.characters.append(character)
        self.save_characters()

    def update_character(self, original_name, updated_character):
        """Replace the character named *original_name* with *updated_character*.

        A truthy ``updated_character['image']`` is the path of an image to store
        under the (possibly new) name; otherwise the existing ``image_path`` is
        kept. When a rename stores the image in a different file, the old file
        is deleted. Returns True on success, False if no character is named
        *original_name*.
        """
        for idx, char in enumerate(self.characters):
            if char['name'] != original_name:
                continue
            old_image_path = char.get('image_path')
            image_source = updated_character.pop('image', None)
            updated_character['image_path'] = old_image_path
            if image_source:
                try:
                    new_image_path = self._save_image(image_source, updated_character['name'])
                except (OSError, ValueError) as e:
                    print(f"Error updating image: {e}")
                else:
                    updated_character['image_path'] = new_image_path
                    renamed = original_name != updated_character['name']
                    # Different names can sanitize to the same file ("Bob!" / "Bob?");
                    # removing the "old" path then would delete the image just saved.
                    if (renamed and old_image_path
                            and os.path.abspath(old_image_path) != os.path.abspath(new_image_path)
                            and os.path.exists(old_image_path)):
                        os.remove(old_image_path)
            updated_character['traits'] = self._normalize_traits(updated_character['traits'])
            self.characters[idx] = updated_character
            self.save_characters()
            return True
        return False

    def delete_character(self, name):
        """Delete the first character called *name* and its image file.

        Returns True if a character was removed, False otherwise.
        """
        for idx, char in enumerate(self.characters):
            if char['name'] == name:
                image_path = char.get('image_path')
                if image_path:
                    try:
                        os.remove(image_path)
                    except FileNotFoundError:
                        pass  # already gone; nothing to clean up
                del self.characters[idx]
                self.save_characters()
                return True
        return False
def prompt_generator_app(data_manager):
with gr.Tab("Prompt Generator"):
gr.Markdown("## Prompt Generator")
# Add a refresh tags button
refresh_tags_button = gr.Button("Refresh Tags")
inputs = {}
tag_displays = {}
for category_name, var_name in categories:
tags_list = data_manager.get_category_tags(var_name)
tags_string = ', '.join(tags_list)
max_tags = len(tags_list)
if max_tags == 0:
default_value = 0
else:
default_value = min(1, max_tags)
with gr.Group():
gr.Markdown(f"### {category_name}")
tag_display = gr.Markdown(f"**Tags:** {tags_string}")
tag_num = gr.Slider(minimum=0, maximum=max_tags, step=1, value=default_value, label=f"Number of {category_name} Tags to Select")
inputs[f"{var_name}_num"] = tag_num
tag_displays[var_name] = (tag_display, tag_num)
# For Character Selection
with gr.Group():
gr.Markdown("### Character Selection")
# Get the list of characters
def get_character_options():
characters = data_manager.get_characters()
character_options = []
for char in characters:
option_label = f"{char['name']} ({char['gender']})"
character_options.append(option_label)
return character_options
character_options = get_character_options()
character_select = gr.CheckboxGroup(choices=character_options, label="Select Characters", interactive=True)
refresh_characters_button = gr.Button("Refresh Character List")
def refresh_characters():
new_options = get_character_options()
return gr.CheckboxGroup.update(choices=new_options)
refresh_characters_button.click(refresh_characters, outputs=character_select)
random_characters = gr.Checkbox(label="Select Random Characters")
num_characters = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of Characters (if random)")
generate_button = gr.Button("Generate Prompt")
prompt_output = gr.Textbox(label="Generated Prompt", lines=5)
def generate_prompt(*args):
arg_idx = 0
prompt_tags = []
for category_name, var_name in categories:
tags_list = data_manager.get_category_tags(var_name)
tags_num = args[arg_idx]
arg_idx += 1
if tags_list and tags_num > 0:
selected_tags = random.sample(tags_list, min(len(tags_list), int(tags_num)))
prompt_tags.extend(selected_tags)
# Handle Characters
selected_character_options = args[arg_idx]
random_chars = args[arg_idx + 1]
num_random_chars = args[arg_idx + 2]
arg_idx += 3
characters = data_manager.get_characters()
if random_chars:
num = min(len(characters), int(num_random_chars))
selected_chars = random.sample(characters, num)
else:
# Extract selected character names from options
selected_chars = []
for option in selected_character_options:
name = option.split(' (')[0]
for char in characters:
if char['name'] == name:
selected_chars.append(char)
break
# Determine the number of boys and girls
num_girls = sum(1 for char in selected_chars if char.get('gender') == 'Girl')
num_boys = sum(1 for char in selected_chars if char.get('gender') == 'Boy')
# Build the initial character count tags
character_count_tags = []
if num_girls > 0:
character_count_tags.append(f"{num_girls}girl" if num_girls == 1 else f"{num_girls}girls")
if num_boys > 0:
character_count_tags.append(f"{num_boys}boy" if num_boys == 1 else f"{num_boys}boys")
prompt_parts = []
if character_count_tags:
prompt_parts.append(', '.join(character_count_tags))
# Build character descriptions
character_descriptions = []
for idx, char in enumerate(selected_chars):
# Get traits for the character
traits = ', '.join(char['traits'])
# Create a description for each character
# For SDXL models, use the format "[char1 description] AND [char2 description]"
# Each character's description is enclosed in parentheses
character_description = f"({traits})"
character_descriptions.append(character_description)
# Join character descriptions appropriately for SDXL models
if character_descriptions:
character_descriptions_str = ' AND '.join(character_descriptions)
prompt_parts.append(character_descriptions_str)
# Append selected prompt tags from categories
if prompt_tags:
prompt_tags_str = ', '.join(prompt_tags)
prompt_parts.append(prompt_tags_str)
# Load persistent tags
persistent_tags = data_manager.get_persistent_tags()
if persistent_tags:
persistent_tags_str = ', '.join(persistent_tags)
prompt_parts.append(persistent_tags_str)
# Add ending tags
ending_tags = "source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres, anime artwork, anime style, vibrant, studio anime, highly detailed"
prompt_parts.append(ending_tags)
prompt_string = ', '.join(prompt_parts)
return prompt_string
# Prepare the list of inputs for the generate_prompt function
inputs_list = []
for category_name, var_name in categories:
inputs_list.append(inputs[f"{var_name}_num"])
# Add character_select directly to inputs
inputs_list.extend([character_select, random_characters, num_characters])
generate_button.click(generate_prompt, inputs=inputs_list, outputs=prompt_output)
# Function to refresh tags display and sliders
def refresh_tags():
updates = []
for category_name, var_name in categories:
# Reload tags from data_manager
tags_list = data_manager.get_category_tags(var_name)
tags_string = ', '.join(tags_list)
max_tags = len(tags_list)
if max_tags == 0:
slider_value = 0
else:
slider_value = min(1, max_tags)
# Update the tag display and slider
tag_display, tag_num = tag_displays[var_name]
updates.append(gr.Markdown.update(value=f"**Tags:** {tags_string}"))
updates.append(gr.Slider.update(maximum=max_tags, value=slider_value))
return updates
# Prepare the outputs list
outputs = [component for pair in tag_displays.values() for component in pair]
# Connect the refresh_tags function to the refresh_tags_button
refresh_tags_button.click(refresh_tags, outputs=outputs)
def character_creation_app(data_manager):
with gr.Tab("Character Creation"):
gr.Markdown("## Create a New Character")
with gr.Row():
name_input = gr.Textbox(label="Character Name", placeholder="Enter unique character name")
traits_input = gr.Textbox(label="Traits/Appearance Tags (comma separated)", placeholder="e.g., blue hair, green eyes, tall")
image_input = gr.Image(label="Upload Character Image", type="filepath")
gender_input = gr.Radio(choices=["Boy", "Girl"], label="Gender")
save_button = gr.Button("Save Character")
output = gr.Textbox(label="Status", interactive=False)
def save_character(name, traits, image_path, gender, characters_container):
if not name.strip() or not traits.strip() or not gender:
return "Please enter all fields.", None
# Check for duplicate names
existing_names = [char['name'] for char in data_manager.get_characters()]
if name in existing_names:
return f"Character with name '{name}' already exists. Please choose a different name.", None
character = {'name': name, 'traits': traits, 'gender': gender, 'image': image_path}
data_manager.add_character(character)
# After adding, re-render the character list
return f"Character '{name}' saved successfully.", list_characters(data_manager)
save_button.click(
save_character,
inputs=[name_input, traits_input, image_input, gender_input, gr.State()],
outputs=[output, "characters_renderer"]
)
# Divider
gr.Markdown("---")
# Display Existing Characters
gr.Markdown("## Existing Characters")
# Renderer for characters
characters_renderer = gr.Column(label="Characters")
def list_characters(data_manager):
characters = data_manager.get_characters()
if not characters:
return gr.Markdown("No characters created yet.")
components = []
for char in characters:
with gr.Accordion(label=char['name'], open=False):
with gr.Row():
if char['image']:
img_display = gr.Image(value=char['image'], label="Character Image", type="filepath", interactive=False)
else:
img_display = gr.Markdown("**No Image Provided**")
with gr.Column():
gr.Markdown(f"**Name:** {char['name']}").render()
gr.Markdown(f"**Gender:** {char['gender']}").render()
gr.Markdown(f"**Traits:** {', '.join(char['traits'])}").render()
with gr.Row():
edit_btn = gr.Button("Edit", variant="primary")
delete_btn = gr.Button("Delete", variant="secondary")
edit_output = gr.Textbox(label="", interactive=False, visible=False)
delete_output = gr.Textbox(label="", interactive=False, visible=False)
# Edit Interface (initially hidden)
edit_interface = gr.Column(visible=False)
with edit_interface:
with gr.Row():
edit_name = gr.Textbox(label="Character Name", value=char['name'], placeholder="Enter new name")
edit_traits = gr.Textbox(label="Traits/Appearance Tags (comma separated)", value=', '.join(char['traits']), placeholder="e.g., red hair, blue eyes")
edit_image = gr.Image(label="Upload New Character Image", type="filepath")
edit_gender = gr.Radio(choices=["Boy", "Girl"], label="Gender", value=char['gender'])
save_edit_btn = gr.Button("Save Changes")
save_edit_output = gr.Textbox(label="Edit Status", interactive=False)
def save_edit(original_name, new_name, new_traits, new_image_path, new_gender):
if not new_name.strip() or not new_traits.strip() or not new_gender:
return "Please enter all fields.", None
# If the name has changed, check for duplicates
if new_name != original_name:
existing_names = [c['name'] for c in data_manager.get_characters()]
if new_name in existing_names:
return f"Character with name '{new_name}' already exists. Please choose a different name.", None
updated_char = {
'name': new_name,
'traits': new_traits,
'gender': new_gender,
'image': new_image_path if new_image_path else data_manager.characters[[c['name'] for c in data_manager.get_characters()].index(original_name)]['image']
}
success = data_manager.update_character(original_name, updated_char)
if success:
return f"Character '{new_name}' updated successfully.", list_characters(data_manager)
else:
return "Failed to update character.", None
save_edit_btn.click(
save_edit,
inputs=[gr.State(char['name']), edit_name, edit_traits, edit_image, edit_gender],
outputs=[save_edit_output, "characters_renderer"]
)
# Define edit functionality
def toggle_edit_visibility():
return gr.update(visible=True)
edit_btn.click(
lambda: True,
inputs=None,
outputs=edit_interface
)
# Define delete functionality
def confirm_delete(char_name):
return f"Are you sure you want to delete '{char_name}'?", True
def perform_delete(char_name, confirm):
if confirm:
success = data_manager.delete_character(char_name)
if success:
return f"Character '{char_name}' deleted successfully.", list_characters(data_manager)
else:
return f"Failed to delete character '{char_name}'.", None
return "Deletion cancelled.", None
delete_btn.click(
confirm_delete,
inputs=[char['name']],
outputs=[delete_output, delete_output]
)
delete_confirm = gr.Button("Confirm Delete", visible=False)
delete_cancel = gr.Button("Cancel Delete", visible=False)
# Show confirmation buttons when delete is clicked
def show_delete_buttons(message):
if "Are you sure" in message:
return gr.update(visible=True), gr.update(visible=True)
return gr.update(visible=False), gr.update(visible=False)
delete_btn.click(
show_delete_buttons,
inputs=[delete_output],
outputs=[delete_confirm, delete_cancel]
)
# Handle delete confirmation
delete_confirm.click(
perform_delete,
inputs=[char['name'], gr.Checkbox(label="", value=True)],
outputs=[delete_output, "characters_renderer"]
)
# Handle delete cancellation
delete_cancel.click(
lambda: ("Deletion cancelled.", None),
inputs=None,
outputs=[delete_output, "characters_renderer"]
)
return components
def list_characters_render():
return list_characters(data_manager)
characters_renderer.render(list_characters(data_manager))
def tags_app(data_manager):
    """Build the "Tags" tab for editing category tags and persistent tags.

    Every category in ``categories`` gets a comma-separated textbox prefilled
    from *data_manager*, a save button and a status box; a final section does
    the same for the persistent tags. Saving splits the text on commas, trims
    each piece and drops empty pieces before storing.
    """

    def parse(raw_text):
        # "a, b,, c " -> ["a", "b", "c"]
        return [piece.strip() for piece in raw_text.split(',') if piece.strip()]

    def bind_category_saver(key, label):
        # Factory so each button's handler captures its own category; a
        # closure defined directly in the loop would see only the last one.
        def fn(tags_string):
            data_manager.set_category_tags(key, parse(tags_string))
            return f"{label} tags saved successfully."
        return fn

    def save_persistent_tags_fn(tags_string):
        data_manager.set_persistent_tags(parse(tags_string))
        return "Persistent tags saved successfully."

    with gr.Tab("Tags"):
        gr.Markdown("## Edit Tags for Each Category")
        for label, key in categories:
            gr.Markdown(f"### {label} Tags")
            current_text = ', '.join(data_manager.get_category_tags(key))
            textbox = gr.Textbox(label=f"{label} Tags (comma separated)", value=current_text)
            button = gr.Button(f"Save {label} Tags")
            status = gr.Textbox(label="", interactive=False)
            button.click(bind_category_saver(key, label), inputs=textbox, outputs=status)

        gr.Markdown("### Persistent Tags")
        persistent_text = ', '.join(data_manager.get_persistent_tags())
        persistent_box = gr.Textbox(label="Persistent Tags (comma separated)", value=persistent_text)
        persistent_button = gr.Button("Save Persistent Tags")
        persistent_status = gr.Textbox(label="", interactive=False)
        persistent_button.click(save_persistent_tags_fn, inputs=persistent_box, outputs=persistent_status)
def main(base_dir=None):
    """Build the three-tab Gradio app and launch it.

    Args:
        base_dir: Directory where DataManager keeps its files. When omitted,
            the ``PROMPTGEN_DATA_DIR`` environment variable is used, falling
            back to ``'/data'`` (the previously hard-coded location), so
            existing deployments behave as before.
    """
    if base_dir is None:
        base_dir = os.environ.get("PROMPTGEN_DATA_DIR", "/data")
    data_manager = DataManager(base_dir=base_dir)
    with gr.Blocks() as demo:
        with gr.Tabs():
            prompt_generator_app(data_manager)
            character_creation_app(data_manager)
            tags_app(data_manager)
    demo.launch()


if __name__ == "__main__":
    main()