File size: 6,025 Bytes
17dcbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24fa644
 
 
 
17dcbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
__import__('pysqlite3')
import sys
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')

import os
import pprint
from dotenv import load_dotenv
from typing import List, Tuple, Optional, Union
from loguru import logger as log

import tiktoken
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.text_splitter import Document
from langchain.output_parsers import PydanticOutputParser
from langchain_openai import AzureOpenAIEmbeddings
from langchain_openai import AzureChatOpenAI
from langchain_community.vectorstores import Chroma
from langchain.document_loaders.pdf import PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter

from pydantic import BaseModel, Field
import streamlit as st
import logging

logging.basicConfig()
logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO)


def _calc_tokens(splits: List[Document]) -> int:
    tokens = 0

    for doc in splits:
        encoding = tiktoken.get_encoding('cl100k_base')
        tokens += len(encoding.encode(doc.page_content))

    return tokens


class LineList(BaseModel):
    lines: List[str] = Field(description="Lines of text")


class LineListOutputParser(PydanticOutputParser):
    def __init__(self) -> None:
        super().__init__(pydantic_object=LineList)

    def parse(self, text: str) -> LineList:
        lines = text.strip().split("\n")
        return LineList(lines=lines)


class Assistant:
    def __init__(self):
        load_dotenv()
        self.db_dir = 'docs/chroma/'
        self.embedding = AzureOpenAIEmbeddings(azure_deployment="ada_dev")
        self.llm = AzureChatOpenAI(
            azure_deployment="35_turbo",
            model_name="gpt-35-turbo",
            temperature=0
        )

        os.environ["AZURE_OPENAI_API_KEY"] = os.environ["AZURE_OPENAI_API_KEY"]
        os.environ["OPENAI_API_TYPE"] = os.environ["OPENAI_API_TYPE"]
        os.environ["OPENAI_API_VERSION"] = os.environ["OPENAI_API_VERSION"]
        os.environ["AZURE_OPENAI_ENDPOINT"] = os.environ["AZURE_OPENAI_ENDPOINT"]
        self.make_template()

    def run(self):
        st.title('Гаррі Поттер асистент')

        instruction = st.text_input('Питання', '')

        if st.button('Згенерувати відповідь'):
            result, docs = self.stuff_search(instruction)
            st.subheader('Відповідь')
            st.text(result)
            st.header('Знайдені чанки')
            for doc in docs:
                st.subheader(f'Сторінка {doc.metadata.get("page")}')
                st.text(doc.page_content)

    def make_template(self):
        template = """Ти ШІ консультант. Твоя задача відповідати на запитання користувачів. Запитання будуть про книгу "Гаррі Поттер та філософський камінь". Додатково тобі будуть надані частини тексту з книги в якості контексту, з яких ти повинен надати відповідь. Ти повинен використовувати для відповіді лише наданий контекст і не додумувати нічого від себе. Якщо в частинах тексту немає відповідної інформації, щоб надати відповідь - вибачся та скажи, що не знаєш відповіді. ВАЖЛИВО відповідати виключно УКРАЇНСЬКОЮ мовою.
        Контекст:
        {context}
        Запитання: {question}
        Відповідь:"""
        self.prompt = PromptTemplate.from_template(template)

    def load_pdf(self, file_name: str) -> List[Document]:
        log.info("Loading pdf")
        loader = PyPDFLoader(f"files_to_load/{file_name}")
        return loader.load()

    def split_documents(self, pages: List[dict]) -> Union[List[Document], None]:
        log.info("Splitting pdf")
        text_splitter = CharacterTextSplitter(
            separator="\n",
            chunk_size=1000,
            chunk_overlap=150,
            length_function=len
        )

        return text_splitter.split_documents(pages)

    def save_in_db(self, splits: List[Document]):
        log.info("Saving chunks in db")
        if len(splits) == 0:
            log.warning(
                "There are no splits to save in db. Please provide them in arguments or call the split_documents(headers_to_split, pages) method")
            return None

        vectordb = Chroma.from_documents(
            documents=splits,
            embedding=self.embedding,
            persist_directory=self.db_dir
        )

        log.info(f"{vectordb._collection.count()} rows were saved")
        log.info(f"{_calc_tokens(splits)} tokens were affected")
        return True

    def stuff_search(self, question: str):
        vectordb = Chroma(persist_directory=self.db_dir,
                          embedding_function=self.embedding)

        qa_chain = RetrievalQA.from_chain_type(
            self.llm,
            retriever=vectordb.as_retriever(),
            return_source_documents=True,
            chain_type_kwargs={"prompt": self.prompt}
        )

        result = qa_chain({"query": question})
        log.info(f'Questing: {question}')
        log.info(f'Result: {result["result"]}')
        log.info("DOCUMENTS:")
        for doc in result["source_documents"]:
            log.info(doc)

        return result["result"], result["source_documents"]

if __name__ == "__main__":
    assistant = Assistant()
    vectordb = Chroma(persist_directory="docs/chroma/",
                    embedding_function=assistant.embedding)
    if(len(vectordb.get().get("documents")) == 0):
        pdf = assistant.load_pdf("Harry_Potter.pdf")
        splits = assistant.split_documents(pdf)
        assistant.save_in_db(splits)
    assistant.run()