import base64
import io
import json
import os
import random

import gradio as gr
from PIL import Image

# Prompt-tag categories shared by every tab: (display name, key in category_tags.json).
# Defined at module level so all tabs iterate them in the same order.
categories = [
    ('Setting', 'scene_tags'),
    ('Position', 'position_tags'),
    ('Outfit', 'outfit_tags'),
    ('Camera View/Angle', 'camera_tags'),
    ('Concept', 'concept_tags'),
    ('Facial Expression', 'facial_expression_tags'),
    ('Pose', 'pose_tags'),
    ('Additional', 'additional_tags'),
    ('LORA', 'lora_tags'),
]

# Genders offered in the UI; generate_prompt counts characters by these exact values.
GENDER_CHOICES = ["Boy", "Girl"]

# Quality/style tags appended to the end of every generated prompt.
ENDING_TAGS = (
    "source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, "
    "very aesthetic, absurdres, anime artwork, anime style, vibrant, studio anime, "
    "highly detailed"
)


def _split_tags(value):
    """Return a list of stripped, non-empty tags.

    Args:
        value: a comma-separated string, an iterable of strings, or None.
    """
    if value is None:
        return []
    if isinstance(value, str):
        value = value.split(',')
    return [tag.strip() for tag in value if tag.strip()]


