ginipick commited on
Commit
80e38a2
ยท
verified ยท
1 Parent(s): 7b9b23e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +241 -72
app.py CHANGED
@@ -1,22 +1,40 @@
1
- import spaces
2
- import argparse
3
  import os
4
  import time
5
  from os import path
 
 
 
 
 
 
 
6
  from safetensors.torch import load_file
7
  from huggingface_hub import hf_hub_download
8
 
9
- cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
10
- os.environ["TRANSFORMERS_CACHE"] = cache_path
11
- os.environ["HF_HUB_CACHE"] = cache_path
12
- os.environ["HF_HOME"] = cache_path
13
-
14
  import gradio as gr
15
- import torch
16
  from diffusers import FluxPipeline
17
 
18
- torch.backends.cuda.matmul.allow_tf32 = True
 
 
 
 
 
 
 
 
 
 
19
 
 
 
 
 
 
 
 
 
20
  class timer:
21
  def __init__(self, method_name="timed process"):
22
  self.method = method_name
@@ -27,83 +45,234 @@ class timer:
27
  end = time.time()
28
  print(f"{self.method} took {str(round(end - self.start, 2))}s")
29
 
30
- if not path.exists(cache_path):
31
- os.makedirs(cache_path, exist_ok=True)
 
 
 
 
32
 
33
- pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
34
- pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
 
 
 
 
 
 
 
35
  pipe.fuse_lora(lora_scale=0.125)
 
 
36
  pipe.to(device="cuda", dtype=torch.bfloat16)
37
 
