# Hugging Face Space status (copied from the Space page): Running on Zero (ZeroGPU).
import hmac
import os
import uuid

import gradio as gr
import spaces
import torch
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
from diffusers.utils import export_to_video
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file
# Custom CSS for the Gradio UI: dark background, gradient title text and
# gradient buttons, styled inputs, and card-like form/example containers.
custom_css = """
.container {
    max-width: 1000px;
    margin: auto;
    padding: 20px;
}
.title {
    background: linear-gradient(90deg, #00ff87 0%, #60efff 100%);
    -webkit-background-clip: text;
    background-clip: text;
    -webkit-text-fill-color: transparent;
    font-size: 2.5em;
    text-align: center;
    margin-bottom: 1em;
    font-weight: bold;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.1);
}
.subtitle {
    color: #666;
    text-align: center;
    margin-bottom: 2em;
    font-size: 1.2em;
}
.warning {
    color: #ff4b4b;
    font-weight: bold;
    text-align: center;
    padding: 10px;
    margin: 10px 0;
    border-radius: 5px;
    background: rgba(255,75,75,0.1);
}
.info {
    color: #4b8bff;
    text-align: center;
    padding: 10px;
    margin: 10px 0;
    border-radius: 5px;
    background: rgba(75,139,255,0.1);
}
.gradio-container {
    background: linear-gradient(135deg, #1a1a1a 0%, #2a2a2a 100%);
}
.gr-button {
    background: linear-gradient(90deg, #00ff87 0%, #60efff 100%);
    border: none;
    color: black;
    font-weight: bold;
    transition: all 0.3s ease;
}
.gr-button:hover {
    background: linear-gradient(90deg, #60efff 0%, #00ff87 100%);
    transform: translateY(-2px);
    box-shadow: 0 5px 15px rgba(0,255,135,0.3);
}
.gr-input, .gr-dropdown {
    border: 2px solid rgba(96,239,255,0.2);
    border-radius: 8px;
    background: rgba(26,26,26,0.9);
    color: white;
}
.gr-input:focus, .gr-dropdown:focus {
    border-color: #00ff87;
    box-shadow: 0 0 10px rgba(0,255,135,0.3);
}
.gr-form {
    background: rgba(42,42,42,0.8);
    border-radius: 15px;
    padding: 20px;
    box-shadow: 0 8px 32px rgba(0,0,0,0.3);
}
.example-container {
    background: rgba(255,255,255,0.05);
    border-radius: 10px;
    padding: 15px;
    margin: 10px 0;
}
"""
# Why `transition` sits on .gr-button rather than .gr-button:hover: a
# transition declared only in the hover state animates the lift on mouse-over
# but snaps back instantly on mouse-out. On the base rule it animates both ways.
# Login accounts for the VIP gate: {username: password}.
# Credentials written into source are readable by anyone who can read the
# repository, so they can be supplied instead through the VIP_USERS
# environment variable (e.g. a Space secret), formatted as
# "name:password,name2:password2". The built-in accounts below remain the
# fallback when VIP_USERS is unset or blank.
def parse_user_list(raw):
    """Parse ``"name:password,name2:password2"`` into ``{name: password}``.

    Entries are comma-separated and split on their first colon, so a password
    may contain ":" but not ",". Whitespace is trimmed from both ends of each
    entry and from each name. Entries with no colon, an empty name, or an
    empty password are ignored. A later duplicate name overrides an earlier one.
    """
    users = {}
    for entry in raw.split(","):
        name, sep, secret = entry.strip().partition(":")
        name = name.strip()
        if sep and name and secret:
            users[name] = secret
    return users


_raw_users = os.environ.get("VIP_USERS", "").strip()
USERS = parse_user_list(_raw_users) if _raw_users else {
    "admin": "svip",
    "svip": "svip8888",
}
# Base (image) models offered in the UI, mapping the UI label to a Hugging
# Face repo id. Labels, in order: cartoon, realistic, 3D, anime style.
bases = {
    "卡通风格": "frankjoshua/toonyou_beta6",
    "写实风格": "emilianJR/epiCRealism",
    "3D风格": "Lykon/DreamShaper",
    "动漫风格": "Yntec/mistoonAnime2",
}

# Record of what generate_image last loaded into the shared pipeline.
# None never equals a requested value, so the first request always loads
# the step checkpoint and resets the motion LoRA.
step_loaded = motion_loaded = None
# The startup pipeline is built from bases[base_loaded].
base_loaded = "写实风格"
# Build the shared AnimateDiff pipeline once at import time from the default
# base model. generate_image later loads weights into this same object.
# A CUDA device is required; the message means "No GPU detected!".
if not torch.cuda.is_available():
    raise NotImplementedError("未检测到GPU!")
device = "cuda"
dtype = torch.float16
pipe = AnimateDiffPipeline.from_pretrained(bases[base_loaded], torch_dtype=dtype).to(device)
# Swap in an Euler scheduler built from the pipeline's own scheduler config,
# overriding timestep spacing ("trailing") and the beta schedule ("linear").
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear"
)