class DataManager:
    """Persist characters, persistent tags and per-category tags as JSON files.

    Files live under ``base_dir``:
      * characters.json       - list of character dicts (name, traits, gender, image_path)
      * persistent_tags.json  - list of tags appended to every prompt
      * category_tags.json    - {category key: [tags]}
      * character_images/     - PNG copies of uploaded character images
    """

    def __init__(self, base_dir='/data'):
        self.base_dir = base_dir
        os.makedirs(self.base_dir, exist_ok=True)

        self.characters_file = os.path.join(self.base_dir, 'characters.json')
        self.persistent_tags_file = os.path.join(self.base_dir, 'persistent_tags.json')
        self.category_tags_file = os.path.join(self.base_dir, 'category_tags.json')
        self.images_folder = os.path.join(self.base_dir, 'character_images')
        os.makedirs(self.images_folder, exist_ok=True)

        self.load_data()

    # ------------------------------------------------------------------ JSON I/O

    @staticmethod
    def _read_json(path, default):
        """Return the parsed contents of *path*, or *default* if the file does not exist."""
        if not os.path.exists(path):
            return default
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def _write_json(path, data):
        """Write *data* to *path* as indented JSON.

        Writes to a temporary file first and then swaps it in, so a crash mid-write
        cannot leave a truncated JSON file behind.
        """
        tmp_path = f"{path}.tmp"
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=4)
        os.replace(tmp_path, path)

    def load_data(self):
        """(Re)load characters, persistent tags and category tags from disk."""
        self.characters = self._read_json(self.characters_file, [])
        self.persistent_tags = self._read_json(self.persistent_tags_file, [])
        self.load_category_tags()

    def save_characters(self):
        """Write the character list to disk.

        The 'image' key is a runtime copy of 'image_path' added by get_characters(),
        so it is left out of the saved file.
        """
        records = [{k: v for k, v in char.items() if k != 'image'} for char in self.characters]
        self._write_json(self.characters_file, records)

    def save_persistent_tags(self):
        """Write the persistent tag list to disk."""
        self._write_json(self.persistent_tags_file, self.persistent_tags)

    def load_category_tags(self):
        """Load the {category key: [tags]} mapping, defaulting to an empty dict."""
        self.category_tags = self._read_json(self.category_tags_file, {})

    def save_category_tags(self):
        """Write the category tag mapping to disk."""
        self._write_json(self.category_tags_file, self.category_tags)

    # ------------------------------------------------------------------ tags

    def get_category_tags(self, category_var_name):
        """Return the tag list for *category_var_name* (empty list if unknown)."""
        return self.category_tags.get(category_var_name, [])

    def set_category_tags(self, category_var_name, tags_list):
        """Replace the tags for one category and persist the change."""
        self.category_tags[category_var_name] = tags_list
        self.save_category_tags()

    def set_persistent_tags(self, tags_list):
        """Replace the persistent tags and persist the change."""
        self.persistent_tags = tags_list
        self.save_persistent_tags()

    def get_persistent_tags(self):
        """Return the persistent tag list."""
        return self.persistent_tags

    # ------------------------------------------------------------------ characters

    def get_characters(self):
        """Return the live character list.

        Each entry's 'image' key is refreshed to its 'image_path' if that file
        exists, otherwise None.
        """
        for char in self.characters:
            image_path = char.get('image_path')
            char['image'] = image_path if image_path and os.path.exists(image_path) else None
        return self.characters

    @staticmethod
    def _safe_name(name):
        """Reduce *name* to filename-safe characters, falling back to 'character'."""
        safe = "".join(c for c in name if c.isalnum() or c in (' ', '_', '-')).strip()
        return safe or 'character'

    def _image_path_for(self, name, current_path=None):
        """Return a PNG path for *name* that no other character's image uses.

        *current_path* is the path already owned by the character being saved. It
        is not treated as a collision, so an unchanged name keeps its file.
        """
        taken = {c.get('image_path') for c in self.characters} - {current_path}
        base = self._safe_name(name)
        candidate = os.path.join(self.images_folder, f"{base}.png")
        suffix = 1
        while candidate in taken:
            candidate = os.path.join(self.images_folder, f"{base}_{suffix}.png")
            suffix += 1
        return candidate

    @staticmethod
    def _save_image(source_path, dest_path):
        """Re-encode the image at *source_path* to *dest_path*; format follows dest's extension."""
        if os.path.abspath(source_path) == os.path.abspath(dest_path):
            return  # already in place
        with Image.open(source_path) as image:  # context manager closes the file handle
            image.save(dest_path)

    def add_character(self, character):
        """Append a new character and persist it.

        Args:
            character: dict with 'name', 'traits' (comma-separated str or list),
                'gender' and optionally 'image' (path to an uploaded image file).
                The caller's dict is not modified.
        """
        character = dict(character)
        image_source = character.pop('image', None)
        character['image_path'] = None
        if image_source:
            image_path = self._image_path_for(character['name'])
            try:
                self._save_image(image_source, image_path)
                character['image_path'] = image_path
            except Exception as e:  # best effort: keep the character even if the image is unusable
                print(f"Error saving image: {e}")
        character['traits'] = _split_tags(character.get('traits'))
        self.characters.append(character)
        self.save_characters()

    def update_character(self, original_name, updated_character):
        """Replace the character named *original_name* with *updated_character*.

        If 'image' is set in *updated_character*, that file becomes the new image.
        Otherwise the existing image is kept, and it is renamed when the
        character's name changes. An image file no longer referenced is deleted.

        Returns:
            True if a character named *original_name* was found and updated, else False.
        """
        for idx, char in enumerate(self.characters):
            if char['name'] != original_name:
                continue

            updated = dict(updated_character)
            new_image_source = updated.pop('image', None)
            old_image_path = char.get('image_path')
            target_path = self._image_path_for(updated['name'], current_path=old_image_path)

            if new_image_source:
                try:
                    self._save_image(new_image_source, target_path)
                    updated['image_path'] = target_path
                except Exception as e:  # best effort: keep the previous image on failure
                    print(f"Error updating image: {e}")
                    updated['image_path'] = old_image_path
            elif old_image_path and os.path.exists(old_image_path) and target_path != old_image_path:
                # Renamed without a new upload: move the file so it follows the new name
                # and a future character with the old name cannot overwrite it.
                try:
                    os.replace(old_image_path, target_path)
                    updated['image_path'] = target_path
                except OSError as e:
                    print(f"Error renaming image: {e}")
                    updated['image_path'] = old_image_path
            else:
                updated['image_path'] = old_image_path

            # Compare paths, not names: two names can sanitise to the same file, and
            # deleting "the old one" would then delete the image just saved.
            if (old_image_path and updated['image_path'] != old_image_path
                    and os.path.exists(old_image_path)):
                os.remove(old_image_path)

            updated['traits'] = _split_tags(updated.get('traits'))
            self.characters[idx] = updated
            self.save_characters()
            return True
        return False

    def delete_character(self, name):
        """Delete the character named *name* and its image file.

        Returns:
            True if the character existed, else False.
        """
        for idx, char in enumerate(self.characters):
            if char['name'] == name:
                image_path = char.get('image_path')
                if image_path and os.path.exists(image_path):
                    os.remove(image_path)
                del self.characters[idx]
                self.save_characters()
                return True
        return False


