# -*- coding: utf-8 -*-
"""CHULA Gino_Parkinson.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1XPgGZILiBbDji5G0dHoFV7OQaUwGM3HJ
"""

# Notebook shell magic (IPython/Colab only; not valid in a plain .py file):
# installs the audio I/O (SoundFile), model (transformers) and
# splitting/metrics (scikit-learn) dependencies used below.
!pip install SoundFile transformers scikit-learn

# Mount Google Drive so the /content/drive/MyDrive/... dataset paths used
# throughout the notebook resolve.
from google.colab import drive
drive.mount('/content/drive')

import matplotlib.pyplot as plt
import numpy as np

import os
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
from sklearn.model_selection import train_test_split
import re
from collections import Counter
from sklearn.metrics import classification_report

# Custom Dataset class
class DysarthriaDataset(Dataset):
    """Audio-classification dataset yielding fixed-length Wav2Vec2 inputs.

    Item ``i`` is ``({"input_values": x}, label)``. ``x`` is produced by the
    Wav2Vec2 processor and then zero-padded at the end, or truncated, so its
    last axis holds exactly ``max_length`` samples.

    NOTE(review): the processor is always told the audio is 16 kHz; the
    sample rate returned by ``soundfile.read`` is ignored. Recordings at
    another rate should be resampled beforehand.

    Args:
        data: Sequence of audio file paths readable by ``soundfile.read``.
        labels: Integer class label for each entry of ``data`` (same order).
        max_length: Length, in samples, every item is padded/truncated to.
    """

    def __init__(self, data, labels, max_length=100000):
        self.data = data
        self.labels = labels
        self.max_length = max_length
        self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _pad_or_truncate(values, max_length):
        """Zero-pad (at the end) or truncate *values* along its last axis to *max_length*."""
        missing = max_length - values.shape[-1]
        if missing > 0:
            padding = torch.zeros(values.shape[:-1] + (missing,), dtype=values.dtype)
            return torch.cat([values, padding], dim=-1)
        return values[..., :max_length]

    def __getitem__(self, idx):
        # An unreadable file is replaced by the next readable one (wrapping
        # around). That file is returned together with its own label. At most
        # one pass over the dataset is attempted, so a dataset with no
        # readable files raises instead of recursing until RecursionError.
        source_idx = idx
        for _ in range(max(len(self.data), 1)):
            # An out-of-range idx raises IndexError here, which the sequence
            # iteration protocol relies on.
            path = self.data[source_idx]
            try:
                wav_data, _ = sf.read(path)
                break
            except Exception:  # best-effort skip; never swallows KeyboardInterrupt/SystemExit
                print(f"Error opening file: {path}. Skipping...")
                source_idx = (source_idx + 1) % len(self.data)
        else:
            raise RuntimeError(
                f"Could not open any audio file in this dataset (starting from index {idx})."
            )

        inputs = self.processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)
        input_values = inputs.input_values.squeeze(0)  # drop the processor's batch axis
        input_values = self._pad_or_truncate(input_values, self.max_length)

        # The label is returned as a plain integer. The default collate
        # function stacks a batch of them into an int64 tensor.
        return {"input_values": input_values}, self.labels[source_idx]

def train(model, dataloader, criterion, optimizer, device, loss_vals, epochs, current_epoch):
    """Train *model* for one epoch and live-plot the accumulated loss curve.

    Args:
        model: Classifier whose forward output exposes ``.logits``.
        dataloader: Yields ``(inputs_dict, labels)`` batches.
        criterion: Loss applied to ``(logits, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the inputs and labels are moved to.
        loss_vals: List every batch loss is appended to. It is mutated and
            accumulates across epochs; it is what gets plotted.
        epochs: Total planned epochs. Only used to fix the plot's x-axis range.
        current_epoch: Zero-based epoch index, used in the title and message.

    Returns:
        Mean batch loss over this epoch.
    """
    model.train()
    running_loss = 0.0

    for i, (inputs, labels) in enumerate(dataloader):
        # With DysarthriaDataset each item is 1-D, so batches are
        # (batch, samples). A blanket .squeeze() would also drop the batch
        # axis when a batch holds a single example (e.g. the last batch with
        # drop_last=False). Instead, only merge any extra trailing axes.
        inputs = {
            key: (value.flatten(start_dim=1) if value.dim() > 2 else value).to(device)
            for key, value in inputs.items()
        }
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(**inputs).logits
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        loss_vals.append(batch_loss)
        running_loss += batch_loss

        if i % 10 == 0:  # redraw the plot every 10 iterations
            plt.clf()
            plt.plot(loss_vals)
            plt.xlim([0, len(dataloader) * epochs])
            plt.ylim([0, max(loss_vals) + 2])
            plt.xlabel('Training Iterations')
            plt.ylabel('Loss')
            plt.title(f"Training Loss at Epoch {current_epoch + 1}")
            plt.pause(0.001)  # let the GUI event loop update the figure

    avg_loss = running_loss / len(dataloader)
    print(f"Average Loss after Epoch {current_epoch + 1}: {avg_loss}\n")
    return avg_loss

