import spaces
import os
import re
import json
import time
import torch
import tempfile
import io
import random
import string
import logging
from typing import Tuple, Optional, List, Dict, Any, Union
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
import gradio as gr
from safetensors.torch import save_file
from src.pipeline import FluxPipeline
from src.transformer_flux import FluxTransformer2DModel
from src.lora_helper import set_single_lora, set_multi_lora, unset_lora
# Google Gemini API (added)
from google import genai
from google.genai import types
# Base model and LoRA paths
base_path = "black-forest-labs/FLUX.1-dev"
lora_base_path = "./models"
# System prompt that will be hidden from users but automatically added to their input
SYSTEM_PROMPT = "Ghibli Studio style, Charming hand-drawn anime-style illustration"
# Logging configuration
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# Load the model
pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16)
pipe.transformer = transformer
pipe.to("cuda")
def clear_cache(transformer):
    """Clear the key/value banks cached by the LoRA attention processors."""
    for name, attn_processor in transformer.attn_processors.items():
        attn_processor.bank_kv.clear()
#######################################
# Utility Functions
#######################################
# Simple Timer Class
class timer:
    def __init__(self, method_name="timed process"):
        self.method = method_name

    def __enter__(self):
        self.start = time.time()
        print(f"[TIMER] {self.method} starts")

    def __exit__(self, exc_type, exc_val, exc_tb):
        end = time.time()
        print(f"[TIMER] {self.method} took {round(end - self.start, 2)}s")
# Simple translation helper (Korean -> English)
def maybe_translate_to_english(text: str) -> str:
    """
    If the text contains Korean characters, translate it to English;
    otherwise return it unchanged.
    """
    if not text or not re.search("[κ°€-힣]", text):
        return text
    try:
        # Simple word-level substitution rules (in production, using a
        # translation API is recommended; see the sketch below)
        translations = {
            "μ•ˆλ…•ν•˜μ„Έμš”": "Hello",
            "ν™˜μ˜ν•©λ‹ˆλ‹€": "Welcome",
            "μ•„λ¦„λ‹€μš΄ λ‹Ήμ‹ ": "Beautiful You",
            "μ•ˆλ…•": "Hello",
            "고양이": "Cat",
            "λ°°λ„ˆ": "Banner",
            "μ¬κΈ€λΌμŠ€": "Sunglasses",
            "μ°©μš©ν•œ": "wearing",
            "흰색": "white"
        }
        # Rough translation of the full sentence, phrase by phrase
        for kr, en in translations.items():
            if kr in text:
                text = text.replace(kr, en)
        print(f"[TRANSLATE] Translated Korean text: '{text}'")
        return text
    except Exception as e:
        print(f"[WARNING] Translation failed: {e}")
        return text
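# A minimal sketch of the API-backed translation recommended above, using the
# google-genai client already imported in this file. It is illustrative and not
# wired into the app; the model name "gemini-2.0-flash" and the prompt wording
# are assumptions, so adjust them to whatever your API key supports.
def translate_via_gemini_sketch(text: str) -> str:
    api_key = os.getenv("GAPI_TOKEN")
    if not api_key or not text or not re.search("[κ°€-힣]", text):
        return text  # no key or no Korean characters: nothing to translate
    try:
        client = genai.Client(api_key=api_key)
        response = client.models.generate_content(
            model="gemini-2.0-flash",  # assumed model name
            contents=(
                "Translate the following Korean text to English. "
                f"Reply with the translation only:\n{text}"
            ),
        )
        return response.text.strip() if response.text else text
    except Exception as e:
        print(f"[WARNING] Gemini translation failed: {e}")
        return text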
def save_binary_file(file_name, data):
    with open(file_name, "wb") as f:
        f.write(data)
#######################################
# Gemini API Functions
#######################################
def generate_by_google_genai(text, file_name, model="gemini-2.0-flash-exp"):
    """
    - Passes an additional instruction along with the image to perform
      image-based editing.
    - If the response is an image, it is saved; if it is text, it is
      accumulated and returned.
    """
    # Get the API key (from the GAPI_TOKEN environment variable)
    api_key = os.getenv("GAPI_TOKEN", None)
    if not api_key:
        raise ValueError("GAPI_TOKEN is missing. Please set an API key.")
    client = genai.Client(api_key=api_key)
    files = [client.files.upload(file=file_name)]
    contents = [
        types.Content(
            role="user",
            parts=[
                types.Part.from_uri(
                    file_uri=files[0].uri,
                    mime_type=files[0].mime_type,
                ),
                types.Part.from_text(text=text),
            ],
        ),
    ]
    generate_content_config = types.GenerateContentConfig(
        temperature=1,
        top_p=0.95,
        top_k=40,
        max_output_tokens=8192,
        response_modalities=["image", "text"],
        response_mime_type="text/plain",
    )
    text_response = ""
    image_path = None
    # Prepare a temporary file so the image can be saved
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        temp_path = tmp.name
    for chunk in client.models.generate_content_stream(
        model=model,
        contents=contents,
        config=generate_content_config,
    ):
        if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
            continue
        candidate = chunk.candidates[0].content.parts[0]
        # If inline_data (image data) is present -> this is the edited image
        if candidate.inline_data:
            save_binary_file(temp_path, candidate.inline_data.data)
            print(f"File of mime type {candidate.inline_data.mime_type} saved to: {temp_path}")
            image_path = temp_path
            # Stop once a single image has been obtained
            break
        else:
            # No inline_data means this chunk is text, so accumulate it
            text_response += chunk.text + "\n"
    del files
    return image_path, text_response
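# Example call (assumes GAPI_TOKEN is set and "input.png" exists). The returned
# path points at a temporary PNG, or is None when Gemini answered with text only:
#
#     edited_path, text_response = generate_by_google_genai(
#         text="Add the word 'Hi' in red at the top", file_name="input.png"
#     )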
def change_text_in_image_two_times(original_image, instruction):
    """
    Call the text-modification API twice (Google Gemini), returning 2 final variations.
    """
    if original_image is None:
        raise gr.Error("There is no image to process. Please generate an image first.")
    results = []
    for version_tag in ["(A)", "(B)"]:
        mod_instruction = f"{instruction} {version_tag}"
        try:
            # Create a temporary file for saving the image
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
                original_path = tmp.name
            # Save the image if it is a PIL Image object
            print(f"[DEBUG] Saving image of type {type(original_image)} to temporary file")
            if isinstance(original_image, Image.Image):
                original_image.save(original_path, format="PNG")
                print(f"[DEBUG] Saved image to temporary file: {original_path}")
            else:
                raise gr.Error(f"Expected a PIL Image but received type {type(original_image)}.")
            print(f"[DEBUG] Instruction sent to the Google Gemini API: {mod_instruction}")
            image_path, text_response = generate_by_google_genai(
                text=mod_instruction,
                file_name=original_path
            )
            if image_path:
                print(f"[DEBUG] Received image from Gemini API: {image_path}")
                try:
                    with open(image_path, "rb") as f:
                        image_data = f.read()
                    new_img = Image.open(io.BytesIO(image_data))
                    results.append(new_img)
                except Exception as img_err:
                    print(f"[ERROR] Failed to process Gemini image: {img_err}")
                    results.append(original_image)
            else:
                # No image came back in the response, only text
                print(f"[WARNING] No image was returned. Text response: {text_response}")
                results.append(original_image)
        except Exception as e:
            logging.exception(f"Text modification error: {e}")
            # Return at least the original image even when an error occurs
            print(f"[ERROR] Error during text modification: {e}")
            results.append(original_image)
    return results
#######################################
# Image Generation Functions
#######################################
@spaces.GPU()
def single_condition_generate_image(user_prompt, spatial_img, height, width, seed):
    # Combine the system prompt with the user prompt
    full_prompt = f"{SYSTEM_PROMPT}, {user_prompt}" if user_prompt else SYSTEM_PROMPT
    # Set the Ghibli LoRA
    lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
    set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
    # Process the image
    spatial_imgs = [spatial_img] if spatial_img else []
    image = pipe(
        full_prompt,
        height=int(height),
        width=int(width),
        guidance_scale=3.5,
        num_inference_steps=25,
        max_sequence_length=512,
        generator=torch.Generator("cpu").manual_seed(int(seed)),
        subject_images=[],
        spatial_images=spatial_imgs,
        cond_size=512,
    ).images[0]
    clear_cache(pipe.transformer)
    return image
@spaces.GPU()
def text_rendering_generate_image(user_prompt, input_text, text_color, text_size, text_position, spatial_img, height, width, seed):
    """
    Generate an image in Ghibli style, then send it to the Gemini API for
    multilingual text rendering.
    """
    try:
        # Step 1: Generate the base image using FLUX
        print("[DEBUG] Generating base image with FLUX")
        full_prompt = f"{SYSTEM_PROMPT}, {user_prompt}" if user_prompt else SYSTEM_PROMPT
        # Set the Ghibli LoRA
        lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
        set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
        # Process the image
        spatial_imgs = [spatial_img] if spatial_img else []
        base_image = pipe(
            full_prompt,
            height=int(height),
            width=int(width),
            guidance_scale=3.5,
            num_inference_steps=25,
            max_sequence_length=512,
            generator=torch.Generator("cpu").manual_seed(int(seed)),
            subject_images=[],
            spatial_images=spatial_imgs,
            cond_size=512,
        ).images[0]
        clear_cache(pipe.transformer)
        # If no text is provided, return the base image
        if not input_text or not input_text.strip():
            return [base_image, base_image]
        # Step 2: Build the instruction for the Gemini API
        instruction = f"Add the text '{input_text}' to this image in {text_color} color"
        # Add position information
        if text_position == "top":
            instruction += " at the top of the image"
        elif text_position == "bottom":
            instruction += " at the bottom of the image"
        else:  # center
            instruction += " at the center of the image"
        # Add size information
        if text_size <= 40:
            instruction += " in small size"
        elif text_size <= 120:
            instruction += " in medium size"
        else:
            instruction += " in large size"
        instruction += ". Make sure the text is clearly visible and readable."
        # Step 3: Call the Gemini API to generate two variations
        print(f"[DEBUG] Sending to Gemini API with instruction: {instruction}")
        return change_text_in_image_two_times(base_image, instruction)
    except Exception as e:
        logging.exception(f"Text rendering error: {e}")
        # Create a placeholder image in case of error (cast to int in case the
        # sliders deliver floats)
        dummy_img = Image.new('RGB', (int(width), int(height)), color=(255, 200, 200))
        draw = ImageDraw.Draw(dummy_img)
        draw.text((int(width) // 2, int(height) // 2), f"Error: {str(e)}", fill="black", anchor="mm")
        return [dummy_img, dummy_img]
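# For reference, the instruction built in Step 2 above for input_text="Hello",
# text_color="#ffffff", text_position="top", text_size=72 would read:
#
#   "Add the text 'Hello' to this image in #ffffff color at the top of the
#    image in medium size. Make sure the text is clearly visible and readable."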
# Load example images
def load_examples():
    examples = []
    test_img_dir = "./test_imgs"
    example_prompts = [
        " ",
        "saying 'HELLO' in 'speech bubble'",
        "background 'alps'"
    ]
    for i, filename in enumerate(["00.jpg", "02.jpg", "03.jpg"]):
        img_path = os.path.join(test_img_dir, filename)
        if os.path.exists(img_path):
            # Use dimensions from the original code for each specific example
            if filename == "00.jpg":
                height, width = 680, 1024
            elif filename == "02.jpg":
                height, width = 560, 1024
            elif filename == "03.jpg":
                height, width = 1024, 768
            else:
                height, width = 768, 768
            examples.append([
                example_prompts[i % len(example_prompts)],  # User prompt (without system prompt)
                Image.open(img_path),  # Reference image
                height,  # Height
                width,  # Width
                i + 1  # Seed
            ])
    return examples
# Load examples for the text rendering tab
def load_text_examples():
    examples = []
    test_img_dir = "./test_imgs"
    example_data = [
        {
            "prompt": "cute character with speech bubble",
            "text": "Hello World!",
            "color": "#ffffff",
            "size": 72,
            "position": "center",
            "filename": "00.jpg",
            "height": 680,
            "width": 1024,
            "seed": 123
        },
        {
            "prompt": "landscape with message",
            "text": "μ•ˆλ…•ν•˜μ„Έμš”!",
            "color": "#ffff00",
            "size": 100,
            "position": "top",
            "filename": "03.jpg",
            "height": 1024,
            "width": 768,
            "seed": 456
        },
        {
            "prompt": "character with subtitles",
            "text": "γ“γ‚“γ«γ‘γ―δΈ–η•Œ!",
            "color": "#00ffff",
            "size": 90,
            "position": "bottom",
            "filename": "02.jpg",
            "height": 560,
            "width": 1024,
            "seed": 789
        }
    ]
    for example in example_data:
        img_path = os.path.join(test_img_dir, example["filename"])
        if os.path.exists(img_path):
            examples.append([
                example["prompt"],
                example["text"],
                example["color"],
                example["size"],
                example["position"],
                Image.open(img_path),
                example["height"],
                example["width"],
                example["seed"]
            ])
    return examples
# Function to check API availability - modified to work directly
def check_api_status():
    # Check Gemini API availability
    api_key = os.getenv("GAPI_TOKEN")
    gemini_available = api_key is not None
    if gemini_available:
        return """<div class="api-status api-connected">βœ“ Connected to FLUX.1 and Gemini API</div>"""
    else:
        return """<div class="api-status api-disconnected">βœ— Gemini API connection issue. Please check GAPI_TOKEN environment variable.</div>"""
# CSS for improved UI
css = """
:root {
--primary-color: #4a6670;
--accent-color: #ff8a65;
--background-color: #f5f5f5;
--card-background: #ffffff;
--text-color: #333333;
--border-radius: 10px;
--shadow: 0 4px 6px rgba(0,0,0,0.1);
}
body {
background-color: var(--background-color);
color: var(--text-color);
font-family: 'Helvetica Neue', Arial, sans-serif;
}
.container {
max-width: 1200px;
margin: 0 auto;
padding: 20px;
}
.gr-header {
background: linear-gradient(135deg, #668796 0%, #4a6670 100%);
padding: 24px;
border-radius: var(--border-radius);
margin-bottom: 24px;
box-shadow: var(--shadow);
text-align: center;
}
.gr-header h1 {
color: white;
font-size: 2.5rem;
margin: 0;
font-weight: 700;
}
.gr-header p {
color: rgba(255, 255, 255, 0.9);
font-size: 1.1rem;
margin-top: 8px;
}
.gr-panel {
background-color: var(--card-background);
border-radius: var(--border-radius);
padding: 16px;
box-shadow: var(--shadow);
}
.gr-button {
background-color: var(--accent-color);
border: none;
color: white;
padding: 10px 20px;
border-radius: 5px;
font-size: 16px;
font-weight: bold;
cursor: pointer;
transition: transform 0.1s, background-color 0.3s;
}
.gr-button:hover {
background-color: #ff7043;
transform: translateY(-2px);
}
.gr-input, .gr-select {
border-radius: 5px;
border: 1px solid #ddd;
padding: 10px;
width: 100%;
}
.gr-form {
display: grid;
gap: 16px;
}
.gr-box {
background-color: var(--card-background);
border-radius: var(--border-radius);
padding: 20px;
box-shadow: var(--shadow);
margin-bottom: 20px;
}
.gr-gallery {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));
gap: 16px;
}
.gr-gallery-item {
overflow: hidden;
border-radius: var(--border-radius);
box-shadow: var(--shadow);
transition: transform 0.3s;
}
.gr-gallery-item:hover {
transform: scale(1.02);
}
.gr-image {
width: 100%;
height: auto;
object-fit: cover;
}
.gr-footer {
text-align: center;
margin-top: 40px;
padding: 20px;
color: #666;
font-size: 14px;
}
.gr-examples-gallery {
margin-top: 20px;
}
/* Responsive adjustments */
@media (max-width: 768px) {
.gr-header h1 {
font-size: 1.8rem;
}
.gr-panel {
padding: 12px;
}
}
/* Ghibli-inspired accent colors */
.gr-accent-1 {
background-color: #95ccd9;
}
.gr-accent-2 {
background-color: #74ad8c;
}
.gr-accent-3 {
background-color: #f9c06b;
}
.text-rendering-options {
background-color: #f0f8ff;
padding: 16px;
border-radius: var(--border-radius);
margin-top: 16px;
}
.api-status {
font-size: 14px;
color: #666;
text-align: center;
margin-bottom: 10px;
}
.api-connected {
color: green;
}
.api-disconnected {
color: red;
}
"""
# Create the Gradio Blocks interface
with gr.Blocks(css=css) as demo:
    gr.HTML("""
    <div class="gr-header">
        <h1>✨ Ghibli Multilingual Text-Rendering ✨</h1>
        <p>Transform your ideas into magical Ghibli-inspired artwork with multilingual text</p>
    </div>
    """)
    # API Status - call directly to set the initial state
    api_status = gr.Markdown(check_api_status(), visible=True)
    with gr.Tabs():
        with gr.Tab("Create Ghibli Art"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.HTML("""
                    <div class="gr-box">
                        <h3>🎨 Your Creative Input</h3>
                        <p>Describe what you want to see in your Ghibli-inspired image</p>
                    </div>
                    """)
                    user_prompt = gr.Textbox(
                        label="Your description",
                        placeholder="Describe what you want to see (e.g., a cat sitting by the window)",
                        lines=2
                    )
                    spatial_img = gr.Image(
                        label="Reference Image (Optional)",
                        type="pil",
                        elem_classes="gr-image-upload"
                    )
                    with gr.Group():
                        with gr.Row():
                            height = gr.Slider(minimum=256, maximum=1024, step=64, label="Height", value=768)
                            width = gr.Slider(minimum=256, maximum=1024, step=64, label="Width", value=768)
                        seed = gr.Slider(minimum=1, maximum=9999, step=1, label="Seed", value=42,
                                         info="Change for different variations")
                    generate_btn = gr.Button("✨ Generate Ghibli Art", variant="primary", elem_classes=["generate-btn"])
                with gr.Column(scale=1):
                    gr.HTML("""
                    <div class="gr-box">
                        <h3>✨ Your Magical Creation</h3>
                        <p>Your Ghibli-inspired artwork will appear here</p>
                    </div>
                    """)
                    output_image = gr.Image(label="Generated Image", elem_classes="gr-output-image")
            gr.HTML("""
            <div class="gr-box gr-examples-gallery">
                <h3>✨ Inspiration Gallery</h3>
                <p>Click on any example to try it out</p>
            </div>
            """)
            # Add examples
            examples = load_examples()
            gr.Examples(
                examples=examples,
                inputs=[user_prompt, spatial_img, height, width, seed],
                outputs=output_image,
                fn=single_condition_generate_image,
                cache_examples=False,
                examples_per_page=4
            )
            # Link the button to the function
            generate_btn.click(
                single_condition_generate_image,
                inputs=[user_prompt, spatial_img, height, width, seed],
                outputs=output_image
            )
        # Second tab for Image & Multilingual Text Rendering with Gemini API
        with gr.Tab("Image & Multilingual Text Rendering"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.HTML("""
                    <div class="gr-box">
                        <h3>🌈 Art with Multilingual Text</h3>
                        <p>Create Ghibli-style images with beautiful text in any language using Gemini AI</p>
                    </div>
                    """)
                    text_user_prompt = gr.Textbox(
                        label="Image Description",
                        placeholder="Describe what you want to see (e.g., a character with speech bubble)",
                        lines=2
                    )
                    with gr.Group(elem_classes="text-rendering-options"):
                        input_text = gr.Textbox(
                            label="Multilingual Text to Add",
                            placeholder="Enter text in any language (Korean, Japanese, English, etc.)",
                            lines=1
                        )
                        with gr.Row():
                            text_color = gr.ColorPicker(
                                label="Text Color",
                                value="#FFFFFF"
                            )
                            text_size = gr.Slider(
                                minimum=24,
                                maximum=200,
                                step=4,
                                label="Text Size",
                                value=72
                            )
                        text_position = gr.Radio(
                            ["top", "center", "bottom"],
                            label="Text Position",
                            value="center"
                        )
                    text_spatial_img = gr.Image(
                        label="Reference Image (Optional)",
                        type="pil",
                        elem_classes="gr-image-upload"
                    )
                    with gr.Group():
                        with gr.Row():
                            text_height = gr.Slider(minimum=256, maximum=1024, step=64, label="Height", value=768)
                            text_width = gr.Slider(minimum=256, maximum=1024, step=64, label="Width", value=768)
                        text_seed = gr.Slider(minimum=1, maximum=9999, step=1, label="Seed", value=42,
                                              info="Change for different variations")
                    text_generate_btn = gr.Button("✨ Generate Art with Multilingual Text", variant="primary", elem_classes=["generate-btn"])
                with gr.Column(scale=1):
                    gr.HTML("""
                    <div class="gr-box">
                        <h3>✨ Your Text Creations (Two Variations)</h3>
                        <p>Two versions of your Ghibli-inspired artwork with text will appear here</p>
                    </div>
                    """)
                    with gr.Row():
                        text_output_image1 = gr.Image(
                            label="Variation A",
                            type="pil",
                            elem_classes="gr-output-image"
                        )
                        text_output_image2 = gr.Image(
                            label="Variation B",
                            type="pil",
                            elem_classes="gr-output-image"
                        )
            gr.HTML("""
            <div class="gr-box gr-examples-gallery">
                <h3>✨ Multilingual Text Examples</h3>
                <p>Click on any example to try it out</p>
            </div>
            """)
            # Add text rendering examples
            text_examples = load_text_examples()
            gr.Examples(
                examples=text_examples,
                inputs=[text_user_prompt, input_text, text_color, text_size, text_position,
                        text_spatial_img, text_height, text_width, text_seed],
                outputs=[text_output_image1, text_output_image2],
                fn=text_rendering_generate_image,
                cache_examples=False,
                examples_per_page=3
            )
            # Link the text render button to the function
            text_generate_btn.click(
                text_rendering_generate_image,
                inputs=[text_user_prompt, input_text, text_color, text_size, text_position,
                        text_spatial_img, text_height, text_width, text_seed],
                outputs=[text_output_image1, text_output_image2]
            )
gr.HTML("""
<div class="gr-footer">
<p>Powered by FLUX.1, Ghibli LoRA, and Google Gemini API β€’ Created with ❀️</p>
</div>
""")
# Launch the Gradio app
demo.queue().launch()