def prompt_generator_app(data_manager):
    """Build the "Prompt Generator" tab.

    The user picks how many random tags to draw from each category and which
    characters to include. The generated prompt is assembled from:
      1. character counts (e.g. "1girl, 2boys")
      2. per-character trait groups joined with " AND "
      3. the sampled category tags
      4. the persistent tags
      5. ENDING_TAGS
    """
    with gr.Tab("Prompt Generator"):
        gr.Markdown("## Prompt Generator")
        refresh_tags_button = gr.Button("Refresh Tags")

        tag_sliders = []          # one slider per category, in `categories` order
        tag_refresh_outputs = []  # [markdown, slider, markdown, slider, ...] for refresh_tags
        for category_name, var_name in categories:
            tags_list = data_manager.get_category_tags(var_name)
            max_tags = len(tags_list)
            with gr.Group():
                gr.Markdown(f"### {category_name}")
                tag_display = gr.Markdown(f"**Tags:** {', '.join(tags_list)}")
                tag_num = gr.Slider(minimum=0,
                                    maximum=max_tags,
                                    step=1,
                                    value=min(1, max_tags),
                                    label=f"Number of {category_name} Tags to Select")
            tag_sliders.append(tag_num)
            tag_refresh_outputs.extend([tag_display, tag_num])

        def character_label(char):
            """Label shown in the CheckboxGroup; the name is recovered with rsplit(' (', 1)."""
            return f"{char['name']} ({char['gender']})"

        def get_character_options():
            return [character_label(char) for char in data_manager.get_characters()]

        with gr.Group():
            gr.Markdown("### Character Selection")
            character_select = gr.CheckboxGroup(choices=get_character_options(),
                                                label="Select Characters",
                                                interactive=True)
            refresh_characters_button = gr.Button("Refresh Character List")
            random_characters = gr.Checkbox(label="Select Random Characters")
            num_characters = gr.Slider(minimum=1, maximum=10, step=1, value=1,
                                       label="Number of Characters (if random)")

        def refresh_characters(selected):
            """Reload choices, dropping selections whose character no longer exists."""
            new_options = get_character_options()
            kept = [option for option in (selected or []) if option in new_options]
            return gr.update(choices=new_options, value=kept)

        refresh_characters_button.click(refresh_characters,
                                        inputs=character_select,
                                        outputs=character_select)

        generate_button = gr.Button("Generate Prompt")
        prompt_output = gr.Textbox(label="Generated Prompt", lines=5)

        def generate_prompt(*args):
            """Build the prompt string.

            Args are one slider value per category (in `categories` order), then the
            selected character labels, the random-selection flag and the random count.
            """
            n_categories = len(categories)
            tag_counts = args[:n_categories]
            selected_character_options, random_chars, num_random_chars = \
                args[n_categories:n_categories + 3]

            prompt_tags = []
            for (_, var_name), tags_num in zip(categories, tag_counts):
                tags_list = data_manager.get_category_tags(var_name)
                count = int(tags_num or 0)
                if tags_list and count > 0:
                    prompt_tags.extend(random.sample(tags_list, min(len(tags_list), count)))

            characters = data_manager.get_characters()
            if random_chars:
                num = min(len(characters), int(num_random_chars))
                selected_chars = random.sample(characters, num)
            else:
                by_name = {}
                for char in characters:
                    by_name.setdefault(char['name'], char)  # first match wins, as before
                selected_chars = []
                for option in selected_character_options or []:
                    # rsplit: names may themselves contain " ("; the gender suffix never does.
                    char = by_name.get(option.rsplit(' (', 1)[0])
                    if char is not None:
                        selected_chars.append(char)

            num_girls = sum(1 for char in selected_chars if char.get('gender') == 'Girl')
            num_boys = sum(1 for char in selected_chars if char.get('gender') == 'Boy')

            character_count_tags = []
            if num_girls > 0:
                character_count_tags.append(f"{num_girls}girl" if num_girls == 1 else f"{num_girls}girls")
            if num_boys > 0:
                character_count_tags.append(f"{num_boys}boy" if num_boys == 1 else f"{num_boys}boys")

            prompt_parts = []
            if character_count_tags:
                prompt_parts.append(', '.join(character_count_tags))

            # Each character's traits go in one parenthesised group; groups are
            # joined with " AND " ("[char1] AND [char2]").
            character_descriptions = [f"({', '.join(char.get('traits', []))})" for char in selected_chars]
            if character_descriptions:
                prompt_parts.append(' AND '.join(character_descriptions))

            if prompt_tags:
                prompt_parts.append(', '.join(prompt_tags))

            persistent_tags = data_manager.get_persistent_tags()
            if persistent_tags:
                prompt_parts.append(', '.join(persistent_tags))

            prompt_parts.append(ENDING_TAGS)
            return ', '.join(prompt_parts)

        generate_button.click(generate_prompt,
                              inputs=[*tag_sliders, character_select, random_characters, num_characters],
                              outputs=prompt_output)

        def refresh_tags():
            """Reload every category's tag text and reset its slider range/value."""
            updates = []
            for _, var_name in categories:
                tags_list = data_manager.get_category_tags(var_name)
                max_tags = len(tags_list)
                updates.append(gr.update(value=f"**Tags:** {', '.join(tags_list)}"))
                updates.append(gr.update(maximum=max_tags, value=min(1, max_tags)))
            return updates

        refresh_tags_button.click(refresh_tags, outputs=tag_refresh_outputs)


