import os
import logging
import time
from difflib import SequenceMatcher

import numpy as np
import torch
import whisperx
from dotenv import load_dotenv
from pyannote.audio import Pipeline
import spaces

load_dotenv()
hf_token = os.getenv("HF_TOKEN")

# Chunking parameters, in seconds
CHUNK_LENGTH = 3
OVERLAP = 1

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def preprocess_audio(audio, chunk_size=CHUNK_LENGTH * 16000, overlap=OVERLAP * 16000):
    """Split a 16 kHz mono waveform into fixed-size chunks with OVERLAP seconds of overlap."""
    chunks = []
    for i in range(0, len(audio), chunk_size - overlap):
        chunk = audio[i:i + chunk_size]
        if len(chunk) < chunk_size:
            # Zero-pad the final chunk so every chunk has the same length
            chunk = np.pad(chunk, (0, chunk_size - len(chunk)))
        chunks.append(chunk)
    return chunks

@spaces.GPU(duration=60)
def process_audio(audio_file, translate=False, model_size="small"):
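    """Run chunked language detection, transcription (and optional translation) on an
    audio file, attach speaker labels from pyannote diarization, and return
    (language_segments, merged transcript segments)."""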
    start_time = time.time()
    
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")
        compute_type = "int8" if torch.cuda.is_available() else "float32"
        audio = whisperx.load_audio(audio_file)
        model = whisperx.load_model(model_size, device, compute_type=compute_type)

        diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
        diarization_pipeline = diarization_pipeline.to(torch.device(device))

        diarization_result = diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000})

        chunks = preprocess_audio(audio)

        language_segments = []
        final_segments = []
        
        overlap_duration = OVERLAP  # seconds of overlap between consecutive chunks
        for i, chunk in enumerate(chunks):
            chunk_proc_start = time.time()
            # Offset of this chunk on the original audio timeline, in seconds
            chunk_start_time = i * (CHUNK_LENGTH - overlap_duration)
            chunk_end_time = chunk_start_time + CHUNK_LENGTH
            logger.info(f"Processing chunk {i+1}/{len(chunks)}")
            lang = model.detect_language(chunk)
            result_transcribe = model.transcribe(chunk, language=lang)
            if translate:
                result_translate = model.transcribe(chunk, task="translate")
            for j, t_seg in enumerate(result_transcribe["segments"]):
                segment_start = chunk_start_time + t_seg["start"]
                segment_end = chunk_start_time + t_seg["end"]
                # Skip segments in the overlapping region of the previous chunk
                if i > 0 and segment_end <= chunk_start_time + overlap_duration:
                    print(f"Skipping segment in overlap with previous chunk: {segment_start:.2f} - {segment_end:.2f}")
                    continue
            
                # Skip segments in the overlapping region of the next chunk
                if i < len(chunks) - 1 and segment_start >= chunk_end_time - overlap_duration:
                    print(f"Skipping segment in overlap with next chunk: {segment_start:.2f} - {segment_end:.2f}")
                    continue
            
                speakers = []
                for turn, track, speaker in diarization_result.itertracks(yield_label=True):
                    if turn.start <= segment_end and turn.end >= segment_start:
                        speakers.append(speaker)
            
                segment = {
                    "start": segment_start,
                    "end": segment_end,
                    "language": lang,
                    "speaker": max(set(speakers), key=speakers.count) if speakers else "Unknown",
                    "text": t_seg["text"],
                }
                
                # Translated segments may not align one-to-one with transcribed ones; guard the index
                if translate and j < len(result_translate["segments"]):
                    segment["translated"] = result_translate["segments"][j]["text"]
            
                final_segments.append(segment)

            language_segments.append({
                "language": lang,
                "start": chunk_start_time,
                "end": chunk_start_time + CHUNK_LENGTH
            })
            logger.info(f"Chunk {i+1} processed in {time.time() - chunk_proc_start:.2f} seconds")

        final_segments.sort(key=lambda x: x["start"])
        merged_segments = merge_nearby_segments(final_segments)

        end_time = time.time()
        logger.info(f"Total processing time: {end_time - start_time:.2f} seconds")

        return language_segments, merged_segments
    except Exception as e:
        logger.error(f"An error occurred during audio processing: {str(e)}")
        raise 


def merge_nearby_segments(segments, time_threshold=0.5, similarity_threshold=0.9):
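    """Merge consecutive segments that are within `time_threshold` seconds of each other and
    whose texts overlap substantially, deduplicating text introduced by chunk overlap."""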
    merged = []
    for segment in segments:
        if not merged or segment['start'] - merged[-1]['end'] > time_threshold:
            merged.append(segment)
        else:
            # Find the overlap
            matcher = SequenceMatcher(None, merged[-1]['text'], segment['text'])
            match = matcher.find_longest_match(0, len(merged[-1]['text']), 0, len(segment['text']))
            
            if segment['text'] and match.size / len(segment['text']) > similarity_threshold:
                # Merge: keep the earlier text and append only the non-overlapping tail
                merged[-1]['end'] = segment['end']
                merged[-1]['text'] = merged[-1]['text'] + segment['text'][match.b + match.size:]
                # Translations exist only when translate=True; merge them only if both segments have one
                if 'translated' in merged[-1] and 'translated' in segment:
                    merged[-1]['translated'] += segment['translated'][match.b + match.size:]
            else:
                # If no significant overlap, append as a new segment
                merged.append(segment)
    return merged


def print_results(segments):
    for segment in segments:
        print(f"[{segment['start']:.2f}s - {segment['end']:.2f}s] ({segment['language']}) {segment['speaker']}:")
        print(f"Original: {segment['text']}")
        if 'translated' in segment:
            print(f"Translated: {segment['translated']}")
        print()
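
# Minimal usage sketch (not part of the original script): "sample.wav" is a
# hypothetical input path, and HF_TOKEN is assumed to be set in the environment
# for the pyannote diarization pipeline.
if __name__ == "__main__":
    language_segments, segments = process_audio("sample.wav", translate=True)
    print_results(segments)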