# Futuretop's picture
# Update app.py
# f8215e3 verified
import os

# Disable the optional hf_transfer download backend. The flag must be in the
# environment BEFORE huggingface_hub is imported: the library reads its
# HF_HUB_* settings when it is imported, so assigning it afterwards is too late.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"

from huggingface_hub import snapshot_download

# Commit of microsoft/Florence-2-base to download; pinning the revision keeps
# the fetched files fixed.
REVISION = "ceaf371f01ef66192264811b390bccad475a4f02"

# Local snapshot directory of the Florence-2 captioning model.
LOCAL_FLORENCE = snapshot_download(
    repo_id="microsoft/Florence-2-base",
    revision=REVISION,
)

# Local snapshot directory of the TurboX checkpoint (pre-fetches the files).
LOCAL_TURBOX = snapshot_download(
    repo_id="tensorart/stable-diffusion-3.5-large-TurboX",
)

# Same repo and revision as LOCAL_FLORENCE, so the snapshot path is identical;
# reuse it instead of resolving the snapshot a second time.
LOCAL_FLORENCE_DIR = LOCAL_FLORENCE
import importlib
import importlib.machinery
import sys
import types

# Register an empty placeholder module under the name `flash_attn` so that a
# later `import flash_attn` resolves to this stub instead of a real package.
# The stub gets a ModuleSpec (with no loader) so that spec lookups against
# sys.modules find a non-None `__spec__`.
_FLASH_ATTN = 'flash_attn'
mod = types.ModuleType(_FLASH_ATTN)
spec = importlib.machinery.ModuleSpec(_FLASH_ATTN, loader=None)
mod.__spec__ = spec
sys.modules[_FLASH_ATTN] = mod
# Compatibility shim: expose `huggingface_hub.cached_download` as an alias of
# `hf_hub_download`, presumably for libraries that still import the old name.
# NOTE(review): the alias is installed unconditionally, so it also replaces a
# real `cached_download` if the installed huggingface_hub still ships one.
import huggingface_hub as _hf_hub
setattr(_hf_hub, "cached_download", _hf_hub.hf_hub_download)
import gradio as gr
import torch
import random
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers import (
CLIPTextModel,
CLIPTokenizer,
CLIPFeatureExtractor,
)
import diffusers
from diffusers import StableDiffusionPipeline
from diffusers import DiffusionPipeline
from diffusers import EulerDiscreteScheduler as FlowMatchEulerDiscreteScheduler
from diffusers import UNet2DConditionModel
# from diffusers import FlowMatchEulerDiscreteScheduler
# diffusers.FlowMatchEulerDiscreteScheduler = EulerDiscreteScheduler
import transformers.utils.import_utils as _import_utils
from transformers.utils import is_flash_attn_2_available
# --- Monkeypatch transformers so that flash-attention is never selected ---


def _package_unavailable(pkg):
    """Report every queried package as not installed."""
    return False


def _flash_attn_off(*a, **k):
    """Answer any flash-attention capability probe with False."""
    return False


def _mask_helper_noop(*a, **k):
    """Placeholder for attention-mask helpers missing from this transformers build."""
    return None


# NOTE(review): this replaces the availability probe for EVERY package name,
# not only flash_attn, and accepts a single positional argument only.
# Confirm that this blanket override is intended.
_import_utils._is_package_available = _package_unavailable
_import_utils.is_flash_attn_2_available = lambda: False

# The same probes are also exported from `transformers.utils`; patch them there too.
hf_utils = importlib.import_module('transformers.utils')
hf_utils.is_flash_attn_2_available = _flash_attn_off
hf_utils.is_flash_attn_greater_or_equal_2_10 = _flash_attn_off

# Provide no-op stand-ins for the SDPA mask helpers, but only where the
# installed transformers does not define them itself.
mask_utils = importlib.import_module("transformers.modeling_attn_mask_utils")
for fn in ("_prepare_4d_attention_mask_for_sdpa", "_prepare_4d_causal_attention_mask_for_sdpa"):
    if hasattr(mask_utils, fn):
        continue
    setattr(mask_utils, fn, _mask_helper_noop)

