File size: 5,568 Bytes
8dd0bda
 
 
 
 
 
 
 
 
 
87d4598
4c9345a
de7d0eb
 
 
 
6d62697
684e4b9
 
6d62697
2a2fa4c
f92c4cd
 
 
19a2293
615e1c0
2dc80b5
0beba8c
bdf7188
 
 
ffe6837
f92c4cd
1042f70
 
 
 
 
eb50505
 
 
1042f70
78839d9
 
 
 
 
bd17032
 
 
 
 
076b130
bd17032
 
 
f92c4cd
 
19a2293
cea6ee2
 
fbb8f01
f92c4cd
 
5dd3dde
 
 
f92c4cd
 
 
c0f4255
f92c4cd
2a2fa4c
69afbf3
f92c4cd
 
 
 
 
 
9d32382
f92c4cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a43544
4f843fa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from huggingface_hub import snapshot_download

# Exact Florence-2 commit used everywhere below. It is passed together with
# trust_remote_code=True, so pinning keeps the executed remote model code fixed.
REVISION = "ceaf371f01ef66192264811b390bccad475a4f02"

# Download (or reuse from the local Hugging Face cache) the whole Florence-2
# snapshot at that revision before any model is constructed.
# NOTE(review): LOCAL_FLORENCE_DIR is not referenced later in this file; the
# from_pretrained calls use the repo id + REVISION instead.
LOCAL_FLORENCE_DIR = snapshot_download(
    repo_id="microsoft/Florence-2-base",
    local_files_only=False,
    revision=REVISION,
)

import sys, types, importlib.machinery, importlib

# Register an empty placeholder module under the name "flash_attn" before any
# model code is imported. Because the placeholder carries a __spec__,
# `import flash_attn` succeeds and importlib.util.find_spec("flash_attn")
# (which consults sys.modules first) reports it as present, even though the
# real package is not installed.
# NOTE(review): presumably this exists to satisfy the Florence-2 remote code's
# import check for flash_attn; confirm. Anything that actually calls into
# flash_attn will fail, since this module defines nothing.
spec = importlib.machinery.ModuleSpec('flash_attn', loader=None)
mod = types.ModuleType('flash_attn')
mod.__spec__ = spec
sys.modules['flash_attn'] = mod

import huggingface_hub as _hf_hub
# Compatibility shim: re-create `huggingface_hub.cached_download`, which was
# removed from recent huggingface_hub releases, as an alias of hf_hub_download.
# NOTE(review): presumably some library imported below still references
# cached_download; confirm. The two functions do not share the same
# signature/return contract, so callers relying on the old semantics may break.
_hf_hub.cached_download = _hf_hub.hf_hub_download

import gradio as gr
import torch
import random
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
import diffusers
from diffusers import StableDiffusionPipeline
from diffusers import DiffusionPipeline
try:
    from diffusers import FlowMatchEulerDiscreteScheduler
except ImportError:
    from diffusers import EulerDiscreteScheduler as FlowMatchEulerDiscreteScheduler

import transformers.utils.import_utils as _import_utils
from transformers.utils import is_flash_attn_2_available
# --- Process-wide monkeypatches steering transformers away from flash-attention ---
# After this point every call to transformers' internal package probe reports
# "not installed". This applies to *every* package name, not only flash_attn.
# Values already computed from the probe before this line are unaffected, and
# modules that imported these functions by name (including the
# `is_flash_attn_2_available` imported just above) keep the originals; only
# lookups through these module objects see the patch.
# NOTE(review): the lambda accepts exactly one positional argument; if the
# installed transformers calls _is_package_available with extra arguments
# (e.g. return_version=...), it will raise TypeError. Confirm.
_import_utils._is_package_available   = lambda pkg: False
_import_utils.is_flash_attn_2_available = lambda: False

# Same overrides on the public `transformers.utils` namespace, accepting any
# arguments.
hf_utils = importlib.import_module('transformers.utils')
hf_utils.is_flash_attn_2_available            = lambda *a, **k: False
hf_utils.is_flash_attn_greater_or_equal_2_10 = lambda *a, **k: False

# If this transformers version lacks the SDPA mask helpers, install no-op
# stand-ins so code importing these names does not fail at import time.
# Any caller that actually invokes a stand-in receives None instead of a mask.
mask_utils = importlib.import_module("transformers.modeling_attn_mask_utils")
for fn in ("_prepare_4d_attention_mask_for_sdpa", "_prepare_4d_causal_attention_mask_for_sdpa"):
    if not hasattr(mask_utils, fn):
        setattr(mask_utils, fn, lambda *a, **k: None)

