File size: 8,446 Bytes
8ca7af9
db9ae99
80c0c14
46e41c9
db9ae99
46e41c9
 
 
 
 
 
 
 
 
 
 
8ca7af9
46e41c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ca7af9
46e41c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ca7af9
46e41c9
 
 
 
 
 
 
 
 
8ca7af9
46e41c9
 
 
 
8ca7af9
46e41c9
 
8ca7af9
46e41c9
 
 
 
 
 
 
 
db9ae99
46e41c9
 
 
 
 
 
80c0c14
46e41c9
db9ae99
46e41c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eedb909
46e41c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ca7af9
46e41c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3f4eb1
46e41c9
 
 
8ca7af9
 
46e41c9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import gradio as gr
import json
import os
import random

class DataManager:
    """Persist characters and persistent tags as JSON files on disk.

    Characters are kept as a list of dicts in ``characters_file``. Each
    character's image is written as a PNG into ``images_folder`` and is
    referenced from its record through the ``image_path`` key. Persistent
    tags are kept as a list of strings in ``persistent_tags_file``.
    """

    def __init__(self, characters_file='characters.json', persistent_tags_file='persistent_tags.json', images_folder='character_images'):
        """Store the file locations, create the image folder and load saved data."""
        self.characters_file = characters_file
        self.persistent_tags_file = persistent_tags_file
        self.images_folder = images_folder

        # exist_ok avoids the race between a separate existence check and
        # the directory creation.
        os.makedirs(images_folder, exist_ok=True)

        self.load_data()

    @staticmethod
    def _load_json(path, default):
        """Return the parsed JSON content of ``path``, or ``default`` if it is missing."""
        if not os.path.exists(path):
            return default
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def _write_json(path, payload):
        """Serialize ``payload`` as JSON into ``path``, replacing previous content."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(payload, f)

    @staticmethod
    def _safe_filename(name):
        """Turn a free-form character name into a file name stem.

        Path separators and other special characters are replaced by ``_``
        and leading/trailing dots and spaces are removed, so the resulting
        file always lands directly inside the image folder. Plain names
        such as ``"Alice"`` or ``"Alice Smith"`` are returned unchanged.
        """
        cleaned = ''.join(
            ch if ch.isalnum() or ch in '-_ .' else '_'
            for ch in str(name)
        ).strip('. ')
        return cleaned or 'character'

    def load_data(self):
        """(Re)load characters and persistent tags; missing files yield empty lists."""
        self.characters = self._load_json(self.characters_file, [])
        self.persistent_tags = self._load_json(self.persistent_tags_file, [])

    def save_characters(self):
        """Write all characters to ``characters_file``.

        The ``image`` key is only a transient value added by
        ``get_characters``; it is left out of the saved records.
        """
        records = [
            {key: value for key, value in char.items() if key != 'image'}
            for char in self.characters
        ]
        self._write_json(self.characters_file, records)

    def save_persistent_tags(self):
        """Write the persistent tags to ``persistent_tags_file``."""
        self._write_json(self.persistent_tags_file, self.persistent_tags)

    def add_character(self, character):
        """Store a new character and save its image to disk.

        ``character`` is a dict with the keys ``name``, ``traits`` (a comma
        separated string or a list) and ``image`` (an object with a
        ``save(path)`` method, e.g. a PIL image). The dict is modified in
        place: ``image`` is replaced by ``image_path`` and a ``traits``
        string is split into a list of non-empty, stripped tags.
        """
        image_data = character['image']
        # Never use the raw name as a path component: it may contain
        # separators or "..".
        image_filename = f"{self._safe_filename(character['name'])}.png"
        image_path = os.path.join(self.images_folder, image_filename)
        image_data.save(image_path)
        character['image_path'] = image_path
        character.pop('image', None)

        if isinstance(character['traits'], str):
            # Drop empty entries, consistent with set_persistent_tags.
            character['traits'] = [
                t.strip() for t in character['traits'].split(',') if t.strip()
            ]

        self.characters.append(character)
        self.save_characters()

    def get_characters(self):
        """Return the character list, with ``image`` set on every record.

        ``image`` is the stored ``image_path`` when that file exists and
        ``None`` otherwise.
        """
        for char in self.characters:
            image_path = char.get('image_path')
            if image_path and os.path.exists(image_path):
                char['image'] = image_path
            else:
                char['image'] = None
        return self.characters

    def set_persistent_tags(self, tags_string):
        """Replace the persistent tags from a comma separated string and save them.

        Values that are not strings are ignored.
        """
        if isinstance(tags_string, str):
            self.persistent_tags = [t.strip() for t in tags_string.split(',') if t.strip()]
            self.save_persistent_tags()

    def get_persistent_tags(self):
        """Return the list of persistent tags."""
        return self.persistent_tags

def character_creation_app(data_manager):
    """Build the "Character Creation" tab and wire up its save button.

    Must be called inside an active ``gr.Blocks`` context. ``data_manager``
    is the object whose ``add_character`` method stores the new character.
    """
    with gr.Tab("Character Creation"):
        with gr.Row():
            name_input = gr.Textbox(label="Character Name")
            traits_input = gr.Textbox(label="Traits/Appearance Tags (comma separated)")
        image_input = gr.Image(label="Upload Character Image", type="pil")
        save_button = gr.Button("Save Character")
        output = gr.Textbox(label="Status", interactive=False)

    def save_character(name, traits, image):
        """Validate the form values, store the character and return a status text."""
        # Whitespace-only input counts as missing; surrounding whitespace is
        # dropped because the name ends up in prompts and in a file name.
        name = (name or '').strip()
        traits = (traits or '').strip()
        if not name or not traits or image is None:
            return "Please enter all fields."

        character = {'name': name, 'traits': traits, 'image': image}
        try:
            data_manager.add_character(character)
        except (OSError, ValueError) as exc:
            # Writing the image or the JSON file failed; report it in the
            # status box instead of raising inside the event handler.
            return f"Failed to save character '{name}': {exc}"
        return f"Character '{name}' saved successfully."

    save_button.click(save_character, inputs=[name_input, traits_input, image_input], outputs=output)

def _sample_tags(raw_tags, count):
    """Pick up to ``count`` random tags from a comma separated string.

    Empty entries are removed before sampling, so the requested number of
    tags is returned whenever enough non-empty tags exist.
    """
    if not raw_tags:
        return []
    candidates = [tag.strip() for tag in raw_tags.split(',') if tag.strip()]
    wanted = min(len(candidates), int(count or 0))
    if wanted <= 0:
        return []
    return random.sample(candidates, wanted)


def prompt_generator_app(data_manager):
    """Build the "Prompt Generator" tab and wire up its generate button.

    Must be called inside an active ``gr.Blocks`` context. ``data_manager``
    supplies the stored characters (``get_characters``) and the persistent
    tags (``get_persistent_tags``).
    """
    with gr.Tab("Prompt Generator"):
        # One (tags textbox, count slider) pair per category, in this order.
        category_names = (
            'Scene',
            'Position',
            'Outfit',
            'Camera View/Angle',
            'Concept',
            'Additional',
            'LORA',
        )

        category_inputs = []
        for category_name in category_names:
            with gr.Row():
                tag_input = gr.Textbox(label=f"{category_name} Tags (comma separated)")
                tag_num = gr.Slider(minimum=0, maximum=10, step=1, value=1, label=f"Number of {category_name} Tags to Select")
            category_inputs.extend([tag_input, tag_num])

        # For Character Selection
        with gr.Group():
            gr.Markdown("### Character Selection")
            # NOTE(review): the choices are computed once here; nothing in
            # this function refreshes them after a character is added.
            character_options = [char['name'] for char in data_manager.get_characters()]
            character_select = gr.CheckboxGroup(choices=character_options, label="Select Characters")
            random_characters = gr.Checkbox(label="Select Random Characters")
            num_characters = gr.Slider(minimum=0, maximum=10, step=1, value=1, label="Number of Characters (if random)")

        # Number of people in the scene
        num_people = gr.Slider(minimum=0, maximum=10, step=1, value=1, label="Number of People in the Scene")

        generate_button = gr.Button("Generate Prompt")
        prompt_output = gr.Textbox(label="Generated Prompt", lines=5)

    def generate_prompt(*args):
        """Assemble the prompt text from the current UI values.

        ``args`` holds the values in the order of the click inputs: a
        (tags text, count) pair per category, followed by the selected
        character names, the random-selection flag, the number of random
        characters and the number of people in the scene.
        """
        split_at = 2 * len(category_names)
        category_values = args[:split_at]
        selected_names, use_random, random_count, _num_people = args[split_at:split_at + 4]
        # _num_people is accepted but not used yet: the prompt does not
        # add generic people when fewer characters are selected.

        prompt_tags = []
        for tags_text, tags_count in zip(category_values[0::2], category_values[1::2]):
            prompt_tags.extend(_sample_tags(tags_text, tags_count))

        characters = data_manager.get_characters()
        if use_random:
            how_many = min(len(characters), int(random_count or 0))
            chosen = random.sample(characters, how_many)
        else:
            wanted_names = set(selected_names or [])
            chosen = [char for char in characters if char['name'] in wanted_names]

        for char in chosen:
            prompt_tags.append(f"{char['name']}")
            prompt_tags.extend(char.get('traits') or [])

        prompt_tags.extend(data_manager.get_persistent_tags())

        # Assemble prompt with classic syntax (word:weight)
        weighted = ', '.join(f"({tag}:1.0)" for tag in prompt_tags if tag)
        ending_tags = "source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres, anime artwork, anime style, vibrant, studio anime, highly detailed"
        # Avoid a leading ", " when no tag was produced at all.
        return f"{weighted}, {ending_tags}" if weighted else ending_tags

    # The order here must match the unpacking in generate_prompt.
    inputs_list = category_inputs + [character_select, random_characters, num_characters, num_people]

    generate_button.click(generate_prompt, inputs=inputs_list, outputs=prompt_output)

def settings_app(data_manager):
    """Build the "Settings" tab for editing the persistent tags.

    Must be called inside an active ``gr.Blocks`` context. ``data_manager``
    provides ``get_persistent_tags`` and ``set_persistent_tags``.
    """
    with gr.Tab("Settings"):
        # Show the currently stored tags so that saving does not blindly
        # replace them with an empty list.
        current_tags = ', '.join(str(tag) for tag in data_manager.get_persistent_tags())
        persistent_tags_input = gr.Textbox(label="Persistent Tags (comma separated)", value=current_tags)
        save_persistent_tags_button = gr.Button("Save Persistent Tags")
        status_output = gr.Textbox(label="Status", interactive=False)

    def save_persistent_tags(tags_string):
        """Store the tags from the textbox and return a status text."""
        # set_persistent_tags ignores non-string values, so map a missing
        # value to "" to make the success message truthful.
        data_manager.set_persistent_tags(tags_string or '')
        return "Persistent tags saved successfully."

    save_persistent_tags_button.click(save_persistent_tags, inputs=persistent_tags_input, outputs=status_output)

def main():
    """Create the data manager, build the three tabs and launch the app."""
    manager = DataManager()
    # Tabs are built in display order.
    tab_builders = (prompt_generator_app, character_creation_app, settings_app)

    with gr.Blocks() as demo:
        for build_tab in tab_builders:
            build_tab(manager)

    demo.launch()

# Build and launch the app only when this file is run as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()