File size: 5,209 Bytes
87d4598
4c9345a
de7d0eb
 
 
 
6d62697
684e4b9
 
6d62697
2a2fa4c
f92c4cd
 
 
74b518c
b6bfe0d
bdf7188
 
 
ffe6837
f92c4cd
1042f70
 
 
 
 
eb50505
 
 
1042f70
78839d9
 
 
 
 
bd17032
 
 
 
 
076b130
bd17032
 
 
0ce165c
9d51561
f92c4cd
 
74b518c
cea6ee2
 
fbb8f01
f92c4cd
 
 
 
 
 
2a2fa4c
f92c4cd
 
 
 
 
 
 
9d32382
f92c4cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a43544
4f843fa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import sys, types, importlib.machinery, importlib

# Register an empty placeholder module named "flash_attn" in sys.modules before
# transformers/diffusers are imported below, so any later `import flash_attn`
# resolves to this stub instead of the real package.
# NOTE(review): presumably this lets the Florence-2 remote code (loaded with
# trust_remote_code=True further down) pass its import check without flash-attn
# installed — confirm. A ModuleSpec is attached so the entry has a non-None
# __spec__ (importlib.util.find_spec raises ValueError for a sys.modules entry
# whose __spec__ is None).
spec = importlib.machinery.ModuleSpec('flash_attn', loader=None)
mod = types.ModuleType('flash_attn')
mod.__spec__ = spec
sys.modules['flash_attn'] = mod

# Re-create `huggingface_hub.cached_download` as an alias of `hf_hub_download`.
# NOTE(review): presumably for an installed library that still imports the
# `cached_download` name — confirm which one. The alias only restores the name;
# the old function's call signature differed, so verify callers' arguments.
import huggingface_hub as _hf_hub
_hf_hub.cached_download = _hf_hub.hf_hub_download

import gradio as gr
import torch
import random
from PIL import Image
from transformers import AutoProcessor, AutoModelForSeq2SeqLM
from diffusers import DiffusionPipeline
try:
    from diffusers import FlowMatchEulerDiscreteScheduler
except ImportError:
    from diffusers import EulerDiscreteScheduler as FlowMatchEulerDiscreteScheduler

# Steer transformers away from flash-attention by patching its availability
# probes. These replacements only affect code that looks the names up on the
# module *after* this point. Note that `is_flash_attn_2_available` imported on
# the next line is bound before the patch and keeps the original function.
import transformers.utils.import_utils as _import_utils
from transformers.utils import is_flash_attn_2_available
# NOTE(review): this reports *every* package as unavailable, and the lambda
# accepts exactly one positional argument, so any later call passing extra
# arguments (e.g. a keyword) would raise TypeError.
_import_utils._is_package_available   = lambda pkg: False
_import_utils.is_flash_attn_2_available = lambda: False

# Same treatment on the public `transformers.utils` namespace. import_module
# returns the already-imported module object, so this patches it in place.
hf_utils = importlib.import_module('transformers.utils')
hf_utils.is_flash_attn_2_available            = lambda *a, **k: False
hf_utils.is_flash_attn_greater_or_equal_2_10 = lambda *a, **k: False

# Install no-op fallbacks (returning None) for the SDPA mask helpers, but only
# when this transformers version does not already define them.
# NOTE(review): presumably some imported code references these names — confirm;
# any caller actually using them would receive None instead of a mask.
mask_utils = importlib.import_module("transformers.modeling_attn_mask_utils")
for fn in ("_prepare_4d_attention_mask_for_sdpa", "_prepare_4d_causal_attention_mask_for_sdpa"):
    if not hasattr(mask_utils, fn):
        setattr(mask_utils, fn, lambda *a, **k: None)

