File size: 16,081 Bytes
9a0bb92
 
 
 
 
 
 
 
 
 
 
 
 
fe73fc1
9a0bb92
 
 
 
 
fe73fc1
 
9a0bb92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe73fc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a0bb92
78506fe
9a0bb92
95bdd33
78506fe
76e32d2
c4be8c2
26548b4
9a0bb92
78506fe
5edca44
c2401c0
26548b4
 
 
 
 
f77e7f8
26548b4
 
 
26de42f
26548b4
 
 
 
 
 
26de42f
26548b4
 
 
 
 
 
26de42f
26548b4
 
 
 
 
 
 
f77e7f8
26548b4
77987e2
 
26548b4
 
 
 
 
 
 
 
 
 
 
77987e2
 
26548b4
fe73fc1
26548b4
 
 
 
 
 
 
e2eb8eb
e33e382
 
 
 
 
 
 
 
 
 
 
e2eb8eb
9969ce7
e2eb8eb
 
e33e382
e2eb8eb
 
78506fe
e33e382
 
 
 
 
 
e2eb8eb
fe73fc1
45523be
e2eb8eb
 
78506fe
dbf8ff3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78506fe
 
 
 
26548b4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
# 필요한 라이브러리 임포트
import gradio as gr
import random
import json
import os
import re
from datetime import datetime
from huggingface_hub import InferenceClient
import subprocess
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
import random
import openai  # OpenAI API 라이브러리 추가

# flash-attn is installed at startup. The parent environment is merged in:
# passing only the new variable as `env` would replace the whole environment,
# dropping PATH and every other inherited setting for the pip subprocess.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

# Configure the OpenAI API key
openai.api_key = os.getenv("OPENAI_API_KEY")

# Initialize Florence model
device = "cuda" if torch.cuda.is_available() else "cpu"
florence_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to(device).eval()
florence_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)

# Florence caption function
def florence_caption(image):
    """Caption an image with the Florence ``<MORE_DETAILED_CAPTION>`` task.

    Args:
        image: A ``PIL.Image.Image``; any other value is converted with
            ``Image.fromarray`` first.

    Returns:
        The entry stored under ``"<MORE_DETAILED_CAPTION>"`` in the
        processor's post-processed generation result.
    """
    task = "<MORE_DETAILED_CAPTION>"
    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)

    model_inputs = florence_processor(text=task, images=pil_image, return_tensors="pt").to(device)
    output_ids = florence_model.generate(
        input_ids=model_inputs["input_ids"],
        pixel_values=model_inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )

    # Only the first decoded sequence is post-processed.
    decoded = florence_processor.batch_decode(output_ids, skip_special_tokens=False)
    result = florence_processor.post_process_generation(
        decoded[0],
        task=task,
        image_size=(pil_image.width, pil_image.height),
    )
    return result[task]