def predict(model, file_path, processor, device, max_length=100000):
    """Return the predicted class index for a single audio file.

    The waveform is read with soundfile and normalised by *processor*, which
    is told the audio is 16 kHz. It is then zero-padded or truncated to
    *max_length* samples, the same treatment DysarthriaDataset applies. Finally
    it is scored by *model* in eval mode without gradient tracking.
    """
    model.eval()
    with torch.no_grad():
        waveform, _sample_rate = sf.read(file_path)
        features = processor(waveform, sampling_rate=16000, return_tensors="pt", padding=True)
        clip = features.input_values.squeeze(0)  # drop the processor's batch axis

        shortfall = max_length - clip.shape[-1]
        if shortfall > 0:
            clip = torch.cat([clip, torch.zeros((shortfall,))], dim=-1)
        else:
            clip = clip[:max_length]

        batch = clip.unsqueeze(0).to(device)  # batch of one
        scores = model(input_values=batch).logits
        return torch.argmax(scores.squeeze(), dim=-1).item()

def evaluate(model, dataloader, criterion, device):
    """Score *model* on *dataloader* without gradient tracking.

    Returns:
        ``(avg_loss, accuracy, wrong_files, all_labels, all_predictions)``:
        - ``avg_loss``: mean batch loss.
        - ``accuracy``: fraction of samples classified correctly.
        - ``wrong_files``: the ``dataloader.dataset.data`` entries that were
          misclassified.
        - ``all_labels`` / ``all_predictions``: 1-D numpy arrays over every
          sample, in loader order.

    Note:
        Misclassified batch positions are mapped to dataset entries through a
        running sample offset. *dataloader* must therefore visit the dataset
        in index order, i.e. no ``shuffle=True`` and no custom sampler.
    """
    model.eval()
    running_loss = 0.0
    correct_predictions = 0
    total_predictions = 0
    wrong_files = []
    all_labels = []
    all_predictions = []
    file_paths = dataloader.dataset.data

    with torch.no_grad():
        for inputs, labels in dataloader:
            # Never drop the batch axis (a blanket .squeeze() breaks
            # single-example batches); only merge extra trailing axes.
            inputs = {
                key: (value.flatten(start_dim=1) if value.dim() > 2 else value).to(device)
                for key, value in inputs.items()
            }
            labels = labels.to(device)

            logits = model(**inputs).logits
            running_loss += criterion(logits, labels).item()

            _, predicted = torch.max(logits, 1)

            # Positions from nonzero() are relative to this batch; shift them
            # by the number of samples already seen to get dataset indices.
            batch_start = total_predictions
            wrong_positions = (predicted != labels).nonzero(as_tuple=True)[0].tolist()
            wrong_files.extend(file_paths[batch_start + pos] for pos in wrong_positions)

            correct_predictions += (predicted == labels).sum().item()
            total_predictions += labels.size(0)

            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    avg_loss = running_loss / len(dataloader)
    accuracy = correct_predictions / total_predictions

    return avg_loss, accuracy, wrong_files, np.array(all_labels), np.array(all_predictions)

def get_wav_files(base_path):
    """Return paths of ``.wav`` files located exactly one folder below *base_path*.

    Files placed directly in *base_path*, and files nested more than one
    sub-folder deep, are not listed. ``os.listdir`` order is arbitrary, so
    results are sorted by sub-folder name, then file name. This keeps the
    list, and any seeded random split built from it, reproducible.
    """
    wav_files = []
    for subject_folder in sorted(os.listdir(base_path)):
        subject_path = os.path.join(base_path, subject_folder)
        if os.path.isdir(subject_path):
            wav_files.extend(
                os.path.join(subject_path, name)
                for name in sorted(os.listdir(subject_path))
                if name.endswith('.wav')
            )

    return wav_files

def get_torgo_data(dysarthria_path, non_dysarthria_path, random_state=None):
    """Build stratified 60/20/20 train/validation/test splits from two flat folders.

    Every ``.wav`` file directly inside *dysarthria_path* is labelled 1, and
    every one directly inside *non_dysarthria_path* is labelled 0.

    Args:
        dysarthria_path: Folder of dysarthric recordings (label 1).
        non_dysarthria_path: Folder of non-dysarthric recordings (label 0).
        random_state: Optional seed forwarded to both ``train_test_split``
            calls. With a seed, and the sorted file listing, the split is
            reproducible. The default ``None`` keeps the previous unseeded
            behaviour.

    Returns:
        ``(train_data, val_data, test_data, train_labels, val_labels, test_labels)``.
    """
    def _wav_paths(folder):
        # os.listdir order is arbitrary; sort so a seeded split is reproducible.
        return [os.path.join(folder, f) for f in sorted(os.listdir(folder)) if f.endswith('.wav')]

    dysarthria_files = _wav_paths(dysarthria_path)
    non_dysarthria_files = _wav_paths(non_dysarthria_path)

    data = dysarthria_files + non_dysarthria_files
    labels = [1] * len(dysarthria_files) + [0] * len(non_dysarthria_files)

    train_data, test_data, train_labels, test_labels = train_test_split(
        data, labels, test_size=0.2, stratify=labels, random_state=random_state
    )
    # 0.25 of the remaining 80% = 20% of all files go to validation.
    train_data, val_data, train_labels, val_labels = train_test_split(
        train_data, train_labels, test_size=0.25, stratify=train_labels, random_state=random_state
    )

    return train_data, val_data, test_data, train_labels, val_labels, test_labels

dysarthria_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/SP_ANALYSIS"
non_dysarthria_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/CT_ANALYSIS"

# Files under dysarthria_path are labelled 1, files under non_dysarthria_path 0.
dysarthria_files = get_wav_files(dysarthria_path)
non_dysarthria_files = get_wav_files(non_dysarthria_path)