38
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
39
- gr.Markdown(
40
- """
41
- <div style="text-align: center; max-width: 650px; margin: 0 auto;">
42
- <h1 style="font-size: 2.5rem; font-weight: 700; margin-bottom: 1rem; display: contents;">Hyper-FLUX-8steps-LoRA</h1>
43
- <p style="font-size: 1rem; margin-bottom: 1.5rem;">AutoML team from ByteDance</p>
44
- </div>
45
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  )
47
 
48
- with gr.Row():
49
- with gr.Column(scale=3):
50
- with gr.Group():
51
- prompt = gr.Textbox(
52
- label="Your Image Description",
53
- placeholder="E.g., A serene landscape with mountains and a lake at sunset",
54
- lines=3
55
- )
56
-
57
- with gr.Accordion("Advanced Settings", open=False):
58
- with gr.Group():
59
- with gr.Row():
60
- height = gr.Slider(label="Height", minimum=256, maximum=1152, step=64, value=1024)
61
- width = gr.Slider(label="Width", minimum=256, maximum=1152, step=64, value=1024)
62
-
63
- with gr.Row():
64
- steps = gr.Slider(label="Inference Steps", minimum=6, maximum=25, step=1, value=8)
65
- scales = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=5.0, step=0.1, value=3.5)
66
-
67
- seed = gr.Number(label="Seed (for reproducibility)", value=3413, precision=0)
68
-
69
- generate_btn = gr.Button("Generate Image", variant="primary", scale=1)
70
-
71
- with gr.Column(scale=4):
72
- output = gr.Image(label="Your Generated Image")
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  gr.Markdown(
75
  """
76
- <div style="max-width: 650px; margin: 2rem auto; padding: 1rem; border-radius: 10px; background-color: #f0f0f0;">
77
- <h2 style="font-size: 1.5rem; margin-bottom: 1rem;">How to Use</h2>
78
- <ol style="padding-left: 1.5rem;">
79
- <li>Enter a detailed description of the image you want to create.</li>
80
- <li>Adjust advanced settings if desired (tap to expand).</li>
81
- <li>Tap "Generate Image" and wait for your creation!</li>
82
- </ol>
83
- <p style="margin-top: 1rem; font-style: italic;">Tip: Be specific in your description for best results!</p>
84
- </div>
85
  """
86
  )
87
 
88
- @spaces.GPU
89
- def process_image(height, width, steps, scales, prompt, seed):
90
- global pipe
91
- with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
92
- return pipe(
93
- prompt=[prompt],
94
- generator=torch.Generator().manual_seed(int(seed)),
95
- num_inference_steps=int(steps),
96
- guidance_scale=float(scales),
97
- height=int(height),
98
- width=int(width),
99
- max_sequence_length=256
100
- ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
 
102
  generate_btn.click(
103
- process_image,
104
- inputs=[height, width, steps, scales, prompt, seed],
105
- outputs=output
 
 
 
 
 
 
106
  )
107
 
 
108
  if __name__ == "__main__":
109
- demo.launch()
 
 
 
1
  import os
2
  import time
3
  from os import path
4
+ import tempfile
5
+ import uuid
6
+ import base64
7
+ import mimetypes
8
+ import json
9
+
10
+ import torch
11
  from safetensors.torch import load_file
12
  from huggingface_hub import hf_hub_download
13
 
14
+ # Diffusers ๊ด€๋ จ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ
 
 
 
 
15
  import gradio as gr
 
16
  from diffusers import FluxPipeline
17
 
18
+ # Google GenAI ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ
19
+ from google import genai
20
+ from google.genai import types
21
+
22
+ #######################################
23
+ # 0. ํ™˜๊ฒฝ์„ค์ •
24
+ #######################################
25
+
26
+ # ๋ชจ๋ธ ์บ์‹œ ๋””๋ ‰ํ† ๋ฆฌ ์„ค์ •
27
+ BASE_DIR = path.dirname(path.abspath(__file__)) if "__file__" in globals() else os.getcwd()
28
+ CACHE_PATH = path.join(BASE_DIR, "models")
29
 
30
+ os.environ["TRANSFORMERS_CACHE"] = CACHE_PATH
31
+ os.environ["HF_HUB_CACHE"] = CACHE_PATH
32
+ os.environ["HF_HOME"] = CACHE_PATH
33
+
34
+ # Google GenAI ์‚ฌ์šฉ์„ ์œ„ํ•ด์„œ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์€ ํ™˜๊ฒฝ ๋ณ€์ˆ˜๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.
35
+ # os.environ["GAPI_TOKEN"] = "<YOUR_GOOGLE_GENAI_API_KEY>"
36
+
37
+ # ์ž‘์—… ์‹œ๊ฐ„ ์ธก์ •์„ ์œ„ํ•œ ๊ฐ„๋‹จํ•œ ํƒ€์ด๋จธ ํด๋ž˜์Šค
38
  class timer:
39
  def __init__(self, method_name="timed process"):
40
  self.method = method_name
 
45
  end = time.time()
46
  print(f"{self.method} took {str(round(end - self.start, 2))}s")
47
 
48
+ #######################################
49
+ # 1. FLUX ํŒŒ์ดํ”„๋ผ์ธ ๋กœ๋“œ
50
+ #######################################
51
+
52
+ if not path.exists(CACHE_PATH):
53
+ os.makedirs(CACHE_PATH, exist_ok=True)
54
 
55
+ # FLUX ํŒŒ์ดํ”„๋ผ์ธ ๋กœ๋“œ
56
+ pipe = FluxPipeline.from_pretrained(
57
+ "black-forest-labs/FLUX.1-dev",
58
+ torch_dtype=torch.bfloat16
59
+ )
60
+
61
+ # LoRA ๊ฐ€์ค‘์น˜ ๋กœ๋“œ
62
+ lora_path = hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
63
+ pipe.load_lora_weights(lora_path)
64
  pipe.fuse_lora(lora_scale=0.125)
65
+
66
+ # GPU๋กœ ์˜ฎ๊ธฐ๊ธฐ
67
  pipe.to(device="cuda", dtype=torch.bfloat16)
68
 
69
+ #######################################
70
+ # 2. Google GenAI๋ฅผ ํ†ตํ•œ ์ด๋ฏธ์ง€ ๋‚ด ํ…์ŠคํŠธ ๋ณ€ํ™˜ ํ•จ์ˆ˜
71
+ #######################################
72
+
73
+ def save_binary_file(file_name, data):
74
+ """Google GenAI์—์„œ ์‘๋‹ต๋ฐ›์€ ์ด์ง„ ๋ฐ์ดํ„ฐ๋ฅผ ์ด๋ฏธ์ง€ ํŒŒ์ผ๋กœ ์ €์žฅ"""
75
+ with open(file_name, "wb") as f:
76
+ f.write(data)
77
+
78
+ def generate_by_google_genai(text, file_name, model="gemini-2.0-flash-exp"):
79
+ """
80
+ Google GenAI(gemini) ๋ชจ๋ธ์„ ํ†ตํ•ด ์ด๋ฏธ์ง€/ํ…์ŠคํŠธ๋ฅผ ์ƒ์„ฑํ•˜๊ฑฐ๋‚˜ ๋ณ€ํ™˜.
81
+ - text: ๋ณ€๊ฒฝํ•  ํ…์ŠคํŠธ๋‚˜ ๋ช…๋ น์–ด ๋“ฑ ํ”„๋กฌํ”„ํŠธ
82
+ - file_name: ์›๋ณธ ์ด๋ฏธ์ง€(์˜ˆ: .png) ๊ฒฝ๋กœ
83
+ - model: ์‚ฌ์šฉํ•  gemini ๋ชจ๋ธ ์ด๋ฆ„
84
+ """
85
+ # 1) Google Client ์ดˆ๊ธฐํ™”
86
+ client = genai.Client(api_key=os.getenv("GAPI_TOKEN"))
87
+
88
+ # 2) ์ด๋ฏธ์ง€ ์—…๋กœ๋“œ
89
+ files = [client.files.upload(file=file_name)]
90
+
91
+ # 3) gemini์— ์ „๋‹ฌํ•  Content ์ค€๋น„ (์ด๋ฏธ์ง€ + ํ”„๋กฌํ”„ํŠธ)
92
+ contents = [
93
+ types.Content(
94
+ role="user",
95
+ parts=[
96
+ types.Part.from_uri(
97
+ file_uri=files[0].uri,
98
+ mime_type=files[0].mime_type,
99
+ ),
100
+ types.Part.from_text(text=text),
101
+ ],
102
+ ),
103
+ ]
104
+
105
+ # 4) ์ƒ์„ฑ/๋ณ€ํ™˜ ์„ค์ •
106
+ generate_content_config = types.GenerateContentConfig(
107
+ temperature=1,
108
+ top_p=0.95,
109
+ top_k=40,
110
+ max_output_tokens=8192,
111
+ response_modalities=["image", "text"],
112
+ response_mime_type="text/plain",
113
  )
114
 
115
+ text_response = ""
116
+ image_path = None
117
+
118
+ # ์ž„์‹œ ํŒŒ์ผ๋กœ ์ด๋ฏธ์ง€ ๋ฐ›์„ ์ค€๋น„
119
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
120
+ temp_path = tmp.name
121
+ # 5) ์ŠคํŠธ๋ฆผ ํ˜•ํƒœ๋กœ ์‘๋‹ต ๋ฐ›์•„์„œ ์ด๋ฏธ์ง€/ํ…์ŠคํŠธ ๊ตฌ๋ถ„ ์ฒ˜๋ฆฌ
122
+ for chunk in client.models.generate_content_stream(
123
+ model=model,
124
+ contents=contents,
125
+ config=generate_content_config,
126
+ ):
127
+ if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
128
+ continue
129
+ candidate = chunk.candidates[0].content.parts[0]
130
+
131
+ # inline_data๊ฐ€ ์žˆ์œผ๋ฉด ์ด๋ฏธ์ง€ ์‘๋‹ต
132
+ if candidate.inline_data:
133
+ save_binary_file(temp_path, candidate.inline_data.data)
134
+ print(f"File of mime type {candidate.inline_data.mime_type} saved to: {temp_path}")
135
+ image_path = temp_path
136
+ break # ์ด๋ฏธ์ง€๊ฐ€ ์˜ค๋ฉด ์šฐ์„  ๋ฉˆ์ถค
137
+ else:
138
+ # ์—†์œผ๋ฉด ํ…์ŠคํŠธ๋ฅผ ๋ˆ„์ 
139
+ text_response += chunk.text + "\n"
140
 
141
+ # ์—…๋กœ๋“œ ํŒŒ์ผ(google.genai.files.File) ๊ฐ์ฒด ์ œ๊ฑฐ
142
+ del files
143
+
144
+ return image_path, text_response
145
+
146
+ #######################################
147
+ # 3. Gradio ํ•จ์ˆ˜: (1) FLUX๋กœ ์ด๋ฏธ์ง€ ์ƒ์„ฑ -> (2) Google GenAI๋กœ ํ…์ŠคํŠธ ๊ต์ฒด
148
+ #######################################
149
+
150
+ def generate_initial_image(prompt, text, height, width, steps, scale, seed):
151
+ """
152
+ FLUX ํŒŒ์ดํ”„๋ผ์ธ์„ ์‚ฌ์šฉํ•ด 'ํ…์ŠคํŠธ๊ฐ€ ํฌํ•จ๋œ ์ด๋ฏธ์ง€๋ฅผ' ๋จผ์ € ์ƒ์„ฑํ•˜๋Š” ํ•จ์ˆ˜.
153
+ prompt: ์ด๋ฏธ์ง€ ๋ฐฐ๊ฒฝ/์žฅ๋ฉด/์Šคํƒ€์ผ ๋ฌ˜์‚ฌ๋ฅผ ์œ„ํ•œ ํ”„๋กฌํ”„ํŠธ
154
+ text: ์‹ค์ œ๋กœ ์ด๋ฏธ์ง€์— ๋“ค์–ด๊ฐ€์•ผ ํ•  ๋ฌธ๊ตฌ(์˜ˆ: "์•ˆ๋…•ํ•˜์„ธ์š”", "Hello world" ๋“ฑ)
155
+ """
156
+ # ์ด๋ฏธ์ง€์— ํ…์ŠคํŠธ๋ฅผ ํฌํ•จ์‹œํ‚ค๋ ค๋ฉด ํ”„๋กฌํ”„ํŠธ์— ์ง์ ‘ ๋ฌธ๊ตฌ ์š”์ฒญ์„ ๋„ฃ๋Š” ๊ฒƒ์ด ์ค‘์š”.
157
+ # Diffusion ๋ชจ๋ธ์— ๋”ฐ๋ผ ์ž˜ ๋ฐ˜์˜๋˜์ง€ ์•Š์„ ์ˆ˜๋„ ์žˆ์œผ๋‹ˆ, ๊ตฌ์ฒด์ ์œผ๋กœ ๊ธฐ์žฌํ• ์ˆ˜๋ก ์œ ๋ฆฌ.
158
+ # ์˜ˆ: "A poster with large bold Korean text that says '์•ˆ๋…•ํ•˜์„ธ์š”' in red color ..."
159
+ # ์—ฌ๊ธฐ์„œ๋Š” ๊ฐ„๋‹จํžˆ prompt ๋’ค์— ํ…์ŠคํŠธ ์‚ฝ์ž… ์˜ˆ์‹œ๋ฅผ ๋ณด์—ฌ์คŒ
160
+ combined_prompt = f"{prompt} with clear readable text that says '{text}'"
161
+
162
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
163
+ result = pipe(
164
+ prompt=[combined_prompt],
165
+ generator=torch.Generator().manual_seed(int(seed)),
166
+ num_inference_steps=int(steps),
167
+ guidance_scale=float(scale),
168
+ height=int(height),
169
+ width=int(width),
170
+ max_sequence_length=256
171
+ ).images[0]
172
+
173
+ return result
174
+
175
+ def change_text_in_image(original_image, new_text):
176
+ """
177
+ Google GenAI์˜ gemini ๋ชจ๋ธ์„ ํ†ตํ•ด,
178
+ ์—…๋กœ๋“œ๋œ ์ด๋ฏธ์ง€ ๋‚ด๋ถ€์˜ ๋ฌธ๊ตฌ๋ฅผ `new_text`๋กœ ๋ณ€๊ฒฝํ•ด์ฃผ๋Š” ํ•จ์ˆ˜.
179
+ """
180
+ try:
181
+ # ์ž„์‹œ ํŒŒ์ผ์— ๋จผ์ € ์ €์žฅ
182
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
183
+ original_path = tmp.name
184
+ original_image.save(original_path)
185
+
186
+ # Gemini ๋ชจ๋ธ ํ˜ธ์ถœ
187
+ image_path, text_response = generate_by_google_genai(
188
+ text=f"Change the text in this image to: '{new_text}'",
189
+ file_name=original_path
190
+ )
191
+
192
+ # ๊ฒฐ๊ณผ๊ฐ€ ์ด๋ฏธ์ง€๋กœ ์™”๋‹ค๋ฉด
193
+ if image_path:
194
+ modified_img = gr.processing_utils.decode_base64_to_image(
195
+ base64.b64encode(open(image_path, "rb").read())
196
+ )
197
+ return modified_img, "" # (๊ฒฐ๊ณผ ์ด๋ฏธ์ง€, ๋นˆ ํ…์ŠคํŠธ)
198
+ else:
199
+ # ์ด๋ฏธ์ง€๊ฐ€ ์—†์ด ํ…์ŠคํŠธ๋งŒ ์‘๋‹ต์œผ๋กœ ์˜จ ๊ฒฝ์šฐ
200
+ return None, text_response
201
+
202
+ except Exception as e:
203
+ raise gr.Error(f"Error: {e}")
204
+
205
+
206
+ #######################################
207
+ # 4. Gradio ์ธํ„ฐํŽ˜์ด์Šค ๊ตฌ์„ฑ
208
+ #######################################
209
+
210
+ with gr.Blocks(title="Flux + Google GenAI Text Replacement") as demo:
211
  gr.Markdown(
212
  """
213
+ # Flux ๊ธฐ๋ฐ˜ ์ด๋ฏธ์ง€ ์ƒ์„ฑ + Google GenAI๋ฅผ ํ†ตํ•œ ํ…์ŠคํŠธ ๋ณ€ํ™˜
214
+ **์ด ๋ฐ๋ชจ๋Š” ์•„๋ž˜ ๋‘ ๋‹จ๊ณ„๋ฅผ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค.**
215
+
216
+ 1) **Diffusion ๋ชจ๋ธ(FluxPipeline)์„ ์ด์šฉํ•ด** ์ด๋ฏธ์ง€ ์ƒ์„ฑ.
217
+ - ์ด๋•Œ, ์‚ฌ์šฉ์ž๊ฐ€ ์ง€์ •ํ•œ ํ…์ŠคํŠธ๋ฅผ ์ด๋ฏธ์ง€ ์•ˆ์— ํ‘œ์‹œํ•˜๋„๋ก ์‹œ๋„ํ•ฉ๋‹ˆ๋‹ค.
218
+ 2) **์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€๋ฅผ Google GenAI(gemini) ๋ชจ๋ธ์— ์ „๋‹ฌ**ํ•˜์—ฌ,
219
+ - ์ด๋ฏธ์ง€ ๋‚ด ํ…์ŠคํŠธ ๋ถ€๋ถ„๋งŒ ๋‹ค๋ฅธ ๋ฌธ์ž์—ด๋กœ ๋ณ€๊ฒฝ.
220
+
221
+ ---
222
  """
223
  )
224
 
225
+ with gr.Row():
226
+ with gr.Column():
227
+ gr.Markdown("## 1) Step 1: FLUX๋กœ ํ…์ŠคํŠธ ํฌํ•จ ์ด๋ฏธ์ง€ ์ƒ์„ฑ")
228
+ prompt_input = gr.Textbox(
229
+ lines=3,
230
+ label="์ด๋ฏธ์ง€ ์žฅ๋ฉด/๋ฐฐ๊ฒฝ Prompt",
231
+ placeholder="์˜ˆ) A poster with futuristic neon style..."
232
+ )
233
+ text_input = gr.Textbox(
234
+ lines=1,
235
+ label="์ด๋ฏธ์ง€ ์•ˆ์— ๋“ค์–ด๊ฐˆ ํ…์ŠคํŠธ",
236
+ placeholder="์˜ˆ) ์•ˆ๋…•ํ•˜์„ธ์š”"
237
+ )
238
+ with gr.Accordion("๊ณ ๊ธ‰ ์„ค์ • (ํ™•์žฅ)", open=False):
239
+ height = gr.Slider(label="Height", minimum=256, maximum=1152, step=64, value=512)
240
+ width = gr.Slider(label="Width", minimum=256, maximum=1152, step=64, value=512)
241
+ steps = gr.Slider(label="Inference Steps", minimum=6, maximum=25, step=1, value=8)
242
+ scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=5.0, step=0.1, value=3.5)
243
+ seed = gr.Number(label="Seed (reproducibility)", value=1234, precision=0)
244
+
245
+ generate_btn = gr.Button("Generate Base Image", variant="primary")
246
+
247
+ # ์ƒ์„ฑ ๊ฒฐ๊ณผ ํ‘œ์‹œ
248
+ generated_image = gr.Image(
249
+ label="Generated Image (with text)",
250
+ type="pil"
251
+ )
252
+
253
+ with gr.Column():
254
+ gr.Markdown("## 2) Step 2: ์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€ ๋‚ด ํ…์ŠคํŠธ ์ˆ˜์ •")
255
+ new_text_input = gr.Textbox(
256
+ label="์ƒˆ๋กœ ๋ฐ”๊ฟ€ ํ…์ŠคํŠธ",
257
+ placeholder="์˜ˆ) Hello world"
258
+ )
259
+ modify_btn = gr.Button("Change Text in Image via Gemini", variant="secondary")
260
+ output_img = gr.Image(label="Modified Image", type="pil")
261
+ output_txt = gr.Textbox(label="(If only text returned)")
262
 
263
+ # ๋ฒ„ํŠผ ์•ก์…˜ ์—ฐ๊ฒฐ
264
  generate_btn.click(
265
+ fn=generate_initial_image,
266
+ inputs=[prompt_input, text_input, height, width, steps, scale, seed],
267
+ outputs=[generated_image]
268
+ )
269
+
270
+ modify_btn.click(
271
+ fn=change_text_in_image,
272
+ inputs=[generated_image, new_text_input],
273
+ outputs=[output_img, output_txt]
274
  )
275
 
276
+ # ์‹ค์ œ ์‹คํ–‰ ์‹œ์—๋Š” ์•„๋ž˜์™€ ๊ฐ™์ด demo.launch()๋ฅผ ํ˜ธ์ถœํ•ฉ๋‹ˆ๋‹ค.
277
  if __name__ == "__main__":
278
+ demo.queue(concurrency_count=1, max_size=20).launch()