# Hugging Face Space (status: Running)
import streamlit as st | |
from streamlit import session_state as session | |
from src.config.configs import ProjectPaths | |
import numpy as np | |
from src.laion_clap.inference import AudioEncoder | |
import pickle | |
import torch | |
import pandas as pd | |
import json | |
def load_data():
    """Load the precomputed audio index and the YouTube link table.

    Returns:
        tuple: ``(vectors, song_names, df_youtube)`` where

        * ``vectors`` is the array stored in ``vectors/audio_representations.npy``
          (presumably one audio embedding per row, row-aligned with
          ``song_names`` -- TODO confirm against the indexing script);
        * ``song_names`` is the object unpickled from ``vectors/song_names.pkl``;
        * ``df_youtube`` holds the records of ``json/youtube_data.json``,
          indexed by ``"<artist_name> - <track_name>.wav"``. Duplicate ids are
          dropped (first occurrence kept) so the index is unique.
    """
    vectors = np.load(ProjectPaths.DATA_DIR.joinpath("vectors", "audio_representations.npy"))

    # NOTE(security): pickle can execute arbitrary code while loading. This file
    # is a trusted project artifact; never point this at user-supplied data.
    with open(ProjectPaths.DATA_DIR.joinpath("vectors", "song_names.pkl"), "rb") as reader:
        song_names = pickle.load(reader)

    # JSON is UTF-8 by specification. Relying on the platform default encoding
    # (e.g. cp1252 on Windows) would mangle non-ASCII artist/track names, and the
    # ids built below would then no longer match the entries in song_names.
    json_path = ProjectPaths.DATA_DIR.joinpath("json", "youtube_data.json")
    with open(json_path, "r", encoding="utf-8") as reader:
        youtube_data = json.load(reader)

    df_youtube = pd.DataFrame(youtube_data)
    df_youtube["id"] = df_youtube["artist_name"] + " - " + df_youtube["track_name"] + ".wav"
    df_youtube.set_index("id", inplace=True)

    # The app aligns df_youtube["link"] onto the ranking frame by id. pandas
    # refuses to reindex from duplicate labels, so keep one row per id.
    df_youtube = df_youtube[~df_youtube.index.duplicated(keep="first")]
    return vectors, song_names, df_youtube
def load_model():
    """Construct the CLAP ``AudioEncoder`` that the app uses to embed text queries."""
    return AudioEncoder()
# Streamlit re-executes this whole script on every widget interaction. Cache the
# model and the index so they are loaded once per server process, not on each
# rerun (calling the cache decorators as functions is equivalent to decorating).
recommender = st.cache_resource(load_model)()
audio_vectors, song_names, df_youtube = st.cache_data(load_data)()

st.title("""Curate me a Playlist.""")
session.text_input = st.text_input(label="Describe a playlist")
session.slider_count = st.slider(label="Track counts", min_value=5, max_value=30, step=5)
# Uneven spacer columns push the button towards the middle of the page.
buffer1, col1, buffer2 = st.columns([1.45, 1, 1])
is_clicked = col1.button(label="Curate")

if is_clicked:
    query = session.text_input
    if not query or not query.strip():
        # Embedding an empty description yields a meaningless ranking.
        st.warning("Please describe a playlist first.")
    else:
        # no_grad has to cover the model call itself. Wrapping only the matmul
        # of freshly created tensors (as before) disabled nothing useful.
        with torch.no_grad():
            text_embed = recommender.get_text_embedding(query)
            ranking = torch.as_tensor(audio_vectors) @ torch.as_tensor(text_embed).t()
        # Column 0 holds each track's similarity to the (first) text embedding.
        scores = ranking[:, 0].cpu().numpy()
        # Use a fixed "score" column instead of the user's free text as a label.
        dataframe = pd.DataFrame({"score": scores}, index=song_names).nlargest(
            int(session.slider_count), "score"
        )
        # Aligned by index: tracks without a YouTube entry get NaN.
        dataframe["link"] = df_youtube["link"]
        st.dataframe(dataframe, use_container_width=True)