# CLIP image preprocessor. CLIPImageProcessor is the non-deprecated class
# behind the old CLIPFeatureExtractor alias and loads the same config.
# NOTE(review): no safety checker is wired up here. Nothing in this file reads
# `feature_extractor` or passes it to `pipe`, so it only costs a download at
# startup. It is kept so the module-level name still exists; consider removing it.
from transformers import CLIPImageProcessor

feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
def login(username, password, users=None):
    """Check a username/password pair against the account table.

    Args:
        username: name typed into the login form.
        password: password typed into the login form; non-str values never match.
        users: optional ``{username: password}`` mapping to check against.
            Defaults to the module-level ``USERS`` table.

    Returns:
        ``(True, "登录成功!")`` ("login succeeded") on a match, otherwise
        ``(False, "用户名或密码错误!")`` ("wrong username or password").
    """
    table = USERS if users is None else users
    expected = table.get(username)
    if expected is None or not isinstance(password, str):
        return False, "用户名或密码错误!"
    # Constant-time comparison, so response timing does not reveal how much of
    # the password matched. Compare UTF-8 bytes because compare_digest raises
    # TypeError for non-ASCII str arguments (e.g. Chinese input).
    if hmac.compare_digest(expected.encode("utf-8"), password.encode("utf-8")):
        return True, "登录成功!"
    return False, "用户名或密码错误!"
def _sync_pipeline(base, motion, step):
    """Make the shared `pipe` match the requested step count, base and motion LoRA.

    Each part is reloaded only when it differs from the matching module-level
    ``*_loaded`` record. Each record is updated only after its load returned
    without raising.
    """
    global step_loaded, base_loaded, motion_loaded

    if step_loaded != step:
        # AnimateDiff-Lightning weights distilled for this step count.
        # strict=False tolerates keys the checkpoint and the UNet do not share.
        repo = "ByteDance/AnimateDiff-Lightning"
        ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
        pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
        step_loaded = step

    if base_loaded != base:
        # Only the UNet weights of the selected base are loaded; the other
        # pipeline components stay as loaded at startup.
        # weights_only=True: this .bin is a pickle from a third-party repo, and
        # a full unpickle could execute arbitrary code.
        base_state = torch.load(
            hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"),
            map_location=device,
            weights_only=True,
        )
        pipe.unet.load_state_dict(base_state, strict=False)
        base_loaded = base

    if motion_loaded != motion:
        pipe.unload_lora_weights()
        if motion:
            pipe.load_lora_weights(motion, adapter_name="motion")
            pipe.set_adapters(["motion"], [0.7])
        motion_loaded = motion


# On ZeroGPU Spaces a GPU is attached only inside functions decorated with
# @spaces.GPU; on other hardware the decorator has no effect. duration=60
# requests up to 60 seconds of GPU time per call.
@spaces.GPU(duration=60)
def generate_image(prompt, base="写实风格", motion="", step=8, progress=gr.Progress()):
    """Generate a short video for `prompt` and return the path of the MP4 file.

    Args:
        prompt: text prompt for the video.
        base: key of `bases` whose UNet weights should be used.
        motion: repo id of a motion LoRA to apply at weight 0.7, or "" for none.
        step: number of inference steps. Also selects the matching
            ``animatediff_lightning_{step}step_diffusers.safetensors`` checkpoint.
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        Path of the exported 10 fps MP4 under /tmp.
    """
    print(prompt, base, step)
    _sync_pipeline(base, motion, step)

    progress((0, step))

    def progress_callback(i, t, z):
        # Invoked by the pipeline after every step (callback_steps=1).
        # NOTE(review): `callback`/`callback_steps` are deprecated in recent
        # diffusers in favour of `callback_on_step_end` — confirm the pinned version.
        progress((i + 1, step))

    output = pipe(
        prompt=prompt,
        guidance_scale=1.2,
        num_inference_steps=step,
        callback=progress_callback,
        callback_steps=1,
    )

    path = f"/tmp/{uuid.uuid4().hex}.mp4"
    export_to_video(output.frames[0], path, fps=10)
    return path