# Make every PretrainedConfig report "sdpa" as its attention implementation,
# whatever value is actually stored on the instance.
cfg_mod = importlib.import_module("transformers.configuration_utils")
_PrC = cfg_mod.PretrainedConfig
_orig_getattr = _PrC.__getattribute__


def _getattr(self, name):
    """Return "sdpa" for `_attn_implementation`; defer to the original lookup otherwise."""
    return "sdpa" if name == "_attn_implementation" else _orig_getattr(self, name)


_PrC.__getattribute__ = _getattr
# Hub repo id of the TurboX checkpoint that the components below are read from.
model_repo = "tensorart/stable-diffusion-3.5-large-TurboX"

# Prefer the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Individually loaded pipeline components.
# NOTE(review): none of `scheduler`, `text_encoder`, `tokenizer`,
# `feature_extractor` or `unet` is referenced again in the visible part of
# this file - confirm whether these loads are still needed.
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    model_repo, subfolder="scheduler", torch_dtype=torch.float16
)

text_encoder = CLIPTextModel.from_pretrained(
    model_repo,
    subfolder="text_encoder",
    torch_dtype=torch.float16,
)

tokenizer = CLIPTokenizer.from_pretrained(model_repo, subfolder="tokenizer")

# The feature extractor comes from a different repo than the other components.
feature_extractor = CLIPFeatureExtractor.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="feature_extractor"
)

unet = UNet2DConditionModel.from_pretrained(
    model_repo,
    subfolder="unet",
    torch_dtype=torch.float16,
)
# Florence-2 captioner: loaded from the local snapshot in float16, placed on
# the CPU and switched to evaluation mode.
florence_model = AutoModelForCausalLM.from_pretrained(
    LOCAL_FLORENCE,
    trust_remote_code=True,
    torch_dtype=torch.float16,
)
florence_model.to("cpu").eval()

# Matching processor (builds the text and image inputs for the model).
florence_processor = AutoProcessor.from_pretrained(
    LOCAL_FLORENCE,
    trust_remote_code=True,
)
# Stable Diffusion TurboX
# Point `diffusers.StableDiffusion3Pipeline` at the classic pipeline class
# before the pipeline is resolved.
# NOTE(review): this swaps the SD3 pipeline class for StableDiffusionPipeline -
# confirm the checkpoint actually loads correctly through that class.
diffusers.StableDiffusion3Pipeline = StableDiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    model_repo,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    safety_checker=None,
    feature_extractor=None,
)

# Use the device detected earlier instead of a hard-coded "cuda", so the
# script does not fail at start-up on a host without CUDA.
pipe = pipe.to(device)

# Replace the pipeline's scheduler with one built from the checkpoint's
# scheduler config, overriding `shift`. Only already-downloaded files are used.
# NOTE(review): in this file `FlowMatchEulerDiscreteScheduler` is an import
# alias of `EulerDiscreteScheduler`; confirm that `shift=5` is honoured by it.
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    model_repo,
    subfolder="scheduler",
    local_files_only=True,
    trust_remote_code=True,
    shift=5,
)
# Largest seed value that may be drawn at random (signed 32-bit maximum).
MAX_SEED = (1 << 31) - 1
def pseudo_translate_to_korean_style(en_prompt: str) -> str:
    """Wrap an English caption in the fixed cartoon-style prompt template.

    Despite the name, no translation takes place: the caption is only
    embedded between a fixed English prefix and suffix.

    Args:
        en_prompt: The English caption to embed.

    Returns:
        The caption surrounded by the style prefix and suffix.
    """
    template = "Cartoon styled {} handsome or pretty people"
    return template.format(en_prompt)
