import gradio as gr
import numpy as np
import random
import torch
from PIL import Image
import os
import spaces
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import StableDiffusionXLPipeline
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from kolors.models.unet_2d_condition import UNet2DConditionModel
from diffusers import AutoencoderKL, EulerDiscreteScheduler
from huggingface_hub import snapshot_download
device = "cuda"
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
ckpt_dir = f'{root_dir}/weights/Kolors'
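# Download the Kolors base weights and the IP-Adapter-Plus weights from the
# Hugging Face Hub (cached after the first run)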
snapshot_download(repo_id="Kwai-Kolors/Kolors", local_dir=ckpt_dir)
snapshot_download(repo_id="Kwai-Kolors/Kolors-IP-Adapter-Plus", local_dir=f"{root_dir}/weights/Kolors-IP-Adapter-Plus")
# Load models
text_encoder = ChatGLMModel.from_pretrained(f'{ckpt_dir}/text_encoder', torch_dtype=torch.float16).half().to(device)
tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to(device)
scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    f'{root_dir}/weights/Kolors-IP-Adapter-Plus/image_encoder',
    ignore_mismatched_sizes=True
).to(dtype=torch.float16, device=device)
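# The IP-Adapter-Plus image encoder works on 336x336 inputs, so the CLIP
# processor is configured to resize and crop to that size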
ip_img_size = 336
clip_image_processor = CLIPImageProcessor(size=ip_img_size, crop_size=ip_img_size)
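# Assemble the Kolors SDXL pipeline: ChatGLM text encoder/tokenizer, SDXL-style
# UNet and VAE, plus the CLIP image encoder used by the IP-Adapter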
pipe = StableDiffusionXLPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    image_encoder=image_encoder,
    feature_extractor=clip_image_processor,
    force_zeros_for_empty_prompt=False
).to(device)
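# Preserve the original text-encoder projection before load_ip_adapter()
# replaces unet.encoder_hid_proj with the image-projection layers; the Kolors
# pipeline appears to look it up under text_encoder_hid_proj (assumed rationale)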
if hasattr(pipe.unet, 'encoder_hid_proj'):
    pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj
pipe.load_ip_adapter(f'{root_dir}/weights/Kolors-IP-Adapter-Plus', subfolder="", weight_name=["ip_adapter_plus_general.bin"])
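# Upper bounds for the seed and the width/height sliders in the UI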
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
# ----------------------------------------------
# infer function (original generation logic kept unchanged)
# ----------------------------------------------
@spaces.GPU(duration=80)
def infer(
    user_prompt,
    ip_adapter_image,
    ip_adapter_scale=0.5,
    negative_prompt="",
    seed=100,
    randomize_seed=False,
    width=1024,
    height=1024,
    guidance_scale=5.0,
    num_inference_steps=50,
    progress=gr.Progress(track_tqdm=True)
):
    # Hidden (always-applied) base prompt
    hidden_prompt = (
        "Ghibli Studio style, Charming hand-drawn anime-style illustration"
    )
    # Final prompt actually passed to the pipeline
    prompt = f"{hidden_prompt}, {user_prompt}"
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device="cuda").manual_seed(seed)
    # Make sure every module is on the GPU allocated by the @spaces.GPU context
    pipe.to("cuda")
    image_encoder.to("cuda")
    pipe.image_encoder = image_encoder
    pipe.set_ip_adapter_scale([ip_adapter_scale])
    image = pipe(
        prompt=prompt,
        ip_adapter_image=[ip_adapter_image],
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        generator=generator,
    ).images[0]
    return image, seed
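# A minimal sketch of calling infer() directly, outside Gradio (the file names
# "ref.png" and "out.png" are illustrative only):
#   image, used_seed = infer("smiling", Image.open("ref.png"),
#                            ip_adapter_scale=0.5, randomize_seed=True)
#   image.save("out.png")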
examples = [
    ["background alps", "gh0.webp", 0.5],
    ["dancing", "gh5.jpg", 0.5],
    ["smile", "gh2.jpg", 0.5],
    ["3d style", "gh3.webp", 0.6],
    ["with Pikachu", "gh4.jpg", 0.5],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", "gh7.jpg", 0.5],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", "gh1.jpg", 0.5],
]
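# Each entry above supplies (prompt, reference image, ip_adapter_scale); the
# remaining infer() arguments fall back to their defaults when an example runs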
# --------------------------
# CSS for the improved UI
# --------------------------
css = """
body {
background: linear-gradient(135deg, #f5f7fa, #c3cfe2);
font-family: 'Helvetica Neue', Arial, sans-serif;
color: #333;
margin: 0;
padding: 0;
}
#col-container {
margin: 0 auto !important;
max-width: 720px;
background: rgba(255,255,255,0.85);
border-radius: 16px;
padding: 2rem;
box-shadow: 0 8px 24px rgba(0,0,0,0.1);
}
#header-title {
text-align: center;
font-size: 2rem;
font-weight: bold;
margin-bottom: 1rem;
}
#prompt-row {
display: flex;
gap: 0.5rem;
align-items: center;
margin-bottom: 1rem;
}
#prompt-text {
flex: 1;
}
#result img {
object-position: top;
border-radius: 8px;
}
#result .image-container {
height: 100%;
}
.gr-button {
background-color: #2E8BFB !important;
color: white !important;
border: none !important;
transition: background-color 0.2s ease;
}
.gr-button:hover {
background-color: #186EDB !important;
}
.gr-slider input[type=range] {
accent-color: #2E8BFB !important;
}
.gr-box {
background-color: #fafafa !important;
border: 1px solid #ddd !important;
border-radius: 8px !important;
padding: 1rem !important;
}
#advanced-settings {
margin-top: 1rem;
border-radius: 8px;
}
"""
with gr.Blocks(theme="apriel", css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("<div id='header-title'>Ghibli Meme Studio</div>")
        gr.Markdown("<div id='header-title' style='font-size: 12px;'>Community: https://discord.gg/openfreeai</div>")
        # Top: prompt input and run button
        with gr.Row(elem_id="prompt-row"):
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                elem_id="prompt-text",
            )
            run_button = gr.Button("Run", elem_id="run-button")
        # Middle: reference image input, influence slider, and result image
        with gr.Row():
            with gr.Column():
                ip_adapter_image = gr.Image(label="IP-Adapter Image", type="pil")
                ip_adapter_scale = gr.Slider(
                    label="Image influence scale",
                    info="Use 1 for creating variations",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.5,
                )
            result = gr.Image(label="Result", elem_id="result")
        # Bottom: advanced settings (collapsed accordion)
        with gr.Accordion("Advanced Settings", open=False, elem_id="advanced-settings"):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=2,
                placeholder=(
                    "(worst quality, low quality:1.4), bad anatomy, bad hands, text, error, "
                    "missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, "
                    "normal quality, jpeg artifacts, signature, watermark, username, blurry, "
                    "artist name, (deformed iris, deformed pupils:1.2), (semi-realistic, cgi, "
                    "3d, render:1.1), amateur, (poorly drawn hands, poorly drawn face:1.2)"
                ),
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=5.0,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                )
        # Examples
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt, ip_adapter_image, ip_adapter_scale],
            outputs=[result, seed],
            cache_examples="lazy",
        )
    # Run on button click or when Enter is pressed in the prompt box
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            ip_adapter_image,
            ip_adapter_scale,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed],
    )
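# Enable request queuing so concurrent users wait their turn, then launch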
demo.queue().launch()