aliceblue11 commited on
Commit
9a0bb92
·
verified ·
1 Parent(s): 26548b4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +182 -8
app.py CHANGED
@@ -1,15 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  def create_interface():
2
- prompt_generator = PromptGenerator()
3
  huggingface_node = HuggingFaceInferenceNode()
4
 
5
  with gr.Blocks(theme='Nymbo/Nymbo_Theme') as demo:
6
 
7
  gr.HTML("""<h1 align="center">FLUX 프롬프트 생성기</h1>
8
- <p><center>
9
- <a href="https://github.com/dagthomas/comfyui_dagthomas" target="_blank">[comfyui_dagthomas]</a>
10
- <a href="https://github.com/dagthomas" target="_blank">[dagthomas Github]</a>
11
- <p align="center">이미지 또는 간단한 텍스트에서 긴 프롬프트를 생성합니다. 짧은 프롬프트를 개선합니다.</p>
12
- </center></p>""")
13
 
14
  with gr.Row():
15
  with gr.Column(scale=2):
@@ -17,8 +193,6 @@ def create_interface():
17
  seed = gr.Number(label="시드", value=random.randint(0, 1000000))
18
  custom = gr.Textbox(label="사용자 정의 입력 프롬프트 (선택사항)")
19
  subject = gr.Textbox(label="주제 (선택사항)")
20
-
21
- # 글로벌 옵션 선택을 위한 라디오 버튼 추가
22
  global_option = gr.Radio(["비활성화", "랜덤"], label="모든 옵션 설정:", value="비활성화")
23
 
24
  with gr.Accordion("예술 형식 및 사진 유형", open=False):
 
# --- Runtime dependencies -------------------------------------------------
# Stdlib first, then third-party (PEP 8 grouping).
# FIX: the original imported `random` twice; the duplicate is removed.
import json
import os
import random
import re
import subprocess
from datetime import datetime

import gradio as gr
import torch
from huggingface_hub import InferenceClient
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM

# Install flash-attn at startup (common Hugging Face Spaces pattern).
# FIX: extend the current environment instead of replacing it wholesale —
# passing env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"} alone drops
# PATH/HOME from the child process and can make `pip` unresolvable.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

# Token for authenticated Hugging Face API access (None when unset).
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
# Initialize Florence model
# Prefer GPU when available; model and inputs must be placed on the same device.
device = "cuda" if torch.cuda.is_available() else "cpu"
# trust_remote_code=True executes Python shipped with the checkpoint — accepted
# here only because the source is the official `microsoft/Florence-2-base` repo.
# .eval() switches off training-time behavior (dropout); inference only.
florence_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to(device).eval()
florence_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
def florence_caption(image):
    """Return a detailed text caption for *image* using the Florence-2 model.

    Accepts either a PIL image or an array-like (as handed over by Gradio);
    array input is converted before processing.
    """
    # Gradio image components may deliver a numpy array; the processor
    # expects a PIL image.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    task_token = "<MORE_DETAILED_CAPTION>"
    inputs = florence_processor(text=task_token, images=image, return_tensors="pt").to(device)
    generated_ids = florence_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )
    # Keep special tokens: post_process_generation needs the raw task markers.
    raw_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed = florence_processor.post_process_generation(
        raw_text,
        task=task_token,
        image_size=(image.width, image.height),
    )
    return parsed[task_token]
# JSON loading helper for the option catalogues bundled with the app.
def load_json_file(file_name):
    """Load and parse a JSON file from the local ``data`` directory.

    Parameters:
        file_name: bare file name (e.g. ``"artform.json"``), resolved
            relative to ``data/`` under the current working directory.

    Returns:
        The parsed JSON value (typically a list or dict).

    Raises:
        FileNotFoundError / json.JSONDecodeError on a missing or bad file.
    """
    file_path = os.path.join("data", file_name)
    # Explicit encoding: JSON is UTF-8 by spec, while open()'s default
    # encoding is platform-dependent (e.g. cp1252 on Windows).
    with open(file_path, "r", encoding="utf-8") as file:
        return json.load(file)