data = dysarthria_files + non_dysarthria_files
labels = [1] * len(dysarthria_files) + [0] * len(non_dysarthria_files)

# Stratified 60/20/20 train/validation/test split. A fixed seed keeps the
# partition identical across notebook restarts. A checkpoint reloaded below
# and the current evaluation splits then come from the same partition.
# NOTE(review): the split is per file. If get_wav_files' sub-folders are one
# per speaker, a speaker can appear in both train and test; consider a
# grouped split (e.g. sklearn GroupShuffleSplit) to measure generalisation to
# unseen speakers.
SPLIT_SEED = 42
train_data, test_data, train_labels, test_labels = train_test_split(
    data, labels, test_size=0.2, stratify=labels, random_state=SPLIT_SEED
)
train_data, val_data, train_labels, val_labels = train_test_split(
    train_data, train_labels, test_size=0.25, stratify=train_labels, random_state=SPLIT_SEED
)  # 0.25 x 0.8 = 0.2

train_dataset = DysarthriaDataset(train_data, train_labels)
test_dataset = DysarthriaDataset(test_data, test_labels)
val_dataset = DysarthriaDataset(val_data, val_labels)

# Only the training loader is shuffled. evaluate() maps misclassified batch
# positions back to dataset entries, which needs the evaluation loaders to
# iterate in order.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, drop_last=False)
test_loader = DataLoader(test_dataset, batch_size=16, drop_last=False)
validation_loader = DataLoader(val_dataset, batch_size=16, drop_last=False)

# Previous TORGO-based setup, kept for reference only. This is a bare string
# literal and has no effect when the cell runs.
""" dysarthria_path = "/content/drive/MyDrive/torgo_data/dysarthria_male/training"
non_dysarthria_path = "/content/drive/MyDrive/torgo_data/non_dysarthria_male/training"

dysarthria_files = [os.path.join(dysarthria_path, f) for f in os.listdir(dysarthria_path) if f.endswith('.wav')]
non_dysarthria_files = [os.path.join(non_dysarthria_path, f) for f in os.listdir(non_dysarthria_path) if f.endswith('.wav')]

data = dysarthria_files + non_dysarthria_files
labels = [1] * len(dysarthria_files) + [0] * len(non_dysarthria_files)

train_data, test_data, train_labels, test_labels = train_test_split(data, labels, test_size=0.2)

train_dataset = DysarthriaDataset(train_data, train_labels)
test_dataset = DysarthriaDataset(test_data, test_labels)

train_loader = DataLoader(train_dataset, batch_size=8, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=8, drop_last=True)
validation_loader = DataLoader(test_dataset, batch_size=8, drop_last=True)

dysarthria_validation_path = "/content/drive/MyDrive/torgo_data/dysarthria_male/validation"
non_dysarthria_validation_path = "/content/drive/MyDrive/torgo_data/non_dysarthria_male/validation"

dysarthria_validation_files = [os.path.join(dysarthria_validation_path, f) for f in os.listdir(dysarthria_validation_path) if f.endswith('.wav')]
non_dysarthria_validation_files = [os.path.join(non_dysarthria_validation_path, f) for f in os.listdir(non_dysarthria_validation_path) if f.endswith('.wav')]

validation_data = dysarthria_validation_files + non_dysarthria_validation_files
validation_labels = [1] * len(dysarthria_validation_files) + [0] * len(non_dysarthria_validation_files)"""

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Wav2Vec2ForSequenceClassification puts a classification head with
# num_labels outputs on top of the encoder. The earlier attempt
# (Wav2Vec2ForCTC with a replaced `classifier`) still produced 32 outputs per
# frame, presumably the CTC vocabulary, because that attribute is not used by
# the CTC forward pass.
model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)

model_path = "/content/dysarthria_classifier1.pth"
if os.path.exists(model_path):
    print(f"Loading saved model {model_path}")
    # map_location lets a checkpoint saved on a GPU runtime load on a CPU-only one.
    model.load_state_dict(torch.load(model_path, map_location=device))

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

from torch.optim.lr_scheduler import StepLR

# Step decay: after every 5 scheduler.step() calls (the training loop below
# steps once per epoch), the learning rate is multiplied by 0.1.
scheduler = StepLR(optimizer=optimizer, step_size=5, gamma=0.1)

# Alternative held-out validation set built from the "testing" folders
# (not used by this cell):
# dysarthria_validation_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/SP_ANALYSIS/testing"
# non_dysarthria_validation_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/CT_ANALYSIS/testing"
# dysarthria_validation_files = get_wav_files(dysarthria_validation_path)
# non_dysarthria_validation_files = get_wav_files(non_dysarthria_validation_path)
# validation_data = dysarthria_validation_files + non_dysarthria_validation_files
# validation_labels = [1] * len(dysarthria_validation_files) + [0] * len(non_dysarthria_validation_files)

epochs = 10
plt.ion()
fig, ax = plt.subplots()
loss_vals = []
# The trailing number in names like "..._<n>.wav" is treated as the sentence
# number when grouping misclassifications.
sentence_pattern = re.compile(r"_(\d+)\.wav$")

