import whisperx
import torch
import numpy as np
from scipy.signal import resample
from pyannote.audio import Pipeline
import os
from dotenv import load_dotenv
load_dotenv()
import logging
import time
from difflib import SequenceMatcher
# Hugging Face access token read from the HF_TOKEN environment variable;
# it is passed as use_auth_token when the pyannote diarization pipeline is loaded.
hf_token = os.getenv("HF_TOKEN")

# Length of one audio chunk, in seconds (preprocess_audio multiplies it by 16000
# to obtain a sample count).
CHUNK_LENGTH=30
# Overlap between consecutive chunks, in seconds; 0 means chunks do not overlap.
OVERLAP=0
import whisperx
import torch
import numpy as np


# Configure logging at INFO level; each record shows timestamp, logger name,
# level and message.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)
import spaces


def preprocess_audio(audio, chunk_size=CHUNK_LENGTH*16000, overlap=OVERLAP*16000):
    """Split ``audio`` into fixed-size chunks, zero-padding the final one.

    Consecutive chunks start ``chunk_size - overlap`` samples apart, so each
    chunk shares its first ``overlap`` samples with the end of the previous
    one. With the default ``overlap`` of ``OVERLAP*16000`` (currently 0) the
    chunks are back to back.

    Args:
        audio: 1-D sequence of samples supporting ``len()`` and slicing
            (the caller passes the array returned by ``whisperx.load_audio``).
        chunk_size: Number of samples in every returned chunk.
        overlap: Number of samples shared between neighbouring chunks.

    Returns:
        A list of chunks, each exactly ``chunk_size`` samples long. An empty
        ``audio`` yields an empty list.

    Raises:
        ValueError: If ``overlap`` is not smaller than ``chunk_size``; the
            window would never advance.
    """
    step = chunk_size - overlap
    if step <= 0:
        # A zero step makes range() raise a cryptic error and a negative one
        # silently yields no chunks at all; fail loudly in both cases.
        raise ValueError(
            f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})"
        )

    chunks = []
    for start in range(0, len(audio), step):
        chunk = audio[start:start + chunk_size]
        missing = chunk_size - len(chunk)
        if missing > 0:
            # Only the tail of the audio can be short; pad it with zeros so
            # every chunk has the same length.
            chunk = np.pad(chunk, (0, missing))
        chunks.append(chunk)
    return chunks

def _dominant_speaker(speaker_turns, segment_start, segment_end):
    """Return the speaker whose turns intersect the segment most often.

    Args:
        speaker_turns: Iterable of ``(turn, track, speaker)`` tuples where
            ``turn`` exposes ``start`` and ``end`` in seconds.
        segment_start: Segment start time in seconds.
        segment_end: Segment end time in seconds.

    Returns:
        The most frequent overlapping speaker label (first seen wins a tie),
        or ``"Unknown"`` when no turn intersects the segment.
    """
    from collections import Counter  # local import: keeps the block self-contained

    overlapping = Counter(
        speaker
        for turn, _track, speaker in speaker_turns
        if turn.start <= segment_end and turn.end >= segment_start
    )
    if not overlapping:
        return "Unknown"
    return overlapping.most_common(1)[0][0]


@spaces.GPU()
def process_audio(audio_file, translate=False, model_size="small"):
    """Transcribe, optionally translate, and diarize an audio file chunk by chunk.

    The audio is loaded with whisperx, diarized once as a whole with the
    pyannote speaker-diarization pipeline, then split by ``preprocess_audio``
    into ``CHUNK_LENGTH``-second chunks. For every chunk the language is
    detected and the chunk is transcribed (and translated when requested);
    each resulting segment is shifted to absolute time and labelled with a
    speaker.

    Args:
        audio_file: Audio source handed to ``whisperx.load_audio``
            (presumably a file path - confirm against callers).
        translate: When true, each segment also gets a ``"translated"`` entry
            taken from a second ``task="translate"`` pass over the chunk.
        model_size: Whisper model name passed to ``whisperx.load_model``.

    Returns:
        A tuple ``(language_segments, segments)``:
        ``language_segments`` has one dict per chunk with ``language``,
        ``start`` and ``end``; ``segments`` is the start-sorted list of
        segment dicts (``start``, ``end``, ``language``, ``speaker``,
        ``text`` and optionally ``translated``) after
        ``merge_nearby_segments`` has folded near-duplicate neighbours.

    Raises:
        RuntimeError: If the diarization pipeline could not be loaded.
        Exception: Any error from the underlying libraries is logged and
            re-raised unchanged.
    """
    start_time = time.time()

    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")
        compute_type = "int8" if device == "cuda" else "float32"
        audio = whisperx.load_audio(audio_file)
        model = whisperx.load_model(model_size, device, compute_type=compute_type)

        diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
        if diarization_pipeline is None:
            # pyannote can hand back None instead of raising when the gated
            # pipeline is not accessible; surface that as a clear error rather
            # than an AttributeError on the next line.
            raise RuntimeError(
                "Could not load 'pyannote/speaker-diarization'; check HF_TOKEN and model access"
            )
        diarization_pipeline = diarization_pipeline.to(torch.device(device))

        diarization_result = diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000})
        # The diarization does not change per segment: materialise it once
        # instead of re-iterating the annotation for every segment.
        speaker_turns = list(diarization_result.itertracks(yield_label=True))

        chunks = preprocess_audio(audio)

        language_segments = []
        final_segments = []

        for i, chunk in enumerate(chunks):
            # Wall-clock start, kept separate from the audio offsets below so
            # the timing log measures real processing time.
            chunk_wall_start = time.time()
            chunk_offset = i * (CHUNK_LENGTH - OVERLAP)
            chunk_end_offset = chunk_offset + CHUNK_LENGTH
            logger.info(f"Processing chunk {i+1}/{len(chunks)}")

            lang = model.detect_language(chunk)
            result_transcribe = model.transcribe(chunk, language=lang)
            translated_segments = []
            if translate:
                # Reuse the detected language so both passes agree on it.
                translated_segments = model.transcribe(chunk, language=lang, task="translate")["segments"]

            # NOTE(review): with OVERLAP > 0 a segment lying entirely inside
            # the overlap is skipped by both neighbouring chunks; harmless
            # while OVERLAP is 0, but worth revisiting before enabling overlap.
            for j, t_seg in enumerate(result_transcribe["segments"]):
                segment_start = chunk_offset + t_seg["start"]
                segment_end = chunk_offset + t_seg["end"]

                # Skip segments in the overlapping region of the previous chunk
                if i > 0 and segment_end <= chunk_offset + OVERLAP:
                    logger.info(f"Skipping segment in overlap with previous chunk: {segment_start:.2f} - {segment_end:.2f}")
                    continue

                # Skip segments in the overlapping region of the next chunk
                if i < len(chunks) - 1 and segment_start >= chunk_end_offset - OVERLAP:
                    logger.info(f"Skipping segment in overlap with next chunk: {segment_start:.2f} - {segment_end:.2f}")
                    continue

                segment = {
                    "start": segment_start,
                    "end": segment_end,
                    "language": lang,
                    "speaker": _dominant_speaker(speaker_turns, segment_start, segment_end),
                    "text": t_seg["text"],
                }

                if translate:
                    # The translate pass may split the chunk differently and
                    # return fewer segments; fall back to an empty string
                    # instead of raising IndexError.
                    segment["translated"] = (
                        translated_segments[j]["text"] if j < len(translated_segments) else ""
                    )

                final_segments.append(segment)

            language_segments.append({
                "language": lang,
                "start": chunk_offset,
                "end": chunk_end_offset,
            })
            logger.info(f"Chunk {i+1} processed in {time.time() - chunk_wall_start:.2f} seconds")

        final_segments.sort(key=lambda x: x["start"])
        merged_segments = merge_nearby_segments(final_segments)

        end_time = time.time()
        logger.info(f"Total processing time: {end_time - start_time:.2f} seconds")

        # Return the merged list; it was previously computed and then dropped.
        return language_segments, merged_segments
    except Exception as e:
        logger.error(f"An error occurred during audio processing: {str(e)}")
        raise

