# JusTalk / process.py
# rein0421's picture
# Upload 5 files
# 99daaaf verified
# raw
# history blame
# 3.76 kB
import os
import shutil
import numpy as np
import string
import random
from datetime import datetime
from pyannote.audio import Model, Inference
from pydub import AudioSegment
class AudioProcessor:
    """Speaker-similarity helpers built on the ``pyannote/embedding`` model.

    The processor cuts an input recording into fixed-length WAV segments,
    embeds each segment and a reference recording, and reports how much of
    the input is similar to the reference according to cosine similarity.
    """

    # Scratch directory used for segments when the caller does not pick one.
    _DEFAULT_SEGMENT_DIR = '/tmp/setup_voice'

    def __init__(self, cache_dir="/tmp/hf_cache"):
        """Load the embedding model and wrap it in an ``Inference`` object.

        Args:
            cache_dir: Directory handed to ``Model.from_pretrained`` as
                ``cache_dir``; it is created when missing.

        Raises:
            ValueError: If neither ``HF`` nor ``HUGGINGFACE_HUB_TOKEN`` holds
                a non-empty token in the environment.
        """
        # ``HF`` is the variable this class has always read. The error message
        # used to name ``HUGGINGFACE_HUB_TOKEN`` instead, so that variable is
        # accepted as a fallback and the message now lists both.
        hf_token = os.environ.get("HF") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
        if not hf_token:
            raise ValueError("環境変数 HF または HUGGINGFACE_HUB_TOKEN が設定されていません。")
        os.makedirs(cache_dir, exist_ok=True)
        # Load the pyannote embedding model.
        model = Model.from_pretrained(
            "pyannote/embedding", use_auth_token=hf_token, cache_dir=cache_dir
        )
        self.inference = Inference(model)

    def cosine_similarity(self, vec1, vec2):
        """Return the cosine similarity of two vectors.

        Args:
            vec1: First vector (array-like).
            vec2: Second vector (array-like) of the same length.

        Returns:
            The dot product of the two L2-normalised vectors, or ``0.0`` when
            either vector has zero norm (the similarity is undefined there and
            dividing would yield NaN).
        """
        vec1 = np.asarray(vec1)
        vec2 = np.asarray(vec2)
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)
        if norm1 == 0 or norm2 == 0:
            return 0.0
        return np.dot(vec1 / norm1, vec2 / norm2)

    def _split_into_segments(self, path, target_path, seg_duration):
        """Export ``seg_duration``-second WAV segments of ``path``.

        Returns:
            ``(segments, duration_ms)`` where ``segments`` is a list of
            ``(file_path, actual_ms)`` tuples in playback order. ``actual_ms``
            is the amount of real audio in the segment, i.e. excluding any
            silence appended to the final one.

        Raises:
            ValueError: If ``seg_duration`` is shorter than one millisecond.
        """
        seg_duration_ms = int(seg_duration * 1000)
        if seg_duration_ms <= 0:
            raise ValueError(
                f"seg_duration must be at least 0.001 seconds, got {seg_duration!r}"
            )
        os.makedirs(target_path, exist_ok=True)
        base_sound = AudioSegment.from_file(path)
        duration_ms = len(base_sound)
        segments = []
        for index, start in enumerate(range(0, duration_ms, seg_duration_ms)):
            end = min(start + seg_duration_ms, duration_ms)
            segment = base_sound[start:end]
            # Pad a short (final) segment with silence up to the full length.
            if len(segment) < seg_duration_ms:
                silence = AudioSegment.silent(duration=(seg_duration_ms - len(segment)))
                segment = segment + silence
            segment_file = os.path.join(target_path, f'{index}.wav')
            segment.export(segment_file, format="wav")
            segments.append((segment_file, end - start))
        return segments, duration_ms

    def segment_audio(self, path, target_path=_DEFAULT_SEGMENT_DIR, seg_duration=1.0):
        """Split an audio file into fixed-length segments, padding the last one.

        Segments are written to ``target_path`` as ``0.wav``, ``1.wav``, ...
        Files already present in ``target_path`` are not removed; those with
        the same name are overwritten.

        Args:
            path: Audio file to split.
            target_path: Directory that receives the segment files.
            seg_duration: Segment length in seconds.

        Returns:
            ``(target_path, duration_ms)`` where ``duration_ms`` is the length
            of the source audio in milliseconds.

        Raises:
            ValueError: If ``seg_duration`` is shorter than one millisecond.
        """
        _, duration_ms = self._split_into_segments(path, target_path, seg_duration)
        return target_path, duration_ms

    def _embed(self, path):
        """Return the flattened embedding of the audio file at ``path``."""
        # NOTE(review): ``Inference`` is created with its default windowing.
        # If that produces a different number of windows for recordings of
        # different length, the flattened vectors will differ in size and the
        # dot product will fail - confirm against the pyannote documentation.
        return self.inference(path).data.flatten()

    def calculate_similarity(self, path1, path2):
        """Return the cosine similarity (as ``float``) of two audio files."""
        return float(self.cosine_similarity(self._embed(path1), self._embed(path2)))

    def process_audio(self, reference_path, input_path, output_folder='/tmp/data/matched_segments', seg_duration=1.0, threshold=0.5):
        """Measure how much of ``input_path`` resembles ``reference_path``.

        The input is split into ``seg_duration``-second segments; every
        segment whose similarity to the reference exceeds ``threshold`` is
        copied into ``output_folder``.

        Only the segments produced by this call are evaluated, so files left
        in the scratch directory by an earlier run are ignored. Matched time
        counts the real audio in each segment, not the silence padding, which
        keeps the unmatched time from going negative.

        Args:
            reference_path: Audio file of the reference voice.
            input_path: Audio file to analyse.
            output_folder: Directory that receives the matching segments;
                created when missing, existing files are left in place.
            seg_duration: Segment length in seconds.
            threshold: Similarity a segment must exceed to count as a match.

        Returns:
            ``(matched_time_ms, unmatched_time_ms)``; the two values sum to
            the duration of the input audio in milliseconds.

        Raises:
            ValueError: If ``seg_duration`` is shorter than one millisecond.
        """
        os.makedirs(output_folder, exist_ok=True)
        segments, total_duration_ms = self._split_into_segments(
            input_path, self._DEFAULT_SEGMENT_DIR, seg_duration
        )
        if not segments:
            # Empty input: nothing to compare, so the reference is not loaded.
            return 0, total_duration_ms

        # The reference does not change between segments; embed it once.
        reference_embedding = self._embed(reference_path)
        matched_time_ms = 0
        for segment_file, actual_ms in segments:
            similarity = float(
                self.cosine_similarity(self._embed(segment_file), reference_embedding)
            )
            if similarity > threshold:
                shutil.copy(segment_file, output_folder)
                matched_time_ms += actual_ms
        unmatched_time_ms = total_duration_ms - matched_time_ms
        return matched_time_ms, unmatched_time_ms

    def generate_random_string(self, length):
        """Return ``length`` random ASCII letters and digits.

        Uses the ``random`` module, so the result is suitable for unique-ish
        file names but not for secrets.
        """
        letters = string.ascii_letters + string.digits
        return ''.join(random.choices(letters, k=length))

    def generate_filename(self, random_length):
        """Return ``<YYYYmmddHHMMSS>_<random>.wav`` using the local time.

        Args:
            random_length: Number of random characters after the timestamp.
        """
        random_string = self.generate_random_string(random_length)
        current_time = datetime.now().strftime("%Y%m%d%H%M%S")
        return f"{current_time}_{random_string}.wav"