# Captured from a Hugging Face Space page (Space status shown: Paused).
import json
import os
import random
import re

import gradio as gr
class DataManager:
    """Persist characters and persistent prompt tags as JSON files on disk.

    Characters are stored as dicts holding at least ``name``, ``traits`` (a
    list of tag strings) and ``image_path`` (a PNG written into
    ``images_folder``). Persistent tags are a plain list of strings.
    """

    def __init__(self, characters_file='characters.json', persistent_tags_file='persistent_tags.json', images_folder='character_images'):
        """Remember the storage locations, ensure the image folder exists and load data.

        Args:
            characters_file: JSON file holding the list of characters.
            persistent_tags_file: JSON file holding the list of persistent tags.
            images_folder: directory where character images are written.
        """
        self.characters_file = characters_file
        self.persistent_tags_file = persistent_tags_file
        self.images_folder = images_folder
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(images_folder, exist_ok=True)
        self.load_data()

    @staticmethod
    def _load_json_list(path):
        """Return the parsed JSON content of *path*, or ``[]`` if the file is missing."""
        if not os.path.exists(path):
            return []
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def _dump_json(path, data):
        """Write *data* to *path* as UTF-8 JSON, overwriting any previous content."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    def load_data(self):
        """(Re)load characters and persistent tags from their JSON files."""
        self.characters = self._load_json_list(self.characters_file)
        self.persistent_tags = self._load_json_list(self.persistent_tags_file)

    def save_characters(self):
        """Write ``self.characters`` to the characters file."""
        self._dump_json(self.characters_file, self.characters)

    def save_persistent_tags(self):
        """Write ``self.persistent_tags`` to the persistent-tags file."""
        self._dump_json(self.persistent_tags_file, self.persistent_tags)

    def _unique_image_path(self, name):
        """Return a not-yet-existing PNG path inside ``images_folder`` derived from *name*.

        Every run of characters other than word characters, ``-`` and space is
        replaced by ``_``. This keeps a user-supplied name from escaping the
        folder (e.g. ``../x``) or producing an invalid filename. A numeric
        suffix is added instead of overwriting an existing image.
        """
        stem = re.sub(r'[^\w\- ]+', '_', str(name)).strip() or 'character'
        candidate = os.path.join(self.images_folder, f"{stem}.png")
        suffix = 1
        while os.path.exists(candidate):
            candidate = os.path.join(self.images_folder, f"{stem}_{suffix}.png")
            suffix += 1
        return candidate

    def add_character(self, character):
        """Save a character's image to disk and append the character to storage.

        Args:
            character: dict with ``name``, ``traits`` (comma-separated str or
                list of str) and ``image`` (an object with a ``save(path)``
                method, e.g. a PIL Image).

        The stored record replaces ``image`` with ``image_path``. Traits are
        stripped, and empty entries are dropped. The caller's dict is not
        modified.

        Raises:
            KeyError: if ``name``, ``image`` or ``traits`` is missing.
        """
        record = dict(character)
        image = record.pop('image')
        image_path = self._unique_image_path(record['name'])
        image.save(image_path)
        record['image_path'] = image_path
        traits = record['traits']
        if isinstance(traits, str):
            traits = traits.split(',')
        # Same normalisation as set_persistent_tags: strip and drop empties.
        record['traits'] = [t.strip() for t in traits if t.strip()]
        self.characters.append(record)
        self.save_characters()

    def get_characters(self):
        """Return the stored characters, each with an extra ``image`` key.

        ``image`` is the stored ``image_path`` when that file still exists,
        otherwise ``None``. Shallow copies are returned, so this display-only
        key is never written back into the characters file.
        """
        views = []
        for char in self.characters:
            image_path = char.get('image_path')
            view = dict(char)
            view['image'] = image_path if image_path and os.path.exists(image_path) else None
            views.append(view)
        return views

    def set_persistent_tags(self, tags_string):
        """Replace the persistent tags with the non-empty comma-separated entries of *tags_string*.

        If *tags_string* is not a str, the current tags are kept. They are
        still re-saved.
        """
        if isinstance(tags_string, str):
            self.persistent_tags = [t.strip() for t in tags_string.split(',') if t.strip()]
        self.save_persistent_tags()

    def get_persistent_tags(self):
        """Return the list of persistent tags."""
        return self.persistent_tags
def character_creation_app(data_manager):
    """Build the "Character Creation" tab.

    The tab collects a name, comma-separated traits and an image, and stores
    them via ``data_manager.add_character``.

    Args:
        data_manager: the ``DataManager`` that persists characters.
    """
    with gr.Tab("Character Creation"):
        with gr.Row():
            name_input = gr.Textbox(label="Character Name")
            traits_input = gr.Textbox(label="Traits/Appearance Tags (comma separated)")
            image_input = gr.Image(label="Upload Character Image", type="pil")
        save_button = gr.Button("Save Character")
        output = gr.Textbox(label="Status", interactive=False)

        def save_character(name, traits, image):
            """Validate the form and store the character; return a status message."""
            # Whitespace-only values would otherwise pass the emptiness check
            # and create a blank-named character or an empty trait list.
            name = (name or "").strip()
            traits = (traits or "").strip()
            if not name or not traits or image is None:
                return "Please enter all fields."
            character = {'name': name, 'traits': traits, 'image': image}
            try:
                data_manager.add_character(character)
            except OSError as exc:
                # Report disk problems (e.g. unwritable image folder) in the
                # status box instead of failing the request outright.
                return f"Could not save character '{name}': {exc}"
            return f"Character '{name}' saved successfully."

        save_button.click(save_character, inputs=[name_input, traits_input, image_input], outputs=output)
def prompt_generator_app(data_manager):
    """Build the "Prompt Generator" tab.

    For each tag category the user enters comma-separated tags and how many
    of them to pick at random. Characters are chosen either explicitly or at
    random.

    The generated prompt contains, in order:
    - the sampled category tags;
    - each selected character's name and traits;
    - the persistent tags.
    Each entry is wrapped as ``(tag:1.0)``, and a fixed string of quality
    tags is appended at the end.

    Args:
        data_manager: ``DataManager`` providing characters and persistent tags.
    """
    # (UI label, internal key) for every tag category, in display order.
    categories = [
        ('Scene', 'scene_tags'),
        ('Position', 'position_tags'),
        ('Outfit', 'outfit_tags'),
        ('Camera View/Angle', 'camera_tags'),
        ('Concept', 'concept_tags'),
        ('Additional', 'additional_tags'),
        ('LORA', 'lora_tags'),
    ]
    # Appended verbatim (without weights) to every generated prompt.
    ending_tags = "source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres, anime artwork, anime style, vibrant, studio anime, highly detailed"

    def character_names():
        """Names of all currently stored characters."""
        return [char['name'] for char in data_manager.get_characters()]

    with gr.Tab("Prompt Generator"):
        # Insertion order matters: generate_prompt reads (tags, count) pairs
        # in exactly this order.
        inputs = {}
        for category_name, var_name in categories:
            with gr.Row():
                tag_input = gr.Textbox(label=f"{category_name} Tags (comma separated)")
                tag_num = gr.Slider(minimum=0, maximum=10, step=1, value=1, label=f"Number of {category_name} Tags to Select")
            inputs[f"{var_name}_input"] = tag_input
            inputs[f"{var_name}_num"] = tag_num

        with gr.Group():
            gr.Markdown("### Character Selection")
            character_select = gr.CheckboxGroup(choices=character_names(), label="Select Characters")
            # Choices are otherwise fixed at build time, so characters created
            # in the other tab would never become selectable.
            refresh_characters_button = gr.Button("Refresh Character List")
            random_characters = gr.Checkbox(label="Select Random Characters")
            num_characters = gr.Slider(minimum=0, maximum=10, step=1, value=1, label="Number of Characters (if random)")

        num_people = gr.Slider(minimum=0, maximum=10, step=1, value=1, label="Number of People in the Scene")
        generate_button = gr.Button("Generate Prompt")
        prompt_output = gr.Textbox(label="Generated Prompt", lines=5)

        def generate_prompt(*args):
            """Assemble a prompt from the widget values (in ``inputs_list`` order)."""
            n_category_args = 2 * len(categories)
            category_args = args[:n_category_args]
            selected_characters, random_chars, num_random_chars, _num_people_in_scene = args[n_category_args:n_category_args + 4]

            prompt_tags = []
            for tags_input, tags_num in zip(category_args[0::2], category_args[1::2]):
                # Drop empty entries ("a,,b", trailing comma) before sampling
                # so they cannot consume the requested tag count.
                tags_list = [tag.strip() for tag in (tags_input or '').split(',') if tag.strip()]
                count = min(len(tags_list), int(tags_num or 0))
                if count > 0:
                    prompt_tags.extend(random.sample(tags_list, count))

            characters = data_manager.get_characters()
            if random_chars:
                selected_chars = random.sample(characters, min(len(characters), int(num_random_chars or 0)))
            else:
                chosen = set(selected_characters or [])
                selected_chars = [char for char in characters if char['name'] in chosen]

            # NOTE: the "Number of People in the Scene" value is not used yet;
            # generic extra people could be added here when it exceeds
            # len(selected_chars).

            for char in selected_chars:
                prompt_tags.append(char['name'])
                prompt_tags.extend(char['traits'])

            prompt_tags.extend(data_manager.get_persistent_tags())

            weighted = ', '.join(f"({tag}:1.0)" for tag in prompt_tags if tag)
            # Avoid a dangling leading ", " when no tags were produced.
            return f"{weighted}, {ending_tags}" if weighted else ending_tags

        inputs_list = list(inputs.values()) + [character_select, random_characters, num_characters, num_people]
        generate_button.click(generate_prompt, inputs=inputs_list, outputs=prompt_output)
        refresh_characters_button.click(lambda: gr.update(choices=character_names()), outputs=character_select)
def settings_app(data_manager):
    """Build the "Settings" tab for editing the persistent prompt tags.

    The textbox is pre-filled with the currently saved tags. Without this,
    pressing "Save" on a fresh page (empty box) would erase the stored tags.

    Args:
        data_manager: ``DataManager`` that stores the persistent tags.
    """
    with gr.Tab("Settings"):
        persistent_tags_input = gr.Textbox(
            label="Persistent Tags (comma separated)",
            value=", ".join(data_manager.get_persistent_tags()),
        )
        save_persistent_tags_button = gr.Button("Save Persistent Tags")
        status_output = gr.Textbox(label="Status", interactive=False)

        def save_persistent_tags(tags_string):
            """Store the comma-separated tags and report success."""
            data_manager.set_persistent_tags(tags_string)
            return "Persistent tags saved successfully."

        save_persistent_tags_button.click(save_persistent_tags, inputs=persistent_tags_input, outputs=status_output)
def main():
    """Open the JSON-backed data store, build the three tabs and launch the Gradio app."""
    store = DataManager()
    with gr.Blocks() as app:
        # Build the tabs in the original order: generator, creation, settings.
        for build_tab in (prompt_generator_app, character_creation_app, settings_app):
            build_tab(store)
    app.launch()


if __name__ == "__main__":
    main()