Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -5,6 +5,65 @@ import time
|
|
5 |
import logging
|
6 |
import google.generativeai as genai
|
7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
# λ‘κΉ
μ€μ
|
9 |
logging.basicConfig(
|
10 |
level=logging.INFO,
|
@@ -189,12 +248,8 @@ physical_transformation_categories = {
|
|
189 |
##############################################################################
|
190 |
def query_gemini_api(prompt):
|
191 |
try:
|
192 |
-
# μμ: κΈ°μ‘΄ gemini-2.0... λμ , λ€λ₯Έ λͺ¨λΈμ΄ νμνλ€λ©΄ κ΅μ²΄νμΈμ.
|
193 |
model = genai.GenerativeModel('gemini-2.0-flash-thinking-exp-01-21')
|
194 |
-
|
195 |
response = model.generate_content(prompt)
|
196 |
-
|
197 |
-
# μλ΅ κ΅¬μ‘° λ°©μ΄μ μΌλ‘ μ²λ¦¬
|
198 |
try:
|
199 |
if hasattr(response, 'text'):
|
200 |
return response.text
|
@@ -210,13 +265,10 @@ def query_gemini_api(prompt):
|
|
210 |
if hasattr(response, 'parts') and response.parts:
|
211 |
if len(response.parts) > 0:
|
212 |
return response.parts[0].text
|
213 |
-
|
214 |
return "Unable to generate a response. API response structure is different than expected."
|
215 |
-
|
216 |
except Exception as inner_e:
|
217 |
logger.error(f"Error processing response: {inner_e}")
|
218 |
return f"An error occurred while processing the response: {str(inner_e)}"
|
219 |
-
|
220 |
except Exception as e:
|
221 |
logger.error(f"Error calling Gemini API: {e}")
|
222 |
if "API key not valid" in str(e):
|
@@ -224,7 +276,7 @@ def query_gemini_api(prompt):
|
|
224 |
return f"An error occurred while calling the API: {str(e)}"
|
225 |
|
226 |
##############################################################################
|
227 |
-
# μ€λͺ
νμ₯
|
228 |
##############################################################################
|
229 |
def enhance_with_llm(base_description, obj_name, category):
|
230 |
prompt = f"""
|
@@ -238,7 +290,7 @@ def enhance_with_llm(base_description, obj_name, category):
|
|
238 |
return query_gemini_api(prompt)
|
239 |
|
240 |
##############################################################################
|
241 |
-
# λ¨μΌ
|
242 |
##############################################################################
|
243 |
def generate_single_object_transformations(obj):
|
244 |
results = {}
|
@@ -283,10 +335,8 @@ def generate_three_objects_interaction(obj1, obj2, obj3):
|
|
283 |
##############################################################################
|
284 |
def enhance_descriptions(results, objects):
|
285 |
obj_name = " λ° ".join([obj for obj in objects if obj])
|
286 |
-
|
287 |
for category, result in results.items():
|
288 |
result["enhanced"] = enhance_with_llm(result["base"], obj_name, category)
|
289 |
-
|
290 |
return results
|
291 |
|
292 |
##############################################################################
|
@@ -302,7 +352,6 @@ def generate_transformations(text1, text2=None, text3=None):
|
|
302 |
else:
|
303 |
results = generate_single_object_transformations(text1)
|
304 |
objects = [text1]
|
305 |
-
|
306 |
return enhance_descriptions(results, objects)
|
307 |
|
308 |
##############################################################################
|
@@ -315,7 +364,7 @@ def format_results(results):
|
|
315 |
return formatted
|
316 |
|
317 |
##############################################################################
|
318 |
-
# Gradio UIμμ νΈμΆν ν¨μ
|
319 |
##############################################################################
|
320 |
def process_inputs(text1, text2, text3, selected_category, progress=gr.Progress()):
|
321 |
text1 = text1.strip() if text1 else None
|
@@ -325,20 +374,13 @@ def process_inputs(text1, text2, text3, selected_category, progress=gr.Progress(
|
|
325 |
if not text1:
|
326 |
return "μ€λ₯: μ΅μ νλμ ν€μλλ₯Ό μ
λ ₯ν΄μ£ΌμΈμ."
|
327 |
|
328 |
-
keyword_info = f"ν€μλ: {text1}"
|
329 |
-
if text2:
|
330 |
-
keyword_info += f", {text2}"
|
331 |
-
if text3:
|
332 |
-
keyword_info += f", {text3}"
|
333 |
-
|
334 |
progress(0.05, desc="μμ΄λμ΄ μμ± μ€λΉ μ€...")
|
335 |
-
time.sleep(0.3)
|
336 |
-
|
337 |
-
progress(0.1, desc="μ°½μμ μΈ λͺ¨λΈ/컨μ
/νμ λ³ν μμ΄λμ΄ μμ± μμ...")
|
338 |
|
339 |
results = generate_transformations(text1, text2, text3)
|
340 |
|
341 |
-
# μ νν
|
342 |
if selected_category in results:
|
343 |
results = {selected_category: results[selected_category]}
|
344 |
else:
|
@@ -346,10 +388,19 @@ def process_inputs(text1, text2, text3, selected_category, progress=gr.Progress(
|
|
346 |
|
347 |
progress(0.8, desc="κ²°κ³Ό ν¬λ§·ν
μ€...")
|
348 |
formatted = format_results(results)
|
349 |
-
|
350 |
progress(1.0, desc="μλ£!")
|
351 |
return formatted
|
352 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
353 |
##############################################################################
|
354 |
# API ν€ κ²½κ³ λ©μμ§
|
355 |
##############################################################################
|
@@ -361,7 +412,7 @@ def get_warning_message():
|
|
361 |
##############################################################################
|
362 |
# Gradio UI
|
363 |
##############################################################################
|
364 |
-
with gr.Blocks(title="ν€μλ κΈ°λ° μ°½μμ λ³ν μμ΄λμ΄ μμ±κΈ°",
|
365 |
theme=gr.themes.Soft(primary_hue="teal", secondary_hue="slate", neutral_hue="neutral")) as demo:
|
366 |
|
367 |
gr.HTML("""
|
@@ -377,27 +428,23 @@ with gr.Blocks(title="ν€μλ κΈ°λ° μ°½μμ λ³ν μμ΄λμ΄ μμ±κΈ°",
|
|
377 |
</style>
|
378 |
""")
|
379 |
|
380 |
-
gr.Markdown("# π ν€μλ κΈ°λ° μ°½μμ λ³ν μμ΄λμ΄ μμ±κΈ°")
|
381 |
-
gr.Markdown("μ
λ ₯ν **ν€μλ**(μ΅λ 3κ°)μ **μΉ΄ν
κ³ λ¦¬**λ₯Ό λ°νμΌλ‘,
|
382 |
|
383 |
warning = gr.Markdown(get_warning_message())
|
384 |
|
385 |
-
# μ’μΈ‘ μ
λ ₯ μμ
|
386 |
with gr.Row():
|
387 |
with gr.Column(scale=1):
|
388 |
text_input1 = gr.Textbox(label="ν€μλ 1 (νμ)", placeholder="μ: μ€λ§νΈν°")
|
389 |
text_input2 = gr.Textbox(label="ν€μλ 2 (μ ν)", placeholder="μ: μΈκ³΅μ§λ₯")
|
390 |
text_input3 = gr.Textbox(label="ν€μλ 3 (μ ν)", placeholder="μ: ν¬μ€μΌμ΄")
|
391 |
-
# μΉ΄ν
κ³ λ¦¬ μ ν λλ‘λ€μ΄ μΆκ°
|
392 |
category_dropdown = gr.Dropdown(
|
393 |
label="μΉ΄ν
κ³ λ¦¬ μ ν",
|
394 |
choices=list(physical_transformation_categories.keys()),
|
395 |
value=list(physical_transformation_categories.keys())[0],
|
396 |
info="μΆλ ₯ν μΉ΄ν
κ³ λ¦¬λ₯Ό μ ννμΈμ."
|
397 |
)
|
398 |
-
|
399 |
-
status_msg = gr.Markdown("π‘ 'μμ΄λμ΄ μμ±νκΈ°' λ²νΌμ ν΄λ¦νλ©΄ μμ΄λμ΄ μμ±μ΄ μμλ©λλ€.")
|
400 |
-
|
401 |
processing_indicator = gr.HTML("""
|
402 |
<div style="display: flex; justify-content: center; align-items: center; margin: 10px 0;">
|
403 |
<div style="border: 5px solid #f3f3f3; border-top: 5px solid #3498db; border-radius: 50%; width: 30px; height: 30px; animation: spin 2s linear infinite;"></div>
|
@@ -410,13 +457,12 @@ with gr.Blocks(title="ν€μλ κΈ°λ° μ°½μμ λ³ν μμ΄λμ΄ μμ±κΈ°",
|
|
410 |
}
|
411 |
</style>
|
412 |
""", visible=False)
|
413 |
-
|
414 |
submit_button = gr.Button("μμ΄λμ΄ μμ±νκΈ°", variant="primary")
|
415 |
|
416 |
-
# μ°μΈ‘ μΆλ ₯ μμ
|
417 |
with gr.Column(scale=2):
|
418 |
idea_output = gr.Markdown(label="μμ΄λμ΄ κ²°κ³Ό")
|
419 |
-
|
|
|
420 |
gr.Examples(
|
421 |
examples=[
|
422 |
["μ€λ§νΈν°", "", "", list(physical_transformation_categories.keys())[0]],
|
@@ -439,9 +485,9 @@ with gr.Blocks(title="ν€μλ κΈ°λ° μ°½μμ λ³ν μμ΄λμ΄ μμ±κΈ°",
|
|
439 |
inputs=None,
|
440 |
outputs=processing_indicator
|
441 |
).then(
|
442 |
-
fn=
|
443 |
inputs=[text_input1, text_input2, text_input3, category_dropdown],
|
444 |
-
outputs=idea_output
|
445 |
).then(
|
446 |
fn=hide_processing_indicator,
|
447 |
inputs=None,
|
|
|
5 |
import logging
|
6 |
import google.generativeai as genai
|
7 |
|
8 |
+
import torch
|
9 |
+
import numpy as np
|
10 |
+
from diffusers import DiffusionPipeline
|
11 |
+
from transformers import pipeline as hf_pipeline
|
12 |
+
|
13 |
+
# ---------------------- μ΄λ―Έμ§ μμ± κ΄λ ¨ μ€μ ----------------------
|
14 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
15 |
+
dtype = torch.bfloat16 if device=="cuda" else torch.float32
|
16 |
+
|
17 |
+
# νκ΅μ΄-μμ΄ λ²μ λͺ¨λΈ λ‘λ (μ₯μΉμ λ°λΌ CPU λλ GPU μ¬μ©)
|
18 |
+
translator = hf_pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device=0 if device=="cuda" else -1)
|
19 |
+
|
20 |
+
# Diffusion Pipeline λ‘λ (μ: FLUX.1-schnell λͺ¨λΈ)
|
21 |
+
pipe = DiffusionPipeline.from_pretrained(
|
22 |
+
"black-forest-labs/FLUX.1-schnell",
|
23 |
+
torch_dtype=dtype
|
24 |
+
).to(device)
|
25 |
+
|
26 |
+
MAX_SEED = np.iinfo(np.int32).max
|
27 |
+
MAX_IMAGE_SIZE = 2048
|
28 |
+
|
29 |
+
def contains_korean(text):
|
30 |
+
for char in text:
|
31 |
+
if ord('κ°') <= ord(char) <= ord('ν£'):
|
32 |
+
return True
|
33 |
+
return False
|
34 |
+
|
35 |
+
def generate_design_image(prompt, seed=42, randomize_seed=True, width=1024, height=1024, num_inference_steps=4):
|
36 |
+
"""
|
37 |
+
μμ±λ νμ₯ μμ΄λμ΄ ν
μ€νΈ(prompt)λ₯Ό μ
λ ₯λ°μ,
|
38 |
+
νμμ νκ΅μ΄λ₯Ό μμ΄λ‘ λ²μν ν DiffusionPipelineμΌλ‘ μ΄λ―Έμ§λ₯Ό μμ±ν©λλ€.
|
39 |
+
"""
|
40 |
+
original_prompt = prompt
|
41 |
+
translated = False
|
42 |
+
|
43 |
+
# νκ΅μ΄κ° ν¬ν¨λμ΄ μμΌλ©΄ μμ΄λ‘ λ²μ
|
44 |
+
if contains_korean(prompt):
|
45 |
+
translation = translator(prompt)
|
46 |
+
prompt = translation[0]['translation_text']
|
47 |
+
translated = True
|
48 |
+
|
49 |
+
# λλ€ μλ μ€μ
|
50 |
+
if randomize_seed:
|
51 |
+
seed = random.randint(0, MAX_SEED)
|
52 |
+
|
53 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
54 |
+
|
55 |
+
image = pipe(
|
56 |
+
prompt=prompt,
|
57 |
+
width=width,
|
58 |
+
height=height,
|
59 |
+
num_inference_steps=num_inference_steps,
|
60 |
+
generator=generator,
|
61 |
+
guidance_scale=0.0
|
62 |
+
).images[0]
|
63 |
+
|
64 |
+
return image
|
65 |
+
|
66 |
+
# ---------------------- Gemini API λ° μμ΄λμ΄ μμ± κ΄λ ¨ κΈ°μ‘΄ μ½λ ----------------------
|
67 |
# λ‘κΉ
μ€μ
|
68 |
logging.basicConfig(
|
69 |
level=logging.INFO,
|
|
|
248 |
##############################################################################
|
249 |
def query_gemini_api(prompt):
|
250 |
try:
|
|
|
251 |
model = genai.GenerativeModel('gemini-2.0-flash-thinking-exp-01-21')
|
|
|
252 |
response = model.generate_content(prompt)
|
|
|
|
|
253 |
try:
|
254 |
if hasattr(response, 'text'):
|
255 |
return response.text
|
|
|
265 |
if hasattr(response, 'parts') and response.parts:
|
266 |
if len(response.parts) > 0:
|
267 |
return response.parts[0].text
|
|
|
268 |
return "Unable to generate a response. API response structure is different than expected."
|
|
|
269 |
except Exception as inner_e:
|
270 |
logger.error(f"Error processing response: {inner_e}")
|
271 |
return f"An error occurred while processing the response: {str(inner_e)}"
|
|
|
272 |
except Exception as e:
|
273 |
logger.error(f"Error calling Gemini API: {e}")
|
274 |
if "API key not valid" in str(e):
|
|
|
276 |
return f"An error occurred while calling the API: {str(e)}"
|
277 |
|
278 |
##############################################################################
|
279 |
+
# μ€λͺ
νμ₯ ν¨μ
|
280 |
##############################################################################
|
281 |
def enhance_with_llm(base_description, obj_name, category):
|
282 |
prompt = f"""
|
|
|
290 |
return query_gemini_api(prompt)
|
291 |
|
292 |
##############################################################################
|
293 |
+
# λ¨μΌ ν€μλμ λν "μ°½μμ λ³ν μμ΄λμ΄" μμ±
|
294 |
##############################################################################
|
295 |
def generate_single_object_transformations(obj):
|
296 |
results = {}
|
|
|
335 |
##############################################################################
|
336 |
def enhance_descriptions(results, objects):
|
337 |
obj_name = " λ° ".join([obj for obj in objects if obj])
|
|
|
338 |
for category, result in results.items():
|
339 |
result["enhanced"] = enhance_with_llm(result["base"], obj_name, category)
|
|
|
340 |
return results
|
341 |
|
342 |
##############################################################################
|
|
|
352 |
else:
|
353 |
results = generate_single_object_transformations(text1)
|
354 |
objects = [text1]
|
|
|
355 |
return enhance_descriptions(results, objects)
|
356 |
|
357 |
##############################################################################
|
|
|
364 |
return formatted
|
365 |
|
366 |
##############################################################################
|
367 |
+
# Gradio UIμμ νΈμΆν ν¨μ (μμ΄λμ΄ ν
μ€νΈ μμ±)
|
368 |
##############################################################################
|
369 |
def process_inputs(text1, text2, text3, selected_category, progress=gr.Progress()):
|
370 |
text1 = text1.strip() if text1 else None
|
|
|
374 |
if not text1:
|
375 |
return "μ€λ₯: μ΅μ νλμ ν€μλλ₯Ό μ
λ ₯ν΄μ£ΌμΈμ."
|
376 |
|
|
|
|
|
|
|
|
|
|
|
|
|
377 |
progress(0.05, desc="μμ΄λμ΄ μμ± μ€λΉ μ€...")
|
378 |
+
time.sleep(0.3)
|
379 |
+
progress(0.1, desc="μ°½μμ μΈ μμ΄λμ΄ μμ± μμ...")
|
|
|
380 |
|
381 |
results = generate_transformations(text1, text2, text3)
|
382 |
|
383 |
+
# μ νν μΉ΄ν
κ³ λ¦¬ κ²°κ³Όλ§ νν°λ§
|
384 |
if selected_category in results:
|
385 |
results = {selected_category: results[selected_category]}
|
386 |
else:
|
|
|
388 |
|
389 |
progress(0.8, desc="κ²°κ³Ό ν¬λ§·ν
μ€...")
|
390 |
formatted = format_results(results)
|
|
|
391 |
progress(1.0, desc="μλ£!")
|
392 |
return formatted
|
393 |
|
394 |
+
##############################################################################
|
395 |
+
# μλ‘μ΄ ν΅ν© ν¨μ: μμ΄λμ΄ ν
μ€νΈ μμ± λ° μ΄λ―Έμ§ μμ±
|
396 |
+
##############################################################################
|
397 |
+
def process_all(text1, text2, text3, selected_category, progress=gr.Progress()):
|
398 |
+
# νμ₯ μμ΄λμ΄ ν
μ€νΈ μμ±
|
399 |
+
idea_result = process_inputs(text1, text2, text3, selected_category, progress)
|
400 |
+
# μμ±λ μμ΄λμ΄λ₯Ό κ·Έλλ‘ μ΄λ―Έμ§ μμ± ν둬ννΈλ‘ μ¬μ©
|
401 |
+
image_result = generate_design_image(idea_result, seed=42, randomize_seed=True, width=1024, height=1024, num_inference_steps=4)
|
402 |
+
return idea_result, image_result
|
403 |
+
|
404 |
##############################################################################
|
405 |
# API ν€ κ²½κ³ λ©μμ§
|
406 |
##############################################################################
|
|
|
412 |
##############################################################################
|
413 |
# Gradio UI
|
414 |
##############################################################################
|
415 |
+
with gr.Blocks(title="ν€μλ κΈ°λ° μ°½μμ λ³ν μμ΄λμ΄ λ° λμμΈ μμ±κΈ°",
|
416 |
theme=gr.themes.Soft(primary_hue="teal", secondary_hue="slate", neutral_hue="neutral")) as demo:
|
417 |
|
418 |
gr.HTML("""
|
|
|
428 |
</style>
|
429 |
""")
|
430 |
|
431 |
+
gr.Markdown("# π ν€μλ κΈ°λ° μ°½μμ λ³ν μμ΄λμ΄ λ° λμμΈ μμ±κΈ°")
|
432 |
+
gr.Markdown("μ
λ ₯ν **ν€μλ**(μ΅λ 3κ°)μ **μΉ΄ν
κ³ λ¦¬**λ₯Ό λ°νμΌλ‘, μ°½μμ μΈ λͺ¨λΈ/컨μ
/νμ λ³ν μμ΄λμ΄λ₯Ό μμ±νκ³ , ν΄λΉ νμ₯ μμ΄λμ΄λ₯Ό ν둬ννΈλ‘ νμ¬ λμμΈ μ΄λ―Έμ§λ₯Ό μμ±ν©λλ€.")
|
433 |
|
434 |
warning = gr.Markdown(get_warning_message())
|
435 |
|
|
|
436 |
with gr.Row():
|
437 |
with gr.Column(scale=1):
|
438 |
text_input1 = gr.Textbox(label="ν€μλ 1 (νμ)", placeholder="μ: μ€λ§νΈν°")
|
439 |
text_input2 = gr.Textbox(label="ν€μλ 2 (μ ν)", placeholder="μ: μΈκ³΅μ§λ₯")
|
440 |
text_input3 = gr.Textbox(label="ν€μλ 3 (μ ν)", placeholder="μ: ν¬μ€μΌμ΄")
|
|
|
441 |
category_dropdown = gr.Dropdown(
|
442 |
label="μΉ΄ν
κ³ λ¦¬ μ ν",
|
443 |
choices=list(physical_transformation_categories.keys()),
|
444 |
value=list(physical_transformation_categories.keys())[0],
|
445 |
info="μΆλ ₯ν μΉ΄ν
κ³ λ¦¬λ₯Ό μ ννμΈμ."
|
446 |
)
|
447 |
+
status_msg = gr.Markdown("οΏ½οΏ½οΏ½οΏ½ 'μμ΄λμ΄ μμ±νκΈ°' λ²νΌμ ν΄λ¦νλ©΄ μμ΄λμ΄ μμ±κ³Ό ν¨κ» λμμΈ μ΄λ―Έμ§κ° μμ±λ©λλ€.")
|
|
|
|
|
448 |
processing_indicator = gr.HTML("""
|
449 |
<div style="display: flex; justify-content: center; align-items: center; margin: 10px 0;">
|
450 |
<div style="border: 5px solid #f3f3f3; border-top: 5px solid #3498db; border-radius: 50%; width: 30px; height: 30px; animation: spin 2s linear infinite;"></div>
|
|
|
457 |
}
|
458 |
</style>
|
459 |
""", visible=False)
|
|
|
460 |
submit_button = gr.Button("μμ΄λμ΄ μμ±νκΈ°", variant="primary")
|
461 |
|
|
|
462 |
with gr.Column(scale=2):
|
463 |
idea_output = gr.Markdown(label="μμ΄λμ΄ κ²°κ³Ό")
|
464 |
+
generated_image = gr.Image(label="μμ±λ λμμΈ μ΄λ―Έμ§", type="pil")
|
465 |
+
|
466 |
gr.Examples(
|
467 |
examples=[
|
468 |
["μ€λ§νΈν°", "", "", list(physical_transformation_categories.keys())[0]],
|
|
|
485 |
inputs=None,
|
486 |
outputs=processing_indicator
|
487 |
).then(
|
488 |
+
fn=process_all,
|
489 |
inputs=[text_input1, text_input2, text_input3, category_dropdown],
|
490 |
+
outputs=[idea_output, generated_image]
|
491 |
).then(
|
492 |
fn=hide_processing_indicator,
|
493 |
inputs=None,
|