Cun-Duck commited on
Commit
b2077e8
·
verified ·
1 Parent(s): a5758ff

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +323 -0
app.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import sys
4
+
5
# 1. Dependency installation (should only run when actually needed)
# Check whether the dependencies are already installed
def check_dependencies():
    """Report whether every runtime dependency can be imported.

    Returns:
        True when all required modules import without error, False as soon
        as one of them raises ImportError.
    """
    required_modules = (
        "torch",
        "transformers",
        "datasets",
        "librosa",
        "numpy",
        "scipy",
        "ffmpeg",
        "gradio",
        "huggingface_hub",
    )
    for module_name in required_modules:
        try:
            __import__(module_name)
        except ImportError:
            return False
    return True
21
+
22
if not check_dependencies():
    # Install PyTorch (CPU-only wheels are served from the PyTorch index).
    subprocess.check_call([sys.executable, "-m", "pip", "install", "torch==1.12.1+cpu", "torchvision==0.13.1+cpu", "torchaudio==0.12.1", "--extra-index-url", "https://download.pytorch.org/whl/cpu"])

    # Install the remaining Python dependencies.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "transformers==4.24.0", "datasets==2.7.1", "librosa==0.9.2", "numpy==1.23.4", "scipy==1.9.3", "ffmpeg-python==0.2.0", "gradio==3.10.1", "huggingface_hub==0.11.0"])

    # Install the ffmpeg binary, which pip cannot provide. This stays best
    # effort (the binary may already be present), but a failure is reported
    # instead of being silently ignored. The install step only runs when the
    # update step succeeded.
    for apt_command in (["apt-get", "update"], ["apt-get", "install", "-y", "ffmpeg"]):
        try:
            completed = subprocess.run(apt_command, check=False)
        except OSError as exc:
            # Raised when apt-get itself is not available on this system.
            print(f"Could not run {' '.join(apt_command)}: {exc}")
            break
        if completed.returncode != 0:
            print(f"{' '.join(apt_command)} exited with status {completed.returncode}")
            break
31
+
32
+ # 2. Impor Libraries
33
+ import torch
34
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
35
+ from datasets import load_dataset, Audio
36
+ import librosa
37
+ import numpy as np
38
+ from scipy.io import wavfile
39
+ import ffmpeg
40
+ import gradio as gr
41
+ from huggingface_hub import HfApi, HfFolder
42
+
43
# 3. Hugging Face Hub configuration
# Read the access token from the environment; on Hugging Face Spaces, define
# it as the HF_TOKEN secret in the Space settings. Never hardcode the token.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Repository configuration: replace with your own username and repo name.
repo_id = "your_username/your_model_repo"
model_filename = "lipsync_model.pth"

# Hub client. The Hub calls in this file pass token=HF_TOKEN explicitly, so
# no separate login step is required. HfApi.set_access_token is deliberately
# not called: it is deprecated and absent from later huggingface_hub
# releases, where calling it would crash the app at import time.
api = HfApi()

if not HF_TOKEN:
    print("HF_TOKEN not found. Model will not be uploaded.")
65
+
66
# 4. Model and helper-function definitions

