Spaces:
Running
on
Zero
Running
on
Zero
Update sonic.py
Browse files
sonic.py
CHANGED
@@ -73,31 +73,47 @@ def test(pipe, cfg, wav_enc, audio_pe, audio2bucket, image_encoder,
|
|
73 |
|
74 |
ref_list, audio_list, uncond_list, motion_buckets = [], [], [], []
|
75 |
|
|
|
76 |
for i in tqdm(range(num_chunks)):
|
77 |
start = i * 2 * step
|
78 |
|
79 |
-
#
|
80 |
-
cond_clip
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
motion = audio2bucket(bucket_clip, image_embeds) * 16 + 16
|
95 |
|
96 |
ref_list.append(ref_img[0])
|
97 |
-
audio_list.append(audio_pe(cond_clip).squeeze(0)[0])
|
98 |
uncond_list.append(audio_pe(torch.zeros_like(cond_clip)).squeeze(0)[0])
|
99 |
motion_buckets.append(motion[0])
|
100 |
|
|
|
|
|
|
|
|
|
101 |
# -------- diffusion --------------------------------------------------
|
102 |
video = pipe(
|
103 |
ref_img, clip_img, face_mask,
|
|
|
73 |
|
74 |
ref_list, audio_list, uncond_list, motion_buckets = [], [], [], []
|
75 |
|
76 |
+
|
77 |
for i in tqdm(range(num_chunks)):
|
78 |
start = i * 2 * step
|
79 |
|
80 |
+
# ------------------------------------------------------------
|
81 |
+
# cond_clip : (bz, f=1, w=10, b=5, c=384)
|
82 |
+
# bucket_clip: (bz, f=1, w=50, b=1, c=384)
|
83 |
+
# Whisper-tiny 는 hidden_state 층 수가 2 → 5 로 패딩
|
84 |
+
# ------------------------------------------------------------
|
85 |
+
clip_raw = audio_prompts[:, start:start + 10] # (1, ≤10, L, 384)
|
86 |
+
if clip_raw.shape[1] < 10: # w 패딩
|
87 |
+
pad_w = torch.zeros_like(clip_raw[:, :10 - clip_raw.shape[1]])
|
88 |
+
clip_raw = torch.cat([clip_raw, pad_w], dim=1)
|
89 |
+
|
90 |
+
# ---- L(=layers) 패딩: 부족하면 마지막 layer 를 반복 ----------
|
91 |
+
L_now = clip_raw.shape[2]
|
92 |
+
if L_now < 5:
|
93 |
+
pad_L = clip_raw[:, :, -1:].repeat(1, 1, 5 - L_now, 1)
|
94 |
+
clip_raw = torch.cat([clip_raw, pad_L], dim=2)
|
95 |
+
clip_raw = clip_raw[:, :, :5] # (1,10,5,384)
|
96 |
+
|
97 |
+
cond_clip = clip_raw.unsqueeze(1) # (1,1,10,5,384)
|
98 |
+
|
99 |
+
# ------------------------------------------------------------
|
100 |
+
bucket_raw = last_prompts[:, start:start + 50] # (1, ≤50, 1, 384)
|
101 |
+
if bucket_raw.shape[1] < 50:
|
102 |
+
pad_w = torch.zeros_like(bucket_raw[:, :50 - bucket_raw.shape[1]])
|
103 |
+
bucket_raw = torch.cat([bucket_raw, pad_w], dim=1)
|
104 |
+
bucket_clip = bucket_raw.unsqueeze(1) # (1,1,50,1,384)
|
105 |
|
106 |
motion = audio2bucket(bucket_clip, image_embeds) * 16 + 16
|
107 |
|
108 |
ref_list.append(ref_img[0])
|
109 |
+
audio_list.append(audio_pe(cond_clip).squeeze(0)[0])
|
110 |
uncond_list.append(audio_pe(torch.zeros_like(cond_clip)).squeeze(0)[0])
|
111 |
motion_buckets.append(motion[0])
|
112 |
|
113 |
+
|
114 |
+
|
115 |
+
|
116 |
+
|
117 |
# -------- diffusion --------------------------------------------------
|
118 |
video = pipe(
|
119 |
ref_img, clip_img, face_mask,
|