#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunClip). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import re
import os
import sys
import copy
import librosa
import logging
import argparse
import numpy as np
import soundfile as sf
import moviepy.editor as mpy
from moviepy.editor import TextClip, CompositeVideoClip, concatenate_videoclips
from moviepy.video.tools.subtitles import SubtitlesClip
from utils.subtitle_utils import generate_srt, generate_srt_clip, generate_audio_srt, trans_format
from utils.argparse_tools import ArgumentParser, get_commandline_args
from utils.trans_utils import pre_proc, proc, write_state, load_state, proc_spk, convert_pcm_to_float
import whisper
class VideoClipper():
def __init__(self, model):
logging.warning("Initializing VideoClipper.")
self.GLOBAL_COUNT = 0
self.model = model
    def recog(self, audio_input, state=None, output_dir=None, text=None):
        '''
        Convert the audio input into text, optionally producing SRT subtitles.
        Returns:
            res_text: the recognized text.
            res_srt: the recognition result rendered as SRT subtitles.
            state: a dict holding the raw recognition result, timestamps and sentence info.
        '''
if state is None:
state = {}
sr, data = audio_input
# Convert to float64 consistently (includes data type checking)
data = convert_pcm_to_float(data)
# assert sr == 16000, "16kHz sample rate required, {} given.".format(sr)
        if sr != 16000:  # resample with librosa to the 16 kHz expected downstream
            data = librosa.resample(data, orig_sr=sr, target_sr=16000)
            sr = 16000
if len(data.shape) == 2: # multi-channel wav input
logging.warning("Input wav shape: {}, only first channel reserved.".format(data.shape))
data = data[:,0]
state['audio_input'] = (sr, data)
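        # trans_format presumably normalizes the whisper result into the FunASR-style
        # structure (raw_text / text / timestamp / sentence_info) consumed below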
rec_result = trans_format(text)
res_srt = generate_srt(rec_result[0]['sentence_info'])
state['recog_res_raw'] = rec_result[0]['raw_text']
state['timestamp'] = rec_result[0]['timestamp']
state['sentences'] = rec_result[0]['sentence_info']
res_text = rec_result[0]['text']
return res_text, res_srt, state
def clip(self, dest_text, start_ost, end_ost, state, dest_spk=None, output_dir=None, timestamp_list=None):
        '''
        dest_text: target text; the matching periods in the audio are located from it.
        start_ost / end_ost: offsets in ms used to fine-tune the start and end of each clip.
        state: the state dict produced by recog (audio data, recognition result, timestamps, ...).
        dest_spk: target speaker; when given, periods are extracted by speaker instead of by text.
        output_dir: directory in which results are saved.
        timestamp_list: when given, clips are extracted directly at these timestamps.
        '''
        # get from state
audio_input = state['audio_input']
recog_res_raw = state['recog_res_raw']
timestamp = state['timestamp']
sentences = state['sentences']
sr, data = audio_input
data = data.astype(np.float64)
if timestamp_list is None:
all_ts = []
if dest_spk is None or dest_spk == '' or 'sd_sentences' not in state:
                for _i, _dest_text in enumerate(dest_text.split('#')):
if '[' in _dest_text:
match = re.search(r'\[(\d+),\s*(\d+)\]', _dest_text)
if match:
offset_b, offset_e = map(int, match.groups())
log_append = ""
else:
offset_b, offset_e = 0, 0
log_append = "(Bracket detected in dest_text but offset time matching failed)"
_dest_text = _dest_text[:_dest_text.find('[')]
                    else:
                        match = None
                        log_append = ""
                        offset_b, offset_e = 0, 0
_dest_text = pre_proc(_dest_text)
                    ts = proc(recog_res_raw, timestamp, _dest_text)  # locate the matching timestamps
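                    # proc returns periods in samples at 16 kHz; the offsets are in ms,
                    # hence the *16 (16 samples per millisecond) conversion below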
                    for _ts in ts:
                        all_ts.append([_ts[0]+offset_b*16, _ts[1]+offset_e*16])
                    if len(ts) > 1 and match:
                        log_append += '(offsets detected but No.{} sub-sentence matched to {} periods in audio, offsets are applied to all periods)'.format(_i, len(ts))
else:
for _dest_spk in dest_spk.split('#'):
ts = proc_spk(_dest_spk, state['sd_sentences'])
                    for _ts in ts:
                        all_ts.append(_ts)
log_append = ""
else:
all_ts = timestamp_list
ts = all_ts
# ts.sort()
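        # srt_index keeps subtitle numbering continuous across the concatenated periods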
srt_index = 0
clip_srt = ""
if len(ts):
start, end = ts[0]
start = min(max(0, start+start_ost*16), len(data))
end = min(max(0, end+end_ost*16), len(data))
res_audio = data[start:end]
start_end_info = "from {} to {}".format(start/16000, end/16000)
srt_clip, _, srt_index = generate_srt_clip(sentences, start/16000.0, end/16000.0, begin_index=srt_index)
clip_srt += srt_clip
for _ts in ts[1:]: # multiple sentence input or multiple output matched
start, end = _ts
start = min(max(0, start+start_ost*16), len(data))
end = min(max(0, end+end_ost*16), len(data))
            start_end_info += ", from {} to {}".format(start/16000, end/16000)
            res_audio = np.concatenate([res_audio, data[start:end]], -1)
srt_clip, _, srt_index = generate_srt_clip(sentences, start/16000.0, end/16000.0, begin_index=srt_index-1)
clip_srt += srt_clip
if len(ts):
message = "{} periods found in the speech: ".format(len(ts)) + start_end_info + log_append
else:
message = "No period found in the speech, return raw speech. You may check the recognition result and try other destination text."
res_audio = data
        return (sr, res_audio), message, clip_srt  # audio data, status message, generated SRT subtitles
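    # Example (a sketch; `clipper` is a VideoClipper instance and `state` came from
    # recog/video_recog):
    #   (sr, clipped), msg, srt = clipper.clip('目标文本', 0, 100, state)
    #   sf.write('clip.wav', clipped, sr)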
    def video_recog(self, video_filename, output_dir=None, ASR="whisper"):
        '''Extract the audio track from the video and run speech recognition on it.'''
video = mpy.VideoFileClip(video_filename)
# Extract the base name, add '_clip.mp4', and 'wav'
if output_dir is not None:
os.makedirs(output_dir, exist_ok=True)
_, base_name = os.path.split(video_filename)
base_name, _ = os.path.splitext(base_name)
clip_video_file = base_name + '_clip.mp4'
audio_file = base_name + '.wav'
audio_file = os.path.join(output_dir, audio_file)
else:
base_name, _ = os.path.splitext(video_filename)
clip_video_file = base_name + '_clip.mp4'
audio_file = base_name + '.wav'
video.audio.write_audiofile(audio_file)
        # transcribe the audio file with whisper, keeping word-level timestamps
        result_audio = self.model.transcribe(audio_file, language="zh", word_timestamps=True)
wav = librosa.load(audio_file, sr=16000)[0]
# delete the audio file after processing
if os.path.exists(audio_file):
os.remove(audio_file)
state = {
'video_filename': video_filename,
'clip_video_file': clip_video_file,
'video': video,
}
return self.recog((16000, wav), state, output_dir,text=result_audio)
def video_clip(self,
dest_text,
start_ost,
end_ost,
state,
font_size=32,
font_color='white',
add_sub=False,
dest_spk=None,
output_dir=None,
timestamp_list=None):
# get from state
recog_res_raw = state['recog_res_raw']
timestamp = state['timestamp']
sentences = state['sentences']
video = state['video']
clip_video_file = state['clip_video_file']
video_filename = state['video_filename']
if timestamp_list is None:
all_ts = []
if dest_spk is None or dest_spk == '' or 'sd_sentences' not in state:
                for _i, _dest_text in enumerate(dest_text.split('#')):
if '[' in _dest_text:
match = re.search(r'\[(\d+),\s*(\d+)\]', _dest_text)
if match:
offset_b, offset_e = map(int, match.groups())
log_append = ""
else:
offset_b, offset_e = 0, 0
log_append = "(Bracket detected in dest_text but offset time matching failed)"
_dest_text = _dest_text[:_dest_text.find('[')]
                    else:
                        match = None
                        offset_b, offset_e = 0, 0
                        log_append = ""
_dest_text = pre_proc(_dest_text)
ts = proc(recog_res_raw, timestamp, _dest_text.lower())
                    for _ts in ts:
                        all_ts.append([_ts[0]+offset_b*16, _ts[1]+offset_e*16])
                    if len(ts) > 1 and match:
                        log_append += '(offsets detected but No.{} sub-sentence matched to {} periods in audio, offsets are applied to all periods)'.format(_i, len(ts))
else:
for _dest_spk in dest_spk.split('#'):
ts = proc_spk(_dest_spk, state['sd_sentences'])
                    for _ts in ts:
                        all_ts.append(_ts)
        else:  # the AI clip path passes timestamps (in ms) in directly
            all_ts = [[i[0]*16.0, i[1]*16.0] for i in timestamp_list]  # ms -> samples at 16 kHz
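        # srt_index keeps subtitle numbering continuous across clips; time_acc_ost
        # accumulates the output-timeline offset so later subtitles line up after concatenation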
srt_index = 0
time_acc_ost = 0.0
ts = all_ts
# ts.sort()
clip_srt = ""
if len(ts):
# if self.lang == 'en' and isinstance(sentences, str):
# sentences = sentences.split()
start, end = ts[0][0] / 16000, ts[0][1] / 16000
srt_clip, subs, srt_index = generate_srt_clip(sentences, start, end, begin_index=srt_index, time_acc_ost=time_acc_ost)
start, end = start+start_ost/1000.0, end+end_ost/1000.0
video_clip = video.subclip(start, end)
start_end_info = "from {} to {}".format(start, end)
clip_srt += srt_clip
            if add_sub:  # overlay subtitles on the clip
generator = lambda txt: TextClip(txt, font='./font/STHeitiMedium.ttc', fontsize=font_size, color=font_color)
subtitles = SubtitlesClip(subs, generator)
video_clip = CompositeVideoClip([video_clip, subtitles.set_pos(('center','bottom'))])
concate_clip = [video_clip]
            time_acc_ost += end - start  # clip duration (offsets already applied above)
for _ts in ts[1:]:
start, end = _ts[0] / 16000, _ts[1] / 16000
srt_clip, subs, srt_index = generate_srt_clip(sentences, start, end, begin_index=srt_index-1, time_acc_ost=time_acc_ost)
if not len(subs):
continue
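                # rebase the subtitle timestamps so this clip's subtitles start at 0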
chi_subs = []
sub_starts = subs[0][0][0]
for sub in subs:
chi_subs.append(((sub[0][0]-sub_starts, sub[0][1]-sub_starts), sub[1]))
start, end = start+start_ost/1000.0, end+end_ost/1000.0
_video_clip = video.subclip(start, end)
start_end_info += ", from {} to {}".format(str(start)[:5], str(end)[:5])
clip_srt += srt_clip
if add_sub:
generator = lambda txt: TextClip(txt, font='./font/STHeitiMedium.ttc', fontsize=font_size, color=font_color)
subtitles = SubtitlesClip(chi_subs, generator)
_video_clip = CompositeVideoClip([_video_clip, subtitles.set_pos(('center','bottom'))])
# _video_clip.write_videofile("debug.mp4", audio_codec="aac")
concate_clip.append(copy.copy(_video_clip))
                time_acc_ost += end - start  # clip duration (offsets already applied above)
message = "{} periods found in the audio: ".format(len(ts)) + start_end_info
logging.warning("Concating...")
if len(concate_clip) > 1: # 对视频片段进行拼接
video_clip = concatenate_videoclips(concate_clip)
# clip_video_file = clip_video_file[:-4] + '_no{}.mp4'.format(self.GLOBAL_COUNT)
if output_dir is not None:
os.makedirs(output_dir, exist_ok=True)
_, file_with_extension = os.path.split(clip_video_file)
clip_video_file_name, _ = os.path.splitext(file_with_extension)
print(output_dir, clip_video_file)
clip_video_file = os.path.join(output_dir, "{}_no{}.mp4".format(clip_video_file_name, self.GLOBAL_COUNT))
temp_audio_file = os.path.join(output_dir, "{}_tempaudio_no{}.mp4".format(clip_video_file_name, self.GLOBAL_COUNT))
else:
clip_video_file = clip_video_file[:-4] + '_no{}.mp4'.format(self.GLOBAL_COUNT)
temp_audio_file = clip_video_file[:-4] + '_tempaudio_no{}.mp4'.format(self.GLOBAL_COUNT)
            video_clip.write_videofile(clip_video_file, audio_codec="aac", temp_audiofile=temp_audio_file, fps=25)  # write the clip to the target path
self.GLOBAL_COUNT += 1
else:
clip_video_file = video_filename
message = "No period found in the audio, return raw speech. You may check the recognition result and try other destination text."
srt_clip = ''
return clip_video_file, message, clip_srt
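    # Example (sketch): clip the matched periods and burn the subtitles in:
    #   clip_file, msg, srt = clipper.video_clip('目标文本', 0, 0, state, add_sub=True,
    #                                            output_dir='./output')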
def get_parser():
parser = ArgumentParser(
description="ClipVideo Argument",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--stage",
type=int,
choices=(1, 2),
help="Stage, 0 for recognizing and 1 for clipping",
required=True
)
parser.add_argument(
"--file",
type=str,
default=None,
help="Input file path",
required=True
)
parser.add_argument(
"--sd_switch",
type=str,
choices=("no", "yes"),
default="no",
help="Turn on the speaker diarization or not",
)
parser.add_argument(
"--output_dir",
type=str,
default='./output',
help="Output files path",
)
parser.add_argument(
"--dest_text",
type=str,
default=None,
help="Destination text string for clipping",
)
parser.add_argument(
"--dest_spk",
type=str,
default=None,
help="Destination spk id for clipping",
)
parser.add_argument(
"--start_ost",
type=int,
default=0,
help="Offset time in ms at beginning for clipping"
)
parser.add_argument(
"--end_ost",
type=int,
default=0,
help="Offset time in ms at ending for clipping"
)
parser.add_argument(
"--output_file",
type=str,
default=None,
help="Output file path"
)
parser.add_argument(
"--lang",
type=str,
default='zh',
help="language"
)
return parser
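

# A minimal command-line driver sketch, not the original FunClip entry point.
# It assumes the openai-whisper package is installed ("base" model size is an
# arbitrary choice) and runs recognition in-process before clipping, rather than
# persisting state between stage 1 and stage 2 via write_state/load_state.
if __name__ == '__main__':
    parser = get_parser()
    args = parser.parse_args()
    model = whisper.load_model("base")
    clipper = VideoClipper(model)
    # stage 1: recognize the speech in the input video
    res_text, res_srt, state = clipper.video_recog(args.file, output_dir=args.output_dir)
    print(res_text)
    if args.stage == 2 and args.dest_text is not None:
        # stage 2: clip the video at the periods matching dest_text
        clip_file, message, srt = clipper.video_clip(
            args.dest_text, args.start_ost, args.end_ost, state,
            dest_spk=args.dest_spk, output_dir=args.output_dir)
        print(message)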