File size: 4,669 Bytes
87d4598
4c9345a
de7d0eb
 
 
 
6d62697
684e4b9
 
6d62697
2a2fa4c
f92c4cd
 
 
 
b6bfe0d
bdf7188
 
 
ffe6837
f92c4cd
1042f70
 
 
 
 
 
 
 
0ce165c
9d51561
f92c4cd
 
cea6ee2
 
 
fbb8f01
f92c4cd
 
 
 
 
 
2a2fa4c
f92c4cd
 
 
 
 
 
 
9d32382
f92c4cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a43544
4f843fa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import importlib
import importlib.machinery
import sys
import types

# Pre-register a placeholder `flash_attn` module so that `import flash_attn`
# (e.g. from import checks run while loading remote model code) succeeds
# without needing the real package. The explicit ModuleSpec matters:
# importlib.util.find_spec() raises ValueError for a sys.modules entry whose
# __spec__ is None.
spec = importlib.machinery.ModuleSpec("flash_attn", loader=None)
mod = types.ModuleType("flash_attn")
mod.__spec__ = spec
sys.modules["flash_attn"] = mod

import huggingface_hub as _hf_hub

# Compatibility alias for code that still imports `cached_download`, which
# newer huggingface_hub releases no longer provide. Only add it when it is
# missing: overwriting a real `cached_download` with `hf_hub_download` would
# swap in a function with a different call signature.
if not hasattr(_hf_hub, "cached_download"):
    _hf_hub.cached_download = _hf_hub.hf_hub_download

import gradio as gr
import torch
import random
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from diffusers import DiffusionPipeline
try:
    from diffusers import FlowMatchEulerDiscreteScheduler
except ImportError:
    from diffusers import EulerDiscreteScheduler as FlowMatchEulerDiscreteScheduler

# Make transformers report flash-attention as unavailable, so that remote
# model code loaded later takes its non-flash-attention code path.
# `import transformers.utils` is required: the other import forms in this
# block bind only their aliases, never the top-level name `transformers`
# that the assignments at the bottom use.
import transformers.utils
import transformers.utils.import_utils as _import_utils
from transformers.utils import is_flash_attn_2_available


def _report_package_unavailable(pkg_name, return_version=False):
    """Stand-in for ``import_utils._is_package_available``: every package is missing.

    Mirrors the upstream signature. When ``return_version`` is true, upstream
    returns an ``(exists, version)`` tuple, so this stub returns one too
    instead of failing with a TypeError.

    Args:
        pkg_name: Name of the package being probed (ignored).
        return_version: Whether a ``(exists, version)`` tuple is expected.

    Returns:
        ``False``, or ``(False, "N/A")`` when ``return_version`` is true.
    """
    return (False, "N/A") if return_version else False


# NOTE(review): this affects *every* availability check transformers makes
# lazily after this point, not only flash-attn; narrow it if that is too broad.
_import_utils._is_package_available = _report_package_unavailable
_import_utils.is_flash_attn_2_available = lambda: False

# Code doing `from transformers.utils import is_flash_attn_2_available`
# resolves through the `transformers.utils` namespace, so the re-exported
# names must be forced to False there as well (not merely re-assigned to
# themselves).
transformers.utils.is_flash_attn_2_available = lambda: False
transformers.utils.is_flash_attn_greater_or_equal_2_10 = lambda *args, **kwargs: False

# Hub revision (a commit hash) pinned for the Florence-2 checkpoint, so the
# trust_remote_code modeling files that get executed stay fixed.
REVISION = "ceaf371f01ef66192264811b390bccad475a4f02"

# Load Florence-2 (image captioning).
# Global compute device: "cuda" when a GPU is visible, otherwise "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"

# Florence-2 is kept on the CPU regardless of `device` (presumably to leave GPU
# memory for the diffusion pipeline -- NOTE(review): confirm). Many CPU kernels
# lack or poorly support float16, and the processor emits float32 pixel
# values, so load the weights in float32 rather than float16.
FLORENCE_DEVICE = "cpu"
FLORENCE_DTYPE = torch.float32

florence_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-base",
    revision=REVISION,
    trust_remote_code=True,
    torch_dtype=FLORENCE_DTYPE,
)
florence_model.to(FLORENCE_DEVICE)
florence_model.eval()
florence_processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base",
    revision=REVISION,
    trust_remote_code=True,
)

# Load the Stable Diffusion 3.5 Large TurboX pipeline.
model_repo = "tensorart/stable-diffusion-3.5-large-TurboX"
# Half precision only when a GPU is available; full precision otherwise.
_pipe_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
pipe = DiffusionPipeline.from_pretrained(model_repo, torch_dtype=_pipe_dtype)

# Swap in a scheduler built from the repo's "scheduler" subfolder config,
# overriding its `shift` setting with 5.
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    model_repo,
    subfolder="scheduler",
    shift=5,
)
pipe = pipe.to(device)

# Largest value for randomly drawn seeds: the maximum signed 32-bit integer.
MAX_SEED = (1 << 31) - 1

def pseudo_translate_to_korean_style(en_prompt: str) -> str:
    """Embed an English caption in the fixed cartoon-style prompt template.

    Despite the name, no translation happens: the caption stays in English
    and is simply placed between a style prefix and a style suffix.

    Args:
        en_prompt: English description of the source image.

    Returns:
        The string ``"Cartoon styled <en_prompt> handsome or pretty people"``.
    """
    template = "Cartoon styled {} handsome or pretty people"
    return template.format(en_prompt)

