Manjot Singh committed
Commit b4a3bdb · 1 Parent(s): 6e1e8ec

asr_timestamp_transcription+diarization

Files changed (1)
  1. audio_processing.py +169 -0
audio_processing.py ADDED
@@ -0,0 +1,169 @@
+ import os
+
+ import numpy as np
+ import torch
+ import whisperx
+ from dotenv import load_dotenv
+ from pyannote.audio import Pipeline
+ from scipy.signal import resample
+
+ load_dotenv()
+
+ hf_token = os.getenv("HF_TOKEN")
+
+ # Window length (in seconds) used when chunking audio for language detection and transcription
+ CHUNK_LENGTH = 5
+
+ # Earlier sliding-window version of process_audio, kept commented out:
+ # def process_audio(audio_file):
+ #     device = "cuda" if torch.cuda.is_available() else "cpu"
+ #     compute_type = "float32"
+ #     audio = whisperx.load_audio(audio_file)
+ #     model = whisperx.load_model("small", device, compute_type=compute_type)
+
+ #     # Initial transcription
+ #     result = model.transcribe(audio, batch_size=8)
+
+ #     # Sliding window for language detection
+ #     window_size = 5  # seconds
+ #     step_size = 1  # seconds
+ #     sample_rate = 16000
+
+ #     language_probs = []
+ #     audio_duration = len(audio) / sample_rate
+
+ #     if audio_duration <= window_size:
+ #         # If audio is shorter than or equal to window size, detect language for entire audio
+ #         lang = model.detect_language(audio)
+ #         language_probs.append((0, lang))
+ #     else:
+ #         for i in range(0, len(audio) - window_size * sample_rate + 1, step_size * sample_rate):
+ #             window = audio[i:i + window_size * sample_rate]
+ #             lang = model.detect_language(window)
+ #             language_probs.append((i / sample_rate, lang))
+
+ #     # Detect language changes
+ #     language_segments = []
+ #     current_lang = language_probs[0][1]
+ #     start_time = 0
+ #     for time, lang in language_probs[1:]:
+ #         if lang != current_lang:
+ #             language_segments.append({
+ #                 "language": current_lang,
+ #                 "start": start_time,
+ #                 "end": time
+ #             })
+ #             current_lang = lang
+ #             start_time = time
+
+ #     # Add the last segment
+ #     language_segments.append({
+ #         "language": current_lang,
+ #         "start": start_time,
+ #         "end": audio_duration
+ #     })
+
+ #     # Re-transcribe each language segment
+ #     final_segments = []
+ #     for segment in language_segments:
+ #         start_sample = int(segment["start"] * sample_rate)
+ #         end_sample = int(segment["end"] * sample_rate)
+ #         segment_audio = audio[start_sample:end_sample]
+
+ #         segment_result = model.transcribe(segment_audio, language=segment["language"])
+
+ #         for seg in segment_result["segments"]:
+ #             seg["start"] += segment["start"]
+ #             seg["end"] += segment["start"]
+ #             seg["language"] = segment["language"]
+ #             final_segments.append(seg)
+
+ #     return language_segments, final_segments
+
+ def preprocess_audio(audio, chunk_size=CHUNK_LENGTH*16000):  # CHUNK_LENGTH seconds at 16 kHz
+     """Split audio into fixed-size chunks, zero-padding the final chunk."""
+     chunks = []
+     for i in range(0, len(audio), chunk_size):
+         chunk = audio[i:i+chunk_size]
+         if len(chunk) < chunk_size:
+             # Zero-pad the last chunk so every chunk has the same length
+             chunk = np.pad(chunk, (0, chunk_size - len(chunk)))
+         chunks.append(chunk)
+     return chunks
+
+ def process_audio(audio_file):
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     compute_type = "float32"
+     audio = whisperx.load_audio(audio_file)
+     model = whisperx.load_model("small", device, compute_type=compute_type)
+
+     # Initialize speaker diarization pipeline
+     diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
+     diarization_pipeline = diarization_pipeline.to(torch.device(device))
+
+     # Perform diarization on the entire audio
+     diarization_result = diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000})
+
+     # Preprocess audio into consistent fixed-length chunks
+     chunks = preprocess_audio(audio)
+
+     language_segments = []
+     final_segments = []
+
+     for i, chunk in enumerate(chunks):
+         # Detect language for this chunk
+         lang = model.detect_language(chunk)
+
+         # Transcribe this chunk
+         result = model.transcribe(chunk, language=lang)
+
+         chunk_start_time = i * CHUNK_LENGTH  # Each chunk is CHUNK_LENGTH seconds long
+
+         # Adjust timestamps and add language information
+         for segment in result["segments"]:
+             segment_start = chunk_start_time + segment["start"]
+             segment_end = chunk_start_time + segment["end"]
+             segment["start"] = segment_start
+             segment["end"] = segment_end
+             segment["language"] = lang
+
+             # Assign the speaker whose diarization turns overlap this segment most often
+             speakers = []
+             for turn, track, speaker in diarization_result.itertracks(yield_label=True):
+                 if turn.start <= segment_end and turn.end >= segment_start:
+                     speakers.append(speaker)
+             if speakers:
+                 segment["speaker"] = max(set(speakers), key=speakers.count)
+             else:
+                 segment["speaker"] = "Unknown"
+
+             final_segments.append(segment)
+
+         # Add language segment for this chunk
+         language_segments.append({
+             "language": lang,
+             "start": chunk_start_time,
+             "end": chunk_start_time + CHUNK_LENGTH
+         })
+
+     return language_segments, final_segments
+
+ def print_results(language, language_probs, segments):
+     print(f"Detected Language: {language}")
+     print("Language Probabilities:")
+     for lang, prob in language_probs.items():
+         print(f"  {lang}: {prob:.4f}")
+
+     print("\nTranscription:")
+     for segment in segments:
+         print(f"[{segment['start']:.2f}s - {segment['end']:.2f}s] Speaker {segment['speaker']}: {segment['text']}")