# ASR_gradio / audio_processing.py
import os
import time
import logging
from difflib import SequenceMatcher

import numpy as np
import torch
import whisperx
from pyannote.audio import Pipeline
from dotenv import load_dotenv
import spaces

load_dotenv()
hf_token = os.getenv("HF_TOKEN")

# Chunking parameters in seconds: 3 s windows with 1 s of overlap between consecutive chunks.
CHUNK_LENGTH = 3
OVERLAP = 1

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def preprocess_audio(audio, chunk_size=CHUNK_LENGTH*16000, overlap=OVERLAP*16000):  # defaults: 3 s chunks with 1 s overlap at 16 kHz
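    """Split a 16 kHz mono waveform into fixed-length, overlapping chunks.

    Chunks shorter than chunk_size (the tail of the recording) are zero-padded
    so every returned chunk has the same length.
    """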
chunks = []
for i in range(0, len(audio), chunk_size - overlap):
chunk = audio[i:i+chunk_size]
if len(chunk) < chunk_size:
chunk = np.pad(chunk, (0, chunk_size - len(chunk)))
chunks.append(chunk)
return chunks
@spaces.GPU(duration=60)
def process_audio(audio_file, translate=False, model_size="small"):
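    """Transcribe an audio file with WhisperX, optionally translate it to English,
    and label each segment with the most frequent overlapping pyannote speaker.

    Returns a tuple (language_segments, segments), where each segment carries
    start/end times, detected language, speaker label and text.
    """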
start_time = time.time()
try:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
compute_type = "int8" if torch.cuda.is_available() else "float32"
audio = whisperx.load_audio(audio_file)
model = whisperx.load_model(model_size, device, compute_type=compute_type)
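        # Diarize the whole recording once with pyannote; speaker labels are
        # matched to transcription segments by time overlap further below.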
diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
diarization_pipeline = diarization_pipeline.to(torch.device(device))
diarization_result = diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000})
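        # Split the waveform into fixed-length, overlapping chunks (see preprocess_audio).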
chunks = preprocess_audio(audio)
language_segments = []
final_segments = []
        overlap_duration = OVERLAP  # seconds of overlap between consecutive chunks
for i, chunk in enumerate(chunks):
            # Position of this chunk on the audio timeline, in seconds.
            chunk_start_time = i * (CHUNK_LENGTH - overlap_duration)
            chunk_end_time = chunk_start_time + CHUNK_LENGTH
            chunk_proc_start = time.time()
            logger.info(f"Processing chunk {i+1}/{len(chunks)}")
lang = model.detect_language(chunk)
result_transcribe = model.transcribe(chunk, language=lang)
            if translate:
                result_translate = model.transcribe(chunk, task="translate")
for j, t_seg in enumerate(result_transcribe["segments"]):
segment_start = chunk_start_time + t_seg["start"]
segment_end = chunk_start_time + t_seg["end"]
# Skip segments in the overlapping region of the previous chunk
if i > 0 and segment_end <= chunk_start_time + overlap_duration:
print(f"Skipping segment in overlap with previous chunk: {segment_start:.2f} - {segment_end:.2f}")
continue
# Skip segments in the overlapping region of the next chunk
if i < len(chunks) - 1 and segment_start >= chunk_end_time - overlap_duration:
print(f"Skipping segment in overlap with next chunk: {segment_start:.2f} - {segment_end:.2f}")
continue
speakers = []
for turn, track, speaker in diarization_result.itertracks(yield_label=True):
if turn.start <= segment_end and turn.end >= segment_start:
speakers.append(speaker)
segment = {
"start": segment_start,
"end": segment_end,
"language": lang,
"speaker": max(set(speakers), key=speakers.count) if speakers else "Unknown",
"text": t_seg["text"],
}
                # Translation segments are assumed to line up 1:1 with transcription
                # segments; guard against the two passes producing different counts.
                if translate and j < len(result_translate["segments"]):
                    segment["translated"] = result_translate["segments"][j]["text"]
final_segments.append(segment)
language_segments.append({
"language": lang,
"start": chunk_start_time,
"end": chunk_start_time + CHUNK_LENGTH
})
            logger.info(f"Chunk {i+1} processed in {time.time() - chunk_proc_start:.2f} seconds")
final_segments.sort(key=lambda x: x["start"])
merged_segments = merge_nearby_segments(final_segments)
end_time = time.time()
logger.info(f"Total processing time: {end_time - start_time:.2f} seconds")
        return language_segments, merged_segments
except Exception as e:
logger.error(f"An error occurred during audio processing: {str(e)}")
raise
def merge_nearby_segments(segments, time_threshold=0.5, similarity_threshold=0.9):
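    """Merge consecutive segments whose gap is at most time_threshold seconds and
    whose texts overlap strongly, dropping the duplicated text introduced by
    overlapping chunks.
    """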
merged = []
for segment in segments:
if not merged or segment['start'] - merged[-1]['end'] > time_threshold:
merged.append(segment)
else:
# Find the overlap
matcher = SequenceMatcher(None, merged[-1]['text'], segment['text'])
match = matcher.find_longest_match(0, len(merged[-1]['text']), 0, len(segment['text']))
            if segment['text'] and match.size / len(segment['text']) > similarity_threshold:
                # Merge: keep only the non-overlapping tail of the new segment's text.
                merged[-1]['end'] = segment['end']
                merged[-1]['text'] += segment['text'][match.b + match.size:]
                # 'translated' only exists when process_audio ran with translate=True.
                if 'translated' in merged[-1] and 'translated' in segment:
                    merged[-1]['translated'] += segment['translated'][match.b + match.size:]
else:
# If no significant overlap, append as a new segment
merged.append(segment)
return merged
def print_results(segments):
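    """Print each segment with its timing, detected language, speaker label and text."""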
for segment in segments:
print(f"[{segment['start']:.2f}s - {segment['end']:.2f}s] ({segment['language']}) {segment['speaker']}:")
print(f"Original: {segment['text']}")
if 'translated' in segment:
print(f"Translated: {segment['translated']}")
print()
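
# Minimal usage sketch (an assumption, not part of the original Space code): run this
# module directly on a local audio file to exercise process_audio and print_results.
# It assumes HF_TOKEN is set in the environment (or a .env file) so the pyannote
# pipeline can be loaded, and "sample.wav" is only a placeholder path.
if __name__ == "__main__":
    import sys

    audio_path = sys.argv[1] if len(sys.argv) > 1 else "sample.wav"
    language_segments, segments = process_audio(audio_path, translate=False, model_size="small")
    print_results(segments)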