import gradio as gr
import torch
import torch.nn as nn
from joblib import load
import numpy as np
import pandas as pd

# Define the neural network model


class ImprovedSongRecommender(nn.Module):
    """Feed-forward network that scores every song title for one input row.

    The input passes through three hidden stages of widths 128, 256 and 128.
    Each stage is Linear -> BatchNorm1d -> ReLU -> Dropout(0.5). A final
    linear layer emits ``num_titles`` raw (un-normalised) scores.

    Args:
        input_size: Number of features per input row.
        num_titles: Number of output scores, one per title class.
    """

    def __init__(self, input_size, num_titles):
        super().__init__()
        # Layers are registered under the names fc1/bn1 ... fc3/bn3 and in
        # that order, because those names are the keys of the saved
        # state_dict this class is loaded from.
        widths = (input_size, 128, 256, 128)
        for stage, (n_in, n_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"fc{stage}", nn.Linear(n_in, n_out))
            setattr(self, f"bn{stage}", nn.BatchNorm1d(n_out))
        self.output = nn.Linear(widths[-1], num_titles)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return the raw title scores for a batch ``x`` of feature rows."""
        stages = (
            (self.fc1, self.bn1),
            (self.fc2, self.bn2),
            (self.fc3, self.bn3),
        )
        hidden = x
        for linear, norm in stages:
            hidden = self.dropout(torch.relu(norm(linear(hidden))))
        return self.output(hidden)


# Load the trained model
# The weights are read from a path relative to the working directory and are
# mapped onto the CPU regardless of the device they were saved from.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
model_path = "models/improved_model.pth"
# Width of the network's output layer.
# NOTE(review): presumably this must equal the number of classes in
# label_encoders['title'] -- confirm against the training run.
num_unique_titles = 4855
# input_size=2 matches the two-element list returned by encode_input().
model = ImprovedSongRecommender(input_size=2, num_titles=num_unique_titles)
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
# Put the dropout and batch-norm layers into inference mode.
model.eval()

# Load the label encoders
# The loaded object is indexed by 'tags', 'artist_name' and 'title' elsewhere
# in this file.
# NOTE(review): joblib.load also unpickles; the file must be trusted.
label_encoders_path = "data/new_label_encoders.joblib"
label_encoders = load(label_encoders_path)


def encode_input(tags, artist_name):
    """Turn raw user input into the two numeric features the model consumes.

    Args:
        tags: Comma-separated tag names, e.g. ``"rock, pop"``. Surrounding
            whitespace is stripped and blank entries (from an empty string
            or stray commas) are ignored. ``None`` is treated as empty.
        artist_name: Artist name to encode.

    Returns:
        ``[encoded_tags, encoded_artist]``. ``encoded_tags`` is the mean of
        the individual tag codes truncated to an integer, or the code of
        ``'unknown'`` when no non-blank tag was supplied.

    Raises:
        ValueError: Propagated from the encoder if ``'unknown'`` itself
            cannot be encoded.
    """
    def _encode(encoder, value):
        # Fall back to the 'unknown' class when the encoder rejects the value.
        try:
            return encoder.transform([value])[0]
        except ValueError:
            return encoder.transform(['unknown'])[0]

    tag_encoder = label_encoders['tags']

    # Blank entries are dropped instead of being encoded as 'unknown', so a
    # trailing comma no longer drags the averaged code towards 'unknown'.
    tag_names = [name.strip() for name in (tags or "").split(',')]
    tag_codes = [_encode(tag_encoder, name) for name in tag_names if name]

    if tag_codes:
        encoded_tags = np.mean(tag_codes).astype(int)
    else:
        encoded_tags = tag_encoder.transform(['unknown'])[0]

    encoded_artist = _encode(label_encoders['artist_name'], artist_name)

    return [encoded_tags, encoded_artist]


def recommend_songs(tags, artist_name):
    """Return the five titles the model scores highest for the given input.

    Args:
        tags: Comma-separated tag names, forwarded to ``encode_input``.
        artist_name: Artist name, forwarded to ``encode_input``.

    Returns:
        A list of five titles, ordered as returned by ``torch.topk``.
    """
    features = encode_input(tags, artist_name)
    # Wrap the feature pair in an outer list to form a batch of one row.
    batch = torch.tensor([features]).float()

    with torch.no_grad():
        scores = model(batch)

    _, top_indices = torch.topk(scores, 5)
    title_encoder = label_encoders['title']

    titles = []
    for index in top_indices.squeeze().tolist():
        # Decode one class index at a time back to its title.
        titles.append(title_encoder.inverse_transform([index])[0])
    return titles


def record_feedback(tags, recommendations, feedbacks,
                    feedback_path="feedback_data/feedback_data.csv"):
    """Append the rated recommendations to the feedback CSV.

    Args:
        tags: The tag string the recommendations were generated from; it is
            stored on every recorded row.
        recommendations: Sequence of recommended titles.
        feedbacks: Sequence of ratings, aligned with ``recommendations``.
            ``None``/NaN means "not rated".
        feedback_path: CSV file to update. Its parent directory is created
            when missing.

    Returns:
        The string ``"Feedback recorded!"``.

    Rows are only recorded when the title is not the placeholder
    ``"No recommendations found"`` and a rating is present. If the two
    sequences differ in length, the surplus entries of the longer one are
    ignored instead of raising.
    """
    import os  # local import so the file's import block is left untouched

    columns = ["Tags", "Recommendation", "Feedback"]

    # Load existing feedback if it exists. A missing or zero-byte file is
    # treated as "no feedback yet".
    try:
        feedback_df = pd.read_csv(feedback_path)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        feedback_df = pd.DataFrame(columns=columns)

    # zip() stops at the shorter sequence, so mismatched lengths (for
    # example feedback submitted before any recommendations exist) cannot
    # make the DataFrame constructor fail.
    pairs = list(zip(recommendations, feedbacks))
    new_feedbacks = pd.DataFrame({
        "Tags": [tags] * len(pairs),
        "Recommendation": [song for song, _ in pairs],
        "Feedback": [rating for _, rating in pairs],
    })

    # Only keep rows where both a song recommendation and a rating are present
    has_song = new_feedbacks["Recommendation"] != "No recommendations found"
    has_rating = new_feedbacks["Feedback"].notna()
    new_feedbacks = new_feedbacks[has_song & has_rating]

    # Empty frames are left out of the concatenation; when both are empty
    # the (empty) existing frame is written back unchanged.
    frames = [frame for frame in (feedback_df, new_feedbacks)
              if not frame.empty]
    if frames:
        feedback_df = pd.concat(frames, ignore_index=True)

    # Make sure the target directory exists before writing.
    directory = os.path.dirname(feedback_path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    feedback_df.to_csv(feedback_path, index=False)

    return "Feedback recorded!"


# Gradio UI: a tag box, five recommendation slots each paired with a
# thumbs-up/down rating, and a button that stores the ratings.
app = gr.Blocks()

with app:
    gr.Markdown("## Music Recommendation System")
    tags_input = gr.Textbox(
        label="Enter Tags (e.g., rock, jazz, pop)", placeholder="rock, pop")
    submit_button = gr.Button("Get Recommendations")

    # Per-session copy of the titles currently shown. A module-level global
    # would be shared by every connected user, so one user's feedback could
    # be recorded against another user's recommendations.
    recommendations_state = gr.State([])

    # Components are created inside the layout containers they belong to;
    # a Blocks component is rendered where it is instantiated, so building
    # them first and passing them to gr.Column(...) afterwards does not
    # place them.
    recommendation_outputs = []
    feedback_inputs = []
    for i in range(5):
        with gr.Row():
            with gr.Column():
                recommendation_outputs.append(
                    gr.HTML(label=f"Recommendation {i+1}"))
                feedback_inputs.append(gr.Radio(
                    choices=["Thumbs Up", "Thumbs Down"],
                    label=f"Feedback {i+1}"))

    with gr.Row():
        with gr.Column():
            feedback_submit_button = gr.Button("Submit Feedback")
            feedback_confirmation = gr.Markdown("")

    def display_recommendations(tags):
        """Fetch recommendations for ``tags`` and refresh the ten widgets.

        Returns the updates for the five HTML slots, then the five radio
        groups, then the list of titles to keep in the session state.
        """
        # Plain ``str`` values so they serialise cleanly to the browser.
        songs = [str(song) for song in recommend_songs(tags, "")]
        html_updates = [gr.update(value=song) for song in songs]
        # Relabel each rating with its song and clear any earlier choice so
        # an old rating is not submitted for a new song.
        radio_updates = [gr.update(label=song, value=None) for song in songs]
        return html_updates + radio_updates + [songs]

    submit_button.click(
        fn=display_recommendations,
        inputs=[tags_input],
        outputs=recommendation_outputs + feedback_inputs
        + [recommendations_state]
    )

    def collect_feedback(tags, songs, *feedbacks):
        """Store the ratings for the titles shown in this session."""
        if not songs:
            # Nothing has been recommended yet, so there is nothing to rate.
            return "Please get recommendations before submitting feedback."
        record_feedback(tags, songs, list(feedbacks))
        return "Feedback submitted successfully!"

    feedback_submit_button.click(
        fn=collect_feedback,
        inputs=[tags_input, recommendations_state] + feedback_inputs,
        outputs=feedback_confirmation
    )

app.launch()