# Globally override attribute access on PretrainedConfig so that every config
# instance (unless a subclass defines its own __getattribute__) reports "sdpa"
# as `_attn_implementation`, whatever was configured. All other attribute reads
# are delegated to the original __getattribute__.
cfg_mod = importlib.import_module("transformers.configuration_utils")
_PrC = cfg_mod.PretrainedConfig
_orig_getattr = _PrC.__getattribute__
def _getattr(self, name: str):
    """Return "sdpa" for `_attn_implementation`; otherwise defer to the original lookup."""
    if name == "_attn_implementation":
        return "sdpa"
    return _orig_getattr(self, name)
_PrC.__getattribute__ = _getattr

# Pinned Hugging Face commit for Florence-2. Both the weights and the
# trust_remote_code model/processor code below are fetched from this revision.
REVISION = "ceaf371f01ef66192264811b390bccad475a4f02"

# --- Florence-2 (image captioning) ---
# `device` is where the diffusion pipeline runs (see pipe.to(device) below);
# the caption model is deliberately kept on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load in float32 because the model stays on the CPU (next line). Half-precision
# weights on the CPU are slow or unsupported for some kernels. float32 also
# matches the dtype of the pixel_values the processor produces by default.
florence_model = AutoModelForSeq2SeqLM.from_pretrained(
    'microsoft/Florence-2-base',
    revision=REVISION,
    trust_remote_code=True,
    torch_dtype=torch.float32,
)
florence_model.to("cpu")
florence_model.eval()
florence_processor = AutoProcessor.from_pretrained(
    'microsoft/Florence-2-base', revision=REVISION, trust_remote_code=True
)

# --- Stable Diffusion 3.5 Large TurboX (image generation) ---
model_repo = "tensorart/stable-diffusion-3.5-large-TurboX"
pipe = DiffusionPipeline.from_pretrained(
    model_repo,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
# Rebuild the scheduler from the repo's scheduler config; the `shift=5` kwarg
# overrides the stored config value.
# NOTE(review): if importing FlowMatchEulerDiscreteScheduler failed at the top of
# the file, this name refers to EulerDiscreteScheduler — confirm that fallback
# suits this pipeline.
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_repo, subfolder="scheduler", shift=5)
pipe = pipe.to(device)

# Largest seed value: 2**31 - 1, the maximum signed 32-bit integer.
MAX_SEED = 2**31 - 1

def pseudo_translate_to_korean_style(en_prompt: str) -> str:
    """Wrap an English caption in the fixed cartoon-style prompt template.

    Despite the name, nothing is translated. The caption is inserted unchanged
    between a "Cartoon styled" prefix and a "handsome or pretty people" suffix.

    Args:
        en_prompt: The English caption to embed.

    Returns:
        The styled prompt string.
    """
    template = "Cartoon styled {} handsome or pretty people"
    return template.format(en_prompt)

def generate_prompt(image):
    """Caption an image with Florence-2 and wrap the caption in the cartoon prompt template.

    Steps: image -> detailed English caption (Florence-2 ``<MORE_DETAILED_CAPTION>``
    task) -> ``pseudo_translate_to_korean_style``. That step applies a fixed
    English template; no translation is performed.

    Args:
        image: A ``PIL.Image.Image``, or an array that ``Image.fromarray`` accepts
            (e.g. the numpy array a default ``gr.Image`` component delivers).

    Returns:
        The styled prompt string.

    Raises:
        ValueError: If ``image`` is None, e.g. when the UI is triggered before an
            image was uploaded.
    """
    if image is None:
        raise ValueError("generate_prompt() needs an input image, got None")
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    task = "<MORE_DETAILED_CAPTION>"
    # Send the inputs to the caption model's own device and dtype, not the global
    # `device`. The model is pinned to the CPU at load time, so moving inputs to
    # CUDA would be a device mismatch. pixel_values must also match whatever
    # dtype the weights were loaded in. BatchFeature.to() casts only
    # floating-point tensors, so input_ids stay integer.
    inputs = florence_processor(text=task, images=image, return_tensors="pt").to(
        florence_model.device, florence_model.dtype
    )
    generated_ids = florence_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=512,
        num_beams=3,
    )
    # Special tokens are kept for post_process_generation, as in the
    # Florence-2 model-card example.
    generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = florence_processor.post_process_generation(
        generated_text,
        task=task,
        image_size=(image.width, image.height),
    )
    caption_en = parsed_answer[task]

    # No translator: apply the fixed style template to the English caption.
    return pseudo_translate_to_korean_style(caption_en)