for epoch in range(epochs):
    train_loss = train(model, train_loader, criterion, optimizer, device, loss_vals, epochs, epoch)
    print(f"Epoch {epoch + 1}, Train Loss: {train_loss}")

    val_loss, val_accuracy, wrong_files, true_labels, pred_labels = evaluate(model, validation_loader, criterion, device)
    print(f"Epoch {epoch + 1}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy:.2f}")
    print("Misclassified Files")
    for file_path in wrong_files:
        print(file_path)

    sentence_counts = Counter()
    for file_path in wrong_files:
        match = sentence_pattern.search(file_path)
        if match:
            sentence_counts[int(match.group(1))] += 1

    total_wrong = len(wrong_files)
    print("Total wrong files:", total_wrong)
    print()

    for sentence_number, count in sentence_counts.most_common():
        percent = count / total_wrong * 100
        print(f"Sentence {sentence_number}: {count} ({percent:.2f}%)")
    scheduler.step()  # one scheduler step per epoch
    # labels=[0, 1] keeps the report well-formed even when one class is
    # missing from both the true and predicted labels.
    print(classification_report(true_labels, pred_labels, labels=[0, 1], target_names=['non_dysarthria', 'dysarthria']))

# Save before the demo prediction, so a missing demo file cannot cost the
# trained weights.
torch.save(model.state_dict(), "dysarthria_classifier1.pth")

# Test on a specific audio file
print("Predicting...")
audio_file = "/content/drive/MyDrive/torgo_data/dysarthria_male/validation/M01_Session1_0005.wav"
if os.path.exists(audio_file):
    predicted_label = predict(model, audio_file, train_dataset.processor, device)
    print(f"Predicted label: {predicted_label}")
else:
    print(f"Demo audio file not found, skipping prediction: {audio_file}")

"""#audio aug"""

# Notebook shell magic (IPython/Colab only): installs the waveform
# augmentation library used by trainaug().
!pip install audiomentations
from audiomentations import Compose, PitchShift, TimeStretch

# Waveform augmentation pipeline. trainaug() applies it clip by clip with
# sample_rate=16000. `p` is each transform's application probability
# (audiomentations convention), so at p=0.1 most clips pass through unchanged.
augmenter = Compose([
    PitchShift(min_semitones=-2, max_semitones=2, p=0.1),
    TimeStretch(min_rate=0.9, max_rate=1.1, p=0.1)
])

# Alternative: step decay instead of warmup.
# from torch.optim.lr_scheduler import StepLR
# scheduler = StepLR(optimizer, step_size=2, gamma=0.5)

from transformers import get_linear_schedule_with_warmup

# Model for the augmented-training run. Resume from a save_pretrained()
# directory when one exists. The later cell saves to the relative path
# "models/my_model_06", which is this directory when the working directory is
# /content. Otherwise start from the base checkpoint.
model_dir = "/content/models/my_model_06"
model_path = os.path.join(model_dir, "pytorch_model.bin")
if os.path.isfile(os.path.join(model_dir, "config.json")):
    # from_pretrained reads whichever weights file save_pretrained wrote:
    # pytorch_model.bin, or model.safetensors in newer transformers versions.
    print(f"Loading saved model {model_dir}")
    model = Wav2Vec2ForSequenceClassification.from_pretrained(model_dir, num_labels=2).to(device)
else:
    model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)
    if os.path.exists(model_path):
        print(f"Loading saved model {model_path}")
        model.load_state_dict(torch.load(model_path, map_location=device))

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

# Linear warmup, then linear decay. The scheduler is created after
# `optimizer` so it drives this optimizer's learning rate.
# The augmented-training loop calls scheduler.step() once per epoch for 20
# epochs, so the schedule is counted in epochs rather than batches.
# Warmup starts from a learning rate of 0, so the first epoch leaves the
# weights unchanged.
aug_epochs = 20
num_training_steps = aug_epochs
num_warmup_steps = int(num_training_steps * 0.3)

scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)

import numpy as np

def trainaug(model, dataloader, criterion, optimizer, device, loss_vals, epochs, current_epoch):
    """Train for one epoch with on-the-fly waveform augmentation; live-plot the loss.

    Each batch's ``input_values`` clips are passed one by one through the
    module-level ``augmenter`` (sample_rate=16000) before the forward pass.
    Non-tensor entries of the input dict are dropped.

    Args:
        model: Classifier whose forward output exposes ``.logits``.
        dataloader: Yields ``(inputs_dict, labels)`` batches with an
            ``"input_values"`` tensor.
        criterion: Loss applied to ``(logits, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the inputs and labels are moved to.
        loss_vals: List every batch loss is appended to. It is mutated and
            accumulates across epochs; it is what gets plotted.
        epochs: Total planned epochs. Only used to fix the plot's x-axis range.
        current_epoch: Zero-based epoch index, used in the title and message.

    Returns:
        Mean batch loss over this epoch.
    """
    model.train()
    running_loss = 0.0

    for i, (inputs, labels) in enumerate(dataloader):
        # Keep tensors only. Merge any extra trailing axes but never drop the
        # batch axis (a blanket .squeeze() breaks single-example batches).
        inputs = {
            key: value.flatten(start_dim=1) if value.dim() > 2 else value
            for key, value in inputs.items()
            if torch.is_tensor(value)
        }

        # The augmenter works on numpy arrays, so augment clip by clip before
        # moving the batch to the device; this avoids a device round trip.
        augmented_audio = np.stack([
            augmenter(clip.cpu().numpy(), sample_rate=16000)
            for clip in inputs['input_values']
        ])
        # .float() keeps float32 input even if an augmentation returns float64.
        inputs['input_values'] = torch.from_numpy(augmented_audio).float()

        inputs = {key: value.to(device) for key, value in inputs.items()}
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(**inputs).logits
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        loss_vals.append(batch_loss)
        running_loss += batch_loss

        if i % 10 == 0:  # redraw the plot every 10 iterations
            plt.clf()
            plt.plot(loss_vals)
            plt.xlim([0, len(dataloader) * epochs])
            plt.ylim([0, max(loss_vals) + 2])
            plt.xlabel('Training Iterations')
            plt.ylabel('Loss')
            plt.title(f"Training Loss at Epoch {current_epoch + 1}")
            plt.pause(0.001)  # let the GUI event loop update the figure

    avg_loss = running_loss / len(dataloader)
    print(f"Average Loss after Epoch {current_epoch + 1}: {avg_loss}\n")
    return avg_loss

