import os
import time
import logging
from difflib import SequenceMatcher

import numpy as np
import torch
import whisperx
import spaces
from pyannote.audio import Pipeline
from dotenv import load_dotenv

load_dotenv()

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

hf_token = os.getenv("HF_TOKEN")

# Audio is processed in fixed-length chunks (seconds) that overlap slightly
# so that words are not cut at chunk boundaries.
CHUNK_LENGTH = 3
OVERLAP = 1
def preprocess_audio(audio, chunk_size=CHUNK_LENGTH * 16000, overlap=OVERLAP * 16000):
    """Split a 16 kHz mono waveform into fixed-size chunks that overlap by OVERLAP seconds."""
    chunks = []
    step = chunk_size - overlap
    for i in range(0, len(audio), step):
        chunk = audio[i:i + chunk_size]
        if len(chunk) < chunk_size:
            # Zero-pad the final chunk so every chunk has the same length.
            chunk = np.pad(chunk, (0, chunk_size - len(chunk)))
        chunks.append(chunk)
    return chunks
@spaces.GPU(duration=60)
def process_audio(audio_file, translate=False, model_size="small"):
    start_time = time.time()
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")
        compute_type = "int8" if torch.cuda.is_available() else "float32"

        audio = whisperx.load_audio(audio_file)
        model = whisperx.load_model(model_size, device, compute_type=compute_type)

        # Run speaker diarization over the full recording once, up front.
        diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
        diarization_pipeline = diarization_pipeline.to(torch.device(device))
        diarization_result = diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000})

        chunks = preprocess_audio(audio)
        language_segments = []
        final_segments = []
        overlap_duration = OVERLAP  # overlap between consecutive chunks, in seconds

        for i, chunk in enumerate(chunks):
            chunk_proc_start = time.time()
            # Position of this chunk on the audio timeline (seconds).
            chunk_start_time = i * (CHUNK_LENGTH - overlap_duration)
            chunk_end_time = chunk_start_time + CHUNK_LENGTH
            logger.info(f"Processing chunk {i+1}/{len(chunks)}")

            lang = model.detect_language(chunk)
            result_transcribe = model.transcribe(chunk, language=lang)
            if translate:
                result_translate = model.transcribe(chunk, task="translate")

            for j, t_seg in enumerate(result_transcribe["segments"]):
                segment_start = chunk_start_time + t_seg["start"]
                segment_end = chunk_start_time + t_seg["end"]

                # Skip segments that fall in the overlapping region of the previous chunk.
                if i > 0 and segment_end <= chunk_start_time + overlap_duration:
                    print(f"Skipping segment in overlap with previous chunk: {segment_start:.2f} - {segment_end:.2f}")
                    continue
                # Skip segments that fall in the overlapping region of the next chunk.
                if i < len(chunks) - 1 and segment_start >= chunk_end_time - overlap_duration:
                    print(f"Skipping segment in overlap with next chunk: {segment_start:.2f} - {segment_end:.2f}")
                    continue

                # Collect every speaker whose diarization turn overlaps this segment,
                # then pick the most frequent one.
                speakers = []
                for turn, track, speaker in diarization_result.itertracks(yield_label=True):
                    if turn.start <= segment_end and turn.end >= segment_start:
                        speakers.append(speaker)

                segment = {
                    "start": segment_start,
                    "end": segment_end,
                    "language": lang,
                    "speaker": max(set(speakers), key=speakers.count) if speakers else "Unknown",
                    "text": t_seg["text"],
                }
                if translate and j < len(result_translate["segments"]):
                    segment["translated"] = result_translate["segments"][j]["text"]
                final_segments.append(segment)

            language_segments.append({
                "language": lang,
                "start": chunk_start_time,
                "end": chunk_start_time + CHUNK_LENGTH
            })
            logger.info(f"Chunk {i+1} processed in {time.time() - chunk_proc_start:.2f} seconds")

        final_segments.sort(key=lambda x: x["start"])
        merged_segments = merge_nearby_segments(final_segments)

        end_time = time.time()
        logger.info(f"Total processing time: {end_time - start_time:.2f} seconds")
        return language_segments, merged_segments
    except Exception as e:
        logger.error(f"An error occurred during audio processing: {str(e)}")
        raise
def merge_nearby_segments(segments, time_threshold=0.5, similarity_threshold=0.9):
    """Merge consecutive segments whose texts overlap substantially and whose gap is below time_threshold seconds."""
    merged = []
    for segment in segments:
        if not merged or segment['start'] - merged[-1]['end'] > time_threshold:
            merged.append(segment)
        else:
            # Find the longest common run of text between the previous segment and this one.
            matcher = SequenceMatcher(None, merged[-1]['text'], segment['text'])
            match = matcher.find_longest_match(0, len(merged[-1]['text']), 0, len(segment['text']))
            if len(segment['text']) > 0 and match.size / len(segment['text']) > similarity_threshold:
                # Merge: keep the previous text and append only the non-overlapping tail.
                merged[-1]['text'] = merged[-1]['text'] + segment['text'][match.b + match.size:]
                if 'translated' in merged[-1] and 'translated' in segment:
                    merged[-1]['translated'] = merged[-1]['translated'] + segment['translated'][match.b + match.size:]
                merged[-1]['end'] = segment['end']
            else:
                # No significant overlap, so keep it as a separate segment.
                merged.append(segment)
    return merged
def print_results(segments):
    for segment in segments:
        print(f"[{segment['start']:.2f}s - {segment['end']:.2f}s] ({segment['language']}) {segment['speaker']}:")
        print(f"Original: {segment['text']}")
        if 'translated' in segment:
            print(f"Translated: {segment['translated']}")
        print()
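

# --- Usage sketch (assumption: not part of the original Space) ---
# A minimal example of how process_audio and print_results might be invoked
# when running this module directly; "sample.wav" is a hypothetical input file.
if __name__ == "__main__":
    language_segments, segments = process_audio("sample.wav", translate=False, model_size="small")
    print_results(segments)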