# Author: nickkornienko
# Last commit 9d7da71: "added feedback submit confirmation" (file size 5.37 kB)
import os

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from joblib import load
# Define the neural network model
class ImprovedSongRecommender(nn.Module):
    """Feed-forward classifier that scores every known song title.

    Three hidden stages of Linear -> BatchNorm1d -> ReLU -> Dropout (widths
    128, 256, 128) are followed by a linear head producing one logit per
    title.

    The attribute names (fc1/bn1/fc2/bn2/fc3/bn3/output/dropout) must not
    change: they become the keys ``load_state_dict`` matches against a saved
    checkpoint.
    """

    def __init__(self, input_size, num_titles):
        super().__init__()
        # Creation order is kept fixed so that parameter registration order
        # (and seeded random initialisation) stays the same.
        self.fc1 = nn.Linear(input_size, 128)
        self.bn1 = nn.BatchNorm1d(128)
        self.fc2 = nn.Linear(128, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.output = nn.Linear(128, num_titles)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return raw title logits for a batch of encoded inputs.

        ``x`` is presumably shaped (batch, input_size), which is how the
        recommender feeds it; confirm before using other shapes.
        """
        hidden_stages = (
            (self.fc1, self.bn1),
            (self.fc2, self.bn2),
            (self.fc3, self.bn3),
        )
        for linear, norm in hidden_stages:
            x = self.dropout(torch.relu(norm(linear(x))))
        return self.output(x)
# Load the trained model and the fitted label encoders from local files.
# NOTE(review): torch.load and joblib.load can unpickle arbitrary objects;
# only point these paths at artifacts you trust.
model_path = "models/improved_model.pth"
label_encoders_path = "data/new_label_encoders.joblib"

# Width of the classifier head. It must equal the size of the checkpoint's
# `output` layer, or load_state_dict rejects the weights. Presumably this is
# the number of classes in label_encoders['title'] - TODO confirm.
num_unique_titles = 4855

# input_size=2 matches the [tag_code, artist_code] pair built by encode_input.
model = ImprovedSongRecommender(input_size=2, num_titles=num_unique_titles)
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()  # inference mode: BatchNorm uses running stats, Dropout is off

# Encoders used below under the keys 'tags', 'artist_name' and 'title'.
label_encoders = load(label_encoders_path)
def _encode_or_unknown(encoder, value):
    """Return the integer code of ``value``, or of ``'unknown'`` if unseen.

    ``encoder.transform`` is expected to raise ValueError for labels it was
    not fitted on. The fallback lookup of ``'unknown'`` is not guarded, so a
    missing ``'unknown'`` label still raises.
    """
    try:
        return int(encoder.transform([value])[0])
    except ValueError:
        return int(encoder.transform(['unknown'])[0])


def encode_input(tags, artist_name, encoders=None):
    """Turn a comma-separated tag string and an artist name into two codes.

    Each non-blank tag is label-encoded, with unseen tags falling back to the
    ``'unknown'`` label. The tag codes are averaged and truncated to an int.
    The artist name is encoded the same way. The pair is the model's
    two-feature input.

    Args:
        tags: Comma-separated tags, e.g. ``"rock, pop"``. Blank entries are
            ignored. If none remain, the ``'unknown'`` tag code is used.
        artist_name: Artist to encode. Unseen names map to ``'unknown'``.
        encoders: Mapping with ``'tags'`` and ``'artist_name'`` encoders
            exposing ``transform``. Defaults to the module-level
            ``label_encoders``.

    Returns:
        ``[tag_code, artist_code]`` as Python ints.

    Raises:
        ValueError: if a fallback is needed but the encoder has no
            ``'unknown'`` label (propagated from ``transform``).
    """
    if encoders is None:
        encoders = label_encoders
    tag_encoder = encoders['tags']

    # Drop blanks produced by inputs like "rock, " or "". Otherwise each one
    # is encoded as 'unknown' and drags the average toward that code.
    tag_list = [tag for tag in (raw.strip() for raw in (tags or "").split(',')) if tag]

    if tag_list:
        # NOTE(review): averaging categorical codes is presumably how the
        # training features were built; confirm before changing it.
        tag_codes = [_encode_or_unknown(tag_encoder, tag) for tag in tag_list]
        encoded_tags = int(np.mean(tag_codes))
    else:
        encoded_tags = int(tag_encoder.transform(['unknown'])[0])

    encoded_artist = _encode_or_unknown(encoders['artist_name'], artist_name)
    return [encoded_tags, encoded_artist]
def recommend_songs(tags, artist_name, top_k=5):
    """Return the titles of the highest-scoring songs for a query.

    Args:
        tags: Comma-separated tag string, passed to ``encode_input``.
        artist_name: Artist name, passed to ``encode_input``.
        top_k: Maximum number of titles to return. It is clamped to the
            number of titles the model scores.

    Returns:
        A list of at most ``top_k`` titles, best first, decoded with
        ``label_encoders['title']``.
    """
    encoded_input = encode_input(tags, artist_name)
    # A batch of one sample: shape (1, 2).
    input_tensor = torch.tensor([encoded_input], dtype=torch.float32)
    with torch.no_grad():
        scores = model(input_tensor)[0]

    # torch.topk raises if k exceeds the number of scores.
    k = min(top_k, scores.shape[-1])
    if k <= 0:
        return []
    # Index row 0 above rather than squeeze() here: squeeze would collapse a
    # k=1 result to a 0-d tensor whose tolist() is a bare int.
    top_indices = torch.topk(scores, k).indices.tolist()
    # Decode all indices in one call; presumably an sklearn-style
    # LabelEncoder returning an ndarray - TODO confirm.
    return label_encoders['title'].inverse_transform(top_indices).tolist()
def record_feedback(tags, recommendations, feedbacks,
                    path="feedback_data/feedback_data.csv"):
    """Append one CSV row per rated recommendation and return a status.

    Args:
        tags: Tag string the user searched with; copied onto every row.
        recommendations: Titles that were shown, in display order.
        feedbacks: Ratings aligned with ``recommendations``. ``None`` or NaN
            means that title was left unrated and is skipped.
        path: CSV file to update. It is created, together with its parent
            directory, if missing.

    Returns:
        The fixed message "Feedback recorded!".
    """
    columns = ["Tags", "Recommendation", "Feedback"]

    # zip() stops at the shorter list. So feedback submitted before any
    # recommendations exist (no titles, five empty ratings) produces no rows
    # instead of a DataFrame length-mismatch ValueError.
    new_feedbacks = pd.DataFrame(
        [(tags, song, rating) for song, rating in zip(recommendations, feedbacks)],
        columns=columns,
    )
    # Only keep rows where both a song recommendation and a rating are present
    new_feedbacks = new_feedbacks[
        (new_feedbacks["Recommendation"] != "No recommendations found")
        & new_feedbacks["Feedback"].notna()
    ]

    # Load existing feedback if it exists
    try:
        existing = pd.read_csv(path)
    except FileNotFoundError:
        existing = pd.DataFrame(columns=columns)

    # Skip empty frames when concatenating; concat with empty inputs is
    # deprecated in recent pandas and contributes nothing.
    frames = [df for df in (existing, new_feedbacks) if not df.empty]
    feedback_df = pd.concat(frames, ignore_index=True) if frames else new_feedbacks

    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    feedback_df.to_csv(path, index=False)
    return "Feedback recorded!"
app = gr.Blocks()
with app:
    gr.Markdown("## Music Recommendation System")
    tags_input = gr.Textbox(
        label="Enter Tags (e.g., rock, jazz, pop)", placeholder="rock, pop")
    submit_button = gr.Button("Get Recommendations")

    # Titles currently on screen, stored per browser session. A module-level
    # global would be shared by every visitor, so one user's "Get
    # Recommendations" click could change which songs another user's
    # feedback is recorded against.
    song_recommendations = gr.State([])

    # Components render where they are created, so each recommendation and
    # its feedback radio are built inside their own row. Wrapping components
    # in gr.Column after creation does not move them, and gr.Column takes
    # layout options rather than a list of children.
    recommendation_outputs = []
    feedback_inputs = []
    for i in range(5):
        with gr.Row():
            with gr.Column():
                recommendation_outputs.append(
                    gr.HTML(label=f"Recommendation {i+1}"))
                feedback_inputs.append(gr.Radio(
                    choices=["Thumbs Up", "Thumbs Down"],
                    label=f"Feedback {i+1}"))

    with gr.Row():
        with gr.Column():
            feedback_submit_button = gr.Button("Submit Feedback")
            feedback_confirmation = gr.Markdown("")

    def display_recommendations(tags):
        """Show five recommendations and reset the feedback radios.

        Returns the new session state, then one update per HTML slot, then
        one per radio, in the order of the ``outputs`` list wired below.
        """
        songs = recommend_songs(tags, "")
        # Pad so every slot gets an update even if fewer than five titles
        # come back; record_feedback skips this placeholder title.
        songs = (list(songs) + ["No recommendations found"] * 5)[:5]
        updated_recommendations = [gr.update(value=song) for song in songs]
        # Clear any rating left from the previous query so it is not
        # recorded against the new titles.
        updated_feedbacks = [gr.update(label=song, value=None)
                             for song in songs]
        return [songs] + updated_recommendations + updated_feedbacks

    submit_button.click(
        fn=display_recommendations,
        inputs=[tags_input],
        outputs=[song_recommendations] + recommendation_outputs + feedback_inputs,
    )

    def collect_feedback(tags, songs, *feedbacks):
        """Record this session's ratings and return a confirmation message."""
        if not songs:
            # Nothing has been recommended yet, so there is nothing to rate.
            return "Please get recommendations before submitting feedback."
        record_feedback(tags, songs, list(feedbacks))
        return "Feedback submitted successfully!"

    feedback_submit_button.click(
        fn=collect_feedback,
        inputs=[tags_input, song_recommendations] + feedback_inputs,
        outputs=feedback_confirmation,
    )

app.launch()