# STUDIO / app.py
# Uploaded by ginipick — commit 5c398a5 "Create app.py" (verified), 10.9 kB.
import spaces
import logging
from datetime import datetime
from pathlib import Path
import gradio as gr
import torch
import torchaudio
import os
import requests
from transformers import pipeline
import tempfile
import numpy as np
from einops import rearrange
import cv2
from scipy.io import wavfile
import librosa
import json
from typing import Optional, Tuple, List
import atexit
# Temporary workaround: set TRANSFORMERS_ALLOW_UNSAFE_DESERIALIZATION=1, intended
# to bypass the torch.load safety check when models are loaded.
# NOTE(review): this is written into os.environ, so it applies to the whole
# process (and any child processes). Confirm that the installed transformers
# version actually reads this variable and that the workaround is still needed.
os.environ.update(TRANSFORMERS_ALLOW_UNSAFE_DESERIALIZATION="1")
try:
    import mmaudio
except ImportError:
    # mmaudio is not importable: do an editable install from the current
    # directory. Use the running interpreter's pip (os.system("pip ...") may
    # pick up a different Python's pip from PATH). Raise if the install fails,
    # instead of ignoring the exit status and failing later with a confusing
    # ImportError.
    import subprocess
    import sys

    subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "."])
    import mmaudio
from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
setup_eval_logging)
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.sequence_config import SequenceConfig
from mmaudio.model.utils.features_utils import FeaturesUtils
# Logging setup: INFO level with a "time - logger - level - message" format,
# configured on the root logger. `log` is that same root logger.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
log = logging.getLogger(name=None)
# Device selection, done once at import time. When CUDA is available, enable
# TF32 for matmul/cuDNN and cuDNN autotuning (benchmark mode). The model dtype
# is bfloat16 regardless of the chosen device.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _use_cuda else "cpu")
if _use_cuda:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
dtype = torch.bfloat16
# λͺ¨λΈ μ„€μ •
model: ModelConfig = all_model_cfg['large_44k_v2']
model.download_if_needed()
output_dir = Path('./output/gradio')
setup_eval_logging()
# Korean -> English translator, run on CPU.
# First attempt: load weights strictly from safetensors. If that fails (for
# example, the repo ships no safetensors file), fall back to the default loader.
# `translator` is None when both attempts fail; translate_prompt() and the UI
# labels check for that.
_TRANSLATION_MODEL_ID = "Helsinki-NLP/opus-mt-ko-en"
try:
    translator = pipeline(
        "translation",
        model=_TRANSLATION_MODEL_ID,
        device="cpu",
        use_fast=True,  # use the fast tokenizer
        trust_remote_code=False,
        # Actually request safetensors. Previously nothing in this branch
        # enforced the "safetensors first" intent.
        model_kwargs={"use_safetensors": True},
    )
except Exception as e:
    log.warning(f"Failed to load translation model with safetensors: {e}")
    # Fallback: default weight format (subject to the environment override
    # set at the top of the file).
    try:
        translator = pipeline("translation",
                              model=_TRANSLATION_MODEL_ID,
                              device="cpu")
    except Exception as e2:
        log.error(f"Failed to load translation model: {e2}")
        translator = None