# Pretrained speech-recognition checkpoint: processor and CTC model are both
# loaded from the same Hub identifier.
asr_model_name = "facebook/wav2vec2-base-960h"
asr_processor = Wav2Vec2Processor.from_pretrained(asr_model_name)
asr_model = Wav2Vec2ForCTC.from_pretrained(asr_model_name)
72
+
73
# Placeholder lipsync model (deliberately small so it can run on CPU)
class LipSyncModel(torch.nn.Module):
    """Two-layer MLP that maps one audio feature vector to one RGB frame.

    Each input row produces exactly one frame, so a batch of shape
    (steps, in_features) yields (steps, 3, frame_size, frame_size). This
    keeps the output aligned one-to-one with the target video frames.

    Args:
        in_features: Size of one audio feature vector. Defaults to 16, the
            number of MFCC coefficients that extract_audio_features computes
            per time step.
        hidden_features: Width of the hidden layer.
        frame_size: Height and width of the generated square frame. Defaults
            to 32, the resolution process_video downsamples frames to.
    """

    def __init__(self, in_features=16, hidden_features=256, frame_size=32):
        super().__init__()
        self.frame_size = frame_size
        self.fc1 = torch.nn.Linear(in_features, hidden_features)
        self.relu = torch.nn.ReLU()
        # One frame (3 channels x frame_size x frame_size) per input row.
        self.fc2 = torch.nn.Linear(hidden_features, 3 * frame_size * frame_size)

    def forward(self, x):
        """Return frames of shape (rows, 3, frame_size, frame_size) for x."""
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x.view(-1, 3, self.frame_size, self.frame_size)
88
+
89
# Shared model instance together with the loss and optimizer used to train it
lipsync_model = LipSyncModel()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(lipsync_model.parameters(), lr=5e-5)
92
+
93
# Audio feature extraction
def extract_audio_features(audio_file, n_mfcc=16, hop_length=512, max_steps=512):
    """Compute MFCC features for an audio file.

    The audio is loaded as mono and resampled to the sampling rate of the
    ASR processor's feature extractor. The wav2vec2 processor itself is not
    run: its output was never used, so only the MFCCs are computed.

    Args:
        audio_file: Path to the audio file to load.
        n_mfcc: Number of MFCC coefficients per time step.
        hop_length: Hop length, in samples, between successive MFCC steps.
        max_steps: Maximum number of time steps to keep; None keeps all.

    Returns:
        A float tensor of shape (steps, n_mfcc) with steps <= max_steps.
    """
    target_sr = asr_processor.feature_extractor.sampling_rate
    audio, sr = librosa.load(audio_file, sr=target_sr, mono=True)

    mfccs = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=n_mfcc, hop_length=hop_length)
    # librosa returns (n_mfcc, steps); transpose so each row is one time step.
    features = torch.tensor(mfccs.T).float()
    if max_steps is not None:
        features = features[:max_steps, :]
    return features
106
+
107
# Video and audio preprocessing
def process_video(video_file, audio_file):
    """Decode a video into frames and extract features from its audio.

    Args:
        video_file: Path to the input video.
        audio_file: Path to the audio to extract features from. When None,
            the video's own audio track is extracted to "temp_audio.wav"
            (16 kHz, mono, PCM) and used instead.

    Returns:
        A tuple (frames, audio_features, fps). frames is a float tensor of
        shape (num_frames, 3, 32, 32) scaled to [0, 1], audio_features is
        the result of extract_audio_features, and fps is the video's frame
        rate as a float. On any ffmpeg failure the tuple is
        (None, None, None), so callers can always unpack three values.
    """
    from fractions import Fraction

    # 1. Extract the audio track from the video when no audio was supplied.
    if audio_file is None:
        audio_file = "temp_audio.wav"
        try:
            (
                ffmpeg.input(video_file)
                .output(audio_file, acodec="pcm_s16le", ar="16000", ac=1)  # Convert to mono
                .run(overwrite_output=True, quiet=True)
            )
        except ffmpeg.Error as e:
            print(f"Error extracting audio from {video_file}: {e.stderr.decode()}")
            return None, None, None

    # 2. Read the frame rate and decode the frames.
    try:
        probe = ffmpeg.probe(video_file)
    except ffmpeg.Error as e:
        detail = e.stderr.decode() if e.stderr else e
        print(f"Error probing {video_file}: {detail}")
        return None, None, None

    video_info = next((s for s in probe['streams'] if s['codec_type'] == 'video'), None)
    if video_info is None:
        print(f"No video stream found in {video_file}")
        return None, None, None

    # r_frame_rate is a rational string such as "30000/1001". Parse it with
    # Fraction rather than eval so file metadata is never executed as code.
    try:
        fps = float(Fraction(video_info['r_frame_rate']))
    except (KeyError, ValueError, ZeroDivisionError):
        print(f"Could not determine the frame rate of {video_file}")
        return None, None, None

    try:
        # run() returns (stdout, stderr); only the raw frame bytes are needed.
        raw_frames, _ = (
            ffmpeg.input(video_file)
            .output("pipe:", format="rawvideo", pix_fmt="rgb24", s="32x32")  # Downsample to 32x32
            .run(capture_stdout=True, quiet=True)
        )
    except ffmpeg.Error as e:
        detail = e.stderr.decode() if e.stderr else e
        print(f"Error decoding frames from {video_file}: {detail}")
        return None, None, None

    # Raw bytes -> (N, 32, 32, 3) uint8 -> (N, 3, 32, 32) float in [0, 1].
    frames = np.frombuffer(raw_frames, np.uint8).reshape([-1, 32, 32, 3])
    frames = torch.tensor(frames).permute(0, 3, 1, 2).float() / 255.0

    # 3. Extract the audio features.
    audio_features = extract_audio_features(audio_file)

    return frames, audio_features, fps
