Futuretop's picture
Update app.py
c9314ca verified
raw
history blame
5.77 kB
import os

# huggingface_hub reads HF_HUB_ENABLE_HF_TRANSFER once, when its constants
# module is first imported, so the flag must be set *before* that import;
# assigning it afterwards has no effect on downloads.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"

from huggingface_hub import snapshot_download

# Pin Florence-2 to a fixed commit so the snapshot (including the
# trust_remote_code modeling files loaded later) is reproducible.
REVISION = "ceaf371f01ef66192264811b390bccad475a4f02"

# Download (or reuse from the local cache) both model repositories up front;
# snapshot_download returns the local snapshot directory.
LOCAL_FLORENCE = snapshot_download(
    repo_id="microsoft/Florence-2-base",
    revision=REVISION,
)
LOCAL_TURBOX = snapshot_download(
    repo_id="tensorart/stable-diffusion-3.5-large-TurboX",
)
# Same repo, same revision: reuse the path instead of calling
# snapshot_download a second time. The name is kept for compatibility.
LOCAL_FLORENCE_DIR = LOCAL_FLORENCE
import sys
import types
import importlib.machinery
import importlib

# Register an empty placeholder module named "flash_attn" so later imports of
# that name resolve to it instead of the real package. A ModuleSpec is
# attached because importlib.util.find_spec() raises ValueError for a module
# present in sys.modules whose __spec__ is None.
mod = types.ModuleType('flash_attn')
spec = importlib.machinery.ModuleSpec(mod.__name__, loader=None)
mod.__spec__ = spec
sys.modules[mod.__name__] = mod
import huggingface_hub as _hf_hub

# `cached_download` is absent from newer huggingface_hub releases, yet some
# code presumably still imports it (e.g. older remote/diffusers code). Only
# install the alias when the attribute is missing: on releases that still ship
# the real, URL-based `cached_download`, overwriting it with hf_hub_download
# (whose first parameter is a repo id) would break its callers.
if not hasattr(_hf_hub, "cached_download"):
    _hf_hub.cached_download = _hf_hub.hf_hub_download
import gradio as gr
import torch
import random
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
import diffusers
from diffusers import StableDiffusionPipeline
from diffusers import DiffusionPipeline
from diffusers import EulerDiscreteScheduler

# Fallback only for diffusers builds that do not provide the flow-matching
# scheduler. When the real FlowMatchEulerDiscreteScheduler exists it must not
# be replaced: the TurboX pipeline's scheduler is a flow-matching scheduler,
# and EulerDiscreteScheduler is a different (non-flow-matching) one.
if not hasattr(diffusers, "FlowMatchEulerDiscreteScheduler"):
    diffusers.FlowMatchEulerDiscreteScheduler = EulerDiscreteScheduler
import transformers.utils.import_utils as _import_utils
from transformers.utils import is_flash_attn_2_available


def _package_unavailable(pkg_name, return_version=False):
    """Stand-in for transformers' ``_is_package_available`` that reports every package missing.

    Mirrors the real helper's signature: when ``return_version`` is true the
    real function returns a ``(bool, version)`` tuple, so this stub does too,
    using "N/A" as the version string for an absent package.
    """
    return (False, "N/A") if return_version else False


# From here on, make transformers' availability checks report optional
# packages (flash_attn in particular) as unavailable.
_import_utils._is_package_available = _package_unavailable
_import_utils.is_flash_attn_2_available = lambda: False
hf_utils = importlib.import_module('transformers.utils')
hf_utils.is_flash_attn_2_available = lambda *a, **k: False
hf_utils.is_flash_attn_greater_or_equal_2_10 = lambda *a, **k: False

# If this transformers version lacks the SDPA mask helpers, install stand-ins
# returning None (no mask), presumably so modules importing them still load.
mask_utils = importlib.import_module("transformers.modeling_attn_mask_utils")
for fn in ("_prepare_4d_attention_mask_for_sdpa", "_prepare_4d_causal_attention_mask_for_sdpa"):
    if not hasattr(mask_utils, fn):
        setattr(mask_utils, fn, lambda *a, **k: None)

# Make every PretrainedConfig report "sdpa" as its attention implementation.
# NOTE(review): this affects *all* transformers configs in the process,
# including the diffusion pipeline's text encoders — confirm they accept "sdpa".
cfg_mod = importlib.import_module("transformers.configuration_utils")
_PrC = cfg_mod.PretrainedConfig
_orig_getattr = _PrC.__getattribute__


def _getattr(self, name):
    """Attribute hook: answer ``_attn_implementation`` with "sdpa", defer the rest."""
    if name == "_attn_implementation":
        return "sdpa"
    return _orig_getattr(self, name)


_PrC.__getattribute__ = _getattr
# Load Florence-2 (image captioner)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only on GPU; float16 is slow or unsupported for many CPU kernels.
florence_dtype = torch.float16 if device == "cuda" else torch.float32
florence_model = AutoModelForCausalLM.from_pretrained(
    LOCAL_FLORENCE, trust_remote_code=True, torch_dtype=florence_dtype
)
# Run the captioner on `device` rather than pinning it to CPU, so inputs
# placed on `device` match the model's weights.
florence_model.to(device)
florence_model.eval()
florence_processor = AutoProcessor.from_pretrained(LOCAL_FLORENCE, trust_remote_code=True)
# Load Stable Diffusion 3.5 Large TurboX
model_repo = "tensorart/stable-diffusion-3.5-large-TurboX"
# Import the scheduler from its defining module so the real class is used even
# if the top-level `diffusers` name has been re-aliased elsewhere.
from diffusers.schedulers.scheduling_flow_match_euler_discrete import (
    FlowMatchEulerDiscreteScheduler,
)