# Force every PretrainedConfig instance (every transformers model loaded in
# this process, not just Florence-2) to report "sdpa" when
# `_attn_implementation` is read, regardless of what was configured. Only
# reads are intercepted; all other attribute lookups go to the original
# __getattribute__.
cfg_mod = importlib.import_module("transformers.configuration_utils")
_PrC = cfg_mod.PretrainedConfig
_orig_getattr = _PrC.__getattribute__
def _getattr(self, name):
    """Return "sdpa" for `_attn_implementation`; delegate everything else."""
    if name == "_attn_implementation":
        return "sdpa"
    return _orig_getattr(self, name)
_PrC.__getattribute__ = _getattr

# Load Florence-2 (image -> detailed caption), pinned to REVISION.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Florence-2 is kept on the CPU (presumably to leave GPU memory for the
# diffusion pipeline), so load it in float32. Half precision on CPU is slow,
# and some PyTorch builds lack float16 CPU kernels, which makes generate() fail.
florence_model = AutoModelForCausalLM.from_pretrained(
    'microsoft/Florence-2-base',
    revision=REVISION,
    trust_remote_code=True,
    torch_dtype=torch.float32,
)
florence_model.to("cpu")
florence_model.eval()
florence_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', revision=REVISION, trust_remote_code=True)

# Load Stable Diffusion 3.5 Large TurboX.

# Alias StableDiffusion3Pipeline only when this diffusers release lacks it.
# DiffusionPipeline.from_pretrained looks the pipeline class up by name on the
# diffusers module, so overwriting a real StableDiffusion3Pipeline with the
# SD1.x StableDiffusionPipeline would make an SD3 repo (whose model_index.json
# presumably names StableDiffusion3Pipeline) resolve to a pipeline that cannot
# accept SD3 components.
# NOTE(review): in the fallback case the alias still points at an incompatible
# class; loading SD3.5 realistically requires a diffusers release that ships
# StableDiffusion3Pipeline.
if not hasattr(diffusers, "StableDiffusion3Pipeline"):
    diffusers.StableDiffusion3Pipeline = StableDiffusionPipeline

model_repo = "tensorart/stable-diffusion-3.5-large-TurboX"
pipe = DiffusionPipeline.from_pretrained(
    model_repo,
    trust_remote_code=True,
    # Half precision only when a GPU is available; CPU runs stay in float32.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
# Swap in the flow-matching Euler scheduler with shift=5. On diffusers
# versions without it, the guarded import above substitutes
# EulerDiscreteScheduler under the same name.
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_repo, subfolder="scheduler", trust_remote_code = True, shift=5)
pipe = pipe.to(device)

# Upper bound for randomly drawn seeds: the largest signed 32-bit integer.
MAX_SEED = (1 << 31) - 1

def pseudo_translate_to_korean_style(en_prompt: str) -> str:
    """Wrap an English caption in the fixed cartoon-style prompt template.

    Despite the name, no translation is performed: the caption is inserted
    verbatim between a style prefix and a style suffix.

    Args:
        en_prompt: Caption text to embed.

    Returns:
        ``"Cartoon styled <en_prompt> handsome or pretty people"``.
    """
    template = "Cartoon styled {} handsome or pretty people"
    return template.format(en_prompt)

def generate_prompt(image):
    """Caption an image with Florence-2 and turn the caption into a styled prompt.

    Flow: image -> English detailed caption (Florence-2
    ``<MORE_DETAILED_CAPTION>`` task) -> fixed cartoon-style template via
    ``pseudo_translate_to_korean_style`` (no actual translation happens).

    Args:
        image: A ``PIL.Image.Image``, or an array accepted by
            ``Image.fromarray`` (e.g. the numpy array ``gr.Image`` supplies).

    Returns:
        The styled prompt string.

    Raises:
        ValueError: If ``image`` is None (nothing was uploaded).
    """
    if image is None:
        # Image.fromarray(None) would otherwise fail with an opaque error.
        raise ValueError("No input image was provided.")
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Normalise RGBA / greyscale / palette images to 3-channel RGB.
    if image.mode != "RGB":
        image = image.convert("RGB")

    task = "<MORE_DETAILED_CAPTION>"
    # Move inputs to the model's own device and dtype instead of the global
    # `device`. The model is pinned to CPU at load time, so "cuda" inputs
    # would mismatch, and pixel_values must match whatever dtype the model was
    # loaded in. BatchFeature.to casts only floating-point tensors, so
    # input_ids keep their integer dtype.
    inputs = florence_processor(text=task, images=image, return_tensors="pt").to(
        florence_model.device, florence_model.dtype
    )
    generated_ids = florence_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=512,
        num_beams=3,
    )
    # skip_special_tokens=False mirrors the Florence-2 usage, where
    # post_process_generation parses the raw decoded text.
    generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = florence_processor.post_process_generation(
        generated_text,
        task=task,
        image_size=(image.width, image.height),
    )
    prompt_en = parsed_answer[task]

    # Apply the style template directly (no translator involved).
    return pseudo_translate_to_korean_style(prompt_en)

