promptGen / app.py
throaway2854's picture
Update app.py
f3f4eb1 verified
raw
history blame
8.45 kB
import gradio as gr
import json
import os
import random
class DataManager:
    """Load and persist characters and persistent prompt tags as JSON files.

    Character images are written into ``images_folder`` as PNG files; each
    character record stores the resulting ``image_path`` instead of the image
    object itself.
    """

    def __init__(self, characters_file='characters.json', persistent_tags_file='persistent_tags.json', images_folder='character_images'):
        self.characters_file = characters_file
        self.persistent_tags_file = persistent_tags_file
        self.images_folder = images_folder
        # exist_ok=True avoids the check-then-create race of exists() + makedirs().
        os.makedirs(images_folder, exist_ok=True)
        self.load_data()

    @staticmethod
    def _split_tags(text):
        """Split a comma-separated string into stripped, non-empty tags."""
        return [t.strip() for t in text.split(',') if t.strip()]

    @staticmethod
    def _safe_filename(name):
        """Turn a user-supplied character name into a safe filename stem.

        Every character other than letters, digits, spaces, '-' and '_' is
        replaced by '_', so names such as '../x' or 'a/b' cannot escape the
        images folder. An empty result falls back to 'character'.
        """
        cleaned = ''.join(c if c.isalnum() or c in ' -_' else '_' for c in name).strip()
        return cleaned or 'character'

    @staticmethod
    def _load_json_list(path):
        """Return the parsed JSON content of *path*, or [] if the file does not exist."""
        if not os.path.exists(path):
            return []
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def load_data(self):
        """(Re)load characters and persistent tags from their JSON files."""
        self.characters = self._load_json_list(self.characters_file)
        self.persistent_tags = self._load_json_list(self.persistent_tags_file)

    def save_characters(self):
        """Write the character list to ``characters_file``.

        The ``'image'`` key that get_characters() adds at runtime is derived
        from ``image_path``, so it is left out of the stored records.
        """
        records = [{k: v for k, v in char.items() if k != 'image'} for char in self.characters]
        with open(self.characters_file, 'w', encoding='utf-8') as f:
            json.dump(records, f)

    def save_persistent_tags(self):
        """Write the persistent tag list to ``persistent_tags_file``."""
        with open(self.persistent_tags_file, 'w', encoding='utf-8') as f:
            json.dump(self.persistent_tags, f)

    def add_character(self, character):
        """Save the character's image to disk and append the character record.

        ``character`` must contain ``'name'``, ``'traits'`` (a comma-separated
        string or a list of strings) and ``'image'`` (an object with a
        ``save(path)`` method; the UI passes a PIL Image). The dict is modified
        in place: ``'image'`` is replaced by ``'image_path'`` and ``'traits'``
        becomes a list of stripped, non-empty tags.
        """
        stem = self._safe_filename(character['name'])
        image_path = os.path.join(self.images_folder, f"{stem}.png")
        # Never overwrite an existing image (e.g. a same-named character or a
        # different name that sanitizes to the same stem).
        suffix = 1
        while os.path.exists(image_path):
            image_path = os.path.join(self.images_folder, f"{stem}_{suffix}.png")
            suffix += 1
        character['image'].save(image_path)
        character['image_path'] = image_path
        character.pop('image', None)

        traits = character['traits']
        if isinstance(traits, str):
            traits = traits.split(',')
        character['traits'] = [t.strip() for t in traits if t.strip()]

        self.characters.append(character)
        self.save_characters()

    def get_characters(self):
        """Return the stored characters, each with an ``'image'`` key.

        ``'image'`` is the image path when the file still exists on disk,
        otherwise None. The internal list is returned, not a copy.
        """
        for char in self.characters:
            image_path = char.get('image_path')
            char['image'] = image_path if image_path and os.path.exists(image_path) else None
        return self.characters

    def set_persistent_tags(self, tags_string):
        """Replace the persistent tags with those parsed from a comma-separated string and save.

        Non-string input leaves the current tags unchanged; they are still re-saved.
        """
        if isinstance(tags_string, str):
            self.persistent_tags = self._split_tags(tags_string)
        self.save_persistent_tags()

    def get_persistent_tags(self):
        """Return the list of tags appended to every generated prompt."""
        return self.persistent_tags
def character_creation_app(data_manager):
    """Build the "Character Creation" tab.

    The user enters a name, comma-separated traits and an image; saving
    stores the character through ``data_manager.add_character``.
    """
    with gr.Tab("Character Creation"):
        # NOTE(review): original indentation was lost; which widgets sit inside
        # this Row is a best guess.
        with gr.Row():
            name_input = gr.Textbox(label="Character Name")
            traits_input = gr.Textbox(label="Traits/Appearance Tags (comma separated)")
        image_input = gr.Image(label="Upload Character Image", type="pil")
        save_button = gr.Button("Save Character")
        output = gr.Textbox(label="Status", interactive=False)

        def save_character(name, traits, image):
            # Strip first so whitespace-only entries count as missing.
            name = (name or '').strip()
            traits = (traits or '').strip()
            if not name or not traits or image is None:
                return "Please enter all fields."
            # Names are the selection key in the prompt generator; duplicates
            # would create ambiguous choices.
            if any(char['name'] == name for char in data_manager.get_characters()):
                return f"Character '{name}' already exists."
            character = {'name': name, 'traits': traits, 'image': image}
            data_manager.add_character(character)
            return f"Character '{name}' saved successfully."

        save_button.click(save_character, inputs=[name_input, traits_input, image_input], outputs=output)