def generate_prompt(image):
    """Caption an image with Florence-2 and wrap the caption in the cartoon style prompt.

    Args:
        image: A PIL image, or an array accepted by ``Image.fromarray``.

    Returns:
        The detailed caption passed through ``pseudo_translate_to_korean_style``.

    Raises:
        ValueError: If ``image`` is ``None`` (nothing to caption).
    """
    if image is None:
        raise ValueError("generate_prompt() needs an image, but got None")
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    task = "<MORE_DETAILED_CAPTION>"

    # Move the inputs to wherever the model actually lives and cast the
    # floating-point tensors to its dtype. Using the global `device` here can
    # put the inputs on CUDA while the model sits on another device, and
    # float32 pixel values do not match a float16 model.
    inputs = florence_processor(text=task, images=image, return_tensors="pt").to(
        florence_model.device, florence_model.dtype
    )

    generated_ids = florence_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=512,
        num_beams=3,
    )

    # Special tokens are kept because post-processing parses the raw output.
    generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = florence_processor.post_process_generation(
        generated_text,
        task=task,
        image_size=(image.width, image.height),
    )
    prompt_en = parsed_answer[task]

    # Apply the style template without using a translator.
    return pseudo_translate_to_korean_style(prompt_en)
def generate_image(
    prompt,
    seed=42,
    randomize_seed=False,
    *,
    guidance_scale=1.5,
    num_inference_steps=8,
    width=768,
    height=768,
):
    """Generate an image from a text prompt.

    Args:
        prompt: Text prompt handed to the diffusion pipeline.
        seed: Seed for the random generator; ignored when ``randomize_seed`` is true.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Number of denoising steps.
        width: Output width in pixels.
        height: Output height in pixels.

    Returns:
        A ``(image, seed)`` tuple: the first generated image and the seed
        that was actually used.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # UI widgets may deliver the seed as a float; manual_seed needs an int.
    seed = int(seed)

    generator = torch.Generator().manual_seed(seed)
    result = pipe(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    )
    return result.images[0], seed
# Build the Gradio UI: image upload on the left, generated prompt and image on the right.
# NOTE(review): the user-facing strings below look mis-encoded (mojibake);
# they are left exactly as found - confirm the file's original encoding.
with gr.Blocks() as demo:
    gr.Markdown("# ๐Ÿ–ผ ์ด๋ฏธ์ง€ โ†’ ์„ค๋ช… ์ƒ์„ฑ โ†’ ์นดํˆฐ ์ด๋ฏธ์ง€ ์ž๋™ ์ƒ์„ฑ๊ธฐ")
    gr.Markdown("**๐Ÿ“Œ ์‚ฌ์šฉ๋ฒ• ์•ˆ๋‚ด (ํ•œ๊ตญ์–ด)**\n"
                "- ์™ผ์ชฝ์— ์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•˜์„ธ์š”.\n"
                "- AI๊ฐ€ ์˜์–ด ์„ค๋ช…์„ ๋งŒ๋“ค๊ณ , ๋‚ด๋ถ€์—์„œ ํ•œ๊ตญ์–ด ์Šคํƒ€์ผ ํ”„๋กฌํ”„ํŠธ๋กœ ์žฌ๊ตฌ์„ฑํ•ฉ๋‹ˆ๋‹ค.\n"
                "- ์˜ค๋ฅธ์ชฝ์— ๊ฒฐ๊ณผ ์ด๋ฏธ์ง€๊ฐ€ ์ƒ์„ฑ๋ฉ๋‹ˆ๋‹ค.")
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="๐ŸŽจ ์›๋ณธ ์ด๋ฏธ์ง€ ์—…๋กœ๋“œ")
            run_button = gr.Button("โœจ ์ƒ์„ฑ ์‹œ์ž‘")
        with gr.Column():
            prompt_out = gr.Textbox(label="๐Ÿ“ ์Šคํƒ€์ผ ์ ์šฉ๋œ ํ”„๋กฌํ”„ํŠธ", lines=3, show_copy_button=True)
            output_img = gr.Image(label="๐ŸŽ‰ ์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€")

    def full_process(img):
        """Caption the uploaded image, then render a new image from that caption.

        Returns:
            A ``(prompt, image)`` tuple matching the two output components.
        """
        prompt = generate_prompt(img)
        # A fresh random seed is used on every click; the seed itself is not shown.
        image, _ = generate_image(prompt, randomize_seed=True)
        return prompt, image

    run_button.click(fn=full_process, inputs=[input_img], outputs=[prompt_out, output_img])

# Start the server only when run as a script, so importing this module
# (tests, tooling, reload mode) does not launch the app.
if __name__ == "__main__":
    demo.launch()