# Pixabay API key used by search_videos(). Prefer supplying it via the
# PIXABAY_API_KEY environment variable (e.g. a Space secret). The literal is
# kept only as a backward-compatible fallback when the variable is unset or empty.
# NOTE(review): this key is committed in source; consider rotating it.
PIXABAY_API_KEY = os.environ.get("PIXABAY_API_KEY") or "33492762-a28a596ec4f286f84cd328b17"
def cleanup_temp_files(temp_dir=None, suffixes=('.mp4', '.flac')):
    """Best-effort removal of this app's temporary media outputs.

    Results are written with ``tempfile.NamedTemporaryFile(delete=False, ...)``,
    which names files with the default ``tempfile.gettempprefix()`` prefix.
    Only regular files in *temp_dir* that have that prefix AND one of
    *suffixes* are removed. Unrelated .mp4/.flac files owned by other
    programs in the shared temp directory are left alone.

    Args:
        temp_dir: Directory to scan; defaults to ``tempfile.gettempdir()``.
        suffixes: Iterable of file-name endings eligible for deletion.
    """
    directory = tempfile.gettempdir() if temp_dir is None else temp_dir
    prefix = tempfile.gettempprefix()
    endings = tuple(suffixes)
    try:
        names = os.listdir(directory)
    except OSError:
        # Directory missing/unreadable: nothing to clean.
        return
    for name in names:
        if not (name.startswith(prefix) and name.endswith(endings)):
            continue
        path = os.path.join(directory, name)
        if not os.path.isfile(path):
            continue
        try:
            os.remove(path)
        except OSError:
            # Deliberately best-effort (runs at interpreter exit): the file may
            # already be gone or be locked by another process.
            pass


atexit.register(cleanup_temp_files)
def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
    """Load the configured MMAudio network and its feature utilities.

    Reads the module-level ``model`` config, ``device`` and ``dtype``.

    Returns:
        ``(net, feature_utils, seq_cfg)``:

        - ``net``: the MMAudio network for ``model.model_name``, with weights
          loaded from ``model.model_path``.
        - ``feature_utils``: a FeaturesUtils built from the config's VAE,
          Synchformer and BigVGAN checkpoint paths, with
          ``need_vae_encoder=False``.
        - ``seq_cfg``: ``model.seq_cfg``.

        Both modules are cast to ``device``/``dtype`` and put in eval mode.
    """
    import contextlib

    # torch.cuda.device() only accepts CUDA devices and raises for the CPU
    # fallback chosen at import time, so use a no-op context off-GPU.
    if device.type == "cuda":
        device_ctx = torch.cuda.device(device)
    else:
        device_ctx = contextlib.nullcontext()

    with device_ctx:
        seq_cfg = model.seq_cfg
        net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
        # weights_only=True restricts unpickling to tensors/primitive containers.
        net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
        log.info(f'Loaded weights from {model.model_path}')
        feature_utils = FeaturesUtils(
            tod_vae_ckpt=model.vae_path,
            synchformer_ckpt=model.synchformer_ckpt,
            enable_conditions=True,
            mode=model.mode,
            bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
            need_vae_encoder=False
        ).to(device, dtype).eval()
        return net, feature_utils, seq_cfg


# Built once at import time and shared by video_to_audio / text_to_audio.
net, feature_utils, seq_cfg = get_model()
# Prompt translation (Korean -> English) applied before generation and search.
def translate_prompt(text):
    """Return *text* translated to English if it contains Hangul, else unchanged.

    *text* is returned as-is when the translator failed to load (``None``),
    when *text* is empty/None, when it contains no Hangul characters, or when
    the translation call raises (the error is logged).

    Only Hangul code points trigger translation. The former check accepted the
    whole span U+3131..U+D7A3, which also covers CJK Unified Ideographs
    (U+4E00..U+9FFF) and other non-Korean blocks. That sent Chinese/Japanese
    prompts to the Korean->English model.
    """
    # Inclusive code-point ranges of the Hangul Unicode blocks.
    hangul_ranges = (
        (0x1100, 0x11FF),  # Hangul Jamo (conjoining, e.g. NFD-decomposed text)
        (0x3131, 0x318E),  # Hangul Compatibility Jamo
        (0xAC00, 0xD7A3),  # Hangul Syllables
    )
    try:
        # Translator unavailable: pass the text through untouched.
        if translator is None:
            return text
        if text and any(lo <= ord(ch) <= hi for ch in text for lo, hi in hangul_ranges):
            # The translation pipeline runs on CPU.
            with torch.no_grad():
                translation = translator(text)[0]['translation_text']
            return translation
        return text
    except Exception as e:
        logging.error(f"Translation error: {e}")
        return text