def _parse_tag_list(text):
    """Split a comma-separated string into stripped, non-empty tags ([] for empty/None)."""
    if not text:
        return []
    return [tag.strip() for tag in text.split(',') if tag.strip()]


def prompt_generator_app(data_manager):
    """Build the "Prompt Generator" tab.

    For each tag category the user enters comma-separated tags and how many
    to pick at random. Characters are chosen explicitly or at random, and
    their names and traits are added. Then the persistent tags from Settings
    and a fixed list of quality tags are appended.
    """
    with gr.Tab("Prompt Generator"):
        # (label, key) for each tag category; the order defines the argument
        # order received by generate_prompt.
        categories = [
            ('Scene', 'scene_tags'),
            ('Position', 'position_tags'),
            ('Outfit', 'outfit_tags'),
            ('Camera View/Angle', 'camera_tags'),
            ('Concept', 'concept_tags'),
            ('Additional', 'additional_tags'),
            ('LORA', 'lora_tags')
        ]
        inputs = {}
        for category_name, var_name in categories:
            with gr.Row():
                tag_input = gr.Textbox(label=f"{category_name} Tags (comma separated)")
                tag_num = gr.Slider(minimum=0, maximum=10, step=1, value=1, label=f"Number of {category_name} Tags to Select")
            inputs[f"{var_name}_input"] = tag_input
            inputs[f"{var_name}_num"] = tag_num

        with gr.Group():
            gr.Markdown("### Character Selection")
            # NOTE(review): choices are read once when the UI is built, so
            # characters created later only appear after a restart.
            character_options = [char['name'] for char in data_manager.get_characters()]
            character_select = gr.CheckboxGroup(choices=character_options, label="Select Characters")
            random_characters = gr.Checkbox(label="Select Random Characters")
            num_characters = gr.Slider(minimum=0, maximum=10, step=1, value=1, label="Number of Characters (if random)")

        num_people = gr.Slider(minimum=0, maximum=10, step=1, value=1, label="Number of People in the Scene")
        generate_button = gr.Button("Generate Prompt")
        prompt_output = gr.Textbox(label="Generated Prompt", lines=5)

        def generate_prompt(*args):
            # args follow inputs_list: a (tags text, count) pair per category,
            # then the four character-selection values.
            n_category_args = 2 * len(categories)
            category_args = args[:n_category_args]
            prompt_tags = []
            for tags_input, tags_num in zip(category_args[::2], category_args[1::2]):
                tags_list = _parse_tag_list(tags_input)
                count = min(len(tags_list), int(tags_num or 0))
                if count > 0:
                    prompt_tags.extend(random.sample(tags_list, count))

            # The people-count slider is accepted but not used yet.
            selected_characters, random_chars, num_random_chars, _num_people = args[n_category_args:]
            characters = data_manager.get_characters()
            if random_chars:
                num = min(len(characters), int(num_random_chars or 0))
                selected_chars = random.sample(characters, num)
            else:
                chosen = set(selected_characters or [])
                selected_chars = [char for char in characters if char['name'] in chosen]

            for char in selected_chars:
                prompt_tags.append(char['name'])
                prompt_tags.extend(char.get('traits', []))

            prompt_tags.extend(data_manager.get_persistent_tags())

            # Classic (word:weight) syntax; empty tags are dropped.
            prompt = ', '.join(f"({tag}:1.0)" for tag in prompt_tags if tag)
            ending_tags = "source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres, anime artwork, anime style, vibrant, studio anime, highly detailed"
            # Avoid a dangling leading ", " when nothing else was selected.
            return f"{prompt}, {ending_tags}" if prompt else ending_tags

        inputs_list = list(inputs.values())
        inputs_list.extend([character_select, random_characters, num_characters, num_people])
        generate_button.click(generate_prompt, inputs=inputs_list, outputs=prompt_output)
def settings_app(data_manager):
    """Build the "Settings" tab for editing the tags appended to every prompt."""
    with gr.Tab("Settings"):
        # Prefill with the stored tags so users can see and edit them; an
        # empty box would silently clobber them on the next save.
        persistent_tags_input = gr.Textbox(
            label="Persistent Tags (comma separated)",
            value=', '.join(data_manager.get_persistent_tags()),
        )
        save_persistent_tags_button = gr.Button("Save Persistent Tags")
        status_output = gr.Textbox(label="Status", interactive=False)

        def save_persistent_tags(tags_string):
            data_manager.set_persistent_tags(tags_string)
            return "Persistent tags saved successfully."

        save_persistent_tags_button.click(save_persistent_tags, inputs=persistent_tags_input, outputs=status_output)
def main():
    """Create the shared DataManager, build every tab, then start the Gradio server."""
    manager = DataManager()
    tab_builders = (prompt_generator_app, character_creation_app, settings_app)
    with gr.Blocks() as demo:
        # Tabs appear in the order they are built.
        for build_tab in tab_builders:
            build_tab(manager)
    demo.launch()
# Launch the app only when executed as a script, not when imported.
if __name__ == "__main__":
    main()