Update app.py
app.py CHANGED
@@ -1,15 +1,191 @@
+# Import the required libraries
+import gradio as gr
+import random
+import json
+import os
+import re
+from datetime import datetime
+from huggingface_hub import InferenceClient
+import subprocess
+import torch
+from PIL import Image
+from transformers import AutoProcessor, AutoModelForCausalLM
+import random
+
+subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+
+huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
+
+
+# Initialize Florence model
+device = "cuda" if torch.cuda.is_available() else "cpu"
+florence_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to(device).eval()
+florence_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
+
+# Florence caption function
+def florence_caption(image):
+    if not isinstance(image, Image.Image):
+        image = Image.fromarray(image)
+
+    inputs = florence_processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to(device)
+    generated_ids = florence_model.generate(
+        input_ids=inputs["input_ids"],
+        pixel_values=inputs["pixel_values"],
+        max_new_tokens=1024,
+        early_stopping=False,
+        do_sample=False,
+        num_beams=3,
+    )
+    generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+    parsed_answer = florence_processor.post_process_generation(
+        generated_text,
+        task="<MORE_DETAILED_CAPTION>",
+        image_size=(image.width, image.height)
+    )
+    return parsed_answer["<MORE_DETAILED_CAPTION>"]
+
+# JSON file loading function
+def load_json_file(file_name):
+    file_path = os.path.join("data", file_name)
+    with open(file_path, "r") as file:
+        return json.load(file)
+
+ARTFORM = load_json_file("artform.json")
+PHOTO_TYPE = load_json_file("photo_type.json")
+BODY_TYPES = load_json_file("body_types.json")
+DEFAULT_TAGS = load_json_file("default_tags.json")
+ROLES = load_json_file("roles.json")
+HAIRSTYLES = load_json_file("hairstyles.json")
+ADDITIONAL_DETAILS = load_json_file("additional_details.json")
+PHOTOGRAPHY_STYLES = load_json_file("photography_styles.json")
+DEVICE = load_json_file("device.json")
+PHOTOGRAPHER = load_json_file("photographer.json")
+ARTIST = load_json_file("artist.json")
+DIGITAL_ARTFORM = load_json_file("digital_artform.json")
+PLACE = load_json_file("place.json")
+LIGHTING = load_json_file("lighting.json")
+CLOTHING = load_json_file("clothing.json")
+COMPOSITION = load_json_file("composition.json")
+POSE = load_json_file("pose.json")
+BACKGROUND = load_json_file("background.json")
+
+# PromptGenerator class definition
+class PromptGenerator:
+    def __init__(self, seed=None):
+        self.rng = random.Random(seed)
+
+    def split_and_choose(self, input_str):
+        choices = [choice.strip() for choice in input_str.split(",")]
+        return self.rng.choices(choices, k=1)[0]
+
+    def get_choice(self, input_str, default_choices):
+        if input_str.lower() == "disabled":
+            return ""
+        elif "," in input_str:
+            return self.split_and_choose(input_str)
+        elif input_str.lower() == "random":
+            return self.rng.choices(default_choices, k=1)[0]
+        else:
+            return input_str
+
+    def clean_consecutive_commas(self, input_string):
+        cleaned_string = re.sub(r',\s*,', ',', input_string)
+        return cleaned_string
+
+    def process_string(self, replaced, seed):
+        replaced = re.sub(r'\s*,\s*', ',', replaced)
+        replaced = re.sub(r',+', ',', replaced)
+        original = replaced
+
+        first_break_clipl_index = replaced.find("BREAK_CLIPL")
+        second_break_clipl_index = replaced.find("BREAK_CLIPL", first_break_clipl_index + len("BREAK_CLIPL"))
+
+        if first_break_clipl_index != -1 and second_break_clipl_index != -1:
+            clip_content_l = replaced[first_break_clipl_index + len("BREAK_CLIPL"):second_break_clipl_index]
+            replaced = replaced[:first_break_clipl_index].strip(", ") + replaced[second_break_clipl_index + len("BREAK_CLIPL"):].strip(", ")
+            clip_l = clip_content_l
+        else:
+            clip_l = ""
+
+        first_break_clipg_index = replaced.find("BREAK_CLIPG")
+        second_break_clipg_index = replaced.find("BREAK_CLIPG", first_break_clipg_index + len("BREAK_CLIPG"))
+
+        if first_break_clipg_index != -1 and second_break_clipg_index != -1:
+            clip_content_g = replaced[first_break_clipg_index + len("BREAK_CLIPG"):second_break_clipg_index]
+            replaced = replaced[:first_break_clipg_index].strip(", ") + replaced[second_break_clipg_index + len("BREAK_CLIPG"):].strip(", ")
+            clip_g = clip_content_g
+        else:
+            clip_g = ""
+
+        t5xxl = replaced
+
+        original = original.replace("BREAK_CLIPL", "").replace("BREAK_CLIPG", "")
+        original = re.sub(r'\s*,\s*', ',', original)
+        original = re.sub(r',+', ',', original)
+        clip_l = re.sub(r'\s*,\s*', ',', clip_l)
+        clip_l = re.sub(r',+', ',', clip_l)
+        clip_g = re.sub(r'\s*,\s*', ',', clip_g)
+        clip_g = re.sub(r',+', ',', clip_g)
+        if clip_l.startswith(","):
+            clip_l = clip_l[1:]
+        if clip_g.startswith(","):
+            clip_g = clip_g[1:]
+        if original.startswith(","):
+            original = original[1:]
+        if t5xxl.startswith(","):
+            t5xxl = t5xxl[1:]
+
+        return original, seed, t5xxl, clip_l, clip_g
+
+    def generate_prompt(self, seed, custom, subject, artform, photo_type, body_types, default_tags, roles, hairstyles,
+                        additional_details, photography_styles, device, photographer, artist, digital_artform,
+                        place, lighting, clothing, composition, pose, background, input_image):
+        # Omitted functionality...
+        pass
+
+    def add_caption_to_prompt(self, prompt, caption):
+        if caption:
+            return f"{prompt}, {caption}"
+        return prompt
+
+# Text-generation class backed by HuggingFace models
+class HuggingFaceInferenceNode:
+    def __init__(self):
+        self.clients = {
+            "Mixtral": InferenceClient("NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO"),
+            "Mistral": InferenceClient("mistralai/Mistral-7B-Instruct-v0.3"),
+            "Llama 3": InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct"),
+            "Mistral-Nemo": InferenceClient("mistralai/Mistral-Nemo-Instruct-2407")
+        }
+        self.prompts_dir = "./prompts"
+        os.makedirs(self.prompts_dir, exist_ok=True)
+
+    def save_prompt(self, prompt):
+        filename_text = "hf_" + prompt.split(',')[0].strip()
+        filename_text = re.sub(r'[^\w\-_\. ]', '_', filename_text)
+        filename_text = filename_text[:30]
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        base_filename = f"{filename_text}_{timestamp}.txt"
+        filename = os.path.join(self.prompts_dir, base_filename)
+
+        with open(filename, "w") as file:
+            file.write(prompt)
+
+        print(f"Prompt saved to {filename}")
+
+    def generate(self, model, input_text, happy_talk, compress, compression_level, poster, custom_base_prompt=""):
+        # Omitted functionality...
+        pass
+
+# Gradio interface creation function
 def create_interface():
-    prompt_generator = PromptGenerator()
+    prompt_generator = PromptGenerator()  # Usable now that the PromptGenerator class is defined
     huggingface_node = HuggingFaceInferenceNode()

     with gr.Blocks(theme='Nymbo/Nymbo_Theme') as demo:

         gr.HTML("""<h1 align="center">FLUX 프롬프트 생성기</h1>
-        <p><center>
-        <a href="https://github.com/dagthomas/comfyui_dagthomas" target="_blank">[comfyui_dagthomas]</a>
-        <a href="https://github.com/dagthomas" target="_blank">[dagthomas Github]</a>
-        <p align="center">이미지 또는 간단한 텍스트에서 긴 프롬프트를 생성합니다. 짧은 프롬프트를 개선합니다.</p>
-        </center></p>""")
+        <p><center>이미지 또는 간단한 텍스트에서 긴 프롬프트를 생성합니다. 짧은 프롬프트를 개선합니다.</center></p>""")

         with gr.Row():
             with gr.Column(scale=2):
@@ -17,8 +193,6 @@ def create_interface():
                 seed = gr.Number(label="시드", value=random.randint(0, 1000000))
                 custom = gr.Textbox(label="사용자 정의 입력 프롬프트 (선택사항)")
                 subject = gr.Textbox(label="주제 (선택사항)")
-
-                # Add a radio button for the global option selection
                 global_option = gr.Radio(["비활성화", "랜덤"], label="모든 옵션 설정:", value="비활성화")

                 with gr.Accordion("예술 형식 및 사진 유형", open=False):
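
Not part of the commit: a minimal sketch of how the helpers added above could be exercised from a Python shell, assuming app.py and its data/*.json files sit in the working directory and that the module-level setup (the flash-attn install step and the Florence-2 download) is acceptable to run at import time. The image path and the sample prompt text are placeholders invented for illustration.

# Hypothetical smoke test for the helpers defined in app.py (not the author's code).
# Importing app also runs its module-level setup, so this is only a sketch.
from PIL import Image
from app import florence_caption, PromptGenerator

# 1) Caption an image with Florence-2 ("example.jpg" is a placeholder path).
image = Image.open("example.jpg")
print(florence_caption(image))

# 2) Split a combined prompt on the BREAK_CLIPL / BREAK_CLIPG markers.
pg = PromptGenerator(seed=42)
combined = ("cinematic photo, BREAK_CLIPL red fox portrait BREAK_CLIPL, "
            "snowy forest, BREAK_CLIPG golden hour light BREAK_CLIPG, 35mm")
original, seed, t5xxl, clip_l, clip_g = pg.process_string(combined, 42)
print("t5xxl :", t5xxl)   # combined text with both marker blocks removed
print("clip_l:", clip_l)  # text that sat between the two BREAK_CLIPL markers
print("clip_g:", clip_g)  # text that sat between the two BREAK_CLIPG markers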