def generate_image(
    prompt,
    seed=42,
    randomize_seed=False,
    *,
    negative_prompt="distorted hands, blurry, strange face",
    guidance_scale=1.5,
    num_inference_steps=8,
    width=768,
    height=768,
):
    """Render one image from a text prompt with the global diffusion ``pipe``.

    Args:
        prompt: Text prompt for the pipeline.
        seed: Seed for the random generator; ignored when ``randomize_seed`` is true.
        randomize_seed: If true, draw a fresh seed from ``[0, MAX_SEED]``.
        negative_prompt: What the sampler should steer away from. The default is
            the English meaning of this app's original Korean negative prompt
            ("distorted hands, blur, strange face"). That prompt had been stored
            mis-encoded, so it reached the model as meaningless characters.
            English is also presumably better understood by the pipeline's text
            encoders.
        guidance_scale, num_inference_steps, width, height: Forwarded to ``pipe``.
            The defaults are the values this function previously hard-coded.

    Returns:
        ``(image, seed)``: the first generated image and the integer seed that
        was used, so a randomized result can be reproduced.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # manual_seed requires an int; a numeric UI input may hand us a float.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)
    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    )
    return result.images[0], seed

# Gradio UI ๊ตฌ์„ฑ
with gr.Blocks() as demo:
    gr.Markdown("# ๐Ÿ–ผ ์ด๋ฏธ์ง€ โ†’ ์„ค๋ช… ์ƒ์„ฑ โ†’ ์นดํˆฐ ์ด๋ฏธ์ง€ ์ž๋™ ์ƒ์„ฑ๊ธฐ")
    
    gr.Markdown("**๐Ÿ“Œ ์‚ฌ์šฉ๋ฒ• ์•ˆ๋‚ด (ํ•œ๊ตญ์–ด)**\n"
                "- ์™ผ์ชฝ์— ์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•˜์„ธ์š”.\n"
                "- AI๊ฐ€ ์˜์–ด ์„ค๋ช…์„ ๋งŒ๋“ค๊ณ , ๋‚ด๋ถ€์—์„œ ํ•œ๊ตญ์–ด ์Šคํƒ€์ผ ํ”„๋กฌํ”„ํŠธ๋กœ ์žฌ๊ตฌ์„ฑํ•ฉ๋‹ˆ๋‹ค.\n"
                "- ์˜ค๋ฅธ์ชฝ์— ๊ฒฐ๊ณผ ์ด๋ฏธ์ง€๊ฐ€ ์ƒ์„ฑ๋ฉ๋‹ˆ๋‹ค.")

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="๐ŸŽจ ์›๋ณธ ์ด๋ฏธ์ง€ ์—…๋กœ๋“œ")
            run_button = gr.Button("โœจ ์ƒ์„ฑ ์‹œ์ž‘")

        with gr.Column():
            prompt_out = gr.Textbox(label="๐Ÿ“ ์Šคํƒ€์ผ ์ ์šฉ๋œ ํ”„๋กฌํ”„ํŠธ", lines=3, show_copy_button=True)
            output_img = gr.Image(label="๐ŸŽ‰ ์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€")

    def full_process(img):
        prompt = generate_prompt(img)
        image, seed = generate_image(prompt, randomize_seed=True)
        return prompt, image

    run_button.click(fn=full_process, inputs=[input_img], outputs=[prompt_out, output_img])

demo.launch()