openfree committed on
Commit 43cb38b · verified · 1 Parent(s): 914dc02

Update app.py

Files changed (1)
  1. app.py +241 -220
app.py CHANGED
@@ -1,223 +1,244 @@
- import spaces
- import gradio as gr
- import os
- import numpy as np
- from pydub import AudioSegment
- import hashlib
- import io
- from sonic import Sonic
  from PIL import Image
- import torch
-
- # Download the models required on first run
- cmd = (
-     'python3 -m pip install "huggingface_hub[cli]" accelerate; '
-     'huggingface-cli download LeonJoe13/Sonic --local-dir checkpoints; '
-     'huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --local-dir checkpoints/stable-video-diffusion-img2vid-xt; '
-     'huggingface-cli download openai/whisper-tiny --local-dir checkpoints/whisper-tiny;'
  )
- os.system(cmd)
-
- pipe = Sonic()
-
- def get_md5(content_bytes: bytes):
-     """Compute the MD5 hash and return its 32-character hex digest."""
-     return hashlib.md5(content_bytes).hexdigest()
-
- tmp_path = './tmp_path/'
- res_path = './res_path/'
- os.makedirs(tmp_path, exist_ok=True)
- os.makedirs(res_path, exist_ok=True)
-
- @spaces.GPU(duration=600)  # duration set to 600 seconds (10 minutes) for video processing
- def get_video_res(img_path, audio_path, res_video_path, dynamic_scale=1.0):
-     """
-     Generate the actual video with the Sonic pipeline.
-     Chooses inference_steps for audio up to 60 seconds long,
-     then runs face detection and video generation.
-     """
-     expand_ratio = 0.0
-     min_resolution = 512
-
-     # Compute the audio duration
-     audio = AudioSegment.from_file(audio_path)
-     duration = len(audio) / 1000.0  # seconds
-
-     # Choose inference_steps from the audio length (min 25 frames, max 750 frames)
-     inference_steps = min(max(int(duration * 12.5), 25), 750)
-     print(f"[INFO] Audio duration: {duration:.2f} seconds, using inference_steps={inference_steps}")
-
-     # Face detection
-     face_info = pipe.preprocess(img_path, expand_ratio=expand_ratio)
-     print(f"[INFO] Face detection info: {face_info}")
-
-     # If at least one face is detected, run the pipeline
-     if face_info['face_num'] > 0:
-         os.makedirs(os.path.dirname(res_video_path), exist_ok=True)
-         pipe.process(
-             img_path,
-             audio_path,
-             res_video_path,
-             min_resolution=min_resolution,
-             inference_steps=inference_steps,
-             dynamic_scale=dynamic_scale
  )
-         return res_video_path
-     else:
-         # No face detected at all: return -1
-         return -1
-
- def process_sonic(image, audio, dynamic_scale):
-     """
-     Called from the Gradio interface:
-     1. Validate the image/audio inputs
-     2. MD5 hash -> file name
-     3. Check the cache -> generate the video if there is no cached result
-     """
-     if image is None:
-         raise gr.Error("Please upload an image")
-     if audio is None:
-         raise gr.Error("Please upload an audio file")
-
-     # (1) Image MD5
-     buf_img = io.BytesIO()
-     image.save(buf_img, format="PNG")
-     img_bytes = buf_img.getvalue()
-     img_md5 = get_md5(img_bytes)
-
-     # (2) Audio MD5
-     sampling_rate, arr = audio[:2]
-     if len(arr.shape) == 1:
-         arr = arr[:, None]
-     audio_segment = AudioSegment(
-         arr.tobytes(),
-         frame_rate=sampling_rate,
-         sample_width=arr.dtype.itemsize,
-         channels=arr.shape[1]
-     )
-     # Convert to mono/16 kHz for Whisper compatibility
-     audio_segment = audio_segment.set_channels(1).set_frame_rate(16000)
-
-     MAX_DURATION_MS = 60000
-     if len(audio_segment) > MAX_DURATION_MS:
-         audio_segment = audio_segment[:MAX_DURATION_MS]
-
-     buf_audio = io.BytesIO()
-     audio_segment.export(buf_audio, format="wav")
-     audio_bytes = buf_audio.getvalue()
-     audio_md5 = get_md5(audio_bytes)
-
-     # (3) File paths
-     image_path = os.path.abspath(os.path.join(tmp_path, f'{img_md5}.png'))
-     audio_path = os.path.abspath(os.path.join(tmp_path, f'{audio_md5}.wav'))
-     res_video_path = os.path.abspath(os.path.join(res_path, f'{img_md5}_{audio_md5}_{dynamic_scale}.mp4'))
-
-     if not os.path.exists(image_path):
-         with open(image_path, "wb") as f:
-             f.write(img_bytes)
-     if not os.path.exists(audio_path):
-         with open(audio_path, "wb") as f:
-             f.write(audio_bytes)
-
-     # (4) Reuse the cached result if it exists
-     if os.path.exists(res_video_path):
-         print(f"[INFO] Using cached result: {res_video_path}")
-         return res_video_path
-     else:
-         print(f"[INFO] Generating new video with dynamic_scale={dynamic_scale}")
-         video_result = get_video_res(image_path, audio_path, res_video_path, dynamic_scale)
-         return video_result
-
- def get_example():
-     return []
-
- css = """
- .gradio-container {
-     font-family: 'Arial', sans-serif;
- }
- .main-header {
-     text-align: center;
-     color: #2a2a2a;
-     margin-bottom: 2em;
- }
- .parameter-section {
-     background-color: #f5f5f5;
-     padding: 1em;
-     border-radius: 8px;
-     margin: 1em 0;
- }
- .example-section {
-     margin-top: 2em;
- }
- """
-
- with gr.Blocks(css=css) as demo:
-     gr.HTML("""
-     <div class="main-header">
-         <h1>🎭 Sonic: Advanced Portrait Animation</h1>
-         <p>Transform still images into dynamic videos synchronized with audio (up to 1 minute)</p>
-     </div>
-     """)
-
-     with gr.Row():
-         with gr.Column():
-             image_input = gr.Image(
-                 type='pil',
-                 label="Portrait Image",
-                 elem_id="image_input"
-             )
-             audio_input = gr.Audio(
-                 label="Voice/Audio Input (up to 1 minute)",
-                 elem_id="audio_input",
-                 type="numpy"
-             )
-         with gr.Column():
-             dynamic_scale = gr.Slider(
-                 minimum=0.5,
-                 maximum=2.0,
-                 value=1.0,
-                 step=0.1,
-                 label="Animation Intensity",
-                 info="Adjust to control movement intensity (0.5: subtle, 2.0: dramatic)"
-             )
-             process_btn = gr.Button(
-                 "Generate Animation",
-                 variant="primary",
-                 elem_id="process_btn"
-             )
-
-         with gr.Column():
-             video_output = gr.Video(
-                 label="Generated Animation",
-                 elem_id="video_output"
-             )
-
-     process_btn.click(
-         fn=process_sonic,
-         inputs=[image_input, audio_input, dynamic_scale],
-         outputs=video_output,
-     )
-
-     gr.Examples(
-         examples=get_example(),
-         fn=process_sonic,
-         inputs=[image_input, audio_input, dynamic_scale],
-         outputs=video_output,
-         cache_examples=False
-     )
-
-     gr.HTML("""
-     <div style="text-align: center; margin-top: 2em;">
-         <div style="margin-bottom: 1em;">
-             <a href="https://github.com/jixiaozhong/Sonic" target="_blank" style="text-decoration: none;">
-                 <img src="https://img.shields.io/badge/GitHub-Repo-blue?style=for-the-badge&logo=github" alt="GitHub Repo">
-             </a>
-             <a href="https://arxiv.org/pdf/2411.16331" target="_blank" style="text-decoration: none;">
-                 <img src="https://img.shields.io/badge/Paper-arXiv-red?style=for-the-badge&logo=arxiv" alt="arXiv Paper">
-             </a>
-         </div>
-         <p>🔔 Note: For optimal results, use clear portrait images and high-quality audio (now supports up to 1 minute!)</p>
-     </div>
-     """)
-
- demo.launch(share=True)
+ import os, math, torch, cv2
  from PIL import Image
+ from omegaconf import OmegaConf
+ from tqdm import tqdm
+
+ from diffusers import AutoencoderKLTemporalDecoder
+ from diffusers.schedulers import EulerDiscreteScheduler
+ from transformers import WhisperModel, CLIPVisionModelWithProjection, AutoFeatureExtractor
+
+ from src.utils.util import save_videos_grid, seed_everything
+ from src.dataset.test_preprocess import process_bbox, image_audio_to_tensor
+ from src.models.base.unet_spatio_temporal_condition import (
+     UNetSpatioTemporalConditionModel, add_ip_adapters,
  )
+ from src.pipelines.pipeline_sonic import SonicPipeline
+ from src.models.audio_adapter.audio_proj import AudioProjModel
+ from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel
+ from src.utils.RIFE.RIFE_HDv3 import RIFEModel
+ from src.dataset.face_align.align import AlignImage
+
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+
+
+ # ------------------------------------------------------------------
+ # single image + speech → video-tensor generator
+ # ------------------------------------------------------------------
+ def test(
+     pipe, config, wav_enc, audio_pe, audio2bucket, image_encoder,
+     width, height, batch,
+ ):
+     # ---- align the batch dimensions --------------------------------
+     for k, v in batch.items():
+         if isinstance(v, torch.Tensor):
+             batch[k] = v.unsqueeze(0).to(pipe.device).float()
+
+     ref_img = batch["ref_img"]
+     clip_img = batch["clip_images"]
+     face_mask = batch["face_mask"]
+     image_embeds = image_encoder(clip_img).image_embeds  # (1,1024)
+
+     audio_feature = batch["audio_feature"]  # (1, 80, T)
+     audio_len = int(batch["audio_len"])
+     step = int(config.step)
+
+     window = 16_000  # 1-sec chunks
+     audio_prompts, last_prompts = [], []
+
+     for i in range(0, audio_feature.shape[-1], window):
+         chunk = audio_feature[:, :, i : i + window]  # (1, 80, win)
+         layers = wav_enc.encoder(chunk, output_hidden_states=True).hidden_states
+         last = wav_enc.encoder(chunk).last_hidden_state.unsqueeze(-2)
+         audio_prompts.append(torch.stack(layers, dim=2))  # (1, w, L, 384)
+         last_prompts.append(last)
+
+     if not audio_prompts:
+         raise ValueError("[ERROR] No speech recognised in the provided audio.")
+
+     audio_prompts = torch.cat(audio_prompts, dim=1)
+     last_prompts = torch.cat(last_prompts, dim=1)
+
+     # padding rule
+     audio_prompts = torch.cat(
+         [torch.zeros_like(audio_prompts[:, :4]), audio_prompts,
+          torch.zeros_like(audio_prompts[:, :6])], dim=1)
+     last_prompts = torch.cat(
+         [torch.zeros_like(last_prompts[:, :24]), last_prompts,
+          torch.zeros_like(last_prompts[:, :26])], dim=1)
+
+     total_tokens = audio_prompts.shape[1]
+     num_chunks = max(1, math.ceil(total_tokens / (2 * step)))
+
+     ref_list, audio_list, uncond_list, motion_buckets = [], [], [], []
+
+     for i in tqdm(range(num_chunks)):
+         start = i * 2 * step
+
+         # ------------ cond_clip : (1,1,10,5,384) ------------------
+         clip_raw = audio_prompts[:, start : start + 10]  # (1, ≤10, L, 384)
+
+         # ★ W-padding must be along dim=1!
+         if clip_raw.shape[1] < 10:
+             pad_w = torch.zeros_like(clip_raw[:, : 10 - clip_raw.shape[1]])
+             clip_raw = torch.cat([clip_raw, pad_w], dim=1)
+
+         # ★ L-padding is along dim=2
+         while clip_raw.shape[2] < 5:
+             clip_raw = torch.cat([clip_raw, clip_raw[:, :, -1:]], dim=2)
+         clip_raw = clip_raw[:, :, :5]  # (1,10,5,384)
+
+         cond_clip = clip_raw.unsqueeze(1)  # (1,1,10,5,384)
+
+         # ------------ bucket_clip : (1,1,50,1,384) -----------------
+         bucket_raw = last_prompts[:, start : start + 50]
+         if bucket_raw.shape[1] < 50:  # ★ dim=1
+             pad_w = torch.zeros_like(bucket_raw[:, : 50 - bucket_raw.shape[1]])
+             bucket_raw = torch.cat([bucket_raw, pad_w], dim=1)
+         bucket_clip = bucket_raw.unsqueeze(1)  # (1,1,50,1,384)
+
+         motion = audio2bucket(bucket_clip, image_embeds) * 16 + 16
+
+         ref_list.append(ref_img[0])
+         audio_list.append(audio_pe(cond_clip).squeeze(0))  # (50,1024)
+         uncond_list.append(audio_pe(torch.zeros_like(cond_clip)).squeeze(0))
+         motion_buckets.append(motion[0])
+
+     # ---- call Stable-Video-Diffusion --------------------------------
+     video = pipe(
+         ref_img, clip_img, face_mask,
+         audio_list, uncond_list, motion_buckets,
+         height=height, width=width,
+         num_frames=len(audio_list),
+         decode_chunk_size=config.decode_chunk_size,
+         motion_bucket_scale=config.motion_bucket_scale,
+         fps=config.fps,
+         noise_aug_strength=config.noise_aug_strength,
+         min_guidance_scale1=config.min_appearance_guidance_scale,
+         max_guidance_scale1=config.max_appearance_guidance_scale,
+         min_guidance_scale2=config.audio_guidance_scale,
+         max_guidance_scale2=config.audio_guidance_scale,
+         overlap=config.overlap,
+         shift_offset=config.shift_offset,
+         frames_per_batch=config.n_sample_frames,
+         num_inference_steps=config.num_inference_steps,
+         i2i_noise_strength=config.i2i_noise_strength,
+     ).frames
+
+     video = (video * 0.5 + 0.5).clamp(0, 1)
+     return video.to(pipe.device).unsqueeze(0).cpu()
+
+
+ # ------------------------------------------------------------------
+ # Sonic class
+ # ------------------------------------------------------------------
+ class Sonic:
+     config_file = os.path.join(BASE_DIR, "config/inference/sonic.yaml")
+     config = OmegaConf.load(config_file)
+
+     def __init__(self, device_id: int = 0, enable_interpolate_frame: bool = True):
+         cfg = self.config
+         cfg.use_interframe = enable_interpolate_frame
+         self.device = f"cuda:{device_id}" if device_id >= 0 and torch.cuda.is_available() else "cpu"
+         cfg.pretrained_model_name_or_path = os.path.join(BASE_DIR, cfg.pretrained_model_name_or_path)
+
+         self._load_models(cfg)
+         print("Sonic init done")
+
+     # --------------------------------------------------------------
+     def _load_models(self, cfg):
+         dtype = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}[cfg.weight_dtype]
+
+         vae = AutoencoderKLTemporalDecoder.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="vae", variant="fp16")
+         sched = EulerDiscreteScheduler.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="scheduler")
+         img_e = CLIPVisionModelWithProjection.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="image_encoder", variant="fp16")
+         unet = UNetSpatioTemporalConditionModel.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="unet", variant="fp16")
+         add_ip_adapters(unet, [32], [cfg.ip_audio_scale])
+
+         a2t = AudioProjModel(10, 5, 384, 1024, 1024, 32).to(self.device)
+         a2b = Audio2bucketModel(50, 1, 384, 1024, 1024, 1, 2).to(self.device)
+
+         unet.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.unet_checkpoint_path), map_location="cpu"))
+         a2t.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2token_checkpoint_path), map_location="cpu"))
+         a2b.load_state_dict(torch.load(os.path.join(BASE_DIR, cfg.audio2bucket_checkpoint_path), map_location="cpu"))
+
+         whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny")).to(self.device).eval()
+         whisper.requires_grad_(False)
+
+         self.feature_extractor = AutoFeatureExtractor.from_pretrained(os.path.join(BASE_DIR, "checkpoints/whisper-tiny"))
+         self.face_det = AlignImage(self.device, det_path=os.path.join(BASE_DIR, "checkpoints/yoloface_v5m.pt"))
+         if cfg.use_interframe:
+             self.rife = RIFEModel(device=self.device)
+             self.rife.load_model(os.path.join(BASE_DIR, "checkpoints/RIFE/"))
+
+         img_e.to(dtype); vae.to(dtype); unet.to(dtype)
+
+         self.pipe = SonicPipeline(unet=unet, image_encoder=img_e, vae=vae, scheduler=sched).to(device=self.device, dtype=dtype)
+         self.image_encoder = img_e
+         self.audio2token = a2t
+         self.audio2bucket = a2b
+         self.whisper = whisper
+
+     # --------------------------------------------------------------
+     def preprocess(self, img_path: str, expand_ratio: float = 1.0):
+         img = cv2.imread(img_path)
+         h, w = img.shape[:2]
+         _, _, faces = self.face_det(img, maxface=True)
+         if faces:
+             x1, y1, ww, hh = faces[0]
+             return {"face_num": 1, "crop_bbox": process_bbox((x1, y1, x1 + ww, y1 + hh), expand_ratio, h, w)}
+         return {"face_num": 0, "crop_bbox": None}
+
+     # --------------------------------------------------------------
+     @torch.no_grad()
+     def process(
+         self,
+         img_path: str,
+         audio_path: str,
+         out_path: str,
+         min_resolution: int = 512,
+         inference_steps: int = 25,
+         dynamic_scale: float = 1.0,
+         keep_resolution: bool = False,
+         seed: int | None = None,
+     ):
+         cfg = self.config
+         if seed is not None: cfg.seed = seed
+         cfg.num_inference_steps = inference_steps
+         cfg.motion_bucket_scale = dynamic_scale
+         seed_everything(cfg.seed)
+
+         sample = image_audio_to_tensor(
+             self.face_det, self.feature_extractor,
+             img_path, audio_path,
+             limit=-1, image_size=min_resolution, area=cfg.area,
+         )
+         if sample is None:
+             return -1
+
+         h, w = sample["ref_img"].shape[-2:]
+         resolution = (f"{(Image.open(img_path).width // 2) * 2}x{(Image.open(img_path).height // 2) * 2}"
+                       if keep_resolution else f"{w}x{h}")
+
+         video = test(
+             self.pipe, cfg, self.whisper, self.audio2token,
+             self.audio2bucket, self.image_encoder,
+             w, h, sample,
+         )
+
+         if cfg.use_interframe:
+             out = video.to(self.device)
+             frames = []
+             for i in tqdm(range(out.shape[2] - 1), ncols=0):
+                 mid = self.rife.inference(out[:, :, i], out[:, :, i + 1]).clamp(0, 1).detach()
+                 frames.extend([out[:, :, i], mid])
+             frames.append(out[:, :, -1])
+             video = torch.stack(frames, 2).cpu()
+
+         tmp = out_path.replace(".mp4", "_noaudio.mp4")
+         save_videos_grid(video, tmp, n_rows=video.shape[0], fps=cfg.fps * (2 if cfg.use_interframe else 1))
+         os.system(
+             f"ffmpeg -i '{tmp}' -i '{audio_path}' -s {resolution} "
+             f"-vcodec libx264 -acodec aac -crf 18 -shortest '{out_path}' -y -loglevel error"
  )
+         os.remove(tmp)
+         return 0
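
A minimal usage sketch of the refactored class above, mirroring the call pattern the previous app.py used (preprocess to check for a face, then process to render). The module name and the input/output paths are placeholders, and it assumes the checkpoints referenced in config/inference/sonic.yaml have been downloaded.

    # Hypothetical driver for the Sonic class above; module name and file paths are placeholders.
    from sonic import Sonic  # assuming this file is importable as sonic.py

    pipe = Sonic(device_id=0)  # falls back to CPU when CUDA is unavailable

    face_info = pipe.preprocess("portrait.png", expand_ratio=0.5)
    if face_info["face_num"] > 0:
        # process() returns 0 on success and -1 when no usable face/audio sample is built
        pipe.process(
            "portrait.png",
            "speech.wav",
            "result.mp4",
            min_resolution=512,
            inference_steps=25,
            dynamic_scale=1.0,
        )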