AlexandraGulamova commited on
Commit
343af91
1 Parent(s): a82ca28
EmbeddingGenerator.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #EmbeddingGenerator.py
2
+
3
+ from transformers import AutoTokenizer, AutoModel
4
+ from sentence_transformers import SentenceTransformer
5
+ import torch
6
+ import numpy as np
7
+ import warnings
8
+ warnings.filterwarnings("ignore", category=UserWarning, module="transformers.models.bert")
9
+ import os
10
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
11
+
12
+
13
+ class EmbeddingGenerator:
14
+ def __init__(self, pavlov_model_name="DeepPavlov/rubert-base-cased", sentence_transformer_model_name="cointegrated/rubert-tiny2"):
15
+ """
16
+ Инициализирует токенизатор и модели для генерации эмбеддингов.
17
+
18
+ Args:
19
+ pavlov_model_name (str): Название модели для загрузки Pavlov модели.
20
+ sentence_transformer_model_name (str): Название модели SentenceTransformer для генерации эмбеддингов.
21
+ """
22
+ self.pavlov_tokenizer = AutoTokenizer.from_pretrained(pavlov_model_name, ignore_mismatched_sizes=True)
23
+ self.pavlov_model = AutoModel.from_pretrained(pavlov_model_name, ignore_mismatched_sizes=True)
24
+ self.sentence_transformer_model = SentenceTransformer(sentence_transformer_model_name)
25
+
26
+ def generate_embeddings(self, texts, method="pavlov"):
27
+ """
28
+ Генерирует эмбеддинги для списка текстов с использованием выбранного метода.
29
+
30
+ Args:
31
+ texts (list of str): Список текстов для генерации эмбеддингов.
32
+ method (str): Метод генерации эмбеддингов: "pavlov" или "rubert_tiny2".
33
+
34
+ Returns:
35
+ np.ndarray: Эмбеддинги текстов.
36
+ """
37
+ if method == "pavlov":
38
+ # Генерация эмбеддингов с использованием Pavlov модели
39
+ inputs = self.pavlov_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
40
+ with torch.no_grad():
41
+ outputs = self.pavlov_model(**inputs)
42
+ # Mean pooling
43
+ embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
44
+ elif method == "rubert_tiny2":
45
+ # Генерация эмбеддингов с использованием SentenceTransformer
46
+ embeddings = self.sentence_transformer_model.encode(texts, show_progress_bar=False)
47
+ else:
48
+ raise ValueError("Unsupported method. Choose 'pavlov' or 'rubert_tiny2'.")
49
+
50
+ return embeddings
51
+
52
+
TextAugmentation.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #TextAugmentation.py
2
+
3
+ from transformers import T5Tokenizer, AutoModelForSeq2SeqLM, MarianMTModel, MarianTokenizer
4
+ import torch
5
+ import os
6
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
7
+
8
+
9
+ class TextAugmentation:
10
+ def __init__(self,
11
+ paraphrase_model_name="cointegrated/rut5-base-paraphraser",
12
+ ru_en_model_name="Helsinki-NLP/opus-mt-ru-en",
13
+ en_ru_model_name="Helsinki-NLP/opus-mt-en-ru"):
14
+ # Инициализация модели для перефразирования
15
+ self.paraphrase_tokenizer = T5Tokenizer.from_pretrained(paraphrase_model_name, legacy=False)
16
+ self.paraphrase_model = AutoModelForSeq2SeqLM.from_pretrained(paraphrase_model_name)
17
+
18
+ # Инициализация моделей для обратного перевода
19
+ self.ru_en_tokenizer = MarianTokenizer.from_pretrained(ru_en_model_name)
20
+ self.ru_en_model = MarianMTModel.from_pretrained(ru_en_model_name)
21
+ self.en_ru_tokenizer = MarianTokenizer.from_pretrained(en_ru_model_name)
22
+ self.en_ru_model = MarianMTModel.from_pretrained(en_ru_model_name)
23
+
24
+
25
+ def paraphrase(self, text, num_return_sequences=1):
26
+ """
27
+ Перефразирование текста с использованием модели.
28
+
29
+ Args:
30
+ text (str): Исходный текст для перефразирования.
31
+ num_return_sequences (int): Количество вариантов перефразирования.
32
+
33
+ Returns:
34
+ list[str]: Список вариантов перефразирования текста.
35
+ """
36
+ inputs = self.paraphrase_tokenizer([text], max_length=512, truncation=True, return_tensors="pt")
37
+ outputs = self.paraphrase_model.generate(
38
+ **inputs,
39
+ max_length=128,
40
+ num_return_sequences=num_return_sequences,
41
+ do_sample=True,
42
+ temperature=1.2,
43
+ top_k=50,
44
+ top_p=0.90
45
+ )
46
+ return [self.paraphrase_tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
47
+
48
+ def back_translate(self, text):
49
+ """
50
+ Выполняет обратный перевод текста: русский -> английский -> русский.
51
+
52
+ Args:
53
+ text (str): Исходный текст для обратного перевода.
54
+
55
+ Returns:
56
+ str: Текст после обратного перевода.
57
+ """
58
+ # Перевод с русского на английский
59
+ inputs = self.ru_en_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
60
+ with torch.no_grad():
61
+ outputs = self.ru_en_model.generate(**inputs)
62
+ translated_text = self.ru_en_tokenizer.decode(outputs[0], skip_special_tokens=True)
63
+
64
+ # Перевод с английского обратно на русский
65
+ inputs = self.en_ru_tokenizer(translated_text, return_tensors="pt", truncation=True, max_length=512)
66
+ with torch.no_grad():
67
+ outputs = self.en_ru_model.generate(**inputs)
68
+ back_translated_text = self.en_ru_tokenizer.decode(outputs[0], skip_special_tokens=True)
69
+
70
+ return back_translated_text
71
+
app.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Основные библиотеки
2
+ import os
3
+ import re
4
+ import string
5
+ import warnings
6
+ import numpy as np
7
+ import pandas as pd
8
+ import torch
9
+ # Машинное обучение и обработка текста
10
+ from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM, MarianMTModel, MarianTokenizer
11
+ from sentence_transformers import SentenceTransformer, util
12
+ from sklearn.base import BaseEstimator, TransformerMixin
13
+ # FAISS для семантического поиска
14
+ import faiss
15
+ # Лемматизация и морфология
16
+ import pymorphy3
17
+ # Streamlit для создания веб-приложений
18
+ import streamlit as st
19
+ # Кастомные модули
20
+ from EmbeddingGenerator import EmbeddingGenerator
21
+ from TextAugmentation import TextAugmentation
22
+ # LangChain для интеграции GigaChat
23
+ from langchain_community.chat_models.gigachat import GigaChat
24
+
25
+ # ======= загружаем ранее рассчитанные эмбеддинги и объявляем классы=======
26
+
27
+ # Инициализация GigaChat с ключом и отключенной проверкой SSL
28
+ giga = GigaChat(
29
+ credentials="ODk0NDE1ODEtYTJhMi00N2Y1LTk4YWItNGZlNzNkM2QwMDNiOjk5YmVjN2ZjLThmM2EtNDhjYy04OWQ2LWNkOTlhOTNkNGY3NQ==",
30
+ verify_ssl_certs=False
31
+ )
32
+
33
+ augmentor = TextAugmentation()
34
+ embedding_gen = EmbeddingGenerator()
35
+ df=pd.read_csv("movies_data_fixed.csv")
36
+ image_path = "image-2.png"
37
+
38
+
39
+
40
+
41
+
42
+
43
+
44
+
45
+
46
+
47
+
48
+
49
+
50
+
51
+ # Загружаем и отображаем картинку
52
+ st.image(image_path, use_container_width=True)
53
+
54
+ # Заголовок
55
+ st.markdown(
56
+ """
57
+ <div class="title">
58
+ КиноКринж
59
+ </div>
60
+ """,
61
+ unsafe_allow_html=True
62
+ )
63
+
64
+ # Добавляем окно ввода текста
65
+ user_input = st.text_area("Добавьте описание фильма", "", height=150)
66
+
67
+ # Слайдер для выбора количества фильмов
68
+ num_results = st.slider('Выберите количество фильмов', min_value=1, max_value=20, value=4)
69
+
70
+ # Выбор модели
71
+ model_option = st.selectbox('Выберите модель для обработки запроса:', ['cointegrated/rubert-tiny2','DeepPavlov/rubert-base-cased','all-MiniLM-L6-v2', 'paraphrase-MiniLM-L6-v2'])
72
+
73
+ if model_option!='DeepPavlov/rubert-base-cased':
74
+ model = SentenceTransformer(model_option)
75
+
76
+
77
+ # ======= дополнительная фильтрация для аугментаций (убираем слишком непохожие) =======
78
+ def filter_paraphrases(original, paraphrases, threshold=0.8):
79
+ original_embedding = model.encode(original)
80
+ filtered = []
81
+ for paraphrase in paraphrases:
82
+ paraphrase_embedding = model.encode(paraphrase)
83
+ similarity = util.cos_sim(original_embedding, paraphrase_embedding).item()
84
+ if similarity >= threshold:
85
+ filtered.append(paraphrase)
86
+ return filtered
87
+ #======================СЕМПЛ======= =======
88
+
89
+
90
+
91
+
92
+
93
+
94
+
95
+
96
+
97
+
98
+ # Проверка наличия рекомендованных фильмов
99
+ if 'recommended_movies' not in st.session_state:
100
+ st.session_state.recommended_movies = []
101
+
102
+ # Кнопка для поиска
103
+ if st.button('Найти фильм'):
104
+ if user_input.strip():
105
+ # Генерация эмбеддинга для запроса
106
+ if model_option != 'DeepPavlov/rubert-base-cased' and model_option != 'cointegrated/rubert-tiny2':
107
+ index = faiss.read_index('faiss_index.bin')
108
+ query_embedding = model.encode([user_input]).astype("float32")
109
+ faiss.normalize_L2(query_embedding)
110
+ elif model_option == 'DeepPavlov/rubert-base-cased':
111
+ index = faiss.read_index('pavlov3.bin')
112
+ back_translate = augmentor.back_translate(user_input)
113
+ augmented_query_pavlov = user_input + " " + back_translate
114
+ query_embedding = embedding_gen.generate_embeddings(augmented_query_pavlov, method="pavlov")
115
+ elif model_option == 'cointegrated/rubert-tiny2':
116
+ index = faiss.read_index('rubert2.bin')
117
+ paraphrase = augmentor.paraphrase(user_input, num_return_sequences=3)
118
+ filtered_rubert = filter_paraphrases(user_input, paraphrase)
119
+ augmented_query_rubert = user_input + " " + " ".join(filtered_rubert)
120
+ query_embedding = embedding_gen.generate_embeddings(augmented_query_rubert, method="rubert_tiny2").reshape(1, -1)
121
+ faiss.normalize_L2(query_embedding)
122
+
123
+ # Поиск ближайших соседей
124
+ distances, indices = index.search(query_embedding, num_results)
125
+
126
+ # Отображение результатов
127
+ st.write(f"Результаты поиска ({num_results} фильмов):")
128
+ recommended_movies = []
129
+ for idx, distance in zip(indices[0], distances[0]):
130
+ recommended_movies.append({
131
+ 'title': df.iloc[idx]['movie_title'],
132
+ 'description': df.iloc[idx]['description'],
133
+ 'image_url': df.iloc[idx]['image_url'],
134
+ 'page_url': df.iloc[idx]['page_url'],
135
+ 'similarity': distance,
136
+ 'short_description': None, # Содержимое краткого описания
137
+ 'is_short_description_shown': False # Флаг для того, чтобы избежать повторного запроса
138
+ })
139
+
140
+ # Сохраняем результаты в session_state
141
+ st.session_state.recommended_movies = recommended_movies
142
+
143
+ # Отображение рекомендованных фильмов
144
+ for idx, movie in enumerate(st.session_state.recommended_movies):
145
+ st.write(f"### {movie['title']}")
146
+ st.write(f"Описание: {movie['description']}")
147
+ st.write(f"Схожесть: {movie['similarity']:.4f}")
148
+
149
+ # Отображаем картинку постера
150
+ if movie.get('image_url'):
151
+ st.image(movie['image_url'], width=200)
152
+
153
+ # Добавляем ссылку на страницу фильма
154
+ if movie.get('page_url'):
155
+ st.markdown(f"[Перейти на страницу фильма]({movie['page_url']})")
156
+
157
+ # Генерируем уникальный ключ с использованием индекса
158
+ button_key = f"short_description_button_{idx}" # Уникальный ключ для кнопки
159
+ if st.button(f"Получить краткое содержание для {movie['title']}", key=button_key):
160
+ if not movie.get('is_short_description_shown', False): # Проверяем состояние
161
+ try:
162
+ # Отправляем запрос в GigaChat
163
+ prompt = f"{movie['title']} краткое содержание фильма не более 100 слов"
164
+ response = giga.invoke(prompt)
165
+
166
+ # Извлекаем описание из ответа
167
+ description = response.content if response else "Описание не найдено."
168
+ movie['short_description'] = description
169
+ movie['is_short_description_shown'] = True
170
+
171
+ except Exception as e:
172
+ st.error(f"Ошибка при запросе в GigaChat: {e}")
173
+
174
+ # Показываем краткое содержание
175
+ if movie.get('short_description') and movie.get('is_short_description_shown', False):
176
+ st.write(f"Краткое содержание для {movie['title']}: {movie['short_description']}")
177
+
178
+ st.write("---")
assets/fonts/Anton-Regular.ttf ADDED
Binary file (162 kB). View file
 
assets/styles.css ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* Подключаем шрифт через Google Fonts */
2
+ @import url('https://fonts.googleapis.com/css2?family=Russo+One&display=swap');
3
+
4
+ /* Фон для страницы */
5
+ body {
6
+ background-image: url('file:///Users/ser/ds_bootcamp/streamlit/movie_seeker/converted_image.png');
7
+ background-size: cover;
8
+ background-position: center;
9
+ margin: 0;
10
+ padding: 0;
11
+ }
12
+
13
+ /* Стили для текста */
14
+ .title {
15
+ font-family: 'Russo One', sans-serif;
16
+ font-size: 80px;
17
+ text-align: center;
18
+ color: white;
19
+ margin-top: 10px;
20
+ text-shadow: 3px 3px 6px rgba(0, 0, 0, 0.7);
21
+ }
22
+
23
+ /* Стили для окна ввода текста */
24
+ textarea {
25
+ margin-top: 40px;
26
+ width: 100%;
27
+ padding: 10px;
28
+ font-size: 16px;
29
+ border-radius: 5px;
30
+ border: 1px solid #ccc;
31
+ background-color: #f4f4f4;
32
+ }
33
+
34
+ /* Стили для кнопки */
35
+ .stCustomButton {
36
+ background-color: #f08b29;
37
+ color: white;
38
+ font-size: 18px;
39
+ padding: 12px 24px;
40
+ border: none;
41
+ border-radius: 5px;
42
+ cursor: pointer;
43
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
44
+ transition: background-color 0.3s ease, transform 0.3s ease;
45
+ }
46
+
47
+ .stCustomButton:hover {
48
+ background-color: #d9791e;
49
+ transform: translateY(-3px);
50
+ }
51
+
52
+ .stCustomButton:active {
53
+ background-color: #bc6a0a;
54
+ transform: translateY(0);
55
+ }
56
+
57
+ /* Стили для фона в шапке */
58
+ .header-image {
59
+ width: 100%;
60
+ height: 300px;
61
+ background-image: url('/Users/ser/ds_bootcamp/streamlit/movie_seeker/Unknown-2.png');
62
+ background-size: cover;
63
+ background-position: center;
64
+ position: absolute;
65
+ top: 0;
66
+ left: 0;
67
+ z-index: -1;
68
+ }
requirements.txt ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohappyeyeballs==2.4.4
2
+ aiohttp==3.11.10
3
+ aiosignal==1.3.1
4
+ altair==5.5.0
5
+ annotated-types==0.7.0
6
+ anyio==4.7.0
7
+ async-timeout==4.0.3
8
+ attrs==24.2.0
9
+ blinker==1.9.0
10
+ cachetools==5.5.0
11
+ certifi==2024.8.30
12
+ charset-normalizer==3.4.0
13
+ click==8.1.7
14
+ dataclasses-json==0.6.7
15
+ DAWG-Python==0.7.2
16
+ exceptiongroup==1.2.2
17
+ faiss-cpu==1.9.0.post1
18
+ filelock @ file:///home/conda/feedstock_root/build_artifacts/filelock_1733240801289/work
19
+ frozenlist==1.5.0
20
+ fsspec @ file:///home/conda/feedstock_root/build_artifacts/fsspec_1729608855534/work
21
+ gigachat==0.1.36
22
+ gitdb==4.0.11
23
+ GitPython==3.1.43
24
+ gmpy2 @ file:///home/conda/feedstock_root/build_artifacts/gmpy2_1733462536562/work
25
+ greenlet==3.1.1
26
+ h11==0.14.0
27
+ httpcore==1.0.7
28
+ httpx==0.28.1
29
+ httpx-sse==0.4.0
30
+ huggingface-hub==0.26.3
31
+ idna==3.10
32
+ Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1733217336947/work
33
+ joblib==1.4.2
34
+ jsonpatch==1.33
35
+ jsonpointer==3.0.0
36
+ jsonschema==4.23.0
37
+ jsonschema-specifications==2024.10.1
38
+ langchain==0.3.10
39
+ langchain-community==0.3.10
40
+ langchain-core==0.3.22
41
+ langchain-text-splitters==0.3.2
42
+ langsmith==0.1.147
43
+ markdown-it-py==3.0.0
44
+ MarkupSafe @ file:///home/conda/feedstock_root/build_artifacts/markupsafe_1733219680183/work
45
+ marshmallow==3.23.1
46
+ mdurl==0.1.2
47
+ mpmath @ file:///home/conda/feedstock_root/build_artifacts/mpmath_1733302684489/work
48
+ multidict==6.1.0
49
+ mypy-extensions==1.0.0
50
+ narwhals==1.15.2
51
+ networkx @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_networkx_1731521053/work
52
+ numpy==1.26.4
53
+ orjson==3.10.12
54
+ packaging==24.2
55
+ pandas==2.2.3
56
+ pillow==11.0.0
57
+ propcache==0.2.1
58
+ protobuf==5.29.1
59
+ pyarrow==18.1.0
60
+ pydantic==2.10.3
61
+ pydantic-settings==2.6.1
62
+ pydantic_core==2.27.1
63
+ pydeck==0.9.1
64
+ Pygments==2.18.0
65
+ pymorphy3==2.0.2
66
+ pymorphy3-dicts-ru==2.4.417150.4580142
67
+ python-dateutil==2.9.0.post0
68
+ python-dotenv==1.0.1
69
+ pytz==2024.2
70
+ PyYAML==6.0.2
71
+ referencing==0.35.1
72
+ regex==2024.11.6
73
+ requests==2.32.3
74
+ requests-toolbelt==1.0.0
75
+ rich==13.9.4
76
+ rpds-py==0.22.3
77
+ safetensors==0.4.5
78
+ scikit-learn==1.5.2
79
+ scipy==1.14.1
80
+ sentence-transformers==3.3.1
81
+ sentencepiece==0.2.0
82
+ six==1.17.0
83
+ smmap==5.0.1
84
+ sniffio==1.3.1
85
+ SQLAlchemy==2.0.36
86
+ streamlit==1.40.2
87
+ sympy==1.13.1
88
+ tenacity==9.0.0
89
+ threadpoolctl==3.5.0
90
+ tokenizers==0.21.0
91
+ toml==0.10.2
92
+ torch==2.5.1+cpu
93
+ torchaudio==2.5.1+cpu
94
+ torchvision==0.20.1+cpu
95
+ tornado==6.4.2
96
+ tqdm==4.67.1
97
+ transformers==4.47.0
98
+ typing-inspect==0.9.0
99
+ typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1733188668063/work
100
+ tzdata==2024.2
101
+ urllib3==2.2.3
102
+ watchdog==6.0.0
103
+ yarl==1.18.3