Kr08 committed
Commit 759bce7 · verified · 1 Parent(s): 444b9c9

Optimized audio_processing.py with optional diarization

Files changed (1)
  1. audio_processing.py +28 -35
audio_processing.py CHANGED
@@ -16,38 +16,41 @@ hf_token = os.getenv("HF_TOKEN")
 CHUNK_LENGTH = 30
 OVERLAP = 2
 
-logging.basicConfig(level=logging.INFO)
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
 
-# Global variables for models
-device = "cuda" if torch.cuda.is_available() else "cpu"
-compute_type = "float16" if device == "cuda" else "int8"
-whisper_model = None
-diarization_pipeline = None
 
-def load_models(model_size="small"):
-    global whisper_model, diarization_pipeline, device, compute_type
-
-    # Load Whisper model
+@spaces.GPU(duration=60)
+def load_whisper_model(model_size="small"):
+    logger.info(f"Loading Whisper model (size: {model_size})...")
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    compute_type = "float16" if device == "cuda" else "int8"
     try:
-        whisper_model = whisperx.load_model(model_size, device, compute_type=compute_type)
+        model = whisperx.load_model(model_size, device, compute_type=compute_type)
+        logger.info(f"Whisper model loaded successfully on {device}")
+        return model
     except RuntimeError as e:
         logger.warning(f"Failed to load Whisper model on {device}. Falling back to CPU. Error: {str(e)}")
         device = "cpu"
         compute_type = "int8"
-        whisper_model = whisperx.load_model(model_size, device, compute_type=compute_type)
+        model = whisperx.load_model(model_size, device, compute_type=compute_type)
+        logger.info("Whisper model loaded successfully on CPU")
+        return model
+
 
+@spaces.GPU(duration=60)
 def load_diarization_pipeline():
-    global diarization_pipeline, device
-
-    # Try to initialize diarization pipeline
+    logger.info("Loading diarization pipeline...")
     try:
-        diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
-        if device == "cuda":
-            diarization_pipeline = diarization_pipeline.to(torch.device(device))
+        pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
+        if torch.cuda.is_available():
+            pipeline = pipeline.to(torch.device("cuda"))
+        logger.info("Diarization pipeline loaded successfully")
+        return pipeline
     except Exception as e:
         logger.warning(f"Diarization pipeline initialization failed: {str(e)}. Diarization will not be available.")
-        diarization_pipeline = None
+        return None
+
 
 def preprocess_audio(audio, chunk_size=CHUNK_LENGTH*16000, overlap=OVERLAP*16000):
     chunks = []
@@ -58,18 +61,17 @@ def preprocess_audio(audio, chunk_size=CHUNK_LENGTH*16000, overlap=OVERLAP*16000):
         chunks.append(chunk)
     return chunks
 
+
 def merge_nearby_segments(segments, time_threshold=0.5, similarity_threshold=0.7):
     merged = []
     for segment in segments:
         if not merged or segment['start'] - merged[-1]['end'] > time_threshold:
             merged.append(segment)
         else:
-            # Find the overlap
            matcher = SequenceMatcher(None, merged[-1]['text'], segment['text'])
            match = matcher.find_longest_match(0, len(merged[-1]['text']), 0, len(segment['text']))
 
            if match.size / len(segment['text']) > similarity_threshold:
-                # Merge the segments
                merged_text = merged[-1]['text'] + segment['text'][match.b + match.size:]
                merged_translated = merged[-1].get('translated', '') + segment.get('translated', '')[match.b + match.size:]
 
@@ -78,11 +80,9 @@ def merge_nearby_segments(segments, time_threshold=0.5, similarity_threshold=0.7):
                 if 'translated' in segment:
                     merged[-1]['translated'] = merged_translated
             else:
-                # If no significant overlap, append as a new segment
                 merged.append(segment)
     return merged
 
-# Helper function to get the most common speaker in a time range
 def get_most_common_speaker(diarization_result, start_time, end_time):
     speakers = []
     for turn, _, speaker in diarization_result.itertracks(yield_label=True):
@@ -90,7 +90,6 @@ def get_most_common_speaker(diarization_result, start_time, end_time):
             speakers.append(speaker)
     return max(set(speakers), key=speakers.count) if speakers else "Unknown"
 
-# Helper function to split long audio files
 def split_audio(audio, max_duration=30):
     sample_rate = 16000
     max_samples = max_duration * sample_rate
@@ -104,25 +103,19 @@ def split_audio(audio, max_duration=30):
 
     return splits
 
-# Main processing function with optimizations
 @spaces.GPU(duration=60)
 def process_audio(audio_file, translate=False, model_size="small", use_diarization=True):
-    global whisper_model, diarization_pipeline
-
-    if whisper_model is None:
-        load_models(model_size)
-
+    logger.info(f"Starting audio processing: translate={translate}, model_size={model_size}, use_diarization={use_diarization}")
     start_time = time.time()
 
     try:
+        whisper_model = load_whisper_model(model_size)
         audio = whisperx.load_audio(audio_file)
         audio_splits = split_audio(audio)
 
-        # Perform diarization if requested and pipeline is available
         diarization_result = None
         if use_diarization:
-            if diarization_pipeline is None:
-                load_diarization_pipeline()
+            diarization_pipeline = load_diarization_pipeline()
             if diarization_pipeline is not None:
                 try:
                     diarization_result = diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000})
@@ -139,8 +132,8 @@ def process_audio(audio_file, translate=False, model_size="small", use_diarization=True):
             lang = result["language"]
 
             for segment in result["segments"]:
-                segment_start = segment["start"] + (i * 30)  # Adjust start time based on split
-                segment_end = segment["end"] + (i * 30)  # Adjust end time based on split
+                segment_start = segment["start"] + (i * 30)
+                segment_end = segment["end"] + (i * 30)
 
                 speaker = "Unknown"
                 if diarization_result is not None:
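
For context, a minimal sketch of how the refactored entry point can be driven after this commit. process_audio and its keyword arguments come from the diff above; the import form, the input path "sample.wav", and the variable names are illustrative assumptions, and the shape of the return value is not shown in this diff.

# Hypothetical driver script (not part of this commit).
from audio_processing import process_audio

# Transcription only: load_diarization_pipeline() is never invoked on this
# path, so no pyannote download or HF_TOKEN is required.
result = process_audio("sample.wav", translate=False, model_size="small", use_diarization=False)

# Transcription plus speaker labels: requires HF_TOKEN with access to
# pyannote/speaker-diarization. If the pipeline fails to load, process_audio
# logs a warning and labels every segment's speaker as "Unknown".
result_with_speakers = process_audio("sample.wav", use_diarization=True)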