def character_creation_app(data_manager):
    """Build the "Character Creation" tab: create, list, edit and delete characters.

    Gradio cannot create components from inside an event handler. So existing
    characters are shown as a Markdown list and managed through a single
    Dropdown plus a fixed set of edit widgets.
    """

    def character_names():
        return [char['name'] for char in data_manager.get_characters()]

    def find_character(name):
        """Return the character dict named *name*, or None."""
        return next((c for c in data_manager.get_characters() if c['name'] == name), None)

    def format_character_list():
        characters = data_manager.get_characters()
        if not characters:
            return "No characters created yet."
        return "\n".join(
            f"- **{char['name']}** ({char.get('gender', '')}): {', '.join(char.get('traits', []))}"
            for char in characters
        )

    with gr.Tab("Character Creation"):
        gr.Markdown("## Create a New Character")
        with gr.Row():
            name_input = gr.Textbox(label="Character Name", placeholder="Enter unique character name")
            traits_input = gr.Textbox(label="Traits/Appearance Tags (comma separated)",
                                      placeholder="e.g., blue hair, green eyes, tall")
        image_input = gr.Image(label="Upload Character Image", type="filepath")
        gender_input = gr.Radio(choices=GENDER_CHOICES, label="Gender")
        save_button = gr.Button("Save Character")
        output = gr.Textbox(label="Status", interactive=False)

        gr.Markdown("---")
        gr.Markdown("## Existing Characters")
        characters_list = gr.Markdown(format_character_list())
        with gr.Row():
            character_dropdown = gr.Dropdown(choices=character_names(), value=None,
                                             label="Select Character to Edit or Delete",
                                             interactive=True)
            refresh_button = gr.Button("Refresh Character List")
        with gr.Row():
            current_image = gr.Image(label="Character Image", type="filepath", interactive=False)
            with gr.Column():
                edit_name = gr.Textbox(label="Character Name", placeholder="Enter new name")
                edit_traits = gr.Textbox(label="Traits/Appearance Tags (comma separated)",
                                         placeholder="e.g., red hair, blue eyes")
                edit_gender = gr.Radio(choices=GENDER_CHOICES, label="Gender")
        edit_image = gr.Image(label="Upload New Character Image", type="filepath")
        with gr.Row():
            save_edit_btn = gr.Button("Save Changes", variant="primary")
            delete_btn = gr.Button("Delete", variant="secondary")
            delete_confirm = gr.Button("Confirm Delete", visible=False)
            delete_cancel = gr.Button("Cancel Delete", visible=False)
        edit_status = gr.Textbox(label="Edit Status", interactive=False)

        # ---- create
        def save_character(name, traits, image_path, gender):
            name = (name or '').strip()
            if not name or not (traits or '').strip() or not gender:
                return "Please enter all fields.", gr.update(), gr.update()
            if find_character(name) is not None:
                return (f"Character with name '{name}' already exists. Please choose a different name.",
                        gr.update(), gr.update())
            data_manager.add_character({'name': name, 'traits': traits, 'gender': gender, 'image': image_path})
            # Only refresh the choices; leaving the value alone avoids reloading the
            # edit form and discarding unsaved edits.
            return (f"Character '{name}' saved successfully.",
                    gr.update(value=format_character_list()),
                    gr.update(choices=character_names()))

        save_button.click(save_character,
                          inputs=[name_input, traits_input, image_input, gender_input],
                          outputs=[output, characters_list, character_dropdown])

        # ---- select / refresh
        def load_character(name):
            """Fill the edit form from the selected character (or clear it)."""
            char = find_character(name) if name else None
            if char is None:
                return None, '', '', None, None
            return char.get('image'), char['name'], ', '.join(char.get('traits', [])), char.get('gender'), None

        character_dropdown.change(load_character,
                                  inputs=character_dropdown,
                                  outputs=[current_image, edit_name, edit_traits, edit_gender, edit_image])

        def refresh_list(selected):
            names = character_names()
            return (gr.update(value=format_character_list()),
                    gr.update(choices=names, value=selected if selected in names else None))

        refresh_button.click(refresh_list,
                             inputs=character_dropdown,
                             outputs=[characters_list, character_dropdown])

        # ---- edit
        def save_edit(original_name, new_name, new_traits, new_image_path, new_gender):
            unchanged = (gr.update(), gr.update(), gr.update(), gr.update())
            if not original_name or find_character(original_name) is None:
                return ("Please select a character to edit.", *unchanged)
            new_name = (new_name or '').strip()
            if not new_name or not (new_traits or '').strip() or not new_gender:
                return ("Please enter all fields.", *unchanged)
            if new_name != original_name and find_character(new_name) is not None:
                return (f"Character with name '{new_name}' already exists. Please choose a different name.",
                        *unchanged)

            # 'image' None means "keep the current image"; DataManager handles renames.
            updated_char = {'name': new_name, 'traits': new_traits, 'gender': new_gender,
                            'image': new_image_path or None}
            if not data_manager.update_character(original_name, updated_char):
                return ("Failed to update character.", *unchanged)

            char = find_character(new_name)
            return (f"Character '{new_name}' updated successfully.",
                    gr.update(value=format_character_list()),
                    gr.update(choices=character_names(), value=new_name),
                    char.get('image') if char else None,
                    None)

        save_edit_btn.click(save_edit,
                            inputs=[character_dropdown, edit_name, edit_traits, edit_image, edit_gender],
                            outputs=[edit_status, characters_list, character_dropdown, current_image, edit_image])

        # ---- delete (two-step: Delete shows Confirm/Cancel)
        def request_delete(name):
            if not name:
                return "Please select a character to delete.", gr.update(visible=False), gr.update(visible=False)
            return f"Are you sure you want to delete '{name}'?", gr.update(visible=True), gr.update(visible=True)

        delete_btn.click(request_delete,
                         inputs=character_dropdown,
                         outputs=[edit_status, delete_confirm, delete_cancel])

        def perform_delete(name):
            hide_buttons = (gr.update(visible=False), gr.update(visible=False))
            if name and data_manager.delete_character(name):
                return (f"Character '{name}' deleted successfully.",
                        gr.update(value=format_character_list()),
                        gr.update(choices=character_names(), value=None),
                        *hide_buttons)
            return (f"Failed to delete character '{name}'.", gr.update(), gr.update(), *hide_buttons)

        delete_confirm.click(perform_delete,
                             inputs=character_dropdown,
                             outputs=[edit_status, characters_list, character_dropdown, delete_confirm, delete_cancel])

        def cancel_delete():
            return "Deletion cancelled.", gr.update(visible=False), gr.update(visible=False)

        delete_cancel.click(cancel_delete, outputs=[edit_status, delete_confirm, delete_cancel])