def generate_prompt(image):
    """Caption an image with Florence-2 and wrap the caption in the cartoon template.

    Pipeline: image -> detailed English caption (Florence-2
    ``<MORE_DETAILED_CAPTION>`` task) -> ``pseudo_translate_to_korean_style``.
    No translation step exists; the style is applied to the English caption.

    Args:
        image: A ``PIL.Image.Image``, or an array accepted by
            ``Image.fromarray`` (such as the numpy array ``gr.Image`` delivers).

    Returns:
        The styled prompt string.

    Raises:
        ValueError: If ``image`` is ``None`` (e.g. nothing was uploaded).
    """
    if image is None:
        raise ValueError("generate_prompt() requires an image, got None")
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Normalise to 3-channel RGB so RGBA / greyscale / palette inputs are
    # handled uniformly.
    if image.mode != "RGB":
        image = image.convert("RGB")

    task = "<MORE_DETAILED_CAPTION>"
    inputs = florence_processor(text=task, images=image, return_tensors="pt")

    # The captioning model may live on a different device than the global
    # `device` and may use a different dtype than the processor's float
    # output, so send the tensors to the model's own device, and cast the
    # pixel values to its weight dtype. Otherwise generation fails with a
    # device/dtype mismatch. input_ids stay integer.
    model_device = florence_model.device
    generated_ids = florence_model.generate(
        input_ids=inputs["input_ids"].to(model_device),
        pixel_values=inputs["pixel_values"].to(model_device, florence_model.dtype),
        max_new_tokens=512,
        num_beams=3,
    )
    # Keep special tokens: post_process_generation parses the raw task output.
    generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = florence_processor.post_process_generation(
        generated_text,
        task=task,
        image_size=(image.width, image.height),
    )
    prompt_en = parsed_answer[task]

    # Apply the style template directly; no translator is involved.
    return pseudo_translate_to_korean_style(prompt_en)

# Default negative prompt (Korean for "distorted hands, blur, strange face").
# NOTE(review): SD3.5's text encoders are trained mostly on English text, so a
# Korean negative prompt may carry little signal -- consider an English one.
DEFAULT_NEGATIVE_PROMPT = "왜곡된 손, 흐림, 이상한 얼굴"


def generate_image(
    prompt,
    seed=42,
    randomize_seed=False,
    *,
    negative_prompt=DEFAULT_NEGATIVE_PROMPT,
    guidance_scale=1.5,
    num_inference_steps=8,
    width=768,
    height=768,
):
    """Render an image from a text prompt with the module-level diffusion pipeline.

    Args:
        prompt: Positive text prompt.
        seed: Seed for the sampling RNG; ignored when ``randomize_seed`` is true.
        randomize_seed: Draw a fresh seed uniformly from ``[0, MAX_SEED]``.
        negative_prompt: Negative prompt forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps.
        width: Output width in pixels.
        height: Output height in pixels.
        (The keyword-only defaults reproduce the previously hard-coded settings.)

    Returns:
        ``(image, seed)``: the first generated image and the seed actually used.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # torch.Generator.manual_seed requires an int; accept int-like values
    # such as floats coming from a numeric UI widget.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)
    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    )
    return result.images[0], seed

# Gradio UI ๊ตฌ์„ฑ
with gr.Blocks() as demo:
    gr.Markdown("# ๐Ÿ–ผ ์ด๋ฏธ์ง€ โ†’ ์„ค๋ช… ์ƒ์„ฑ โ†’ ์นดํˆฐ ์ด๋ฏธ์ง€ ์ž๋™ ์ƒ์„ฑ๊ธฐ")
    
    gr.Markdown("**๐Ÿ“Œ ์‚ฌ์šฉ๋ฒ• ์•ˆ๋‚ด (ํ•œ๊ตญ์–ด)**\n"
                "- ์™ผ์ชฝ์— ์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•˜์„ธ์š”.\n"
                "- AI๊ฐ€ ์˜์–ด ์„ค๋ช…์„ ๋งŒ๋“ค๊ณ , ๋‚ด๋ถ€์—์„œ ํ•œ๊ตญ์–ด ์Šคํƒ€์ผ ํ”„๋กฌํ”„ํŠธ๋กœ ์žฌ๊ตฌ์„ฑํ•ฉ๋‹ˆ๋‹ค.\n"
                "- ์˜ค๋ฅธ์ชฝ์— ๊ฒฐ๊ณผ ์ด๋ฏธ์ง€๊ฐ€ ์ƒ์„ฑ๋ฉ๋‹ˆ๋‹ค.")

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="๐ŸŽจ ์›๋ณธ ์ด๋ฏธ์ง€ ์—…๋กœ๋“œ")
            run_button = gr.Button("โœจ ์ƒ์„ฑ ์‹œ์ž‘")

        with gr.Column():
            prompt_out = gr.Textbox(label="๐Ÿ“ ์Šคํƒ€์ผ ์ ์šฉ๋œ ํ”„๋กฌํ”„ํŠธ", lines=3, show_copy_button=True)
            output_img = gr.Image(label="๐ŸŽ‰ ์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€")

    def full_process(img):
        prompt = generate_prompt(img)
        image, seed = generate_image(prompt, randomize_seed=True)
        return prompt, image

    run_button.click(fn=full_process, inputs=[input_img], outputs=[prompt_out, output_img])

demo.launch()