Kr08 committed (verified)
Commit 03f8b40 · 1 Parent(s): 0caaf5e

Update audio_processing.py

Files changed (1)
  1. audio_processing.py +189 -41
audio_processing.py CHANGED
@@ -1,48 +1,196 @@
 import torch
-import whisper
-import torchaudio as ta
-from model_utils import get_processor, get_model, get_whisper_model_small, get_device
-from config import SAMPLING_RATE, CHUNK_LENGTH_S
-
-def detect_language(audio_file):
-    whisper_model = get_whisper_model_small()
-    trimmed_audio = whisper.pad_or_trim(audio_file.squeeze())
-    mel = whisper.log_mel_spectrogram(trimmed_audio).to(whisper_model.device)
-    _, probs = whisper_model.detect_language(mel)
-    detected_lang = max(probs[0], key=probs[0].get)
-    print(f"Detected language: {detected_lang}")
-    return detected_lang
-
-def process_long_audio(waveform, sampling_rate, task="transcribe", language=None):
-    processor = get_processor()
-    model = get_model()
-    device = get_device()
-
-    input_length = waveform.shape[1]
-    chunk_length = int(CHUNK_LENGTH_S * sampling_rate)
-    chunks = [waveform[:, i:i + chunk_length] for i in range(0, input_length, chunk_length)]
-
-    results = []
-    for chunk in chunks:
-        input_features = processor(chunk[0], sampling_rate=sampling_rate, return_tensors="pt").input_features.to(device)
-
         with torch.no_grad():
-            if task == "translate":
-                forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task="translate")
-                generated_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
-            else:
-                generated_ids = model.generate(input_features)
-
-        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
-        results.extend(transcription)
-
-        # Clear GPU cache
-        torch.cuda.empty_cache()
-
-    return " ".join(results)
-
-def load_and_resample_audio(file):
-    waveform, sampling_rate = ta.load(file)
-    if sampling_rate != SAMPLING_RATE:
-        waveform = ta.functional.resample(waveform, orig_freq=sampling_rate, new_freq=SAMPLING_RATE)
-    return waveform
+import gc
 import torch
+import torchaudio
+import numpy as np
+from transformers import (
+    Wav2Vec2ForSequenceClassification,
+    AutoFeatureExtractor,
+    Wav2Vec2ForCTC,
+    AutoProcessor,
+    AutoTokenizer,
+    AutoModelForSeq2SeqLM
+)
+import logging
+from difflib import SequenceMatcher
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+class AudioProcessor:
+    def __init__(self, chunk_size=5, overlap=1, sample_rate=16000):
+        self.chunk_size = chunk_size
+        self.overlap = overlap
+        self.sample_rate = sample_rate
+        self.previous_text = ""
+        self.previous_lang = None
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    def load_models(self):
+        """Load all required models"""
+        logger.info("Loading MMS models...")
+
+        # Language identification model
+        lid_processor = AutoFeatureExtractor.from_pretrained("facebook/mms-lid-256")
+        lid_model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/mms-lid-256")
+
+        # Transcription model
+        mms_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
+        mms_model = Wav2Vec2ForCTC.from_pretrained("facebook/mms-1b-all")
+
+        # Translation model
+        translation_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
+        translation_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
+
+        return {
+            'lid': (lid_model, lid_processor),
+            'mms': (mms_model, mms_processor),
+            'translation': (translation_model, translation_tokenizer)
+        }
+
+    def identify_language(self, audio_chunk, models):
+        """Identify language of audio chunk"""
+        lid_model, lid_processor = models['lid']
+        inputs = lid_processor(audio_chunk, sampling_rate=16000, return_tensors="pt")
+
         with torch.no_grad():
+            outputs = lid_model(inputs.input_values.to(self.device)).logits
+            lang_id = torch.argmax(outputs, dim=-1)[0].item()
+            detected_lang = lid_model.config.id2label[lang_id]
+
+        return detected_lang
+
+    def transcribe_chunk(self, audio_chunk, language, models):
+        """Transcribe audio chunk"""
+        mms_model, mms_processor = models['mms']
+
+        mms_processor.tokenizer.set_target_lang(language)
+        mms_model.load_adapter(language)
+
+        inputs = mms_processor(audio_chunk, sampling_rate=16000, return_tensors="pt")
+
+        with torch.no_grad():
+            outputs = mms_model(inputs.input_values.to(self.device)).logits
+            ids = torch.argmax(outputs, dim=-1)[0]
+            transcription = mms_processor.decode(ids)
+
+        return transcription
+
+    def translate_text(self, text, models):
+        """Translate text to English"""
+        translation_model, translation_tokenizer = models['translation']
+
+        inputs = translation_tokenizer(text, return_tensors="pt")
+        inputs = inputs.to(self.device)
+
+        with torch.no_grad():
+            outputs = translation_model.generate(
+                **inputs,
+                forced_bos_token_id=translation_tokenizer.convert_tokens_to_ids("eng_Latn"),
+                max_length=100
+            )
+            translation = translation_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+
+        return translation
+
+    def process_audio(self, audio_path, translate=False):
+        """Main processing function"""
+        try:
+            # Load audio
+            waveform, sample_rate = torchaudio.load(audio_path)
+            if waveform.shape[0] > 1:
+                waveform = torch.mean(waveform, dim=0)
+
+            # Resample if necessary
+            if sample_rate != self.sample_rate:
+                waveform = torchaudio.transforms.Resample(sample_rate, self.sample_rate)(waveform)
+
+            # Load models
+            models = self.load_models()
+
+            # Process in chunks
+            chunk_samples = int(self.chunk_size * self.sample_rate)
+            overlap_samples = int(self.overlap * self.sample_rate)
+
+            segments = []
+            language_segments = []
+
+            for i in range(0, len(waveform), chunk_samples - overlap_samples):
+                chunk = waveform[i:i + chunk_samples]
+                if len(chunk) < chunk_samples:
+                    chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))
+
+                # Process chunk
+                start_time = i / self.sample_rate
+                end_time = (i + len(chunk)) / self.sample_rate
+
+                # Identify language
+                language = self.identify_language(chunk, models)
+
+                # Record language segment
+                language_segments.append({
+                    "language": language,
+                    "start": start_time,
+                    "end": end_time
+                })
+
+                # Transcribe
+                transcription = self.transcribe_chunk(chunk, language, models)
+
+                segment = {
+                    "start": start_time,
+                    "end": end_time,
+                    "language": language,
+                    "text": transcription,
+                    "speaker": "Speaker"  # Simple speaker assignment
+                }
+
+                if translate:
+                    translation = self.translate_text(transcription, models)
+                    segment["translated"] = translation
+
+                segments.append(segment)
+
+                # Clean up GPU memory
+                torch.cuda.empty_cache()
+                gc.collect()
+
+            # Merge nearby segments
+            merged_segments = self.merge_segments(segments)
+
+            return language_segments, merged_segments
+
+        except Exception as e:
+            logger.error(f"Error processing audio: {str(e)}")
+            raise
+
+    def merge_segments(self, segments, time_threshold=0.5, similarity_threshold=0.7):
+        """Merge similar nearby segments"""
+        if not segments:
+            return segments
+
+        merged = []
+        current = segments[0]
+
+        for next_segment in segments[1:]:
+            if (next_segment['start'] - current['end'] <= time_threshold and
+                    current['language'] == next_segment['language']):
+
+                # Check text similarity
+                matcher = SequenceMatcher(None, current['text'], next_segment['text'])
+                similarity = matcher.ratio()
+
+                if similarity > similarity_threshold:
+                    # Merge segments
+                    current['end'] = next_segment['end']
+                    current['text'] = current['text'] + ' ' + next_segment['text']
+                    if 'translated' in current and 'translated' in next_segment:
+                        current['translated'] = current['translated'] + ' ' + next_segment['translated']
+                else:
+                    merged.append(current)
+                    current = next_segment
+            else:
+                merged.append(current)
+                current = next_segment
+
+        merged.append(current)
+        return merged
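
For reference, a minimal usage sketch of the AudioProcessor class added in this commit. The import path, the "sample.wav" file name, and the printing loop are illustrative assumptions rather than part of the commit; process_audio returns (language_segments, merged_segments) as defined above.

# Hypothetical usage sketch (assumes audio_processing.py from this commit is importable
# and that "sample.wav" is any local audio file -- both are illustrative assumptions).
from audio_processing import AudioProcessor

processor = AudioProcessor(chunk_size=5, overlap=1, sample_rate=16000)

# Returns per-chunk language segments and merged transcription segments.
language_segments, segments = processor.process_audio("sample.wav", translate=True)

for seg in segments:
    print(f"[{seg['start']:.1f}s-{seg['end']:.1f}s] ({seg['language']}) {seg['text']}")
    if "translated" in seg:
        print(f"  -> EN: {seg['translated']}")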