142
+
143
# Lipsync model training
def train_lipsync_model(video_file, audio_file, epochs=5):
    """Fit the global lipsync model to one video/audio pair.

    Frames and audio features are truncated to their common length so that
    each audio feature row is paired with exactly one target frame; the MSE
    loss requires both tensors to have the same number of rows. After
    training, the model is saved and uploaded when HF_TOKEN is set.

    Args:
        video_file: Path to the training video.
        audio_file: Path to the audio, or None to use the video's own audio.
        epochs: Number of full-batch optimization steps to run.

    Returns:
        None. Progress and errors are reported with print.
    """
    processed = process_video(video_file, audio_file)
    # Only the first two items are needed for training, so index them rather
    # than unpacking a fixed-length tuple.
    frames, audio_features = processed[0], processed[1]

    if frames is None or audio_features is None:
        print("Skipping training due to error in video or audio processing.")
        return

    # Align both sequences once, outside the epoch loop (loop-invariant).
    num_steps = min(frames.shape[0], audio_features.shape[0])
    if num_steps == 0:
        print("Skipping training because no frames or audio features were found.")
        return
    target_frames = frames[:num_steps]
    model_inputs = audio_features[:num_steps]

    for epoch in range(epochs):
        optimizer.zero_grad()

        # Generate one frame per audio feature row.
        generated_frames = lipsync_model(model_inputs)

        # Compute the loss against the real frames.
        loss = criterion(generated_frames, target_frames)

        # Backpropagation and optimization step.
        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")

    # Save and upload the model once training has finished.
    if HF_TOKEN:
        save_and_upload_model()
185
+
186
# Inference
def lipsync_inference(video_file, audio_file, output_file="output.mp4"):
    """Generate a video from audio with the lipsync model and mux the audio in.

    One frame is generated per audio feature row, encoded to a temporary
    H.264 file at the source video's frame rate, and then combined with the
    audio track into output_file.

    Args:
        video_file: Path to the source video (used for frame rate and, when
            audio_file is None, for the audio track).
        audio_file: Path to the driving audio, or None to reuse the source
            video's audio.
        output_file: Path of the resulting video.

    Returns:
        output_file on success, or None when processing or encoding fails.
    """
    processed = process_video(video_file, audio_file)
    frames, audio_features = processed[0], processed[1]

    if frames is None or audio_features is None:
        print("Error during video or audio processing.")
        return None
    fps = processed[2]

    if audio_features.shape[0] == 0:
        print("Error during video or audio processing.")
        return None

    with torch.no_grad():
        generated_frames = lipsync_model(audio_features)

    # Clamp before the uint8 conversion: the model output is unbounded and
    # out-of-range values would otherwise wrap around when cast to bytes.
    generated_frames = (
        (generated_frames.clamp(0.0, 1.0) * 255)
        .byte()
        .permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C) for rawvideo
        .cpu()
        .numpy()
    )

    temp_video = "temp_output.mp4"
    try:
        # Encode the generated frames to a temporary, audio-less video.
        (
            ffmpeg.input(
                "pipe:",
                format="rawvideo",
                pix_fmt="rgb24",
                s=f"{generated_frames.shape[2]}x{generated_frames.shape[1]}",
                r=fps,
            )
            .output(temp_video, pix_fmt="yuv420p", vcodec="libx264", crf=28)
            .overwrite_output()
            .run(input=generated_frames.tobytes(), quiet=True)
        )

        # Combine the generated video with the audio. Two separate inputs
        # are passed to ffmpeg.output; the video stream is copied as-is and
        # the audio is encoded to AAC so that it fits in an MP4 container.
        video_stream = ffmpeg.input(temp_video).video
        audio_source = audio_file if audio_file is not None else video_file
        audio_stream = ffmpeg.input(audio_source).audio
        (
            ffmpeg.output(video_stream, audio_stream, output_file, vcodec="copy", acodec="aac")
            .overwrite_output()
            .run(quiet=True)
        )
    except ffmpeg.Error as e:
        detail = e.stderr.decode() if e.stderr else e
        print(f"Error writing {output_file}: {detail}")
        return None
    finally:
        # Always remove the intermediate file, even when encoding failed.
        if os.path.exists(temp_video):
            os.remove(temp_video)

    print(f"Video hasil lipsync disimpan di: {output_file}")
    return output_file