def tags_app(data_manager):
    """Build the "Tags" tab for editing each category's tags and the persistent tags."""
    with gr.Tab("Tags"):
        gr.Markdown("## Edit Tags for Each Category")

        def make_save_category_tags_fn(var_name, category_name):
            """Return a click handler bound to one category.

            A factory is used so each handler keeps its own category, instead of
            every handler seeing the last value of the loop variable.
            """
            def fn(tags_string):
                data_manager.set_category_tags(var_name, _split_tags(tags_string))
                return f"{category_name} tags saved successfully."
            return fn

        for category_name, var_name in categories:
            gr.Markdown(f"### {category_name} Tags")
            tags_string = ', '.join(data_manager.get_category_tags(var_name))
            tag_input = gr.Textbox(label=f"{category_name} Tags (comma separated)", value=tags_string)
            save_button = gr.Button(f"Save {category_name} Tags")
            status_output = gr.Textbox(label="", interactive=False)
            save_button.click(make_save_category_tags_fn(var_name, category_name),
                              inputs=tag_input, outputs=status_output)

        # Persistent tags are appended to every generated prompt.
        gr.Markdown("### Persistent Tags")
        persistent_tags_string = ', '.join(data_manager.get_persistent_tags())
        persistent_tags_input = gr.Textbox(label="Persistent Tags (comma separated)", value=persistent_tags_string)
        save_persistent_tags_button = gr.Button("Save Persistent Tags")
        persistent_status_output = gr.Textbox(label="", interactive=False)

        def save_persistent_tags_fn(tags_string):
            data_manager.set_persistent_tags(_split_tags(tags_string))
            return "Persistent tags saved successfully."

        save_persistent_tags_button.click(save_persistent_tags_fn,
                                          inputs=persistent_tags_input,
                                          outputs=persistent_status_output)


def main():
    """Create the data store, assemble the three tabs and launch the Gradio app."""
    data_manager = DataManager(base_dir='/data')
    with gr.Blocks() as demo:
        with gr.Tabs():
            prompt_generator_app(data_manager)
            character_creation_app(data_manager)
            tags_app(data_manager)
    demo.launch()


if __name__ == "__main__":
    main()