File size: 9,329 Bytes
8ca7af9
 
db9ae99
80c0c14
db9ae99
 
 
80c0c14
 
db9ae99
 
 
80c0c14
db9ae99
 
 
 
 
 
 
 
 
 
80c0c14
db9ae99
 
 
80c0c14
 
db9ae99
 
 
8ca7af9
80c0c14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ca7af9
 
 
 
 
 
 
 
 
 
 
 
80c0c14
 
 
 
 
8ca7af9
 
 
 
80c0c14
8ca7af9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db9ae99
 
 
 
 
 
80c0c14
 
 
 
 
 
 
 
 
8ca7af9
db9ae99
 
8ca7af9
80c0c14
8ca7af9
80c0c14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8f0231
 
 
8ca7af9
80c0c14
 
 
 
 
 
 
 
 
e8f0231
80c0c14
db9ae99
 
80c0c14
 
 
 
 
 
 
 
 
 
 
 
 
 
db9ae99
8ca7af9
db9ae99
80c0c14
 
8ca7af9
 
80c0c14
 
e8f0231
 
80c0c14
 
 
 
 
 
 
 
 
 
8ca7af9
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import gradio as gr
import random
import json
import os
from huggingface_hub import HfApi, hf_hub_download, upload_file

# Hugging Face repository that save_data/save_character_image upload to.
REPO_ID = "throaway2854/promptGen"  # Replace with your actual Space name
# JSON file holding the saved tag lists and characters; used both as the
# local file name and as the path inside the repository.
DATA_FILE = "saved_data.json"
# Directory for character images; used both locally and as the path prefix
# inside the repository.
IMAGES_DIR = "character_images"

def load_data():
    """Load the saved tag lists and characters from the Space repository.

    The file is fetched with ``repo_type="space"`` so that it is read from
    the same repository that ``save_data`` uploads to. Loading is
    best-effort: if the download or parsing fails for any reason, an empty
    data set is returned so the UI can still start.

    Returns:
        dict: the stored data, with every expected key present (missing
        keys are filled with empty defaults).
    """
    defaults = {
        "scene_tags": [], "position_tags": [], "outfit_tags": [],
        "camera_tags": [], "concept_tags": [], "lora_tags": [],
        "characters": {}
    }
    try:
        file_path = hf_hub_download(repo_id=REPO_ID, filename=DATA_FILE, repo_type="space")
        with open(file_path, 'r', encoding="utf-8") as f:
            loaded = json.load(f)
    except Exception as exc:
        # Deliberate best-effort: a missing file or network error must not
        # prevent the app from starting, but it should be visible in the logs.
        print(f"Could not load {DATA_FILE} from {REPO_ID}: {exc}")
        return defaults
    if not isinstance(loaded, dict):
        return defaults
    # Fill in keys that an older or hand-edited file might be missing.
    return {**defaults, **loaded}

def save_data(data):
    """Write *data* to the local JSON file and upload that file to the Space.

    Args:
        data: JSON-serialisable object to persist.
    """
    with open(DATA_FILE, 'w') as handle:
        json.dump(data, handle)

    # The file keeps the same name locally and inside the repository.
    upload_kwargs = {
        "path_or_fileobj": DATA_FILE,
        "path_in_repo": DATA_FILE,
        "repo_id": REPO_ID,
        "repo_type": "space",
    }
    HfApi().upload_file(**upload_kwargs)

def save_character_image(name, image):
    """Save a character image locally as PNG and upload it to the Space.

    Args:
        name: character name; used as the file name (without extension).
        image: object with a ``save(path)`` method (the UI passes a PIL image).

    Returns:
        str: local path of the saved image.
    """
    # The name comes from a free-text field. Neutralise path separators so it
    # cannot point outside IMAGES_DIR (e.g. "../x" or an absolute path).
    safe_name = str(name).replace("/", "_").replace("\\", "_")
    if os.sep not in ("/", "\\"):
        safe_name = safe_name.replace(os.sep, "_")

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(IMAGES_DIR, exist_ok=True)
    image_path = os.path.join(IMAGES_DIR, f"{safe_name}.png")
    image.save(image_path)
    api = HfApi()
    api.upload_file(
        path_or_fileobj=image_path,
        path_in_repo=f"{IMAGES_DIR}/{safe_name}.png",
        repo_id=REPO_ID,
        repo_type="space"
    )
    return image_path

def _split_tags(raw):
    """Split a comma-separated string into stripped, non-empty tags."""
    return [tag.strip() for tag in (raw or "").split(',') if tag.strip()]