# JSON 파일 로드 함수
def load_json_file(file_name):
    """Load and parse a JSON file from the ``data`` directory.

    Args:
        file_name: Name of the file inside ``data`` (resolved relative to the
            current working directory).

    Returns:
        The parsed JSON content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    file_path = os.path.join("data", file_name)
    # Read as UTF-8 explicitly so non-ASCII entries do not depend on the
    # platform's default locale encoding.
    with open(file_path, "r", encoding="utf-8") as file:
        return json.load(file)

# Option data for the UI dropdowns, one JSON file per category. The assignment
# targets and the file names are listed in the same order, and the files are
# loaded in that order.
(
    ARTFORM,
    PHOTO_TYPE,
    BODY_TYPES,
    DEFAULT_TAGS,
    ROLES,
    HAIRSTYLES,
    ADDITIONAL_DETAILS,
    PHOTOGRAPHY_STYLES,
    DEVICE,
    PHOTOGRAPHER,
    ARTIST,
    DIGITAL_ARTFORM,
    PLACE,
    LIGHTING,
    CLOTHING,
    COMPOSITION,
    POSE,
    BACKGROUND,
) = [
    load_json_file(json_name)
    for json_name in (
        "artform.json",
        "photo_type.json",
        "body_types.json",
        "default_tags.json",
        "roles.json",
        "hairstyles.json",
        "additional_details.json",
        "photography_styles.json",
        "device.json",
        "photographer.json",
        "artist.json",
        "digital_artform.json",
        "place.json",
        "lighting.json",
        "clothing.json",
        "composition.json",
        "pose.json",
        "background.json",
    )
]

# PromptGenerator 클래스 정의
class PromptGenerator:
    """Helpers for choosing prompt options and post-processing prompt strings.

    All random selections go through ``self.rng`` so that a seed passed to the
    constructor makes them reproducible.
    """

    # Sentinel option labels. The English forms are the original values; the
    # Korean forms are the labels the dropdowns in this file actually offer.
    _DISABLED_LABELS = frozenset({"disabled", "비활성화"})
    _RANDOM_LABELS = frozenset({"random", "랜덤"})

    def __init__(self, seed=None):
        """Create the generator with its own ``random.Random(seed)``."""
        self.rng = random.Random(seed)

    def split_and_choose(self, input_str):
        """Return one randomly chosen, stripped item of a comma-separated string."""
        choices = [choice.strip() for choice in input_str.split(",")]
        # choices(k=1) rather than choice(): the two consume the generator
        # differently, so this keeps seeded results unchanged.
        return self.rng.choices(choices, k=1)[0]

    def get_choice(self, input_str, default_choices):
        """Resolve a dropdown/text value to the text that goes into the prompt.

        Args:
            input_str: The selected value. ``None`` and the "disabled" labels
                yield an empty string; a comma-separated value yields one of
                its items; the "random" labels yield one of
                ``default_choices``; anything else is returned unchanged.
            default_choices: Candidates used for the "random" labels.

        Returns:
            The resolved string ("" when nothing should be added).
        """
        if input_str is None:
            return ""
        label = input_str.lower()
        if label in self._DISABLED_LABELS:
            return ""
        if "," in input_str:
            return self.split_and_choose(input_str)
        if label in self._RANDOM_LABELS:
            # An empty candidate list would make rng.choices raise IndexError.
            if not default_choices:
                return ""
            return self.rng.choices(default_choices, k=1)[0]
        return input_str

    def clean_consecutive_commas(self, input_string):
        """Replace each ``,<optional whitespace>,`` pair with a single comma."""
        return re.sub(r',\s*,', ',', input_string)

    @staticmethod
    def _normalize_commas(text):
        """Drop whitespace around commas and collapse runs of commas."""
        text = re.sub(r'\s*,\s*', ',', text)
        return re.sub(r',+', ',', text)

    @staticmethod
    def _extract_marked_section(text, marker):
        """Cut out the text between the first two occurrences of ``marker``.

        Returns ``(remainder, section)``. When the marker does not occur
        twice, ``text`` is returned unchanged with an empty section.
        """
        start = text.find(marker)
        end = text.find(marker, start + len(marker))
        if start == -1 or end == -1:
            return text, ""
        section = text[start + len(marker):end]
        head = text[:start].strip(", ")
        tail = text[end + len(marker):].strip(", ")
        # Join with a comma: plain concatenation fused the last item before
        # the section with the first item after it ("x" + "y" -> "xy").
        remainder = ",".join(part for part in (head, tail) if part)
        return remainder, section

    def process_string(self, replaced, seed):
        """Split a prompt into its full, T5XXL, CLIP-L and CLIP-G variants.

        The text between the first pair of ``BREAK_CLIPL`` markers becomes the
        CLIP-L prompt and the text between the first pair of ``BREAK_CLIPG``
        markers the CLIP-G prompt; what is left is the T5XXL prompt. The full
        prompt is the input with the marker words removed.

        Args:
            replaced: Raw comma-separated prompt, optionally containing the
                markers.
            seed: Passed through unchanged.

        Returns:
            ``(original, seed, t5xxl, clip_l, clip_g)``.
        """
        replaced = self._normalize_commas(replaced)
        original = replaced

        replaced, clip_l = self._extract_marked_section(replaced, "BREAK_CLIPL")
        replaced, clip_g = self._extract_marked_section(replaced, "BREAK_CLIPG")
        t5xxl = replaced

        original = original.replace("BREAK_CLIPL", "").replace("BREAK_CLIPG", "")
        original = self._normalize_commas(original)
        clip_l = self._normalize_commas(clip_l)
        clip_g = self._normalize_commas(clip_g)

        # Removing a marker can leave one comma at the very start.
        original, t5xxl, clip_l, clip_g = (
            part[1:] if part.startswith(",") else part
            for part in (original, t5xxl, clip_l, clip_g)
        )

        return original, seed, t5xxl, clip_l, clip_g

    def generate_prompt(self, seed, custom, subject, artform, photo_type, body_types, default_tags, roles, hairstyles,
                        additional_details, photography_styles, device, photographer, artist, digital_artform,
                        place, lighting, clothing, composition, pose, background, input_image):
        """Placeholder for prompt assembly.

        The body is elided in this file, so the method currently returns
        ``None`` regardless of its arguments.
        """
        # Omitted functionality...
        pass

    def add_caption_to_prompt(self, prompt, caption):
        """Return ``prompt`` with ``caption`` appended after ", " when the caption is non-empty."""
        if caption:
            return f"{prompt}, {caption}"
        return prompt

# HuggingFace 모델을 사용한 텍스트 생성 클래스 정의
class HuggingFaceInferenceNode:
    """Holds one ``InferenceClient`` per supported model and saves prompts to disk."""

    def __init__(self):
        """Create the per-model clients and make sure the prompts directory exists."""
        # Keyed by the model's display name.
        self.clients = {
            "Mixtral": InferenceClient("NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO"),
            "Mistral": InferenceClient("mistralai/Mistral-7B-Instruct-v0.3"),
            "Llama 3": InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct"),
            "Mistral-Nemo": InferenceClient("mistralai/Mistral-Nemo-Instruct-2407")
        }
        self.prompts_dir = "./prompts"
        os.makedirs(self.prompts_dir, exist_ok=True)

    def save_prompt(self, prompt):
        """Write ``prompt`` to a timestamped text file in ``self.prompts_dir``.

        The file name is ``hf_`` plus the text before the first comma, with
        characters other than word characters, ``-``, ``_``, ``.`` and space
        replaced by ``_``, cut to 30 characters, followed by a
        ``YYYYMMDD_HHMMSS`` timestamp.

        Args:
            prompt: The prompt text to save.
        """
        filename_text = "hf_" + prompt.split(',')[0].strip()
        filename_text = re.sub(r'[^\w\-_\. ]', '_', filename_text)
        filename_text = filename_text[:30]
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        base_filename = f"{filename_text}_{timestamp}.txt"
        filename = os.path.join(self.prompts_dir, base_filename)

        # Write as UTF-8 explicitly: prompts may contain non-ASCII text that
        # the platform's default locale encoding cannot represent.
        with open(filename, "w", encoding="utf-8") as file:
            file.write(prompt)

        print(f"Prompt saved to {filename}")

    def generate(self, model, input_text, happy_talk, compress, compression_level, poster, custom_base_prompt=""):
        """Placeholder for LLM text generation.

        The body is elided in this file, so the method currently returns
        ``None`` regardless of its arguments.
        """
        # Omitted functionality...
        pass

# gpt-4o-mini와 Cohere Command R+를 사용한 프롬프트 생성 함수
def call_gpt4o_mini(content, system_message, max_tokens=1000, temperature=0.7, top_p=1):
    """Send a system + user message pair to ``gpt-4o-mini`` and return the reply text.

    Args:
        content: The user message.
        system_message: The system message.
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling parameter.

    Returns:
        The message content of the first returned choice.
    """
    request = dict(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": content},
        ],
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    # openai>=1.0 removed `openai.ChatCompletion`; its module-level client is
    # exposed as `openai.chat`. Use it when present, else the legacy call.
    if hasattr(openai, "chat"):
        response = openai.chat.completions.create(**request)
        return response.choices[0].message.content
    response = openai.ChatCompletion.create(**request)
    return response.choices[0].message['content']

def call_cohere(content, temperature=0.7, max_tokens=1000):
    """Send a single user message to the ``Cohere-Command-R+`` model and return the reply text.

    Args:
        content: The user message.
        temperature: Sampling temperature.
        max_tokens: Upper bound on generated tokens.

    Returns:
        The message content of the first returned choice.
    """
    # NOTE(review): the request goes through the `openai` client; presumably
    # this needs an OpenAI-compatible endpoint that serves this model name —
    # confirm the configured API base.
    request = dict(
        model="Cohere-Command-R+",
        messages=[
            {"role": "user", "content": content},
        ],
        max_tokens=max_tokens,
        temperature=temperature,
    )
    # openai>=1.0 removed `openai.ChatCompletion`; its module-level client is
    # exposed as `openai.chat`. Use it when present, else the legacy call.
    if hasattr(openai, "chat"):
        response = openai.chat.completions.create(**request)
        return response.choices[0].message.content
    response = openai.ChatCompletion.create(**request)
    return response.choices[0].message['content']


# Gradio 인터페이스 생성 함수
def create_interface():
    """Build the Gradio Blocks UI and wire its buttons to the prompt helpers.

    The layout has three columns: option selectors plus the generate button,
    image captioning plus the prompt outputs, and LLM-based text generation.

    Returns:
        The constructed ``gr.Blocks`` object (not yet launched).
    """
    prompt_generator = PromptGenerator()
    huggingface_node = HuggingFaceInferenceNode()

    # Sentinel labels shared by every option dropdown and the global radio.
    disabled_label = "비활성화"
    random_label = "랜덤"

    def option_dropdown(label, choices):
        """Create a dropdown offering both sentinel labels followed by ``choices``."""
        return gr.Dropdown([disabled_label, random_label] + choices, label=label, value=disabled_label)

    with gr.Blocks(theme='Nymbo/Nymbo_Theme') as demo:

        gr.HTML("""<h1 align="center">FLUX 프롬프트 생성기</h1>
                   <p><center>이미지 또는 간단한 텍스트에서 긴 프롬프트를 생성합니다. 짧은 프롬프트를 개선합니다.</center></p>""")

        with gr.Row():
            with gr.Column(scale=2):
                with gr.Accordion("기본 설정"):
                    seed = gr.Number(label="시드", value=random.randint(0, 1000000))
                    custom = gr.Textbox(label="사용자 정의 입력 프롬프트 (선택사항)")
                    subject = gr.Textbox(label="주제 (선택사항)")
                    global_option = gr.Radio([disabled_label, random_label], label="모든 옵션 설정:", value=disabled_label)

                with gr.Accordion("예술 형식 및 사진 유형", open=False):
                    artform = option_dropdown("예술 형식", ARTFORM)
                    photo_type = option_dropdown("사진 유형", PHOTO_TYPE)

                with gr.Accordion("캐릭터 세부사항", open=False):
                    body_types = option_dropdown("체형", BODY_TYPES)
                    default_tags = option_dropdown("기본 태그", DEFAULT_TAGS)
                    roles = option_dropdown("역할", ROLES)
                    hairstyles = option_dropdown("헤어스타일", HAIRSTYLES)
                    clothing = option_dropdown("의상", CLOTHING)

                with gr.Accordion("장면 세부사항", open=False):
                    place = option_dropdown("장소", PLACE)
                    lighting = option_dropdown("조명", LIGHTING)
                    composition = option_dropdown("구성", COMPOSITION)
                    pose = option_dropdown("포즈", POSE)
                    background = option_dropdown("배경", BACKGROUND)

                with gr.Accordion("스타일 및 아티스트", open=False):
                    additional_details = option_dropdown("추가 세부 사항", ADDITIONAL_DETAILS)
                    photography_styles = option_dropdown("사진 스타일", PHOTOGRAPHY_STYLES)
                    device = option_dropdown("장비", DEVICE)
                    photographer = option_dropdown("사진작가", PHOTOGRAPHER)
                    artist = option_dropdown("아티스트", ARTIST)
                    digital_artform = option_dropdown("디지털 예술 형식", DIGITAL_ARTFORM)

                generate_button = gr.Button("프롬프트 생성")

            with gr.Column(scale=2):
                with gr.Accordion("이미지 및 설명", open=False):
                    input_image = gr.Image(label="입력 이미지 (선택사항)")
                    caption_output = gr.Textbox(label="생성된 설명", lines=3)
                    create_caption_button = gr.Button("설명 생성")
                    add_caption_button = gr.Button("프롬프트에 설명 추가")

                with gr.Accordion("프롬프트 생성", open=True):
                    output = gr.Textbox(label="생성된 프롬프트 / 입력 텍스트", lines=4)
                    t5xxl_output = gr.Textbox(label="T5XXL 출력", visible=True)
                    clip_l_output = gr.Textbox(label="CLIP L 출력", visible=True)
                    clip_g_output = gr.Textbox(label="CLIP G 출력", visible=True)

            with gr.Column(scale=2):
                with gr.Accordion("LLM을 사용한 프롬프트 생성", open=False):
                    model = gr.Dropdown(["Mixtral", "Mistral", "Llama 3", "Mistral-Nemo", "gpt-4o-mini", "Cohere-Command-R+"], label="모델", value="Llama 3")
                    happy_talk = gr.Checkbox(label="행복한 대화", value=True)
                    compress = gr.Checkbox(label="압축", value=True)
                    compression_level = gr.Radio(["부드럽게", "중간", "강하게"], label="압축 레벨", value="강하게")
                    poster = gr.Checkbox(label="포스터 형식", value=False)
                    custom_base_prompt = gr.Textbox(label="사용자 정의 기본 프롬프트", lines=5)
                generate_text_button = gr.Button("LLM으로 프롬프트 생성")
                text_output = gr.Textbox(label="생성된 텍스트", lines=10)

        # Every option dropdown, in the order used by the "set all" radio.
        option_dropdowns = [
            artform, photo_type, body_types, default_tags, roles, hairstyles, clothing,
            place, lighting, composition, pose, background, additional_details,
            photography_styles, device, photographer, artist, digital_artform
        ]

        def create_caption(image):
            """Caption the uploaded image, or return "" when no image is set."""
            if image is not None:
                return florence_caption(image)
            return ""

        create_caption_button.click(
            create_caption,
            inputs=[input_image],
            outputs=[caption_output]
        )

        # Inputs are passed positionally, so they follow generate_prompt's
        # parameter order. input_image is its last parameter and was missing
        # from this list, which left the call one argument short.
        generate_button.click(
            prompt_generator.generate_prompt,
            inputs=[seed, custom, subject, artform, photo_type, body_types, default_tags, roles, hairstyles,
                    additional_details, photography_styles, device, photographer, artist, digital_artform,
                    place, lighting, clothing, composition, pose, background, input_image],
            outputs=[output, gr.Number(visible=False), t5xxl_output, clip_l_output, clip_g_output]
        )

        add_caption_button.click(
            prompt_generator.add_caption_to_prompt,
            inputs=[output, caption_output],
            outputs=[output]
        )

        def generate_text(model, input_text, happy_talk, compress, compression_level, poster, custom_base_prompt):
            """Route the request to the backend that serves the selected model."""
            if model == "gpt-4o-mini":
                return call_gpt4o_mini(input_text, custom_base_prompt)
            if model == "Cohere-Command-R+":
                return call_cohere(input_text)
            # The remaining dropdown entries are the Hugging Face models;
            # previously they were all sent to call_cohere.
            return huggingface_node.generate(
                model, input_text, happy_talk, compress, compression_level, poster, custom_base_prompt
            )

        generate_text_button.click(
            generate_text,
            inputs=[model, output, happy_talk, compress, compression_level, poster, custom_base_prompt],
            outputs=text_output
        )

        def update_all_options(choice):
            """Set every option dropdown to the label picked in the global radio."""
            return {dropdown: gr.update(value=choice) for dropdown in option_dropdowns}

        global_option.change(
            update_all_options,
            inputs=[global_option],
            outputs=option_dropdowns
        )

    return demo

# Script entry point: build the UI and start the Gradio server only when the
# file is run directly, not when it is imported.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()