import sys, types, importlib.machinery
# Florence-2's remote code imports flash_attn unconditionally; register an
# empty stub module so that import succeeds without the package installed.
spec = importlib.machinery.ModuleSpec('flash_attn', loader=None)
mod = types.ModuleType('flash_attn')
mod.__spec__ = spec
sys.modules['flash_attn'] = mod
# Report only flash_attn as unavailable (the original blanket
# `lambda pkg: False` would hide every other package from transformers too).
import transformers.utils.import_utils as _import_utils
_orig_is_package_available = _import_utils._is_package_available
_import_utils._is_package_available = lambda pkg, *a, **kw: (
    False if pkg == "flash_attn" else _orig_is_package_available(pkg, *a, **kw)
)
_import_utils.is_flash_attn_2_available = lambda: False
# Older diffusers releases still call huggingface_hub.cached_download, which
# newer huggingface_hub versions removed; alias it to hf_hub_download.
import huggingface_hub as _hf_hub
if not hasattr(_hf_hub, "cached_download"):
    _hf_hub.cached_download = _hf_hub.hf_hub_download
import gradio as gr
import torch
import random
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
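# Pin the Florence-2 repository to a fixed commit so later upstream changes
# to its remote code cannot silently break the app.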
REVISION = "ceaf371f01ef66192264811b390bccad475a4f02"
# Load Florence-2 (image captioning model)
device = "cuda" if torch.cuda.is_available() else "cpu"
florence_model = AutoModelForCausalLM.from_pretrained(
    'microsoft/Florence-2-base', revision=REVISION, trust_remote_code=True
).to(device).eval()
florence_processor = AutoProcessor.from_pretrained(
    'microsoft/Florence-2-base', revision=REVISION, trust_remote_code=True
)
# Load Stable Diffusion 3.5 Large TurboX (fp16 on GPU, fp32 on CPU)
model_repo = "tensorart/stable-diffusion-3.5-large-TurboX"
pipe = DiffusionPipeline.from_pretrained(
    model_repo,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
# Flow-matching Euler scheduler with the timestep shift this checkpoint is tuned for
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_repo, subfolder="scheduler", shift=5)
pipe = pipe.to(device)
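# Largest value that fits in a signed 32-bit integer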
MAX_SEED = 2**31 - 1
def pseudo_translate_to_korean_style(en_prompt: str) -> str:
    # No actual translation happens: the English caption is wrapped in a fixed
    # Korean instruction ("This scene is {en_prompt}. Please draw it in a
    # bright, cute cartoon style, with a digital-illustration feel.")
    return f"이 장면은 {en_prompt} 장면입니다. 밝고 귀여운 카툰 스타일로 그려주세요. 디지털 일러스트 느낌으로 묘사해 주세요."
def generate_prompt(image):
    """Image -> English caption -> Korean-style cartoon prompt."""
if not isinstance(image, Image.Image):
image = Image.fromarray(image)
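    # <MORE_DETAILED_CAPTION> is Florence-2's task prompt for a long, detailed caption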
inputs = florence_processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to(device)
generated_ids = florence_model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=512,
num_beams=3
)
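    # Keep special tokens: Florence-2's post-processor parses the raw generated text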
generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
parsed_answer = florence_processor.post_process_generation(
generated_text,
task="<MORE_DETAILED_CAPTION>",
image_size=(image.width, image.height)
)
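    # post_process_generation returns a dict keyed by the task prompt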
prompt_en = parsed_answer["<MORE_DETAILED_CAPTION>"]
    # Apply the cartoon-style wrapper (no translation model involved)
    cartoon_prompt = pseudo_translate_to_korean_style(prompt_en)
return cartoon_prompt
def generate_image(prompt, seed=42, randomize_seed=False):
    """Text prompt -> generated image (returns the image and the seed used)."""
if randomize_seed:
seed = random.randint(0, MAX_SEED)
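    # torch.Generator() lives on the CPU by default; seeding it makes runs reproducible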
generator = torch.Generator().manual_seed(seed)
image = pipe(
prompt=prompt,
        # Negative prompt, in Korean: "distorted hands, blur, strange faces"
        negative_prompt="왜곡된 손, 흐림, 이상한 얼굴",
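        # Few-step settings for the distilled TurboX checkpoint (low CFG, 8 steps)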
guidance_scale=1.5,
num_inference_steps=8,
width=768,
height=768,
generator=generator
).images[0]
return image, seed
# Build the Gradio UI (user-facing text is in Korean)
with gr.Blocks() as demo:
    gr.Markdown("# 🖼 이미지 → 설명 생성 → 카툰 이미지 자동 생성기")  # "Image -> caption -> automatic cartoon image generator"
    # Usage guide (Korean): upload an image on the left; the AI writes an
    # English caption, rewraps it as a Korean-style prompt, and the result
    # image appears on the right.
    gr.Markdown("**📌 사용법 안내 (한국어)**\n"
                "- 왼쪽에 이미지를 업로드하세요.\n"
                "- AI가 영어 설명을 만들고, 내부에서 한국어 스타일 프롬프트로 재구성합니다.\n"
                "- 오른쪽에 결과 이미지가 생성됩니다.")
with gr.Row():
with gr.Column():
            input_img = gr.Image(label="🎨 원본 이미지 업로드")  # "Upload source image"
            run_button = gr.Button("✨ 생성 시작")  # "Start generating"
with gr.Column():
            prompt_out = gr.Textbox(label="📝 스타일 적용된 프롬프트", lines=3, show_copy_button=True)  # "Styled prompt"
            output_img = gr.Image(label="🎉 생성된 이미지")  # "Generated image"
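    # Full pipeline: uploaded image -> Florence-2 caption -> Korean-style prompt -> TurboX image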
def full_process(img):
prompt = generate_prompt(img)
        image, _ = generate_image(prompt, randomize_seed=True)  # seed is discarded here
return prompt, image
run_button.click(fn=full_process, inputs=[input_img], outputs=[prompt_out, output_img])
demo.launch()