def generate_prompt(scene_tags, num_people, position_tags, selected_characters, outfit_tags, camera_tags, concept_tags, lora_tags, tag_counts, data):
    """Assemble a randomized prompt string from the given tag inputs.

    Args:
        scene_tags, position_tags, outfit_tags, camera_tags, concept_tags:
            comma-separated tag strings; a random subset of each is used and
            every chosen tag gets a random weight in [0.8, 1.2].
        num_people: free text; if non-blank, "<num_people> people:1.1" is added.
        selected_characters: iterable of character names; names that are not
            in ``data["characters"]`` are ignored.
        lora_tags: comma-separated LoRA names, each emitted as "<lora:NAME:1>".
        tag_counts: mapping from category ("scene", "position", "outfit",
            "camera", "concept", "character") to the maximum number of
            tags/traits to pick. Values may be ints or floats.
        data: dict with a "characters" mapping of name -> {"traits": [...]}.

    Returns:
        str: shuffled character/tag parts, then the fixed quality tags, then
        the LoRA references.
    """
    all_tags = {
        "scene": _split_tags(scene_tags),
        "position": _split_tags(position_tags),
        "outfit": _split_tags(outfit_tags),
        "camera": _split_tags(camera_tags),
        "concept": _split_tags(concept_tags),
    }

    def pick_count(category, available):
        # Slider values can arrive as floats, but random.sample needs an int;
        # clamp to [0, available] so the sample size is always valid.
        return max(0, min(int(tag_counts[category]), available))

    character_prompts = []
    for char_name in selected_characters or []:
        if char_name in data["characters"]:
            char_traits = data["characters"][char_name]["traits"]
            chosen_traits = random.sample(char_traits, pick_count("character", len(char_traits)))
            # Joining name and traits together avoids a dangling "name, "
            # when the character has no traits.
            character_prompts.append(", ".join([char_name, *chosen_traits]))

    selected_tags = []
    for category, tags in all_tags.items():
        if tags:
            picked = random.sample(tags, pick_count(category, len(tags)))
            selected_tags.extend(f"{tag}:{random.uniform(0.8, 1.2):.2f}" for tag in picked)

    people = str(num_people).strip() if num_people is not None else ""
    if people:
        selected_tags.append(f"{people} people:1.1")

    prompt_parts = character_prompts + selected_tags
    random.shuffle(prompt_parts)
    main_prompt = ", ".join(prompt_parts)

    # NOTE(review): tag_counts may carry a "lora" entry, but it is not applied
    # here; every LoRA tag is always emitted. Confirm whether that is intended.
    lora_prompt = " ".join(f"<lora:{lora}:1>" for lora in _split_tags(lora_tags))

    fixed_tags = "source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres, anime artwork, anime style, vibrant, studio anime, highly detailed"

    # Skip the empty main prompt so the result never starts with ", ".
    tag_section = ", ".join(part for part in (main_prompt, fixed_tags) if part)
    return f"{tag_section} {lora_prompt}".strip()

def update_data(data, key, value):
    """Merge comma-separated tags from *value* into the list at ``data[key]``.

    New tags are appended after the existing ones and duplicates are dropped
    while keeping first-seen order, so the stored list does not get reshuffled
    on every call. The data is persisted with ``save_data`` only when the list
    actually changed, which avoids redundant uploads.

    Args:
        data: dict holding the tag lists; modified in place.
        key: key in *data* to update. Non-list values are left untouched.
        value: comma-separated tags (None is treated as empty).

    Returns:
        dict: the same *data* object.
    """
    current = data[key]
    if isinstance(current, list):
        additions = [v.strip() for v in (value or "").split(',') if v.strip()]
        # dict.fromkeys de-duplicates while preserving insertion order.
        merged = list(dict.fromkeys(current + additions))
        if merged != current:
            data[key] = merged
            save_data(data)
    return data

def create_character(name, traits, image, data):
    """Create or update a character and persist the change.

    Args:
        name: character name; surrounding whitespace is ignored. Nothing is
            stored when the name is blank.
        traits: comma-separated trait string (None is treated as empty).
        image: image object to store, or None. When None, an existing
            character keeps the image it already had.
        data: dict with a "characters" mapping; modified in place.

    Returns:
        tuple: ``(data, update)`` where *update* is a ``gr.update`` carrying
        the current list of character names as choices.
    """
    name = (name or "").strip()
    if name:
        if image is not None:
            image_path = save_character_image(name, image)
        else:
            # Updating only the traits must not drop a previously saved image.
            image_path = data["characters"].get(name, {}).get("image")
        data["characters"][name] = {
            "traits": [trait.strip() for trait in (traits or "").split(',') if trait.strip()],
            "image": image_path
        }
        save_data(data)
    return data, gr.update(choices=list(data["characters"].keys()))