# Option catalogues loaded once at import time, one JSON file per category.
# NOTE(review): presumably each list feeds one dropdown/choice field in the
# Gradio UI — confirm against the (truncated) create_interface body.
# NOTE(review): DEVICE here is the contents of device.json (camera gear list?)
# and is unrelated to the lowercase torch `device` variable defined above.
ARTFORM = load_json_file("artform.json")
PHOTO_TYPE = load_json_file("photo_type.json")
BODY_TYPES = load_json_file("body_types.json")
DEFAULT_TAGS = load_json_file("default_tags.json")
ROLES = load_json_file("roles.json")
HAIRSTYLES = load_json_file("hairstyles.json")
ADDITIONAL_DETAILS = load_json_file("additional_details.json")
PHOTOGRAPHY_STYLES = load_json_file("photography_styles.json")
DEVICE = load_json_file("device.json")
PHOTOGRAPHER = load_json_file("photographer.json")
ARTIST = load_json_file("artist.json")
DIGITAL_ARTFORM = load_json_file("digital_artform.json")
PLACE = load_json_file("place.json")
LIGHTING = load_json_file("lighting.json")
CLOTHING = load_json_file("clothing.json")
COMPOSITION = load_json_file("composition.json")
POSE = load_json_file("pose.json")
BACKGROUND = load_json_file("background.json")
# Prompt-building helpers: seeded random selection plus post-processing of
# BREAK_CLIPL / BREAK_CLIPG sections embedded in a combined prompt string.
class PromptGenerator:
    def __init__(self, seed=None):
        """Create a generator with its own seedable RNG (reproducible runs)."""
        self.rng = random.Random(seed)

    def split_and_choose(self, input_str):
        """Pick one entry at random from a comma-separated string."""
        options = [part.strip() for part in input_str.split(",")]
        return self.rng.choices(options, k=1)[0]

    def get_choice(self, input_str, default_choices):
        """Resolve a UI field value.

        '' for "disabled", a pick from the comma list when one is given,
        a random default for "random", otherwise the value as typed.
        """
        lowered = input_str.lower()
        if lowered == "disabled":
            return ""
        if "," in input_str:
            return self.split_and_choose(input_str)
        if lowered == "random":
            return self.rng.choices(default_choices, k=1)[0]
        return input_str

    def clean_consecutive_commas(self, input_string):
        """Collapse ',  ,' pairs left behind by empty fields into one comma."""
        return re.sub(r',\s*,', ',', input_string)

    @staticmethod
    def _split_marker(text, marker):
        """Cut the first MARKER...MARKER section out of *text*.

        Returns (text_without_section, section_content); the content is ''
        and *text* is returned untouched when fewer than two markers exist.
        """
        start = text.find(marker)
        end = text.find(marker, start + len(marker))
        if start == -1 or end == -1:
            return text, ""
        inner = text[start + len(marker):end]
        remainder = text[:start].strip(", ") + text[end + len(marker):].strip(", ")
        return remainder, inner

    def process_string(self, replaced, seed):
        """Split a combined prompt into its per-encoder variants.

        Returns (original, seed, t5xxl, clip_l, clip_g) where clip_l/clip_g
        are the BREAK_CLIPL/BREAK_CLIPG sections and t5xxl is what remains.
        """
        # Normalise comma spacing and collapse runs of commas.
        replaced = re.sub(r'\s*,\s*', ',', replaced)
        replaced = re.sub(r',+', ',', replaced)
        original = replaced

        # Extract the two marker-delimited sections in order; CLIPG is
        # searched in the string already reduced by the CLIPL cut.
        replaced, clip_l = self._split_marker(replaced, "BREAK_CLIPL")
        replaced, clip_g = self._split_marker(replaced, "BREAK_CLIPG")
        t5xxl = replaced

        # The "original" keeps all content but drops the marker tokens.
        original = original.replace("BREAK_CLIPL", "").replace("BREAK_CLIPG", "")

        def _tidy(text):
            # Re-normalise commas and drop a single leading comma, if any.
            text = re.sub(r'\s*,\s*', ',', text)
            text = re.sub(r',+', ',', text)
            return text[1:] if text.startswith(",") else text

        original = _tidy(original)
        clip_l = _tidy(clip_l)
        clip_g = _tidy(clip_g)
        # t5xxl was already comma-normalised above; only trim the lead comma.
        if t5xxl.startswith(","):
            t5xxl = t5xxl[1:]

        return original, seed, t5xxl, clip_l, clip_g

    def generate_prompt(self, seed, custom, subject, artform, photo_type, body_types, default_tags, roles, hairstyles,
                        additional_details, photography_styles, device, photographer, artist, digital_artform,
                        place, lighting, clothing, composition, pose, background, input_image):
        # Omitted in this revision; intentionally a no-op stub.
        pass

    def add_caption_to_prompt(self, prompt, caption):
        """Append *caption* to *prompt*; return *prompt* unchanged when falsy."""
        return f"{prompt}, {caption}" if caption else prompt