242
+
243
# 5. Save the model and upload it
def save_and_upload_model():
    """Write the lipsync weights to disk and push the file to the Hub repo.

    Repo creation and upload are each attempted independently; a failure in
    either one is printed and does not raise.
    """
    # Ensure the target repository exists (created private for safety).
    try:
        api.create_repo(repo_id=repo_id, token=HF_TOKEN, private=True, exist_ok=True)
    except Exception as repo_error:
        print(f"Error creating repo: {repo_error}")

    # Store the weights locally first.
    weights = lipsync_model.state_dict()
    torch.save(weights, model_filename)
    print(f"Model saved locally to {model_filename}")

    # Push the local file to the Hugging Face Hub.
    upload_args = {
        "path_or_fileobj": model_filename,
        "path_in_repo": model_filename,
        "repo_id": repo_id,
        "token": HF_TOKEN,
    }
    try:
        api.upload_file(**upload_args)
        print(f"Model uploaded to {repo_id}/{model_filename}")
    except Exception as upload_error:
        print(f"Error uploading model: {upload_error}")
266
+
267
# 6. Download and load the model
def download_and_load_model():
    """Fetch the saved weights from the Hub and load them into the model.

    The file named model_filename is downloaded from repo_id and loaded into
    the global lipsync_model. Any failure (missing repo or file, incompatible
    weights, network error) is printed and the current weights are kept.
    """
    from huggingface_hub import hf_hub_download

    try:
        # Request the weights file by name instead of taking the first file
        # listed in the repo, which is not guaranteed to be the weights.
        try:
            local_path = hf_hub_download(
                repo_id=repo_id,
                filename=model_filename,
                token=HF_TOKEN,
            )
        except TypeError:
            # NOTE(review): older huggingface_hub releases appear to name
            # this parameter use_auth_token - confirm against the pinned
            # version.
            local_path = hf_hub_download(
                repo_id=repo_id,
                filename=model_filename,
                use_auth_token=HF_TOKEN,
            )
        # map_location keeps loading CPU-only regardless of where the
        # weights were saved.
        state_dict = torch.load(local_path, map_location="cpu")
        lipsync_model.load_state_dict(state_dict)
        print("Model loaded from Hugging Face Hub")
    except Exception as e:
        print(f"Error loading model: {e}")
        print("Starting with a fresh model.")
283
+
284
# 7. Gradio interface
def _save_upload(upload, destination):
    """Copy an uploaded file to destination, whatever form it arrives in.

    Accepts an in-memory buffer exposing getbuffer(), a filesystem path, or
    a file-like object exposing a name attribute that points to a file.
    """
    import shutil

    if hasattr(upload, "getbuffer"):
        with open(destination, "wb") as f:
            f.write(upload.getbuffer())
    elif isinstance(upload, (str, os.PathLike)):
        shutil.copyfile(upload, destination)
    elif hasattr(upload, "name"):
        shutil.copyfile(upload.name, destination)
    else:
        raise TypeError(f"Unsupported upload type: {type(upload).__name__}")


def run_app(input_video, input_audio, output_video="output_video.mp4"):
    """Train on the uploaded pair, then run inference and return the result.

    Args:
        input_video: Uploaded video (path, buffer, or file-like object).
        input_audio: Uploaded audio (path, buffer, or file-like object).
        output_video: Path the generated video is written to. It has a
            default so the function can be called with only the two inputs.

    Returns:
        The path returned by lipsync_inference, or None if inference failed.
    """
    # Try to load previously trained weights from the HF Hub.
    if HF_TOKEN:
        download_and_load_model()

    # Working copies of the uploads.
    input_video_path = "input_video.mp4"
    input_audio_path = "input_audio.wav"

    try:
        _save_upload(input_video, input_video_path)
        _save_upload(input_audio, input_audio_path)

        # Train for 5 epochs, then generate the output video.
        train_lipsync_model(input_video_path, input_audio_path, epochs=5)
        return lipsync_inference(input_video_path, input_audio_path, output_video)
    finally:
        # Remove the working copies even when training or inference raised.
        for path in (input_video_path, input_audio_path):
            if os.path.exists(path):
                os.remove(path)
310
+
311
# Input components. The top-level gr.Video / gr.Audio classes are used in
# place of the deprecated gr.inputs namespace.
input_video = gr.Video(label="Input Video")
input_audio = gr.Audio(type="filepath", label="Input Audio")
output_video = "output_video.mp4"

iface = gr.Interface(
    # The interface supplies exactly two inputs, so the output path is bound
    # here rather than expected as a third user input.
    fn=lambda video, audio: run_app(video, audio, output_video),
    inputs=[input_video, input_audio],
    outputs="video",
    title="LipSync AI on CPU",
    description="Ubah audio dari video menggunakan AI Lipsync (CPU Version).",
)

iface.launch(debug=True)