RMakushkin commited on
Commit
7eb3326
·
1 Parent(s): a18e62f

Upload 5 files

Browse files
Files changed (6) hide show
  1. .gitattributes +2 -0
  2. app.py +87 -0
  3. dataset.csv +3 -0
  4. embeddings_main.npy +3 -0
  5. faiss_index_main.index +3 -0
  6. func.py +53 -0
.gitattributes CHANGED
@@ -34,3 +34,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  images/ser2.png filter=lfs diff=lfs merge=lfs -text
 
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  images/ser2.png filter=lfs diff=lfs merge=lfs -text
37
+ dataset.csv filter=lfs diff=lfs merge=lfs -text
38
+ faiss_index_main.index filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import ast
5
+ import faiss
6
+
7
+ from func import filter_by_ganre, embed_user
8
+
9
+
10
+
11
+ """
12
+ # Умный поиск сериалов
13
+ """
14
+
15
+ df = pd.read_csv('dataset.csv')
16
+ embeddings = np.load('embeddings_main.npy')
17
+ index = faiss.read_index('faiss_index_main.index')
18
+
19
+ df['ganres'] = df['ganres'].apply(lambda x: ast.literal_eval(x))
20
+
21
+ st.write(f'<p style="font-family: Arial, sans-serif; font-size: 24px; ">Количество сериалов, \
22
+ предоставляемых сервисом {len(df)}</p>', unsafe_allow_html=True)
23
+
24
+ ganres_lst = sorted(['драма', 'документальный', 'биография', 'комедия', 'фэнтези', 'приключения', 'для детей', 'мультсериалы',
25
+ 'мелодрама', 'боевик', 'детектив', 'фантастика', 'триллер', 'семейный', 'криминал', 'исторический', 'музыкальные',
26
+ 'мистика', 'аниме', 'ужасы', 'спорт', 'скетч-шоу', 'военный', 'для взрослых', 'вестерн'])
27
+
28
+ st.sidebar.header('Панель инструментов :gear:')
29
+ choice_g = st.sidebar.multiselect("Выберите жанры", options=ganres_lst)
30
+ n = st.sidebar.selectbox("Количество отображаемых элементов на странице", options=[5, 10, 15])
31
+
32
+
33
+ # col3, col4 = st.columns([5,2])
34
+
35
+ # with col3:
36
+ text = st.text_input('Введите описание для рекомендации')
37
+
38
+ # with col4:
39
+
40
+ button = st.button('Отправить запрос', type="primary")
41
+
42
+ if text and button:
43
+ if len(choice_g) == 0:
44
+ choice_g = ganres_lst
45
+ filt_ind = filter_by_ganre(df, choice_g)
46
+ user_emb = embed_user(filt_ind, embeddings, text, n)
47
+ _, sorted_indices = index.search(user_emb.reshape(1, -1), n)
48
+ st.write(f'<p style="font-family: Arial, sans-serif; font-size: 18px; text-align: center;"><strong>Всего подобранных \
49
+ рекомендаций {len(sorted_indices[0])}</strong></p>', unsafe_allow_html=True)
50
+ st.write('\n')
51
+
52
+ # Отображение изображений и названий
53
+ # for ind, sim in top_dict.items():
54
+ # col1, col2 = st.columns([3, 4])
55
+ # with col1:
56
+ # st.image(df['poster'][ind], width=300)
57
+ # with col2:
58
+ # st.write(f"***Название:*** {df['title'][ind]}")
59
+ # st.write(f"***Жанр:*** {', '.join(df['ganres'][ind])}")
60
+ # st.write(f"***Описание:*** {df['description'][ind]}")
61
+ # similarity = round(sim, 4)
62
+ # st.write(f"***Cosine Similarity : {similarity}***")
63
+ # st.write(f"***Ссылка на фильм : {df['url'][ind]}***")
64
+
65
+ # st.markdown(
66
+ # "<hr style='border: 2px solid #000; margin-top: 10px; margin-bottom: 10px;'>",
67
+ # unsafe_allow_html=True
68
+ # )
69
+
70
+ for ind in sorted_indices[0]:
71
+ col1, col2 = st.columns([3, 4])
72
+ with col1:
73
+ st.image(df['poster'][ind], width=300)
74
+ with col2:
75
+ st.write(f"***Название:*** {df['title'][ind]}")
76
+ st.write(f"***Жанр:*** {', '.join(df['ganres'][ind])}")
77
+ st.write(f"***Описание:*** {df['description'][ind]}")
78
+ # similarity = round(sim, 4)
79
+ # st.write(f"***Cosine Similarity : {similarity}***")
80
+ st.write(f"***Ссылка на фильм : {df['url'][ind]}***")
81
+
82
+ st.markdown(
83
+ "<hr style='border: 2px solid #000; margin-top: 10px; margin-bottom: 10px;'>",
84
+ unsafe_allow_html=True
85
+ )
86
+
87
+ # streamlit run app.py
dataset.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f6c10dbf7a899fbf0553bf6cab5fd11abf35cf224e4e6e4f7843fdd19144c550
3
+ size 19266108
embeddings_main.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b33d9e4726eff511c3f0f74dd9d1f22f863828aa0c03ff060c2983be3dce0115
3
+ size 45892736
faiss_index_main.index ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5fbaa50af8354c8a54372b1c763337f98792c351fa2e3aa266f448ec8266da2
3
+ size 45892653
func.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import torch
4
+ from transformers import BertModel, BertTokenizer
5
+ from sklearn.metrics.pairwise import cosine_similarity
6
+
7
+
8
+ tokenizer = BertTokenizer.from_pretrained("DeepPavlov/rubert-base-cased-sentence")
9
+ model = BertModel.from_pretrained("DeepPavlov/rubert-base-cased-sentence")
10
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
11
+
12
+
13
+ def filter_by_ganre(df: pd.DataFrame, ganre_list: list):
14
+ filtered_df = df[df['ganres'].apply(lambda x: any(g in ganre_list for g in(x)))]
15
+ filt_ind = filtered_df.index.to_list()
16
+ return filt_ind
17
+
18
+ # def mean_pooling(model_output, attention_mask):
19
+ # token_embeddings = model_output['last_hidden_state']
20
+ # input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
21
+ # sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
22
+ # sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
23
+ # return sum_embeddings / sum_mask
24
+
25
+ # def recommendation(filt_ind: list, embeddings: np.array, user_text: str, n=10):
26
+ # token_user_text = tokenizer(user_text, return_tensors='pt', padding='max_length', truncation=True, max_length=512)
27
+ # user_embeddings = torch.Tensor().to(device)
28
+ # model.to(device)
29
+ # model.eval()
30
+ # with torch.no_grad():
31
+ # batch = {k: v.to(device) for k, v in token_user_text.items()}
32
+ # outputs = model(**batch)
33
+ # user_embeddings = torch.cat([user_embeddings, mean_pooling(outputs, batch['attention_mask'])])
34
+ # user_embeddings = user_embeddings.cpu().numpy()
35
+ # cosine_similarities = cosine_similarity(embeddings[filt_ind], user_embeddings.reshape(1, -1))
36
+ # df_res = pd.DataFrame(cosine_similarities.ravel(), columns=['cos_sim']).sort_values('cos_sim', ascending=False)
37
+ # dict_topn = df_res.iloc[:n, :].cos_sim.to_dict()
38
+ # return dict_topn
39
+
40
+
41
+ def embed_user(filt_ind: list, embeddings:np.array, user_text: str, n=10):
42
+ tokens = tokenizer(user_text, return_tensors="pt", padding=True, truncation=True).to(device)
43
+ model.to(device)
44
+ model.eval()
45
+ with torch.no_grad():
46
+ outputs = model(**tokens)
47
+ user_embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy().reshape(1, -1)
48
+ return user_embedding
49
+
50
+ # cosine_similarities = cosine_similarity(embeddings[filt_ind], user_embedding.reshape(1, -1))
51
+ # df_res = pd.DataFrame(cosine_similarities.ravel(), columns=['cos_sim']).sort_values('cos_sim', ascending=False)
52
+ # dict_topn = df_res.iloc[:n, :].cos_sim.to_dict()
53
+ # return dict_topn