epochs = 20
plt.ion()
fig, ax = plt.subplots()
loss_vals = []
# The trailing number in names like "..._<n>.wav" is treated as the sentence
# number when grouping misclassifications.
sentence_pattern = re.compile(r"_(\d+)\.wav$")

for epoch in range(epochs):
    train_loss = trainaug(model, train_loader, criterion, optimizer, device, loss_vals, epochs, epoch)
    print(f"Epoch {epoch + 1}, Train Loss: {train_loss}")

    val_loss, val_accuracy, wrong_files, true_labels, pred_labels = evaluate(model, validation_loader, criterion, device)
    print(f"Epoch {epoch + 1}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy:.2f}")
    print("Misclassified Files")
    for file_path in wrong_files:
        print(file_path)

    sentence_counts = Counter()
    for file_path in wrong_files:
        match = sentence_pattern.search(file_path)
        if match:
            sentence_counts[int(match.group(1))] += 1

    total_wrong = len(wrong_files)
    print("Total wrong files:", total_wrong)
    print()

    for sentence_number, count in sentence_counts.most_common():
        percent = count / total_wrong * 100
        print(f"Sentence {sentence_number}: {count} ({percent:.2f}%)")
    scheduler.step()  # one scheduler step per epoch
    # labels=[0, 1] keeps the report well-formed even when one class is
    # missing from both the true and predicted labels.
    print(classification_report(true_labels, pred_labels, labels=[0, 1], target_names=['non_dysarthria', 'dysarthria']))

# Single-file prediction (disabled):
# audio_file = "/content/drive/MyDrive/torgo_data/dysarthria_male/validation/M01_Session1_0005.wav"
# predicted_label = predict(model, audio_file, train_dataset.processor, device)
# print(f"Predicted label: {predicted_label}")

import re
from collections import Counter
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import classification_report

# Define the pattern to extract the sentence number from the file path
sentence_pattern = re.compile(r"_(\d+)\.wav$")

# Number of files per sentence number in the VALIDATION set. The evaluation
# loop below reports misclassified files from validation_loader and divides
# by these counts, so the denominator must come from the same set. Counting
# the training set gave meaningless percentages, and a ZeroDivisionError for
# sentences absent from training.
total_sentence_counts = Counter(
    int(match.group(1))
    for match in map(sentence_pattern.search, validation_loader.dataset.data)
    if match
)

# Evaluation-only pass of the current model (no training happens in this cell),
# followed by saving the model.
epochs = 1
loss_vals = []

for epoch in range(epochs):
    val_loss, val_accuracy, wrong_files, true_labels, pred_labels = evaluate(model, validation_loader, criterion, device)
    print(f"Epoch {epoch + 1}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy:.2f}")
    print("Misclassified Files")
    for file_path in wrong_files:
        print(file_path)

    # Counter for the misclassified sentences
    sentence_counts = Counter(
        int(match.group(1))
        for match in map(sentence_pattern.search, wrong_files)
        if match
    )

    print("Total wrong files:", len(wrong_files))
    print()

    # Percentage = share of that sentence's files (in total_sentence_counts)
    # that were misclassified.
    for sentence_number, count in sentence_counts.most_common():
        total = total_sentence_counts[sentence_number]
        if total:
            print(f"Sentence {sentence_number}: {count} ({count / total * 100:.2f}%)")
        else:
            print(f"Sentence {sentence_number}: {count} (no reference count)")

    # No scheduler.step() here: nothing is trained, so the LR schedule must
    # not advance.
    print(classification_report(true_labels, pred_labels, labels=[0, 1], target_names=['non_dysarthria', 'dysarthria']))

torch.save(model.state_dict(), "dysarthria_classifier2.pth")

save_dir = "models/my_model_06"
model.save_pretrained(save_dir)

"""## Cross testing

"""

# Evaluation-only pass of the current model on validation_loader (no training
# happens in this cell), followed by a single-file demo prediction. The
# separate "testing"-folder validation set is configured in main().
epochs = 1
# The trailing number in names like "..._<n>.wav" is treated as the sentence
# number when grouping misclassifications.
sentence_pattern = re.compile(r"_(\d+)\.wav$")

for epoch in range(epochs):
    val_loss, val_accuracy, wrong_files, true_labels, pred_labels = evaluate(model, validation_loader, criterion, device)
    print(f"Epoch {epoch + 1}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy:.2f}")
    print("Misclassified Files")
    for file_path in wrong_files:
        print(file_path)

    sentence_counts = Counter()
    for file_path in wrong_files:
        match = sentence_pattern.search(file_path)
        if match:
            sentence_counts[int(match.group(1))] += 1

    total_wrong = len(wrong_files)
    print("Total wrong files:", total_wrong)
    print()

    for sentence_number, count in sentence_counts.most_common():
        percent = count / total_wrong * 100
        print(f"Sentence {sentence_number}: {count} ({percent:.2f}%)")
    # No scheduler.step() here: nothing is trained, so the LR schedule must
    # not advance.
    print(classification_report(true_labels, pred_labels, labels=[0, 1], target_names=['non_dysarthria', 'dysarthria']))

