Kr08 committed
Commit 7d0ee66 · verified · 1 Parent(s): a3f7705

Create audio_processing.py

Files changed (1)
  1. audio_processing.py +145 -0
audio_processing.py ADDED
@@ -0,0 +1,145 @@
+ import os
+ import logging
+ import time
+ from difflib import SequenceMatcher
+ 
+ import numpy as np
+ import torch
+ import whisperx
+ from scipy.signal import resample
+ from pyannote.audio import Pipeline
+ from dotenv import load_dotenv
+ import spaces
+ 
+ load_dotenv()
+ hf_token = os.getenv("HF_TOKEN")
+ 
+ # Chunking parameters, in seconds; overlap is currently disabled (0 s).
+ CHUNK_LENGTH = 30
+ OVERLAP = 0
+ 
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+ 
+ def preprocess_audio(audio, chunk_size=CHUNK_LENGTH * 16000, overlap=OVERLAP * 16000):
+     """Split a 16 kHz audio array into fixed-length chunks, zero-padding the last one."""
+     chunks = []
+     for i in range(0, len(audio), chunk_size - overlap):
+         chunk = audio[i:i + chunk_size]
+         if len(chunk) < chunk_size:
+             chunk = np.pad(chunk, (0, chunk_size - len(chunk)))
+         chunks.append(chunk)
+     return chunks
+ 
+ @spaces.GPU()
+ def process_audio(audio_file, translate=False, model_size="small"):
+     start_time = time.time()
+ 
+     try:
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         logger.info(f"Using device: {device}")
+         compute_type = "int8" if torch.cuda.is_available() else "float32"
+         audio = whisperx.load_audio(audio_file)
+         model = whisperx.load_model(model_size, device, compute_type=compute_type)
+ 
+         diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
+         diarization_pipeline = diarization_pipeline.to(torch.device(device))
+ 
+         diarization_result = diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000})
+ 
+         chunks = preprocess_audio(audio)
+ 
+         language_segments = []
+         final_segments = []
+ 
+         overlap_duration = OVERLAP  # overlap between consecutive chunks, in seconds
+         for i, chunk in enumerate(chunks):
+             chunk_proc_start = time.time()
+             chunk_start_time = i * (CHUNK_LENGTH - overlap_duration)
+             chunk_end_time = chunk_start_time + CHUNK_LENGTH
+             logger.info(f"Processing chunk {i+1}/{len(chunks)}")
+             lang = model.detect_language(chunk)
+             result_transcribe = model.transcribe(chunk, language=lang)
+             if translate:
+                 result_translate = model.transcribe(chunk, task="translate")
+             for j, t_seg in enumerate(result_transcribe["segments"]):
+                 segment_start = chunk_start_time + t_seg["start"]
+                 segment_end = chunk_start_time + t_seg["end"]
+ 
+                 # Skip segments in the overlapping region of the previous chunk
+                 if i > 0 and segment_end <= chunk_start_time + overlap_duration:
+                     logger.info(f"Skipping segment in overlap with previous chunk: {segment_start:.2f} - {segment_end:.2f}")
+                     continue
+ 
+                 # Skip segments in the overlapping region of the next chunk
+                 if i < len(chunks) - 1 and segment_start >= chunk_end_time - overlap_duration:
+                     logger.info(f"Skipping segment in overlap with next chunk: {segment_start:.2f} - {segment_end:.2f}")
+                     continue
+ 
+                 # Collect every diarization speaker whose turn overlaps this segment
+                 speakers = []
+                 for turn, track, speaker in diarization_result.itertracks(yield_label=True):
+                     if turn.start <= segment_end and turn.end >= segment_start:
+                         speakers.append(speaker)
+ 
+                 segment = {
+                     "start": segment_start,
+                     "end": segment_end,
+                     "language": lang,
+                     "speaker": max(set(speakers), key=speakers.count) if speakers else "Unknown",
+                     "text": t_seg["text"],
+                 }
+ 
+                 if translate and j < len(result_translate["segments"]):
+                     segment["translated"] = result_translate["segments"][j]["text"]
+ 
+                 final_segments.append(segment)
+ 
+             language_segments.append({
+                 "language": lang,
+                 "start": chunk_start_time,
+                 "end": chunk_start_time + CHUNK_LENGTH
+             })
+             logger.info(f"Chunk {i+1} processed in {time.time() - chunk_proc_start:.2f} seconds")
+ 
+         final_segments.sort(key=lambda x: x["start"])
+         merged_segments = merge_nearby_segments(final_segments)
+ 
+         end_time = time.time()
+         logger.info(f"Total processing time: {end_time - start_time:.2f} seconds")
+ 
+         return language_segments, merged_segments
+     except Exception as e:
+         logger.error(f"An error occurred during audio processing: {str(e)}")
+         raise
+ 
+ def merge_nearby_segments(segments, time_threshold=0.5, similarity_threshold=0.9):
+     merged = []
+     for segment in segments:
+         if not merged or segment['start'] - merged[-1]['end'] > time_threshold:
+             merged.append(segment)
+         else:
+             # Find the longest common run of text between the two adjacent segments
+             matcher = SequenceMatcher(None, merged[-1]['text'], segment['text'])
+             match = matcher.find_longest_match(0, len(merged[-1]['text']), 0, len(segment['text']))
+ 
+             if match.size / len(segment['text']) > similarity_threshold:
+                 # Merge the segments, appending only the non-overlapping tail
+                 merged[-1]['end'] = segment['end']
+                 merged[-1]['text'] = merged[-1]['text'] + segment['text'][match.b + match.size:]
+                 if 'translated' in merged[-1] and 'translated' in segment:
+                     merged[-1]['translated'] = merged[-1]['translated'] + segment['translated'][match.b + match.size:]
+             else:
+                 # If no significant overlap, append as a new segment
+                 merged.append(segment)
+     return merged
+ 
+ def print_results(segments):
+     for segment in segments:
+         print(f"[{segment['start']:.2f}s - {segment['end']:.2f}s] ({segment['language']}) {segment['speaker']}:")
+         print(f"Original: {segment['text']}")
+         if 'translated' in segment:
+             print(f"Translated: {segment['translated']}")
+         print()
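A minimal usage sketch of the two entry points defined above, assuming the module runs inside the Space (the @spaces.GPU() decorator needs that runtime) and that a local audio file exists; "sample.wav" is a placeholder path, not part of the committed file:

if __name__ == "__main__":
    # Hypothetical example call; process_audio returns (language_segments, merged_segments).
    language_segments, segments = process_audio("sample.wav", translate=True, model_size="small")
    for lang_seg in language_segments:
        print(f"{lang_seg['language']}: {lang_seg['start']:.0f}s - {lang_seg['end']:.0f}s")
    print_results(segments)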