buletomato25 committed on
Commit
4e9c42f
·
2 Parent(s): dd43f38 1792d9b

Merge branch 'main' into suwabe/docker

Browse files
Files changed (1) hide show
  1. process.py +65 -8
process.py CHANGED
@@ -7,6 +7,7 @@ import random
7
  from datetime import datetime
8
  from pyannote.audio import Model, Inference
9
  from pydub import AudioSegment
 
10
  class AudioProcessor():
11
  def __init__(self,cache_dir = "/tmp/hf_cache"):
12
  hf_token = os.environ.get("HUGGINGFACE_HUB_TOKEN")
@@ -17,6 +18,7 @@ class AudioProcessor():
17
  model = Model.from_pretrained("pyannote/embedding", use_auth_token=hf_token, cache_dir=cache_dir)
18
  self.inference = Inference(model)
19
 
 
20
  def cosine_similarity(self,vec1, vec2):
21
  vec1 = vec1 / np.linalg.norm(vec1)
22
  vec2 = vec2 / np.linalg.norm(vec2)
@@ -53,7 +55,17 @@ class AudioProcessor():
53
  embedding1 = self.inference(path1)
54
  embedding2 = self.inference(path2)
55
  return float(self.cosine_similarity(embedding1.data.flatten(), embedding2.data.flatten()))
 
 
 
 
56
 
 
 
 
 
 
 
57
  def process_audio(self, reference_path, input_path, output_folder='/tmp/data/matched_segments', seg_duration=1.0, threshold=0.5):
58
  # 出力先ディレクトリの中身をクリアする
59
  if os.path.exists(output_folder):
@@ -76,13 +88,58 @@ class AudioProcessor():
76
 
77
  unmatched_time_ms = total_duration_ms - matched_time_ms
78
  return matched_time_ms, unmatched_time_ms
 
79
 
80
- def generate_random_string(self,length):
81
- letters = string.ascii_letters + string.digits
82
- return ''.join(random.choice(letters) for i in range(length))
 
 
 
 
 
 
83
 
84
- def generate_filename(self,random_length):
85
- random_string = self.generate_random_string(random_length)
86
- current_time = datetime.now().strftime("%Y%m%d%H%M%S")
87
- filename = f"{current_time}_{random_string}.wav"
88
- return filename
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  from datetime import datetime
8
  from pyannote.audio import Model, Inference
9
  from pydub import AudioSegment
10
+
11
  class AudioProcessor():
12
  def __init__(self,cache_dir = "/tmp/hf_cache"):
13
  hf_token = os.environ.get("HUGGINGFACE_HUB_TOKEN")
 
18
  model = Model.from_pretrained("pyannote/embedding", use_auth_token=hf_token, cache_dir=cache_dir)
19
  self.inference = Inference(model)
20
 
21
+
22
  def cosine_similarity(self,vec1, vec2):
23
  vec1 = vec1 / np.linalg.norm(vec1)
24
  vec2 = vec2 / np.linalg.norm(vec2)
 
55
  embedding1 = self.inference(path1)
56
  embedding2 = self.inference(path2)
57
  return float(self.cosine_similarity(embedding1.data.flatten(), embedding2.data.flatten()))
58
+
59
+ def generate_random_string(self,length):
60
+ letters = string.ascii_letters + string.digits
61
+ return ''.join(random.choice(letters) for i in range(length))
62
 
63
+ def generate_filename(self,random_length):
64
+ random_string = self.generate_random_string(random_length)
65
+ current_time = datetime.now().strftime("%Y%m%d%H%M%S")
66
+ filename = f"{current_time}_{random_string}.wav"
67
+ return filename
68
+
69
  def process_audio(self, reference_path, input_path, output_folder='/tmp/data/matched_segments', seg_duration=1.0, threshold=0.5):
70
  # 出力先ディレクトリの中身をクリアする
71
  if os.path.exists(output_folder):
 
88
 
89
  unmatched_time_ms = total_duration_ms - matched_time_ms
90
  return matched_time_ms, unmatched_time_ms
91
+
92
 
93
+ def process_multi_audio(self, reference_pathes, input_path, output_folder='/tmp/data/matched_multi_segments', seg_duration=1.0, threshold=0.5):
94
+ # 出力先ディレクトリの中身をクリアする
95
+ if os.path.exists(output_folder):
96
+ for file in os.listdir(output_folder):
97
+ file_path = os.path.join(output_folder, file)
98
+ if os.path.isfile(file_path):
99
+ os.remove(file_path)
100
+ else:
101
+ os.makedirs(output_folder, exist_ok=True)
102
 
103
+ # 入力音声をセグメントに分割
104
+ segmented_path, total_duration_ms = self.segment_audio(input_path, seg_duration=seg_duration)
105
+ segment_files = sorted(os.listdir(segmented_path))
106
+ num_segments = len(segment_files)
107
+
108
+ # 各リファレンスごとにセグメントとの類似度を計算し、行列 (rows: reference, columns: segment) を作成
109
+ similarity = []
110
+ for reference_path in reference_pathes:
111
+ ref_similarity = []
112
+ for file in segment_files:
113
+ segment_file = os.path.join(segmented_path, file)
114
+ sim = self.calculate_similarity(segment_file, reference_path)
115
+ ref_similarity.append(sim)
116
+ similarity.append(ref_similarity)
117
+
118
+ # 転置行列を作成 (rows: segment, columns: reference)
119
+ similarity_transposed = []
120
+ for seg_idx in range(num_segments):
121
+ seg_sim = []
122
+ for ref_idx in range(len(reference_pathes)):
123
+ seg_sim.append(similarity[ref_idx][seg_idx])
124
+ similarity_transposed.append(seg_sim)
125
+
126
+ # 各セグメントについて、最も高い類似度のリファレンスを選択
127
+ best_matches = []
128
+ for seg_sim in similarity_transposed:
129
+ best_ref = np.argmax(seg_sim) # 最も類似度の高いリファレンスのインデックス
130
+ # 閾値チェック (必要に応じて)
131
+ if seg_sim[best_ref] < threshold:
132
+ best_matches.append(None) # 閾値未満の場合はマッチなしとする
133
+ else:
134
+ best_matches.append(best_ref)
135
+
136
+ # 各リファレンスごとに一致時間を集計 (セグメントごとの長さ seg_duration を加算)
137
+ matched_time = [0] * len(reference_pathes)
138
+ for match in best_matches:
139
+ if match is not None:
140
+ matched_time[match] += seg_duration
141
+
142
+ return matched_time
143
+
144
+
145
+