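# Register a stub "flash_attn" module so remote model code that imports it
# unconditionally does not crash on hardware without FlashAttention.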
import sys, types, importlib.machinery, importlib
spec = importlib.machinery.ModuleSpec('flash_attn', loader=None)
mod = types.ModuleType('flash_attn')
mod.__spec__ = spec
sys.modules['flash_attn'] = mod
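# Newer huggingface_hub releases dropped cached_download; alias it so older
# library code paths that still call it keep working.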
import huggingface_hub as _hf_hub
_hf_hub.cached_download = _hf_hub.hf_hub_download
import gradio as gr
import torch
import random
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
import diffusers
from diffusers import DiffusionPipeline, StableDiffusionPipeline
try:
    from diffusers import FlowMatchEulerDiscreteScheduler
except ImportError:
    # Older diffusers releases lack FlowMatchEulerDiscreteScheduler.
    from diffusers import EulerDiscreteScheduler as FlowMatchEulerDiscreteScheduler
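# Patch transformers so it never tries to select the FlashAttention 2 backend.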
import transformers.utils.import_utils as _import_utils
_orig_is_package_available = _import_utils._is_package_available
# Only report flash_attn as missing; a blanket False would hide every package.
_import_utils._is_package_available = (
    lambda pkg, *a, **k: False if pkg == "flash_attn" else _orig_is_package_available(pkg, *a, **k)
)
_import_utils.is_flash_attn_2_available = lambda: False
hf_utils = importlib.import_module('transformers.utils')
hf_utils.is_flash_attn_2_available = lambda *a, **k: False
hf_utils.is_flash_attn_greater_or_equal_2_10 = lambda *a, **k: False
# Provide no-op stand-ins for SDPA mask helpers missing from this transformers version.
mask_utils = importlib.import_module("transformers.modeling_attn_mask_utils")
for fn in ("_prepare_4d_attention_mask_for_sdpa", "_prepare_4d_causal_attention_mask_for_sdpa"):
    if not hasattr(mask_utils, fn):
        setattr(mask_utils, fn, lambda *a, **k: None)
cfg_mod = importlib.import_module("transformers.configuration_utils")
_PrC = cfg_mod.PretrainedConfig
_orig_getattr = _PrC.__getattribute__

def _getattr(self, name):
    # Make every model config report SDPA as its attention implementation.
    if name == "_attn_implementation":
        return "sdpa"
    return _orig_getattr(self, name)

_PrC.__getattribute__ = _getattr
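# Pin the Florence-2 remote code to a known revision for reproducibility.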
REVISION = "ceaf371f01ef66192264811b390bccad475a4f02"
# Load Florence-2 (image captioning)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Keep Florence-2 on CPU in float32; float16 is unsupported for many CPU ops.
florence_model = AutoModelForCausalLM.from_pretrained(
    'microsoft/Florence-2-base', revision=REVISION, trust_remote_code=True, torch_dtype=torch.float32
)
florence_model.to("cpu")
florence_model.eval()
florence_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', revision=REVISION, trust_remote_code=True)
# Load Stable Diffusion 3.5 TurboX
# Alias the SD3 pipeline class so diffusers builds without it can still
# resolve the class name declared in the repo's model_index.json.
diffusers.StableDiffusion3Pipeline = StableDiffusionPipeline
model_repo = "tensorart/stable-diffusion-3.5-large-TurboX"
pipe = DiffusionPipeline.from_pretrained(
    model_repo,
    trust_remote_code=True,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    model_repo, subfolder="scheduler", trust_remote_code=True, shift=5
)
pipe = pipe.to(device)
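# Largest 32-bit signed integer, used as the upper bound for random seeds.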
MAX_SEED = 2**31 - 1
def pseudo_translate_to_korean_style(en_prompt: str) -> str:
    # Apply the cartoon style directly; no actual translation happens here.
    return f"Cartoon styled {en_prompt} handsome or pretty people"
def generate_prompt(image):
    """Image โ†’ English caption โ†’ styled cartoon prompt."""
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Florence-2 runs on CPU (see above), so keep its inputs there too.
    inputs = florence_processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to("cpu")
    generated_ids = florence_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=512,
        num_beams=3
    )
    generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = florence_processor.post_process_generation(
        generated_text,
        task="<MORE_DETAILED_CAPTION>",
        image_size=(image.width, image.height)
    )
    prompt_en = parsed_answer["<MORE_DETAILED_CAPTION>"]
    # Apply the cartoon style without a translation model
    cartoon_prompt = pseudo_translate_to_korean_style(prompt_en)
    return cartoon_prompt
def generate_image(prompt, seed=42, randomize_seed=False):
    """Text prompt โ†’ generated image."""
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)
    image = pipe(
        prompt=prompt,
        negative_prompt="distorted hands, blurry, deformed face",
        guidance_scale=1.5,
        num_inference_steps=8,
        width=768,
        height=768,
        generator=generator
    ).images[0]
    return image, seed
# Build the Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# ๐Ÿ–ผ Image โ†’ Caption โ†’ Cartoon Image Generator")
    gr.Markdown("**๐Ÿ“Œ How to use**\n"
                "- Upload an image on the left.\n"
                "- The AI writes an English caption and internally reworks it into a styled prompt.\n"
                "- The generated image appears on the right.")
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="๐ŸŽจ Upload source image")
            run_button = gr.Button("โœจ Generate")
        with gr.Column():
            prompt_out = gr.Textbox(label="๐Ÿ“ Styled prompt", lines=3, show_copy_button=True)
            output_img = gr.Image(label="๐ŸŽ‰ Generated image")

    def full_process(img):
        prompt = generate_prompt(img)
        image, seed = generate_image(prompt, randomize_seed=True)
        return prompt, image

    run_button.click(fn=full_process, inputs=[input_img], outputs=[prompt_out, output_img])

demo.launch()