def _non_overlapping_tail(previous_text, new_text, similarity_threshold):
    """Return the part of ``new_text`` after its longest match with ``previous_text``.

    Returns ``None`` when ``new_text`` is empty or when the longest common
    block covers no more than ``similarity_threshold`` of ``new_text``, i.e.
    when the two strings should not be treated as overlapping.
    """
    if not new_text:
        return None
    match = SequenceMatcher(None, previous_text, new_text).find_longest_match(
        0, len(previous_text), 0, len(new_text)
    )
    if match.size / len(new_text) > similarity_threshold:
        return new_text[match.b + match.size:]
    return None


def merge_nearby_segments(segments, time_threshold=0.5, similarity_threshold=0.9):
    """Fold near-duplicate neighbouring segments into one.

    A segment is merged into the previously kept one when it starts no more
    than ``time_threshold`` seconds after that one ends and the longest
    common block of the two texts covers more than ``similarity_threshold``
    of the new segment's text. Only the part of the new text that follows the
    common block is appended.

    Args:
        segments: Segment dicts with at least ``start``, ``end`` and ``text``,
            expected in chronological order. ``translated`` is optional.
        time_threshold: Maximum gap in seconds for two segments to be
            considered for merging.
        similarity_threshold: Minimum fraction (exclusive) of the new text
            that must be covered by the common block.

    Returns:
        A new list of segment dicts. The input dicts are never modified.
    """
    merged = []
    for segment in segments:
        if not merged or segment['start'] - merged[-1]['end'] > time_threshold:
            # Copy so that later merges do not mutate the caller's dicts.
            merged.append(dict(segment))
            continue

        previous = merged[-1]
        text_tail = _non_overlapping_tail(previous['text'], segment['text'], similarity_threshold)
        if text_tail is None:
            # No significant overlap (or empty text): keep as its own segment.
            merged.append(dict(segment))
            continue

        # Never let a merge move the end time backwards.
        previous['end'] = max(previous['end'], segment['end'])
        previous['text'] = previous['text'] + text_tail

        # Translations are optional; when present, find their overlap on the
        # translated strings themselves, since offsets computed on the source
        # text do not apply to them.
        if 'translated' in previous or 'translated' in segment:
            previous_translation = previous.get('translated', '')
            new_translation = segment.get('translated', '')
            translation_tail = _non_overlapping_tail(
                previous_translation, new_translation, similarity_threshold
            )
            if translation_tail is None:
                translation_tail = new_translation
            previous['translated'] = previous_translation + translation_tail
    return merged

def print_results(segments):
    """Print every segment as a small text block on standard output.

    Each block consists of a header line with the time span, language and
    speaker, an ``Original:`` line with the source text, a ``Translated:``
    line when the segment has a ``'translated'`` key, and a trailing blank
    line.

    Args:
        segments: Iterable of segment mappings with ``start``, ``end``,
            ``language``, ``speaker`` and ``text`` keys.
    """
    for entry in segments:
        header = f"[{entry['start']:.2f}s - {entry['end']:.2f}s] ({entry['language']}) {entry['speaker']}:"
        print(header)

        original_line = f"Original: {entry['text']}"
        print(original_line)

        if 'translated' not in entry:
            # No translation: just emit the blank separator line.
            print()
            continue

        # The doubled newline produces the translation line plus the separator.
        print(f"Translated: {entry['translated']}", end="\n\n")