import os
import argparse
from subprocess import CalledProcessError, run

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import librosa

import speechbrain
import speechbrain.nnet.CNN  # ensure the CNN submodule is available as an attribute
from transformers import AutoFeatureExtractor, WhisperModel

# OpenAI Whisper-style audio loading (SAMPLE_RATE and load_audio below).
SAMPLE_RATE = 16000


def denorm(input_x):
    """Map a normalized score in [0, 1] back to the [0, 5] MOS scale."""
    # x * (max - min) + min, with min = 0 and max = 5
    return input_x * (5 - 0) + 0
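
# Illustrative example: a sigmoid output of 0.7 denormalizes to 0.7 * 5 = 3.5.
# >>> denorm(0.7)
# 3.5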
    
def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Open an audio file and read as mono waveform, resampling as necessary

    Parameters
    ----------
    file: str
        The audio file to open

    sr: int
        The sample rate to resample the audio if necessary

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.
    """

    # This launches a subprocess to decode audio while down-mixing
    # and resampling as necessary.  Requires the ffmpeg CLI in PATH.
    # fmt: off
    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads", "0",
        "-i", file,
        "-f", "s16le",
        "-ac", "1",
        "-acodec", "pcm_s16le",
        "-ar", str(sr),
        "-"
    ]
    # fmt: on
    try:
        out = run(cmd, capture_output=True, check=True).stdout
    except CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
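
# Usage sketch (the file path is hypothetical): decode a clip to 16 kHz mono
# float32, then wrap it as a (1, 1, num_samples) tensor for MosPredictor:
#   wav = torch.from_numpy(load_audio("clip.wav")).unsqueeze(0).unsqueeze(0)

# The `lps` input to MosPredictor is computed outside this file. The helper
# below is a sketch of one plausible front-end, assuming a 512-point STFT
# (257 bins) with a hop of 256 samples to mirror the SincConv branch
# (257 filters, stride 256); the Hamming window and log floor are assumptions.
def compute_lps(wav: torch.Tensor, n_fft: int = 512, hop_length: int = 256) -> torch.Tensor:
    """Log power spectrum of a (batch, samples) waveform, shaped (batch, 1, frames, n_fft // 2 + 1)."""
    window = torch.hamming_window(n_fft, device=wav.device)
    spec = torch.stft(wav, n_fft=n_fft, hop_length=hop_length,
                      window=window, return_complex=True)  # (batch, freq, frames)
    lps = torch.log(spec.abs() ** 2 + 1e-10)               # small floor avoids log(0)
    return lps.transpose(1, 2).unsqueeze(1)                # (batch, 1, frames, freq)

# The Whisper features are likewise produced elsewhere; one plausible sketch
# with the transformers imports above (the checkpoint name is an assumption,
# chosen because its 1280-dim hidden states match ssl_features below):
#   extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large-v2")
#   encoder = WhisperModel.from_pretrained("openai/whisper-large-v2").encoder
#   inputs = extractor(load_audio("clip.wav"), sampling_rate=SAMPLE_RATE, return_tensors="pt")
#   whisper_feat = encoder(inputs.input_features).last_hidden_state  # (1, 1500, 1280)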

class MosPredictor(nn.Module):
    """Multi-task MOS predictor over LPS, sinc-filterbank, and Whisper features."""

    def __init__(self):
        super().__init__()
        
        # Four conv stages; each downsamples the frequency axis by 3 via its
        # stride-(1, 3) layer, taking 257 bins down to 4 (128 ch x 4 bins = 512).
        self.mean_net_conv = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), padding=(1, 1)),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), padding=(1, 1)),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), padding=(1, 1), stride=(1, 3)),
            nn.Dropout(0.3),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), padding=(1, 1)),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1)),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1), stride=(1, 3)),
            nn.Dropout(0.3),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), padding=(1, 1)),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), padding=(1, 1)),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), padding=(1, 1), stride=(1, 3)),
            nn.Dropout(0.3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), padding=(1, 1)),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), padding=(1, 1)),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), padding=(1, 1), stride=(1, 3)),
            nn.Dropout(0.3),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        
        self.relu_ = nn.ReLU()
        self.sigmoid_ = nn.Sigmoid()

        # Project Whisper encoder states (1280-dim in the large models) down to 512.
        self.ssl_features = 1280
        self.dim_layer = nn.Linear(self.ssl_features, 512)

        # BiLSTM over the concatenated CNN and SSL sequences: 2 x 128 = 256-dim outputs.
        self.mean_net_rnn = nn.LSTM(input_size=512, hidden_size=128, num_layers=1,
                                    batch_first=True, bidirectional=True)
        self.mean_net_dnn = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
        )

        # Learnable sinc-filterbank front-end: 257 band-pass filters, hop of 256 samples.
        self.sinc = speechbrain.nnet.CNN.SincConv(in_channels=1, out_channels=257,
                                                  kernel_size=251, stride=256,
                                                  sample_rate=16000)

        # Task heads: self-attention, frame-level score, utterance-level pooling.
        # batch_first=True so attention runs over time, matching the
        # (batch, seq, feature) tensors passed in forward().
        self.att_output_layer_quality = nn.MultiheadAttention(128, num_heads=8, batch_first=True)
        self.output_layer_quality = nn.Linear(128, 1)
        self.qualaverage_score = nn.AdaptiveAvgPool1d(1)

        self.att_output_layer_intell = nn.MultiheadAttention(128, num_heads=8, batch_first=True)
        self.output_layer_intell = nn.Linear(128, 1)
        self.intellaverage_score = nn.AdaptiveAvgPool1d(1)

        # STOI head (defined here but not used by forward() below).
        self.att_output_layer_stoi = nn.MultiheadAttention(128, num_heads=8, batch_first=True)
        self.output_layer_stoi = nn.Linear(128, 1)
        self.stoiaverage_score = nn.AdaptiveAvgPool1d(1)

    def forward(self, wav, lps, whisper):
        # SSL features: project Whisper encoder states down to 512 dims.
        wav_ = wav.squeeze(1)  # (batch, audio_len)
        ssl_feat_red = self.dim_layer(whisper.squeeze(1))
        ssl_feat_red = self.relu_(ssl_feat_red)

        # PS features: sinc-filterbank frames concatenated with LPS frames
        # along the time axis, then passed through the CNN stack.
        sinc_feat = self.sinc(wav_)
        unsq_sinc = torch.unsqueeze(sinc_feat, dim=1)
        concat_lps_sinc = torch.cat((lps, unsq_sinc), dim=2)
        cnn_out = self.mean_net_conv(concat_lps_sinc)  # (batch, 128, time, 4)
        batch = concat_lps_sinc.shape[0]
        time = concat_lps_sinc.shape[2]
        # Bring channels next to frequency before flattening so each time step
        # keeps its own 128 x 4 = 512 features.
        re_cnn = cnn_out.permute(0, 2, 1, 3).contiguous().view(batch, time, 512)

        # Stack CNN frames and SSL frames along the sequence axis.
        concat_feat = torch.cat((re_cnn, ssl_feat_red), dim=1)
        out_lstm, (h, c) = self.mean_net_rnn(concat_feat)
        out_dense = self.mean_net_dnn(out_lstm)  # (batch, seq, 128)
        
        # Quality head: per-frame scores in [0, 1], averaged over frames.
        quality_att, _ = self.att_output_layer_quality(out_dense, out_dense, out_dense)
        frame_quality = self.output_layer_quality(quality_att)
        frame_quality = self.sigmoid_(frame_quality)
        quality_utt = self.qualaverage_score(frame_quality.permute(0, 2, 1))

        # Intelligibility head, same structure.
        int_att, _ = self.att_output_layer_intell(out_dense, out_dense, out_dense)
        frame_int = self.output_layer_intell(int_att)
        frame_int = self.sigmoid_(frame_int)
        int_utt = self.intellaverage_score(frame_int.permute(0, 2, 1))

        return quality_utt.squeeze(1), int_utt.squeeze(1), frame_quality.squeeze(2), frame_int.squeeze(2)
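

# Quick smoke test with random tensors. All shapes are assumptions inferred
# from forward(): 16 kHz waveform (batch, 1, samples), LPS (batch, 1, frames, 257),
# and Whisper encoder states (batch, 1, frames, 1280).
if __name__ == "__main__":
    model = MosPredictor().eval()
    wav = torch.randn(2, 1, 32000)               # 2 s of 16 kHz audio
    lps = compute_lps(wav.squeeze(1))            # (2, 1, 126, 257)
    whisper_feat = torch.randn(2, 1, 100, 1280)  # stand-in for real encoder states
    with torch.no_grad():
        q_utt, i_utt, q_frame, i_frame = model(wav, lps, whisper_feat)
    # Expected: (2, 1), (2, 1), (2, seq), (2, seq)
    print(q_utt.shape, i_utt.shape, q_frame.shape, i_frame.shape)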