import logging
import os
import time
from difflib import SequenceMatcher

import numpy as np
import torch
import whisperx
from dotenv import load_dotenv
from pyannote.audio import Pipeline
from scipy.signal import resample

load_dotenv()
hf_token = os.getenv("HF_TOKEN")

# Chunking parameters (seconds)
CHUNK_LENGTH = 30
OVERLAP = 0

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

import spaces  # Hugging Face Spaces package providing the @spaces.GPU decorator used below


def preprocess_audio(audio, chunk_size=CHUNK_LENGTH * 16000, overlap=OVERLAP * 16000):
    """Split a 16 kHz waveform into fixed-size chunks, zero-padding the final chunk."""
    chunks = []
    for i in range(0, len(audio), chunk_size - overlap):
        chunk = audio[i:i + chunk_size]
        if len(chunk) < chunk_size:
            chunk = np.pad(chunk, (0, chunk_size - len(chunk)))
        chunks.append(chunk)
    return chunks


@spaces.GPU()
def process_audio(audio_file, translate=False, model_size="small"):
    start_time = time.time()

    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")
        compute_type = "int8" if torch.cuda.is_available() else "float32"

        # Load the full recording (whisperx.load_audio returns 16 kHz mono float32) and the ASR model.
        audio = whisperx.load_audio(audio_file)
        model = whisperx.load_model(model_size, device, compute_type=compute_type)

        # Run speaker diarization once over the whole recording.
        diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
        diarization_pipeline = diarization_pipeline.to(torch.device(device))
        diarization_result = diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000})

        chunks = preprocess_audio(audio)

        language_segments = []
        final_segments = []

        overlap_duration = OVERLAP
        for i, chunk in enumerate(chunks):
            chunk_proc_start = time.time()  # wall-clock timer for this chunk
            chunk_start_time = i * (CHUNK_LENGTH - overlap_duration)  # chunk position in the audio (seconds)
            chunk_end_time = chunk_start_time + CHUNK_LENGTH
            logger.info(f"Processing chunk {i+1}/{len(chunks)}")

            # Detect the language of this chunk, then transcribe (and optionally translate) it.
            lang = model.detect_language(chunk)
            result_transcribe = model.transcribe(chunk, language=lang)
            if translate:
                result_translate = model.transcribe(chunk, task="translate")

            for j, t_seg in enumerate(result_transcribe["segments"]):
                segment_start = chunk_start_time + t_seg["start"]
                segment_end = chunk_start_time + t_seg["end"]

                # Skip segments that fall entirely inside the overlap with the previous chunk.
                if i > 0 and segment_end <= chunk_start_time + overlap_duration:
                    logger.info(f"Skipping segment in overlap with previous chunk: {segment_start:.2f} - {segment_end:.2f}")
                    continue

                # Skip segments that fall entirely inside the overlap with the next chunk.
                if i < len(chunks) - 1 and segment_start >= chunk_end_time - overlap_duration:
                    logger.info(f"Skipping segment in overlap with next chunk: {segment_start:.2f} - {segment_end:.2f}")
                    continue

                # Collect all diarization turns that overlap this segment; the most frequent speaker wins.
                speakers = []
                for turn, track, speaker in diarization_result.itertracks(yield_label=True):
                    if turn.start <= segment_end and turn.end >= segment_start:
                        speakers.append(speaker)

                segment = {
                    "start": segment_start,
                    "end": segment_end,
                    "language": lang,
                    "speaker": max(set(speakers), key=speakers.count) if speakers else "Unknown",
                    "text": t_seg["text"],
                }

                # Transcription and translation may produce different segment counts, so guard the index.
                if translate and j < len(result_translate["segments"]):
                    segment["translated"] = result_translate["segments"][j]["text"]

                final_segments.append(segment)

            language_segments.append({
                "language": lang,
                "start": chunk_start_time,
                "end": chunk_start_time + CHUNK_LENGTH
            })
            logger.info(f"Chunk {i+1} processed in {time.time() - chunk_proc_start:.2f} seconds")

        # Merge fragments that are adjacent in time and textually overlapping.
        final_segments.sort(key=lambda x: x["start"])
        merged_segments = merge_nearby_segments(final_segments)

        end_time = time.time()
        logger.info(f"Total processing time: {end_time - start_time:.2f} seconds")

        return language_segments, merged_segments
    except Exception as e:
        logger.error(f"An error occurred during audio processing: {str(e)}")
        raise


def merge_nearby_segments(segments, time_threshold=0.5, similarity_threshold=0.9):
    """Merge consecutive segments that are close in time and share a long common substring."""
    merged = []
    for segment in segments:
        if not merged or segment['start'] - merged[-1]['end'] > time_threshold:
            merged.append(segment)
        else:
            # Find the longest common substring between the previous text and the current one.
            matcher = SequenceMatcher(None, merged[-1]['text'], segment['text'])
            match = matcher.find_longest_match(0, len(merged[-1]['text']), 0, len(segment['text']))

            if segment['text'] and match.size / len(segment['text']) > similarity_threshold:
                # Append only the non-overlapping tail of the current segment.
                merged[-1]['end'] = segment['end']
                merged[-1]['text'] = merged[-1]['text'] + segment['text'][match.b + match.size:]
                # Translations only exist when translate=True, so merge them defensively.
                if 'translated' in merged[-1] and 'translated' in segment:
                    merged[-1]['translated'] = merged[-1]['translated'] + segment['translated'][match.b + match.size:]
            else:
                merged.append(segment)
    return merged


def print_results(segments):
    for segment in segments:
        print(f"[{segment['start']:.2f}s - {segment['end']:.2f}s] ({segment['language']}) {segment['speaker']}:")
        print(f"Original: {segment['text']}")
        if 'translated' in segment:
            print(f"Translated: {segment['translated']}")
        print()
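

# Minimal usage sketch (an assumption, not part of the original script): run the full
# pipeline on a local recording and print the merged, speaker-labelled segments.
# "sample.wav" is a hypothetical path; replace it with a real audio file.
if __name__ == "__main__":
    language_segments, segments = process_audio("sample.wav", translate=False, model_size="small")
    print_results(segments)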