# Test on a specific audio file
audio_file = "/content/drive/MyDrive/torgo_data/dysarthria_male/validation/M01_Session1_0005.wav"
if os.path.exists(audio_file):
    predicted_label = predict(model, audio_file, train_dataset.processor, device)
    print(f"Predicted label: {predicted_label}")
else:
    print(f"Demo audio file not found, skipping prediction: {audio_file}")

"""## DEBUGGING"""

# Debugging cell: reproduces predict()'s preprocessing step by step on one
# TORGO file, so the input and logits shapes can be inspected.
# WARNING: this cell rebinds the module-level data splits, datasets, loaders
# and `model`. `model` becomes a fresh Wav2Vec2ForSequenceClassification
# loaded from the base checkpoint, not a fine-tuned one. Re-run the setup
# cells before training or evaluating again.

dysarthria_path = "/content/drive/MyDrive/torgo_data/dysarthria_male/training"
non_dysarthria_path = "/content/drive/MyDrive/torgo_data/non_dysarthria_male/training"

# Flat folders: only .wav files directly inside each path are used
# (label 1 = dysarthria_path, 0 = non_dysarthria_path).
dysarthria_files = [os.path.join(dysarthria_path, f) for f in os.listdir(dysarthria_path) if f.endswith('.wav')]
non_dysarthria_files = [os.path.join(non_dysarthria_path, f) for f in os.listdir(non_dysarthria_path) if f.endswith('.wav')]

data = dysarthria_files + non_dysarthria_files
labels = [1] * len(dysarthria_files) + [0] * len(non_dysarthria_files)

# Unseeded, unstratified 80/20 split (differs on every run).
train_data, test_data, train_labels, test_labels = train_test_split(data, labels, test_size=0.2)

train_dataset = DysarthriaDataset(train_data, train_labels)
test_dataset = DysarthriaDataset(test_data, test_labels)

train_loader = DataLoader(train_dataset, batch_size=4, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=4, drop_last=True)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Earlier approach (CTC model with a replaced classifier), kept for reference:
# model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(device)
# model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)

model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)

# Same fixed input length as DysarthriaDataset's default.
max_length = 100_000
processor = train_dataset.processor

model.eval()
audio_file = "/content/drive/MyDrive/torgo_data/dysarthria_male/validation/M01_Session1_0005.wav"
# predicted_label = predict(model, audio_file, train_dataset.processor, device)
# print(f"Predicted label: {predicted_label}")

# Manual copy of predict(): read, normalise, pad/truncate to max_length, add a batch axis.
wav_data, _ = sf.read(audio_file)
inputs = processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)
input_values = inputs.input_values.squeeze(0)  # Squeeze the batch dimension
if max_length - input_values.shape[-1] > 0:
    input_values = torch.cat([input_values, torch.zeros((max_length - input_values.shape[-1],))], dim=-1)
else:
    input_values = input_values[:max_length]

input_values = input_values.unsqueeze(0).to(device)
input_values.shape  # bare expression: displayed by the notebook, no effect in a script

with torch.no_grad():
  outputs = model(**{"input_values": input_values})
  logits = outputs.logits

input_values.shape, logits.shape  # notebook display of input and logits shapes

import torch.nn.functional as F
# Remove the batch dimension.
logits = logits.squeeze()
predicted_class_id = torch.argmax(logits, dim=-1)
predicted_class_id  # notebook display of the predicted class (a 0-d tensor)

"""Cross testing

##original code
"""

import os
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
from sklearn.model_selection import train_test_split

# Custom Dataset class
class DysarthriaDataset(Dataset):
    """Audio-classification dataset yielding fixed-length Wav2Vec2 inputs.

    Item ``i`` is ``({"input_values": x}, label)``. ``x`` is produced by the
    Wav2Vec2 processor and then zero-padded at the end, or truncated, so its
    last axis holds exactly ``max_length`` samples.

    NOTE(review): the processor is always told the audio is 16 kHz; the
    sample rate returned by ``soundfile.read`` is ignored. Recordings at
    another rate should be resampled beforehand.

    Args:
        data: Sequence of audio file paths readable by ``soundfile.read``.
        labels: Integer class label for each entry of ``data`` (same order).
        max_length: Length, in samples, every item is padded/truncated to.
    """

    def __init__(self, data, labels, max_length=100000):
        self.data = data
        self.labels = labels
        self.max_length = max_length
        self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _pad_or_truncate(values, max_length):
        """Zero-pad (at the end) or truncate *values* along its last axis to *max_length*."""
        missing = max_length - values.shape[-1]
        if missing > 0:
            padding = torch.zeros(values.shape[:-1] + (missing,), dtype=values.dtype)
            return torch.cat([values, padding], dim=-1)
        return values[..., :max_length]

    def __getitem__(self, idx):
        # An unreadable file is replaced by the next readable one (wrapping
        # around). That file is returned together with its own label. At most
        # one pass over the dataset is attempted, so a dataset with no
        # readable files raises instead of recursing until RecursionError.
        source_idx = idx
        for _ in range(max(len(self.data), 1)):
            # An out-of-range idx raises IndexError here, which the sequence
            # iteration protocol relies on.
            path = self.data[source_idx]
            try:
                wav_data, _ = sf.read(path)
                break
            except Exception:  # best-effort skip; never swallows KeyboardInterrupt/SystemExit
                print(f"Error opening file: {path}. Skipping...")
                source_idx = (source_idx + 1) % len(self.data)
        else:
            raise RuntimeError(
                f"Could not open any audio file in this dataset (starting from index {idx})."
            )

        inputs = self.processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)
        input_values = inputs.input_values.squeeze(0)  # drop the processor's batch axis
        input_values = self._pad_or_truncate(input_values, self.max_length)

        # The label is returned as a plain integer. The default collate
        # function stacks a batch of them into an int64 tensor.
        return {"input_values": input_values}, self.labels[source_idx]