# Gradio handler for the video-search tab.
@torch.no_grad()
def search_videos(query):
    """Translate *query* if needed and return Pixabay video URLs for it.

    Any exception (translation or search) is logged and yields an empty list.
    """
    try:
        # Translation runs on CPU.
        english_query = translate_prompt(query)
        urls = search_pixabay_videos(english_query, PIXABAY_API_KEY)
    except Exception as err:
        logging.error(f"Video search error: {err}")
        return []
    return urls
def search_pixabay_videos(query, api_key, per_page=40, timeout=10):
    """Query the Pixabay video search API and return one URL per hit.

    Args:
        query: Search text, sent as the ``q`` parameter.
        api_key: Pixabay API key.
        per_page: Number of hits requested (default 40, as before).
        timeout: Seconds to wait for the HTTP response. Without a timeout,
            ``requests`` can block the handler indefinitely.

    Returns:
        A list of video URLs. For each hit, the ``large`` rendition is
        preferred. If its URL is missing or empty, ``medium``, ``small`` and
        ``tiny`` are tried in turn. Hits with no usable URL are skipped
        instead of failing the whole result. Returns ``[]`` on a non-200
        response or on any request/parsing error (which is logged).
    """
    base_url = "https://pixabay.com/api/videos/"
    params = {
        "key": api_key,
        "q": query,
        "per_page": per_page,
    }
    try:
        response = requests.get(base_url, params=params, timeout=timeout)
        if response.status_code != 200:
            logging.warning(f"Pixabay API returned HTTP {response.status_code}")
            return []
        hits = response.json().get('hits', [])
        urls = []
        for video in hits:
            renditions = video.get('videos') or {}
            for size in ('large', 'medium', 'small', 'tiny'):
                url = (renditions.get(size) or {}).get('url')
                if url:
                    urls.append(url)
                    break
        return urls
    except Exception as e:
        logging.error(f"Pixabay API error: {e}")
        return []
@spaces.GPU
@torch.inference_mode()
def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
                   cfg_strength: float, duration: float):
    """Generate audio for *video* and return the path of an .mp4 combining both.

    Args:
        video: Value from the gr.Video input, passed to mmaudio's load_video()
            and make_video() (presumably a file path; confirm).
        prompt: Conditioning text. Korean is translated to English first.
        negative_prompt: Passed to generate() as ``negative_text``; translated
            the same way.
        seed: Seed for the torch.Generator on ``device``. Coerced to int
            because gr.Number may deliver non-integer values.
        num_steps: Euler step count for FlowMatching; coerced to int likewise.
        cfg_strength: Guidance strength forwarded to generate().
        duration: Requested length in seconds. load_video() returns the
            duration actually used, and that value is stored in
            ``seq_cfg.duration``.

    Returns:
        Path of a temporary ``.mp4`` (created with delete=False) written by
        make_video() with the first generated audio sample.

    NOTE(review): this mutates the shared module-level ``seq_cfg`` and
    ``net`` sequence lengths. Concurrent calls could interfere unless
    requests are serialized; confirm.
    """
    prompt = translate_prompt(prompt)
    negative_prompt = translate_prompt(negative_prompt)
    # torch.Generator.manual_seed rejects floats, and the step count is
    # expected to be an integer.
    seed = int(seed)
    num_steps = int(num_steps)

    rng = torch.Generator(device=device)
    rng.manual_seed(seed)
    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)

    clip_frames, sync_frames, duration = load_video(video, duration)
    # Add a leading dimension of size 1 to match the single prompt below.
    clip_frames = clip_frames.unsqueeze(0)
    sync_frames = sync_frames.unsqueeze(0)
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    audios = generate(clip_frames,
                      sync_frames, [prompt],
                      negative_text=[negative_prompt],
                      feature_utils=feature_utils,
                      net=net,
                      fm=fm,
                      rng=rng,
                      cfg_strength=cfg_strength)
    audio = audios.float().cpu()[0]

    video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
    make_video(video,
               video_save_path,
               audio,
               sampling_rate=seq_cfg.sampling_rate,
               duration_sec=seq_cfg.duration)
    return video_save_path