# DiffusionPipeline picks the concrete pipeline class named in the repo's
# model_index.json. diffusers.StableDiffusion3Pipeline must NOT be aliased to
# the SD1.x StableDiffusionPipeline: that class expects different components
# (e.g. a UNet instead of a transformer) and cannot load this checkpoint.
pipe = diffusers.DiffusionPipeline.from_pretrained(
    LOCAL_TURBOX,  # local snapshot downloaded at startup from `model_repo`
    trust_remote_code=True,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
# Replace the scheduler with one built from the repo's scheduler config,
# overriding `shift` to 5.
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    LOCAL_TURBOX, subfolder="scheduler", shift=5
)
pipe = pipe.to(device)
# Inclusive upper bound for randomly drawn seeds (2**31 - 1).
MAX_SEED = (1 << 31) - 1


def pseudo_translate_to_korean_style(en_prompt: str) -> str:
    """Wrap an English caption in the fixed cartoon-style prompt template.

    Despite the name, nothing is translated: the caption stays in English and
    is only surrounded by style keywords.
    """
    template = "Cartoon styled {} handsome or pretty people"
    return template.format(en_prompt)
def generate_prompt(image):
    """Caption an image with Florence-2 and turn the caption into a cartoon prompt.

    Flow: image -> detailed English caption -> cartoon-style template from
    ``pseudo_translate_to_korean_style`` (no translation takes place).

    Args:
        image: a ``PIL.Image.Image`` or an array accepted by
            ``PIL.Image.fromarray`` (e.g. the array a ``gr.Image`` delivers).

    Returns:
        The styled English prompt string.

    Raises:
        gr.Error: if no image was supplied (``image is None``).
    """
    if image is None:
        # Shown to the user by Gradio instead of an opaque PIL traceback.
        raise gr.Error("먼저 이미지를 업로드하세요.")
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    task = "<MORE_DETAILED_CAPTION>"
    # Send the inputs to the model's own device and cast floating-point tensors
    # (pixel_values) to the model's dtype; BatchFeature.to leaves integer
    # tensors such as input_ids un-cast. This keeps inputs and weights
    # consistent wherever the model was placed.
    inputs = florence_processor(text=task, images=image, return_tensors="pt").to(
        florence_model.device, florence_model.dtype
    )
    generated_ids = florence_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=512,
        num_beams=3,
    )
    # Special tokens are kept; post_process_generation receives the raw decoded text.
    generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = florence_processor.post_process_generation(
        generated_text,
        task=task,
        image_size=(image.width, image.height),
    )
    prompt_en = parsed_answer[task]
    # Apply the style template (no translator involved).
    return pseudo_translate_to_korean_style(prompt_en)
def generate_image(
    prompt,
    seed=42,
    randomize_seed=False,
    *,
    negative_prompt="왜곡된 손, 흐림, 이상한 얼굴",
    guidance_scale=1.5,
    num_inference_steps=8,
    width=768,
    height=768,
):
    """Render an image from a text prompt with the TurboX pipeline.

    Args:
        prompt: positive text prompt.
        seed: RNG seed; ignored when ``randomize_seed`` is true.
        randomize_seed: draw a fresh seed uniformly from ``[0, MAX_SEED]``.
        negative_prompt, guidance_scale, num_inference_steps, width, height:
            forwarded to the pipeline; the defaults reproduce the previously
            hard-coded settings (8 steps, guidance 1.5, 768x768).

    Returns:
        ``(image, seed)``: the first generated image and the integer seed used.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Callers may pass a float (e.g. from a numeric UI field); manual_seed needs an int.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)
    # NOTE(review): the default negative prompt is Korean ("distorted hands,
    # blur, strange face"); if the pipeline's text encoders are English-centric
    # it may have little effect — consider an English equivalent.
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ).images[0]
    return image, seed
# Build the Gradio UI. The user-facing strings are restored from their
# mis-encoded (mojibake) form to the intended Korean/emoji text.
with gr.Blocks() as demo:
    gr.Markdown("# 🖼 이미지 → 설명 생성 → 카툰 이미지 자동 생성기")
    gr.Markdown("**📌 사용법 안내 (한국어)**\n"
                "- 왼쪽에 이미지를 업로드하세요.\n"
                "- AI가 영어 설명을 만들고, 내부에서 한국어 스타일 프롬프트로 재구성합니다.\n"
                "- 오른쪽에 결과 이미지가 생성됩니다.")
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="🎨 원본 이미지 업로드")
            run_button = gr.Button("✨ 생성 시작")
        with gr.Column():
            prompt_out = gr.Textbox(label="📜 스타일 적용된 프롬프트", lines=3, show_copy_button=True)
            output_img = gr.Image(label="🎉 생성된 이미지")

    def full_process(img):
        """Caption ``img``, then render a cartoon image from that caption.

        Returns ``(prompt, image)`` for the two output components; the random
        seed reported by generate_image is not shown.
        """
        prompt = generate_prompt(img)
        image, _seed = generate_image(prompt, randomize_seed=True)
        return prompt, image

    # Event listeners must be attached inside the Blocks context.
    run_button.click(fn=full_process, inputs=[input_img], outputs=[prompt_out, output_img])
demo.launch()