def train(model, dataloader, criterion, optimizer, device, ax, loss_vals, x_vals, fig, train_loader, epochs):
    """Train *model* for one epoch, redrawing the loss curve on *ax* as it goes.

    Args:
        model: Classifier whose forward output exposes ``.logits``.
        dataloader: Yields ``(inputs_dict, labels)`` batches.
        criterion: Loss applied to ``(logits, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the inputs and labels are moved to.
        ax, fig: Matplotlib axes/figure the loss curve is drawn on.
        loss_vals: List every batch loss is appended to. It is mutated and
            accumulates across epochs.
        x_vals: x coordinates for the curve. Must be at least as long as
            ``loss_vals``.
        train_loader: Loader whose length, times *epochs*, sets the x-axis range.
        epochs: Total number of planned epochs (x-axis range only).

    Returns:
        Mean batch loss over this epoch.
    """
    model.train()
    running_loss = 0.0

    for i, (inputs, labels) in enumerate(dataloader):
        # Never drop the batch axis (a blanket .squeeze() breaks
        # single-example batches); only merge extra trailing axes.
        inputs = {
            key: (value.flatten(start_dim=1) if value.dim() > 2 else value).to(device)
            for key, value in inputs.items()
        }
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(**inputs).logits
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        loss_vals.append(batch_loss)
        running_loss += batch_loss

        # Redrawing on every batch dominated run time; refresh every 10 batches.
        if i % 10 == 0:
            ax.clear()
            ax.set_xlim([0, len(train_loader) * epochs])
            ax.set_xlabel('Training Iterations')
            ax.set_ylim([0, max(loss_vals) + 2])
            ax.set_ylabel('Loss')
            ax.plot(x_vals[:len(loss_vals)], loss_vals)
            fig.canvas.draw()
            plt.pause(0.001)  # let the GUI event loop update the figure

    avg_loss = running_loss / len(dataloader)
    print(avg_loss)
    print("\n")
    return avg_loss



def main():
    """Train on the "training" folders and validate on the separate "testing" folders.

    Files under the SP_ANALYSIS paths are labelled 1, files under the
    CT_ANALYSIS paths 0. The training folders are split 80/20 into train and
    test sets. The "testing" folders form the validation set; if they contain
    no .wav files, the test split is used instead.

    Each epoch prints:
    - the validation loss and accuracy;
    - the misclassified files, grouped by the trailing ``_<n>.wav`` number.

    The final weights are saved to "dysarthria_classifier4.pth".
    """
    dysarthria_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/SP_ANALYSIS/training"
    non_dysarthria_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/CT_ANALYSIS/training"
    dysarthria_validation_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/SP_ANALYSIS/testing"
    non_dysarthria_validation_path = "/content/drive/MyDrive/RECORDINGS_ANALYSIS/CT_ANALYSIS/testing"

    dysarthria_files = get_wav_files(dysarthria_path)
    non_dysarthria_files = get_wav_files(non_dysarthria_path)

    data = dysarthria_files + non_dysarthria_files
    labels = [1] * len(dysarthria_files) + [0] * len(non_dysarthria_files)

    # Stratified and seeded so the split is balanced and reproducible.
    train_data, test_data, train_labels, test_labels = train_test_split(
        data, labels, test_size=0.2, stratify=labels, random_state=42
    )

    dysarthria_validation_files = get_wav_files(dysarthria_validation_path)
    non_dysarthria_validation_files = get_wav_files(non_dysarthria_validation_path)

    validation_data = dysarthria_validation_files + non_dysarthria_validation_files
    validation_labels = [1] * len(dysarthria_validation_files) + [0] * len(non_dysarthria_validation_files)

    train_dataset = DysarthriaDataset(train_data, train_labels)
    test_dataset = DysarthriaDataset(test_data, test_labels)
    if validation_data:
        validation_dataset = DysarthriaDataset(validation_data, validation_labels)
    else:
        print("No .wav files found in the testing folders; validating on the test split instead.")
        validation_dataset = test_dataset

    # Only the training loader is shuffled. evaluate() maps misclassified batch
    # positions back to dataset file paths, which needs in-order iteration.
    # NOTE(review): drop_last=True on the evaluation loaders skips up to 7
    # files per pass. It is kept to avoid single-example batches.
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, drop_last=True)
    test_loader = DataLoader(test_dataset, batch_size=8, drop_last=True)
    validation_loader = DataLoader(validation_dataset, batch_size=8, drop_last=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Sequence-classification head with num_labels=2 outputs. The CTC model
    # with a replaced `classifier` still produced 32 outputs.
    model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)
    # To resume from a checkpoint:
    # model.load_state_dict(torch.load("/content/dysarthria_classifier3.pth", map_location=device))

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)

    epochs = 10
    fig, ax = plt.subplots()
    x_vals = np.arange(len(train_loader) * epochs)
    loss_vals = []
    sentence_pattern = re.compile(r"_(\d+)\.wav$")

    for epoch in range(epochs):
        # train()'s last argument only sets the x-axis range, so pass the
        # total epoch count to keep the axis fixed across epochs.
        train_loss = train(model, train_loader, criterion, optimizer, device, ax, loss_vals, x_vals, fig, train_loader, epochs)
        print(f"Epoch {epoch + 1}, Train Loss: {train_loss}")

        val_loss, val_accuracy, wrong_files = evaluate(model, validation_loader, criterion, device)
        print(f"Epoch {epoch + 1}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy:.2f}")
        print("Misclassified Files")
        for file_path in wrong_files:
            print(file_path)

        sentence_counts = Counter()
        for file_path in wrong_files:
            match = sentence_pattern.search(file_path)
            if match:
                sentence_counts[int(match.group(1))] += 1

        total_wrong = len(wrong_files)
        print("Total wrong files:", total_wrong)
        print()

        for sentence_number, count in sentence_counts.most_common():
            percent = count / total_wrong * 100
            print(f"Sentence {sentence_number}: {count} ({percent:.2f}%)")

    torch.save(model.state_dict(), "dysarthria_classifier4.pth")
    print("Predicting...")
    # Test on a specific audio file
    ##audio_file = "/content/drive/MyDrive/torgo_data/dysarthria_male/validation/M01_Session1_0005.wav"
    ##predicted_label = predict(model, audio_file, train_dataset.processor, device)
    ##print(f"Predicted label: {predicted_label}")