@spaces.GPU
@torch.inference_mode()
def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int, cfg_strength: float,
                  duration: float):
    """Generate audio from a text prompt and return the path of a .flac file.

    Args:
        prompt: Text describing the audio. Korean is translated to English first.
        negative_prompt: Passed to generate() as ``negative_text``; translated
            the same way.
        seed: Seed for the torch.Generator on ``device``. Coerced to int
            because gr.Number may deliver non-integer values.
        num_steps: Euler step count for FlowMatching; coerced to int likewise.
        cfg_strength: Guidance strength forwarded to generate().
        duration: Audio length in seconds, stored in ``seq_cfg.duration``.

    Returns:
        Path of a temporary ``.flac`` (created with delete=False) holding the
        first generated sample at ``seq_cfg.sampling_rate``.

    NOTE(review): this mutates the shared module-level ``seq_cfg`` and
    ``net`` sequence lengths. Concurrent calls could interfere unless
    requests are serialized; confirm.
    """
    prompt = translate_prompt(prompt)
    negative_prompt = translate_prompt(negative_prompt)
    # torch.Generator.manual_seed rejects floats, and the step count is
    # expected to be an integer.
    seed = int(seed)
    num_steps = int(num_steps)

    rng = torch.Generator(device=device)
    rng.manual_seed(seed)
    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)

    # Text-only generation: no video conditioning frames.
    clip_frames = sync_frames = None
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    audios = generate(clip_frames,
                      sync_frames, [prompt],
                      negative_text=[negative_prompt],
                      feature_utils=feature_utils,
                      net=net,
                      fm=fm,
                      rng=rng,
                      cfg_strength=cfg_strength)
    audio = audios.float().cpu()[0]

    audio_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.flac').name
    torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
    return audio_save_path
