File size: 1,842 Bytes
19759e2
 
 
 
f90365c
b697762
 
 
1cb6e78
19759e2
 
640b157
19759e2
 
b697762
 
1cb6e78
 
 
 
 
 
 
 
19759e2
 
101af6f
 
 
 
19759e2
 
101af6f
1cb6e78
19759e2
101af6f
f90365c
101af6f
19759e2
 
 
 
 
b697762
 
 
eb4c8c4
1cb6e78
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import streamlit as st
from streamlit import session_state as session
from src.config.configs import ProjectPaths
import numpy as np
from src.laion_clap.inference import AudioEncoder
import pickle
import torch
import pandas as pd
import json


@st.cache_data
def load_data():
    """Load the precomputed song embeddings, their names, and the YouTube metadata.

    Returns:
        tuple: ``(vectors, song_names, df_youtube)`` where

        * ``vectors`` is the array stored in ``vectors/audio_representations.npy``,
        * ``song_names`` is the object unpickled from ``vectors/song_names.pkl``,
        * ``df_youtube`` is a DataFrame built from ``json/youtube_data.json``,
          indexed by ``"<artist_name> - <track_name>.wav"``.
    """
    vectors_dir = ProjectPaths.DATA_DIR.joinpath("vectors")
    vectors = np.load(vectors_dir.joinpath("audio_representations.npy"))

    # NOTE(review): pickle.load can execute arbitrary code. This is only acceptable
    # if song_names.pkl is a trusted, locally generated artefact (presumably it is);
    # never point this at externally supplied data.
    with open(vectors_dir.joinpath("song_names.pkl"), "rb") as reader:
        song_names = pickle.load(reader)

    # JSON is UTF-8 by spec (RFC 8259). Be explicit so non-ASCII artist/track names
    # do not depend on the platform's locale encoding (e.g. cp1252 on Windows).
    youtube_path = ProjectPaths.DATA_DIR.joinpath("json", "youtube_data.json")
    with open(youtube_path, "r", encoding="utf-8") as reader:
        youtube_data = json.load(reader)

    df_youtube = pd.DataFrame(youtube_data)
    # Key each row as "<artist> - <track>.wav". The app later aligns this index
    # against song_names, so the format must match those entries exactly.
    df_youtube["id"] = df_youtube["artist_name"] + " - " + df_youtube["track_name"] + ".wav"
    df_youtube.set_index("id", inplace=True)
    return vectors, song_names, df_youtube


@st.cache_resource
def load_model():
    """Construct the CLAP ``AudioEncoder`` once; Streamlit caches and reuses it."""
    return AudioEncoder()


# Cached loaders: the encoder and the precomputed library are built once and
# reused across Streamlit reruns.
recommender = load_model()
audio_vectors, song_names, df_youtube = load_data()

st.title("""Curate me a Playlist.""")
session.text_input = st.text_input(label="Describe a playlist")
session.slider_count = st.slider(label="Track counts", min_value=5, max_value=30, step=5)
# Uneven spacer columns place the "Curate" button roughly in the middle of the page.
buffer1, col1, buffer2 = st.columns([1.45, 1, 1])

is_clicked = col1.button(label="Curate")
if is_clicked:
    if not session.text_input.strip():
        # An empty description would still be embedded and produce a meaningless ranking.
        st.warning("Please describe the playlist you want before curating.")
    else:
        text_embed = recommender.get_text_embedding(session.text_input)
        with torch.no_grad():
            audio = torch.as_tensor(audio_vectors)
            text = torch.as_tensor(text_embed)
            # Assumes audio_vectors is (n_songs, dim) and text_embed is (n_queries, dim)
            # — TODO confirm. Column 0 holds every song's score against the first query.
            scores = (audio @ text.t())[:, 0]
        # Hand pandas a plain NumPy array: older pandas turns a raw tensor into an
        # object column of 0-d tensors (which nlargest rejects), and GPU tensors
        # cannot be converted implicitly at all.
        scores = scores.cpu().numpy()

        dataframe = pd.DataFrame({"score": scores}, index=song_names).nlargest(
            int(session.slider_count), "score"
        )
        # Index-aligned assignment fails on duplicate labels, so keep only the first
        # link per "<artist> - <track>.wav" id. Songs without a link get NaN.
        unique_links = df_youtube.loc[~df_youtube.index.duplicated(keep="first"), "link"]
        dataframe["link"] = unique_links

        st.dataframe(dataframe, use_container_width=True)