Spanicin committed on
Commit 6abf199 · verified · 1 Parent(s): a709b8b

Upload app_parallel.py

Files changed (1)
  1. app_parallel.py +359 -0
app_parallel.py ADDED
@@ -0,0 +1,359 @@
from flask import Flask, request, jsonify, stream_with_context
import torch
import shutil
import os
import sys
from time import strftime
from src.utils.preprocess import CropAndExtract
from src.test_audio2coeff import Audio2Coeff
from src.facerender.animate import AnimateFromCoeff
from src.generate_batch import get_data
from src.generate_facerender_batch import get_facerender_data
# from src.utils.init_path import init_path
import tempfile
from openai import OpenAI
import elevenlabs
from elevenlabs import set_api_key, generate, play, clone, Voice, VoiceSettings
import uuid
import time
from PIL import Image
import moviepy.editor as mp
import requests
import json
import pickle
from dotenv import load_dotenv
from concurrent.futures import ProcessPoolExecutor, as_completed

# Load environment variables from .env file
load_dotenv()

# OpenAI client used by generate_audio() for text-to-speech; the constructor
# reads OPENAI_API_KEY from the environment loaded above.
client = OpenAI()

# Initialize ProcessPoolExecutor for parallel processing
executor = ProcessPoolExecutor(max_workers=3)

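# Note on the executor above: with max_workers=3, up to three audio chunks can be rendered
# concurrently, each in its own worker process. Every process_chunk call loads its own
# Audio2Coeff and AnimateFromCoeff models, so memory use grows with the worker count.
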
class AnimationConfig:
    def __init__(self, driven_audio_path, source_image_path, result_folder, pose_style, expression_scale,
                 enhancer, still, preprocess, ref_pose_video_path, image_hardcoded):
        self.driven_audio = driven_audio_path
        self.source_image = source_image_path
        self.ref_eyeblink = None
        self.ref_pose = ref_pose_video_path
        self.checkpoint_dir = './checkpoints'
        self.result_dir = result_folder
        self.pose_style = pose_style
        self.batch_size = 2
        self.expression_scale = expression_scale
        self.input_yaw = None
        self.input_pitch = None
        self.input_roll = None
        self.enhancer = enhancer
        self.background_enhancer = None
        self.cpu = False
        self.face3dvis = False
        self.still = still
        self.preprocess = preprocess
        self.verbose = False
        self.old_version = False
        self.net_recon = 'resnet50'
        self.init_path = None
        self.use_last_fc = False
        self.bfm_folder = './checkpoints/BFM_Fitting/'
        self.bfm_model = 'BFM_model_front.mat'
        self.focal = 1015.
        self.center = 112.
        self.camera_d = 10.
        self.z_near = 5.
        self.z_far = 15.
        self.device = 'cpu'
        self.image_hardcoded = image_hardcoded

app = Flask(__name__)
# CORS(app)

TEMP_DIR = None
start_time = None

app.config['temp_response'] = None
app.config['generation_thread'] = None
app.config['text_prompt'] = None
app.config['final_video_path'] = None
app.config['final_video_duration'] = None

# Global paths
dir_path = os.path.dirname(os.path.realpath(__file__))
current_root_path = dir_path

path_of_lm_croper = os.path.join(current_root_path, 'checkpoints', 'shape_predictor_68_face_landmarks.dat')
path_of_net_recon_model = os.path.join(current_root_path, 'checkpoints', 'epoch_20.pth')
dir_of_BFM_fitting = os.path.join(current_root_path, 'checkpoints', 'BFM_Fitting')
wav2lip_checkpoint = os.path.join(current_root_path, 'checkpoints', 'wav2lip.pth')
audio2pose_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2pose_00140-model.pth')
audio2pose_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2pose.yaml')
audio2exp_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2exp_00300-model.pth')
audio2exp_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2exp.yaml')
free_view_checkpoint = os.path.join(current_root_path, 'checkpoints', 'facevid2vid_00189-model.pth.tar')

# Function for running the actual task (using preprocessed data)
def process_chunk(audio_chunk, preprocessed_data, args):
    print("Entered Process Chunk Function")
    global audio2pose_checkpoint, audio2pose_yaml_path, audio2exp_checkpoint, audio2exp_yaml_path, wav2lip_checkpoint
    global free_view_checkpoint
    if args.preprocess == 'full':
        mapping_checkpoint = os.path.join(current_root_path, 'checkpoints', 'mapping_00109-model.pth.tar')
        facerender_yaml_path = os.path.join(current_root_path, 'src', 'config', 'facerender_still.yaml')
    else:
        mapping_checkpoint = os.path.join(current_root_path, 'checkpoints', 'mapping_00229-model.pth.tar')
        facerender_yaml_path = os.path.join(current_root_path, 'src', 'config', 'facerender.yaml')

    first_coeff_path = preprocessed_data["first_coeff_path"]
    crop_pic_path = preprocessed_data["crop_pic_path"]
    crop_info = preprocessed_data["crop_info"]

    print("first_coeff_path", first_coeff_path)
    print("crop_pic_path", crop_pic_path)
    print("crop_info", crop_info)

    batch = get_data(first_coeff_path, audio_chunk, args.device, ref_eyeblink_coeff_path=None, still=args.still)
    audio_to_coeff = Audio2Coeff(audio2pose_checkpoint, audio2pose_yaml_path,
                                 audio2exp_checkpoint, audio2exp_yaml_path,
                                 wav2lip_checkpoint, args.device)
    coeff_path = audio_to_coeff.generate(batch, args.result_dir, args.pose_style, ref_pose_coeff_path=None)

    # Further processing with animate_from_coeff using the coeff_path
    animate_from_coeff = AnimateFromCoeff(free_view_checkpoint, mapping_checkpoint,
                                          facerender_yaml_path, args.device)

    data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_chunk,
                               args.batch_size, args.input_yaw, args.input_pitch, args.input_roll,
                               expression_scale=args.expression_scale, still_mode=args.still, preprocess=args.preprocess)

    print("Will Enter Animation")
    result, base64_video, temp_file_path, _ = animate_from_coeff.generate(data, args.result_dir, args.source_image, crop_info,
                                                                          enhancer=args.enhancer, background_enhancer=args.background_enhancer, preprocess=args.preprocess)

    video_clip = mp.VideoFileClip(temp_file_path)
    duration = video_clip.duration

    # Note: when this runs inside a ProcessPoolExecutor worker, these assignments update
    # the worker's copy of app.config, not the parent Flask process.
    app.config['temp_response'] = base64_video
    app.config['final_video_path'] = temp_file_path
    app.config['final_video_duration'] = duration

    return base64_video, temp_file_path, duration

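# Note: process_chunk above re-creates Audio2Coeff and AnimateFromCoeff (and reloads their
# checkpoints) for every chunk; caching the models once per worker process would likely cut
# per-chunk latency, at the cost of holding them in memory between requests.
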
def create_temp_dir():
    return tempfile.TemporaryDirectory()

def save_uploaded_file(file, filename, TEMP_DIR):
    print("Entered save_uploaded_file")
    unique_filename = str(uuid.uuid4()) + "_" + filename
    file_path = os.path.join(TEMP_DIR.name, unique_filename)
    file.save(file_path)
    return file_path


def custom_cleanup(temp_dir, exclude_dir):
    # Iterate over the files and directories in TEMP_DIR
    for filename in os.listdir(temp_dir):
        file_path = os.path.join(temp_dir, filename)
        # Skip the directory we want to exclude
        if file_path != exclude_dir:
            try:
                if os.path.isdir(file_path):
                    shutil.rmtree(file_path)
                else:
                    os.remove(file_path)
                print(f"Deleted: {file_path}")
            except Exception as e:
                print(f"Failed to delete {file_path}. Reason: {e}")

def generate_audio(voice_cloning, voice_gender, text_prompt):
    print("generate_audio")
    if voice_cloning == 'no':
        if voice_gender == 'male':
            voice = 'echo'
            print('Entering Audio creation using elevenlabs')
            # set_api_key(os.getenv('ELEVENLABS_API_KEY'))

            audio = generate(text=text_prompt, voice="Daniel", model="eleven_multilingual_v2", stream=True, latency=4)
            with tempfile.NamedTemporaryFile(suffix=".mp3", prefix="text_to_speech_", dir=TEMP_DIR.name, delete=False) as temp_file:
                for chunk in audio:
                    temp_file.write(chunk)
                driven_audio_path = temp_file.name
            print('driven_audio_path', driven_audio_path)
            print('Audio file saved using elevenlabs')

        else:
            voice = 'nova'

            print('Entering Audio creation using OpenAI TTS')
            response = client.audio.speech.create(model="tts-1-hd",
                                                  voice=voice,
                                                  input=text_prompt)

            print('Audio created using OpenAI TTS')
            with tempfile.NamedTemporaryFile(suffix=".wav", prefix="text_to_speech_", dir=TEMP_DIR.name, delete=False) as temp_file:
                driven_audio_path = temp_file.name

            response.write_to_file(driven_audio_path)
            print('Audio file saved using OpenAI TTS')

    elif voice_cloning == 'yes':
        set_api_key(os.getenv('ELEVENLABS_API_KEY'))
        # voice = clone(name = "User Cloned Voice",
        #               files = [user_voice_path] )
        voice = Voice(voice_id="CEii8R8RxmB0zhAiloZg", name="Marc", settings=VoiceSettings(
            stability=0.71, similarity_boost=0.5, style=0.0, use_speaker_boost=True),)

        audio = generate(text=text_prompt, voice=voice, model="eleven_multilingual_v2", stream=True, latency=4)
        with tempfile.NamedTemporaryFile(suffix=".mp3", prefix="cloned_audio_", dir=TEMP_DIR.name, delete=False) as temp_file:
            for chunk in audio:
                temp_file.write(chunk)
            driven_audio_path = temp_file.name
        print('driven_audio_path', driven_audio_path)

    return driven_audio_path

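# Voice selection above: voice_cloning='no' with voice_gender='male' streams the ElevenLabs
# "Daniel" voice, any other gender uses OpenAI tts-1-hd with the "nova" voice, and
# voice_cloning='yes' streams the pre-configured ElevenLabs cloned voice "Marc".
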
# Preprocessing step that runs only once
def run_preprocessing(args):
    global path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting
    first_frame_dir = os.path.join(args.result_dir, 'first_frame_dir')
    os.makedirs(first_frame_dir, exist_ok=True)

    # Check if preprocessed data already exists
    fixed_temp_dir = "C:/Users/fd01076/Downloads/preprocess_data"
    os.makedirs(fixed_temp_dir, exist_ok=True)
    preprocessed_data_path = os.path.join(fixed_temp_dir, "preprocessed_data.pkl")

    if os.path.exists(preprocessed_data_path) and args.image_hardcoded == "yes":
        with open(preprocessed_data_path, "rb") as f:
            preprocessed_data = pickle.load(f)
        print("Loaded existing preprocessed data from:", preprocessed_data_path)
    else:
        preprocess_model = CropAndExtract(path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, args.device)
        first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(args.source_image, first_frame_dir, args.preprocess, source_image_flag=True)

        if not first_coeff_path:
            raise Exception("Failed to get coefficients")

        # Save the preprocessed data
        preprocessed_data = {
            "first_coeff_path": first_coeff_path,
            "crop_pic_path": crop_pic_path,
            "crop_info": crop_info
        }
        with open(preprocessed_data_path, "wb") as f:
            pickle.dump(preprocessed_data, f)

    return preprocessed_data

def split_audio(audio_path, chunk_duration=5):
    audio_clip = mp.AudioFileClip(audio_path)
    total_duration = audio_clip.duration

    audio_chunks = []
    for start_time in range(0, int(total_duration), chunk_duration):
        end_time = min(start_time + chunk_duration, total_duration)
        chunk = audio_clip.subclip(start_time, end_time)
        with tempfile.NamedTemporaryFile(suffix=f"_chunk_{start_time}-{end_time}.wav", prefix="audio_chunk_", dir=TEMP_DIR.name, delete=False) as temp_file:
            chunk_path = temp_file.name
            chunk.write_audiofile(chunk_path)
            audio_chunks.append(chunk_path)

    return audio_chunks

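# For example, a 12-second clip with chunk_duration=5 yields three chunks covering roughly
# 0-5s, 5-10s and 10-12s, each written to its own temporary .wav file inside TEMP_DIR.
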
# Generator function to yield chunk results as they are processed
def generate_chunks(audio_chunks, preprocessed_data, args):
    future_to_chunk = {executor.submit(process_chunk, chunk, preprocessed_data, args): chunk for chunk in audio_chunks}

    for future in as_completed(future_to_chunk):
        chunk = future_to_chunk[future]  # Get the original chunk that was processed
        try:
            base64_video, temp_file_path, duration = future.result()  # Get the result of the completed task
            yield f"Task for chunk {chunk} completed with video path: {temp_file_path}\n"
        except Exception as e:
            yield f"Task for chunk {chunk} failed: {e}\n"

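# Because as_completed() yields futures in completion order, the streamed status lines above
# can arrive out of order relative to the audio timeline; a client that needs ordered playback
# has to reassemble the chunk videos by their original start times.
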
@app.route("/run", methods=['POST'])
def parallel_processing():
    global start_time
    start_time = time.time()
    global TEMP_DIR
    global audio_chunks
    TEMP_DIR = create_temp_dir()
    print('request:', request.method)
    try:
        if request.method == 'POST':
            # source_image = request.files['source_image']
            image_path = 'C:/Users/fd01076/Downloads/marc_smile_image_videos/marc_smile_enhanced.png'
            source_image = Image.open(image_path)
            text_prompt = request.form['text_prompt']

            print('Input text prompt: ', text_prompt)
            text_prompt = text_prompt.strip()
            if not text_prompt:
                return jsonify({'error': 'Input text prompt cannot be blank'}), 400

            voice_cloning = request.form.get('voice_cloning', 'yes')
            image_hardcoded = request.form.get('image_hardcoded', 'no')
            chat_model_used = request.form.get('chat_model_used', 'openai')
            target_language = request.form.get('target_language', 'original_text')
            print('target_language', target_language)
            pose_style = int(request.form.get('pose_style', 1))
            expression_scale = float(request.form.get('expression_scale', 1))
            enhancer = request.form.get('enhancer', None)
            voice_gender = request.form.get('voice_gender', 'male')
            still_str = request.form.get('still', 'False')
            still = still_str.lower() == 'true'  # parse the form value as a boolean
            print('still', still)
            preprocess = request.form.get('preprocess', 'crop')
            print('preprocess selected: ', preprocess)
            ref_pose_video = request.files.get('ref_pose', None)

            app.config['text_prompt'] = text_prompt
            print('Final output text prompt using openai: ', text_prompt)

            source_image_path = save_uploaded_file(source_image, 'source_image.png', TEMP_DIR)
            print(source_image_path)

            driven_audio_path = generate_audio(voice_cloning, voice_gender, text_prompt)

            save_dir = tempfile.mkdtemp(dir=TEMP_DIR.name)
            result_folder = os.path.join(save_dir, "results")
            os.makedirs(result_folder, exist_ok=True)

            ref_pose_video_path = None
            if ref_pose_video:
                with tempfile.NamedTemporaryFile(suffix=".mp4", prefix="ref_pose_", dir=TEMP_DIR.name, delete=False) as temp_file:
                    ref_pose_video_path = temp_file.name
                    ref_pose_video.save(ref_pose_video_path)
                print('ref_pose_video_path', ref_pose_video_path)

    except Exception as e:
        app.logger.error(f"An error occurred: {e}")
        return "An error occurred", 500

    args = AnimationConfig(driven_audio_path=driven_audio_path, source_image_path=source_image_path, result_folder=result_folder, pose_style=pose_style, expression_scale=expression_scale, enhancer=enhancer, still=still, preprocess=preprocess, ref_pose_video_path=ref_pose_video_path, image_hardcoded=image_hardcoded)

    preprocessed_data = run_preprocessing(args)
    chunk_duration = 5
    print(f"Splitting the audio into {chunk_duration}-second chunks...")
    audio_chunks = split_audio(driven_audio_path, chunk_duration=chunk_duration)
    print(f"Audio has been split into {len(audio_chunks)} chunks: {audio_chunks}")

    try:
        return stream_with_context(generate_chunks(audio_chunks, preprocessed_data, args))
        # base64_video, temp_file_path, duration = process_chunk(driven_audio_path, preprocessed_data, args)
    except Exception as e:
        return jsonify({'status': 'error', 'message': str(e)}), 500

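# Example request against /run (illustrative sketch: assumes Flask's default dev server at
# http://127.0.0.1:5000; the form fields mirror those read in parallel_processing above, and
# curl's -N flag disables buffering so the streamed chunk messages show up as they complete):
#
#   curl -N -X POST http://127.0.0.1:5000/run \
#        -F "text_prompt=Hello from the avatar" \
#        -F "voice_cloning=no" \
#        -F "voice_gender=male" \
#        -F "preprocess=crop" \
#        -F "still=False"
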
@app.route("/health", methods=["GET"])
def health_status():
    response = {"online": "true"}
    return jsonify(response)

if __name__ == '__main__':
    app.run(debug=True)