# CSS styles: a dark theme (dark gradient container, translucent panels,
# blue-gradient buttons, dark text inputs, hover lift/scale effects).
# This string is passed as css= to each gr.Interface tab, and the page-level
# `css` string appends it after a footer rule.
# NOTE(review): `.input-container:hover { transform: translateZ(20px) }` only
# has a visible effect if an ancestor sets `perspective`, which this
# stylesheet does not do. Confirm whether the effect is intended.
custom_css = """
.gradio-container {
background: linear-gradient(45deg, #1a1a1a, #2a2a2a);
border-radius: 15px;
box-shadow: 0 8px 32px rgba(0,0,0,0.3);
color: #e0e0e0;
}
.input-container, .output-container {
background: rgba(40, 40, 40, 0.95);
backdrop-filter: blur(10px);
border-radius: 10px;
padding: 20px;
transform-style: preserve-3d;
transition: transform 0.3s ease;
border: 1px solid rgba(255, 255, 255, 0.1);
}
.input-container:hover {
transform: translateZ(20px);
box-shadow: 0 8px 32px rgba(0,0,0,0.5);
}
.gallery-item {
transition: transform 0.3s ease;
border-radius: 8px;
overflow: hidden;
background: #2a2a2a;
}
.gallery-item:hover {
transform: scale(1.05);
box-shadow: 0 4px 15px rgba(0,0,0,0.4);
}
.tabs {
background: rgba(30, 30, 30, 0.95);
border-radius: 10px;
padding: 10px;
border: 1px solid rgba(255, 255, 255, 0.05);
}
button {
background: linear-gradient(45deg, #2196F3, #1976D2);
border: none;
border-radius: 5px;
transition: all 0.3s ease;
color: white;
}
button:hover {
transform: translateY(-2px);
box-shadow: 0 4px 15px rgba(33,150,243,0.3);
}
textarea, input[type="text"], input[type="number"] {
background: rgba(30, 30, 30, 0.95) !important;
border: 1px solid rgba(255, 255, 255, 0.1) !important;
color: #e0e0e0 !important;
border-radius: 5px !important;
}
label {
color: #e0e0e0 !important;
}
.gallery {
background: rgba(30, 30, 30, 0.95);
padding: 15px;
border-radius: 10px;
border: 1px solid rgba(255, 255, 255, 0.05);
}
"""
# Page-level stylesheet for the TabbedInterface: hide the Gradio footer, then
# append the shared dark-theme rules from custom_css.
_HIDE_FOOTER_CSS = """
footer {
visibility: hidden;
}
"""
css = _HIDE_FOOTER_CSS + custom_css
# Gradio interface construction.
# Text-to-Audio tab. Inputs map positionally onto text_to_audio(prompt,
# negative_prompt, seed, num_steps, cfg_strength, duration).
# Seed and Steps use precision=0 so Gradio always delivers ints.
# torch.Generator.manual_seed and the step count require integers, and
# gr.Number without a precision passes the raw (possibly fractional) value.
text_to_audio_tab = gr.Interface(
    fn=text_to_audio,
    inputs=[
        gr.Textbox(label="Prompt(ν•œκΈ€μ§€μ›)" if translator else "Prompt"),
        gr.Textbox(label="Negative Prompt"),
        gr.Number(label="Seed", value=0, precision=0),
        gr.Number(label="Steps", value=25, precision=0),
        gr.Number(label="Guidance Scale", value=4.5),
        gr.Number(label="Duration (sec)", value=8),
    ],
    outputs=gr.Audio(label="Generated Audio"),
    css=custom_css
)
# Video-to-Audio tab. Inputs map positionally onto video_to_audio(video,
# prompt, negative_prompt, seed, num_steps, cfg_strength, duration).
# Seed and Steps use precision=0 so Gradio always delivers ints.
# torch.Generator.manual_seed and the step count require integers.
video_to_audio_tab = gr.Interface(
    fn=video_to_audio,
    inputs=[
        gr.Video(label="Input Video"),
        gr.Textbox(label="Prompt(ν•œκΈ€μ§€μ›)" if translator else "Prompt"),
        gr.Textbox(label="Negative Prompt", value="music"),
        gr.Number(label="Seed", value=0, precision=0),
        gr.Number(label="Steps", value=25, precision=0),
        gr.Number(label="Guidance Scale", value=4.5),
        gr.Number(label="Duration (sec)", value=8),
    ],
    outputs=gr.Video(label="Generated Result"),
    css=custom_css
)
# Video Search tab: a text query in, a gallery of Pixabay video URLs out.
# api_name=False hides this endpoint from Gradio's named API.
_search_label = "Search Query(ν•œκΈ€μ§€μ›)" if translator else "Search Query"
video_search_tab = gr.Interface(
    fn=search_videos,
    inputs=gr.Textbox(label=_search_label),
    outputs=gr.Gallery(label="Search Results", columns=4, rows=20),
    api_name=False,
    css=custom_css,
)
# Entry point: assemble the three tabs into one app and launch it.
if __name__ == "__main__":
    # Translation is optional; warn once when the translator is unavailable.
    if translator is None:
        log.warning("Translation model failed to load. Korean translation will be disabled.")

    tabs = [video_search_tab, video_to_audio_tab, text_to_audio_tab]
    tab_titles = ["Video Search", "Video-to-Audio", "Text-to-Audio"]
    demo = gr.TabbedInterface(tabs, tab_titles, theme="soft", css=css)
    # output_dir is whitelisted so Gradio may serve files from it.
    demo.launch(allowed_paths=[output_dir])