def generate_image(
    prompt,
    seed=42,
    randomize_seed=False,
    *,
    negative_prompt="์™œ๊ณก๋œ ์†, ํ๋ฆผ, ์ด์ƒํ•œ ์–ผ๊ตด",
    guidance_scale=1.5,
    num_inference_steps=8,
    width=768,
    height=768,
):
    """Render one image from a text prompt with the TurboX pipeline.

    Args:
        prompt: Text prompt passed to the diffusion pipeline.
        seed: Seed for the torch generator. It is ignored when
            ``randomize_seed`` is true and coerced with ``int()`` so numeric
            UI values such as ``42.0`` are accepted.
        randomize_seed: If true, draw a fresh seed from ``0..MAX_SEED``
            (inclusive).
        negative_prompt, guidance_scale, num_inference_steps, width, height:
            Keyword-only sampling overrides. The defaults reproduce the
            original fixed settings.

    Returns:
        ``(image, seed)``: the first generated image and the seed actually
        used, so a result can be reproduced.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # torch.Generator.manual_seed requires an integer.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)
    result = pipe(
        prompt=prompt,
        # NOTE(review): the default negative prompt is Korean text; the SD3.5
        # text encoders are presumably English-centric, so it may have little
        # effect. Consider passing an English equivalent.
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    )
    return result.images[0], seed

# Gradio UI layout
with gr.Blocks() as demo:
    gr.Markdown("# ๐Ÿ–ผ ์ด๋ฏธ์ง€ โ†’ ์„ค๋ช… ์ƒ์„ฑ โ†’ ์นดํˆฐ ์ด๋ฏธ์ง€ ์ž๋™ ์ƒ์„ฑ๊ธฐ")

    gr.Markdown("**๐Ÿ“Œ ์‚ฌ์šฉ๋ฒ• ์•ˆ๋‚ด (ํ•œ๊ตญ์–ด)**\n"
                "- ์™ผ์ชฝ์— ์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•˜์„ธ์š”.\n"
                "- AI๊ฐ€ ์˜์–ด ์„ค๋ช…์„ ๋งŒ๋“ค๊ณ , ๋‚ด๋ถ€์—์„œ ํ•œ๊ตญ์–ด ์Šคํƒ€์ผ ํ”„๋กฌํ”„ํŠธ๋กœ ์žฌ๊ตฌ์„ฑํ•ฉ๋‹ˆ๋‹ค.\n"
                "- ์˜ค๋ฅธ์ชฝ์— ๊ฒฐ๊ณผ ์ด๋ฏธ์ง€๊ฐ€ ์ƒ์„ฑ๋ฉ๋‹ˆ๋‹ค.")

    with gr.Row():
        # Left column: source image upload and the trigger button.
        with gr.Column():
            input_img = gr.Image(label="๐ŸŽจ ์›๋ณธ ์ด๋ฏธ์ง€ ์—…๋กœ๋“œ")
            run_button = gr.Button("โœจ ์ƒ์„ฑ ์‹œ์ž‘")

        # Right column: the styled prompt and the generated image.
        with gr.Column():
            prompt_out = gr.Textbox(label="๐Ÿ“ ์Šคํƒ€์ผ ์ ์šฉ๋œ ํ”„๋กฌํ”„ํŠธ", lines=3, show_copy_button=True)
            output_img = gr.Image(label="๐ŸŽ‰ ์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€")

    def full_process(img):
        """Caption the upload, build the styled prompt, and render a new image.

        Returns ``(prompt, image)`` to match ``outputs=[prompt_out, output_img]``.
        A random seed is used on every click.
        """
        # gr.Image passes None when nothing has been uploaded; show a readable
        # message in the UI instead of a stack trace from deeper code.
        if img is None:
            raise gr.Error("Please upload an image first.")
        prompt = generate_prompt(img)
        image, _seed = generate_image(prompt, randomize_seed=True)
        return prompt, image

    run_button.click(fn=full_process, inputs=[input_img], outputs=[prompt_out, output_img])

demo.launch()