def predict(model, file_path, processor, device, max_length=100000):
    """Return the predicted class index for a single audio file.

    The waveform is read with soundfile and normalised by *processor*, which
    is told the audio is 16 kHz. It is then zero-padded or truncated to
    *max_length* samples, the same treatment DysarthriaDataset applies. Finally
    it is scored by *model* in eval mode without gradient tracking.
    """
    model.eval()
    with torch.no_grad():
        waveform, _sample_rate = sf.read(file_path)
        features = processor(waveform, sampling_rate=16000, return_tensors="pt", padding=True)
        clip = features.input_values.squeeze(0)  # drop the processor's batch axis

        shortfall = max_length - clip.shape[-1]
        if shortfall > 0:
            clip = torch.cat([clip, torch.zeros((shortfall,))], dim=-1)
        else:
            clip = clip[:max_length]

        batch = clip.unsqueeze(0).to(device)  # batch of one
        scores = model(input_values=batch).logits
        return torch.argmax(scores.squeeze(), dim=-1).item()
def evaluate(model, dataloader, criterion, device):
    """Score *model* on *dataloader* without gradient tracking.

    Returns:
        ``(avg_loss, accuracy, wrong_files)``:
        - ``avg_loss``: mean batch loss.
        - ``accuracy``: fraction of samples classified correctly.
        - ``wrong_files``: the ``dataloader.dataset.data`` entries that were
          misclassified.

    Note:
        Misclassified batch positions are mapped to dataset entries through a
        running sample offset. *dataloader* must therefore visit the dataset
        in index order, i.e. no ``shuffle=True`` and no custom sampler.
    """
    model.eval()
    running_loss = 0.0
    correct_predictions = 0
    total_predictions = 0
    wrong_files = []
    file_paths = dataloader.dataset.data

    with torch.no_grad():
        for inputs, labels in dataloader:
            # Never drop the batch axis (a blanket .squeeze() breaks
            # single-example batches); only merge extra trailing axes.
            inputs = {
                key: (value.flatten(start_dim=1) if value.dim() > 2 else value).to(device)
                for key, value in inputs.items()
            }
            labels = labels.to(device)

            logits = model(**inputs).logits
            running_loss += criterion(logits, labels).item()

            _, predicted = torch.max(logits, 1)

            # Positions from nonzero() are relative to this batch; shift them
            # by the number of samples already seen to get dataset indices.
            batch_start = total_predictions
            wrong_positions = (predicted != labels).nonzero(as_tuple=True)[0].tolist()
            wrong_files.extend(file_paths[batch_start + pos] for pos in wrong_positions)

            correct_predictions += (predicted == labels).sum().item()
            total_predictions += labels.size(0)

    avg_loss = running_loss / len(dataloader)
    accuracy = correct_predictions / total_predictions
    return avg_loss, accuracy, wrong_files



def get_wav_files(base_path):
    """Return paths of ``.wav`` files located exactly one folder below *base_path*.

    Files placed directly in *base_path*, and files nested more than one
    sub-folder deep, are not listed. ``os.listdir`` order is arbitrary, so
    results are sorted by sub-folder name, then file name. This keeps the
    list, and any seeded random split built from it, reproducible.
    """
    wav_files = []
    for subject_folder in sorted(os.listdir(base_path)):
        subject_path = os.path.join(base_path, subject_folder)
        if os.path.isdir(subject_path):
            wav_files.extend(
                os.path.join(subject_path, name)
                for name in sorted(os.listdir(subject_path))
                if name.endswith('.wav')
            )
    return wav_files
# Script entry point for the "original code" cross-testing run. In a
# notebook kernel __name__ is normally "__main__" as well, so running this
# cell starts a full training run.
if __name__ == "__main__":
    main()