# Thin wrapper around several hosted chat models on the HF Inference API.
class HuggingFaceInferenceNode:
    def __init__(self):
        """Create one InferenceClient per supported model and ensure the
        local prompt-archive directory exists."""
        self.clients = {
            "Mixtral": InferenceClient("NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO"),
            "Mistral": InferenceClient("mistralai/Mistral-7B-Instruct-v0.3"),
            "Llama 3": InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct"),
            "Mistral-Nemo": InferenceClient("mistralai/Mistral-Nemo-Instruct-2407")
        }
        self.prompts_dir = "./prompts"
        os.makedirs(self.prompts_dir, exist_ok=True)

    def save_prompt(self, prompt):
        """Archive *prompt* as a timestamped .txt file under self.prompts_dir.

        The filename stem is the prompt's first comma-separated segment,
        sanitised to filesystem-safe characters and truncated to 30 chars.
        """
        filename_text = "hf_" + prompt.split(',')[0].strip()
        # Replace anything outside [word chars, dash, underscore, dot, space].
        filename_text = re.sub(r'[^\w\-_\. ]', '_', filename_text)
        filename_text = filename_text[:30]
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        base_filename = f"{filename_text}_{timestamp}.txt"
        filename = os.path.join(self.prompts_dir, base_filename)

        with open(filename, "w") as file:
            file.write(prompt)

        # BUG FIX: the original printed the literal text "(unknown)" —
        # a broken f-string with no placeholder. Report the actual path.
        print(f"Prompt saved to {filename}")

    def generate(self, model, input_text, happy_talk, compress, compression_level, poster, custom_base_prompt=""):
        # Omitted in this revision; intentionally a no-op stub.
        pass
180
+ # Gradio 인터페이스 생성 함수
181
  def create_interface():
182
+ prompt_generator = PromptGenerator() # PromptGenerator 클래스가 정의되었으므로 사용 가능
183
  huggingface_node = HuggingFaceInferenceNode()
184
 
185
  with gr.Blocks(theme='Nymbo/Nymbo_Theme') as demo:
186
 
187
  gr.HTML("""<h1 align="center">FLUX 프롬프트 생성기</h1>
188
+ <p><center>이미지 또는 간단한 텍스트에서 긴 프롬프트를 생성합니다. 짧은 프롬프트를 개선합니다.</center></p>""")
 
 
 
 
189
 
190
  with gr.Row():
191
  with gr.Column(scale=2):
 
193
  seed = gr.Number(label="시드", value=random.randint(0, 1000000))
194
  custom = gr.Textbox(label="사용자 정의 입력 프롬프트 (선택사항)")
195
  subject = gr.Textbox(label="주제 (선택사항)")
 
 
196
  global_option = gr.Radio(["비활성화", "랜덤"], label="모든 옵션 설정:", value="비활성화")
197
 
198
  with gr.Accordion("예술 형식 및 사진 유형", open=False):