File size: 5,575 Bytes
a23f3fb
 
 
 
 
 
3344c3a
a23f3fb
3191631
a23f3fb
81d46d1
 
 
 
 
 
3191631
 
3344c3a
81d46d1
 
 
a23f3fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81d46d1
a23f3fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4d61fc
a23f3fb
 
 
 
 
 
 
 
 
 
 
 
 
 
c4d61fc
 
 
81d46d1
c4d61fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a23f3fb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import numpy as np
import pandas as pd
import streamlit as st
import requests
from sentence_transformers import util
from sentence_transformers import SentenceTransformer, util
import os
st.set_page_config(page_title="Custom Button Example", layout="wide")
from dotenv import load_dotenv

from langchain.chat_models.gigachat import GigaChat
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)

load_dotenv()
credentials = os.getenv('API_KEY')
chat = GigaChat(model='GigaChat', credentials=credentials, verify_ssl_certs=False)


@st.cache_resource
def load_model_all_mpnet():
    return SentenceTransformer('all-mpnet-base-v2')
model_mp = load_model_all_mpnet()
@st.cache_data
def load_embeddings(file_path):
    return np.load(file_path)
book_embeddings_mp = load_embeddings('data/book_embeddings.npy')
@st.cache_data
def load_data(file_path):
    return pd.read_csv(file_path)
df = load_data('data/books_data_cleaned.csv')
@st.cache_resource
def load_model_msmarco():
    return SentenceTransformer('msmarco-roberta-base-v3')
model_ms = load_model_msmarco()
@st.cache_data
def load_embeddings(file_path):
    return np.load(file_path)
book_embeddings_ms = load_embeddings('data/book_embeddings_ms.npy')


def get_embedding(text, model):
    text = model.encode(text, convert_to_tensor=True)
    return text


def get_top_10_recommendations(query, model, book_embeddings, top_k):
    query_embedding = get_embedding(query, model).cpu()
    similarities = util.pytorch_cos_sim(query_embedding, book_embeddings)[0]
    top_results = similarities.cpu().numpy().argsort()[::-1][:top_k]
    top_books = df.iloc[top_results].copy()
    similarity_scores = similarities.cpu().numpy()[top_results]
    top_books['similarity_score'] = similarity_scores
    return top_books


st.title('Рекомендации книг')

search = st.radio(
    "Выберите тип семантического поиска:",
    [":blue[Симметричный]", ":blue[Асимметричный]"],
    captions=[
        "Используем 'all-mpnet-base-v2'",
        "Используем 'msmarco-roberta-base-v3'",
    ],
    horizontal=True,
)

def params(search):
    if search == ":blue[Симметричный]":
        text = '''Я ищу книги в жанре фэнтези, которые описывают приключения магов и волшебников, обучающихся в специальных магических школах и сражающихся с темными силами или злыми существами. Особенно интересуют произведения, где главные герои сталкиваются с эпическими испытаниями и развивают свои уникальные способности.'''
        model = model_mp
        book_embeddings = book_embeddings_mp
        return text, model, book_embeddings
    elif search == ":blue[Асимметричный]":
        text = '''путешествие во времени'''
        model = model_ms
        book_embeddings = book_embeddings_ms
        return text, model, book_embeddings
text, model, book_embeddings = params(search)


col1, col2 = st.columns([3, 1])
with col1:
    query = st.text_area('Введите запрос, чтобы получить рекомендации', f'{text}', height=95)
with col2:
    number = st.number_input(
        "Сколько книг найти?", value=3
)
    find_button = st.button('Найти', key='find_button', use_container_width=True)

if find_button and query:
        top_10_books = get_top_10_recommendations(query, model, book_embeddings, number)
        for idx, row in top_10_books.iterrows():
            with st.container():
                col1, col2 = st.columns([1, 3])
                
                with col1:
                    st.image(row['image_url'], width = 300)
                with col2:
                    st.subheader(f"{row['title']}")
                    st.write(f"**Автор:** {row['author']}")
                    tab1, tab2 = st.tabs(['Аннотация', 'Краткое содержание'])
                    with tab1:
                        st.write(row['annotation'])
                        st.metric(label="Схожесть", value=f"{row['similarity_score']:.3f}")
                        st.write(f"**Ссылка:** {row['page_url']}")
                    with tab2:
                        template = "ты умеешь кратко в несколько предложений описывать содержание книги по ее названию"
                        system_message_prompt = SystemMessagePromptTemplate.from_template(template)
                        human_template = "Кратко опиши содержание книги под названием: {book_title}"
                        human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
                        chat_prompt = ChatPromptTemplate.from_messages(
                            [system_message_prompt, human_message_prompt]
                        )
                        formatted_prompt = chat_prompt.format_prompt(
                            book_title=row['title']
                        )
                        response = chat(formatted_prompt.to_messages())
                        st.write(response.content)
                        st.write("---")
                st.write("---")