|
import streamlit as st |
|
import pandas as pd |
|
import numpy as np |
|
import ast |
|
import random |
|
import torch |
|
import time |
|
from joblib import load |
|
|
|
from transformers import BertTokenizer, BertModel |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
|
""" |
|
## Сервис умного поиска сериалов 📽️ |
|
""" |
|
|
|
|
|
embeddings = np.loadtxt('data/embs.txt') |
|
|
|
model_path = "model" |
|
tokenizer_path = "tokenizer" |
|
|
|
loaded_model = BertModel.from_pretrained(model_path) |
|
|
|
loaded_tokenizer = BertTokenizer.from_pretrained(tokenizer_path) |
|
|
|
df = pd.read_csv('data/data.csv') |
|
df['ganres'] = df['ganres'].apply(lambda x: ast.literal_eval(x)) |
|
df['description'] = df['description'].astype(str) |
|
|
|
st.write(f'<p style="font-family: Arial, sans-serif; font-size: 24px; ">Наш сервис насчитывает \ |
|
{len(df)} лучших сериалов</p>', unsafe_allow_html=True) |
|
|
|
st.image('images/ser2.png') |
|
|
|
ganres_lst = sorted(['драма', 'документальный', 'биография', 'комедия', 'фэнтези', 'приключения', 'для детей', 'мультсериалы', |
|
'мелодрама', 'боевик', 'детектив', 'фантастика', 'триллер', 'семейный', 'криминал', 'исторический', 'музыкальные', |
|
'мистика', 'аниме', 'ужасы', 'спорт', 'скетч-шоу', 'военный', 'для взрослых', 'вестерн']) |
|
|
|
st.sidebar.header('Панель инструментов :gear:') |
|
choice_g = st.sidebar.multiselect("Выберите жанры", options=ganres_lst) |
|
n = st.sidebar.selectbox("Количество отображаемых элементов на странице", options=[5, 10, 15, 20, 30]) |
|
st.sidebar.info("📚 Для наилучшего соответствия, запрос должен быть максимально развернутым") |
|
|
|
text = st.text_input('Введите описание для рекомендации') |
|
|
|
|
|
loaded_model.eval() |
|
tokens = loaded_tokenizer(text, return_tensors="pt", padding=True, truncation=True) |
|
start_time = time.time() |
|
tokens = {key: value.to(loaded_model.device) for key, value in tokens.items()} |
|
|
|
|
|
with torch.no_grad(): |
|
output = loaded_model(**tokens) |
|
|
|
|
|
user_embedding = output.last_hidden_state.mean(dim=1).squeeze().cpu().detach().numpy() |
|
cosine_similarities = cosine_similarity(embeddings, user_embedding.reshape(1, -1)) |
|
|
|
button = st.button('Отправить запрос', type="primary") |
|
|
|
if text and button: |
|
|
|
if len(choice_g) == 0: |
|
choice_g = ganres_lst |
|
|
|
top_ind = np.unravel_index(np.argsort(cosine_similarities, axis=None)[-30:][::-1], cosine_similarities.shape) |
|
confidence = cosine_similarities[top_ind] |
|
top_ind = list(top_ind[0]) |
|
conf_dict = {} |
|
for value, conf in zip(top_ind, confidence): |
|
conf_dict[int(value)] = conf |
|
|
|
output_dict = {} |
|
for i in top_ind: |
|
for ganre in df['ganres'][i]: |
|
if ganre in choice_g: |
|
output_dict[i] = df['ganres'][i] |
|
|
|
sorted_lst = sorted(output_dict.items(), key=lambda x: len(set(x[1]) & set(choice_g)), reverse=True) |
|
n_lst = [i[0] for i in sorted_lst[:n]] |
|
st.write(f'<p style="font-family: Arial, sans-serif; font-size: 18px; text-align: center;"><strong>Всего подобранных \ |
|
рекомендаций {len(sorted_lst)}</strong></p>', unsafe_allow_html=True) |
|
st.write('\n') |
|
|
|
|
|
for i in n_lst: |
|
col1, col2 = st.columns([2, 5]) |
|
with col1: |
|
st.image(df['poster'][i], width=200) |
|
with col2: |
|
st.write(f"***Название:*** {df['title'][i]}") |
|
st.write(f"***Жанр:*** {', '.join(df['ganres'][i])}") |
|
st.write(f"***Описание:*** {df['description'][i]}") |
|
|
|
|
|
st.markdown(f"[***ссылка на сериал***]({df['url'][i]})") |
|
st.write(f"") |
|
end_time = time.time() |
|
st.write(f"<small>*Степень соответствия по косинусному сходству: {conf_dict[i]:.4f}*</small>", unsafe_allow_html=True) |
|
st.markdown( |
|
"<hr style='border: 2px solid #000; margin-top: 10px; margin-bottom: 10px;'>", |
|
unsafe_allow_html=True |
|
) |