# Gradio interface: a login panel, plus a generation studio that is shown
# after a successful login.
with gr.Blocks(css=custom_css) as demo:
    # Per-session login flag, set by handle_login. Toggling group visibility
    # only changes what the browser shows; event handlers remain callable
    # through Gradio's HTTP API. The generate handler therefore checks this
    # flag on the server before doing any work.
    logged_in = gr.State(False)

    # Login panel (visible until the user signs in).
    with gr.Group(visible=True) as login_container:
        gr.HTML("""
            <div class="container">
                <h1 class="title">🌟 OfficeChatAI 视频生成系统</h1>
                <p class="subtitle">欢迎使用AI视频生成系统,让创意转化为现实</p>
            </div>
        """)
        with gr.Group(elem_classes="gr-form"):
            username = gr.Textbox(label="用户名", placeholder="请输入VIP用户名")
            password = gr.Textbox(label="密码", type="password", placeholder="请输入密码")
            login_button = gr.Button("登 录", variant="primary")
            login_msg = gr.Textbox(label="登录状态", interactive=False)

    # Main studio panel (hidden until login succeeds).
    with gr.Group(visible=False) as main_container:
        gr.HTML("""
            <div class="container">
                <h1 class="title">🎬 OfficeChatAI 视频生成工作室</h1>
                <p class="subtitle">专业的AI视频生成平台 | VIP尊享服务</p>
                <div class="warning">提示:首次生成视频需要较长时间,后续生成速度会显著提升</div>
                <div class="info">为获得最佳效果,建议使用英文提示词,参考示例格式</div>
            </div>
        """)
        with gr.Group(elem_classes="gr-form"):
            with gr.Row():
                prompt = gr.Textbox(
                    label='创作提示词',
                    placeholder='请输入您想要生成的视频场景描述...',
                    elem_classes="gr-input"
                )
            with gr.Row():
                # Choices come from `bases` so the UI and the model table cannot drift apart.
                select_base = gr.Dropdown(
                    label='选择基础模型',
                    choices=list(bases),
                    value=base_loaded,
                    interactive=True,
                    elem_classes="gr-dropdown"
                )
                select_motion = gr.Dropdown(
                    label='动作特效',
                    choices=[
                        ("默认效果", ""),
                        ("镜头拉近", "guoyww/animatediff-motion-lora-zoom-in"),
                        ("镜头拉远", "guoyww/animatediff-motion-lora-zoom-out"),
                        ("向上倾斜", "guoyww/animatediff-motion-lora-tilt-up"),
                        ("向下倾斜", "guoyww/animatediff-motion-lora-tilt-down"),
                        ("向左平移", "guoyww/animatediff-motion-lora-pan-left"),
                        ("向右平移", "guoyww/animatediff-motion-lora-pan-right"),
                        ("逆时针旋转", "guoyww/animatediff-motion-lora-rolling-anticlockwise"),
                        ("顺时针旋转", "guoyww/animatediff-motion-lora-rolling-clockwise"),
                    ],
                    value="guoyww/animatediff-motion-lora-zoom-in",
                    interactive=True,
                    elem_classes="gr-dropdown"
                )
                select_step = gr.Dropdown(
                    label='生成质量',
                    choices=[
                        ('快速模式(1步)', 1),
                        ('平衡模式(2步)', 2),
                        ('高质量(4步)', 4),
                        ('超高清(8步)', 8),
                    ],
                    value=4,
                    interactive=True,
                    elem_classes="gr-dropdown"
                )
                submit = gr.Button(
                    value="✨ 开始生成",
                    scale=1,
                    variant="primary",
                    elem_classes=["gr-button"]
                )
            video = gr.Video(
                label='创作结果',
                autoplay=True,
                height=512,
                width=512,
                elem_id="video_output",
                elem_classes="output-video"
            )
        with gr.Group(elem_classes="example-container"):
            gr.HTML("<h3 class='subtitle'>🎯 创作灵感</h3>")
            # NOTE(review): lazily cached examples call generate_image directly,
            # without the login check (limited to these fixed prompts).
            gr.Examples(
                examples=[
                    ["A majestic Eiffel Tower with moving clouds in the background"],
                    ["A lion running through a dense forest"],
                    ["An astronaut floating in space with stars twinkling"],
                    ["A flock of birds flying in formation against a blue sky"],
                    ["Statue of Liberty viewed from a approaching drone"],
                    ["A cute panda drinking tea in a bamboo forest"],
                    ["Children playing in the snow"],
                    ["Cars driving on a rainy city street"]
                ],
                fn=generate_image,
                inputs=[prompt],
                outputs=[video],
                cache_examples="lazy",
            )

    def generate_if_logged_in(is_logged_in, prompt_text, base, motion, step, progress=gr.Progress()):
        """Run generate_image only for sessions that passed the login check.

        Raises:
            gr.Error: "请先登录!" ("please log in first") when this session's
                login flag is not set.
        """
        if not is_logged_in:
            raise gr.Error("请先登录!")
        return generate_image(prompt_text, base, motion, step, progress)

    # Generate button. api_name keeps the endpoint name the handler had before
    # it was wrapped. API clients must log in within the same session first.
    submit.click(
        fn=generate_if_logged_in,
        inputs=[logged_in, prompt, select_base, select_motion, select_step],
        outputs=[video],
        api_name="generate_image",
    )

    def handle_login(username, password):
        """Validate credentials, toggle the panels and record the session's login flag.

        Returns, in the order of ``login_button.click`` outputs: the status
        message, the login-panel update, the main-panel update, and the new flag.
        """
        success, message = login(username, password)
        return (
            message,
            gr.update(visible=not success),
            gr.update(visible=success),
            success,
        )

    login_button.click(
        fn=handle_login,
        inputs=[username, password],
        outputs=[login_msg, login_container, main_container, logged_in]
    )
# Enable Gradio's request queue on the app, then start serving it.
demo.queue()
demo.launch()