def create_ui():
    """Build the Gradio Blocks interface for the prompt generator.

    The saved data is loaded once and kept in a closure variable that the
    event handlers below read and update.

    Returns:
        The assembled ``gr.Blocks`` app (not launched).
    """
    data = load_data()

    def character_image_paths():
        # Only characters that have a stored image can be shown in a gallery.
        return [char_data["image"] for char_data in data["characters"].values() if char_data.get("image")]

    with gr.Blocks() as demo:
        gr.Markdown("# Advanced Pony SDXL Prompt Generator with Character Creation")
        
        with gr.Tabs():
            with gr.TabItem("Prompt Generator"):
                with gr.Row():
                    with gr.Column():
                        scene_input = gr.Textbox(label="Scene Tags (comma-separated)", value=", ".join(data["scene_tags"]))
                        scene_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of scene tags")
                        
                        num_people_input = gr.Textbox(label="Number of People")
                        
                        position_input = gr.Textbox(label="Position Tags (comma-separated)", value=", ".join(data["position_tags"]))
                        position_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of position tags")
                        
                        character_select = gr.CheckboxGroup(label="Select Characters", choices=list(data["characters"].keys()))
                        character_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of character traits")
                        
                        outfit_input = gr.Textbox(label="Outfit Tags (comma-separated)", value=", ".join(data["outfit_tags"]))
                        outfit_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of outfit tags")
                        
                        camera_input = gr.Textbox(label="Camera View/Angle Tags (comma-separated)", value=", ".join(data["camera_tags"]))
                        camera_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of camera tags")
                        
                        concept_input = gr.Textbox(label="Concept Tags (comma-separated)", value=", ".join(data["concept_tags"]))
                        concept_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of concept tags")
                        
                        lora_input = gr.Textbox(label="LORA Tags (comma-separated)", value=", ".join(data["lora_tags"]))
                        lora_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of LORA tags")
                        
                        generate_button = gr.Button("Generate Prompt")
                    
                    with gr.Column():
                        output = gr.Textbox(label="Generated Prompt", lines=5)
                        
                        with gr.Row():
                            # Distinct elem_id: two components must not share one DOM id.
                            prompt_gallery = gr.Gallery(value=character_image_paths(), label="Character Images", show_label=True, elem_id="prompt_char_gallery", columns=2, rows=2, height="auto")
            
            with gr.TabItem("Character Creation"):
                with gr.Row():
                    with gr.Column():
                        char_name_input = gr.Textbox(label="Character Name")
                        char_traits_input = gr.Textbox(label="Character Traits (comma-separated)")
                        char_image_input = gr.Image(label="Character Image", type="pil")
                        create_char_button = gr.Button("Create/Update Character")
                    
                    with gr.Column():
                        # Start with the already-saved images instead of an empty gallery.
                        char_gallery = gr.Gallery(value=character_image_paths(), label="Existing Characters", show_label=True, elem_id="char_gallery", columns=2, rows=2, height="auto")

        def update_and_generate(*args):
            """Store the entered tags, then build a prompt from them."""
            nonlocal data
            # The first eight values are the text/selection inputs; the rest
            # are the slider counts in the order wired up below.
            scene_tags, num_people, position_tags, selected_characters, outfit_tags, camera_tags, concept_tags, lora_tags, *tag_counts = args
            data = update_data(data, "scene_tags", scene_tags)
            data = update_data(data, "position_tags", position_tags)
            data = update_data(data, "outfit_tags", outfit_tags)
            data = update_data(data, "camera_tags", camera_tags)
            data = update_data(data, "concept_tags", concept_tags)
            data = update_data(data, "lora_tags", lora_tags)
            
            tag_count_dict = {
                "scene": tag_counts[0], "position": tag_counts[1], "character": tag_counts[2],
                "outfit": tag_counts[3], "camera": tag_counts[4], "concept": tag_counts[5], "lora": tag_counts[6]
            }
            
            return generate_prompt(scene_tags, num_people, position_tags, selected_characters, outfit_tags, camera_tags, concept_tags, lora_tags, tag_count_dict, data)

        generate_button.click(
            update_and_generate,
            inputs=[scene_input, num_people_input, position_input, character_select, outfit_input, camera_input, concept_input, lora_input,
                    scene_count, position_count, character_count, outfit_count, camera_count, concept_count, lora_count],
            outputs=[output]
        )

        def on_create_character(name, traits, image):
            """Create/update a character using the shared data dict."""
            nonlocal data
            # create_character needs the data dict as its fourth argument; the
            # UI only supplies three inputs, so it is passed from the closure.
            data, choices_update = create_character(name, traits, image, data)
            return choices_update

        def update_char_gallery():
            """Refresh both galleries with the currently stored images."""
            char_images = character_image_paths()
            return gr.update(value=char_images), gr.update(value=char_images)

        create_char_button.click(
            on_create_character,
            inputs=[char_name_input, char_traits_input, char_image_input],
            outputs=[character_select]
        ).then(
            update_char_gallery,
            outputs=[prompt_gallery, char_gallery]
        )

    return demo

# Build and launch the app only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    demo = create_ui()
    demo.launch()