VlaTal commited on
Commit
17dcbf0
·
1 Parent(s): 0d3862a
Files changed (3) hide show
  1. app.py +153 -0
  2. files_to_load/Harry_Potter.pdf +0 -0
  3. requirements.txt +175 -0
app.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __import__('pysqlite3')
2
+ import sys
3
+ sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
4
+
5
+ import os
6
+ import pprint
7
+ from dotenv import load_dotenv
8
+ from typing import List, Tuple, Optional, Union
9
+ from loguru import logger as log
10
+
11
+ import tiktoken
12
+ from langchain.chains import RetrievalQA
13
+ from langchain.prompts import PromptTemplate
14
+ from langchain.text_splitter import Document
15
+ from langchain.output_parsers import PydanticOutputParser
16
+ from langchain_openai import AzureOpenAIEmbeddings
17
+ from langchain_openai import AzureChatOpenAI
18
+ from langchain_community.vectorstores import Chroma
19
+ from langchain.document_loaders.pdf import PyPDFLoader
20
+ from langchain.text_splitter import CharacterTextSplitter
21
+
22
+ from pydantic import BaseModel, Field
23
+ import streamlit as st
24
+ import logging
25
+
26
+ logging.basicConfig()
27
+ logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO)
28
+
29
+
30
+ def _calc_tokens(splits: List[Document]) -> int:
31
+ tokens = 0
32
+
33
+ for doc in splits:
34
+ encoding = tiktoken.get_encoding('cl100k_base')
35
+ tokens += len(encoding.encode(doc.page_content))
36
+
37
+ return tokens
38
+
39
+
40
+ class LineList(BaseModel):
41
+ lines: List[str] = Field(description="Lines of text")
42
+
43
+
44
+ class LineListOutputParser(PydanticOutputParser):
45
+ def __init__(self) -> None:
46
+ super().__init__(pydantic_object=LineList)
47
+
48
+ def parse(self, text: str) -> LineList:
49
+ lines = text.strip().split("\n")
50
+ return LineList(lines=lines)
51
+
52
+
53
+ class Assistant:
54
+ def __init__(self):
55
+ load_dotenv()
56
+ self.db_dir = 'docs/chroma/'
57
+ self.embedding = AzureOpenAIEmbeddings(azure_deployment="ada_dev")
58
+ self.llm = AzureChatOpenAI(
59
+ azure_deployment="35_turbo",
60
+ model_name="gpt-35-turbo",
61
+ temperature=0
62
+ )
63
+
64
+ os.environ["AZURE_OPENAI_API_KEY"] = st.secrets["AZURE_OPENAI_API_KEY"]
65
+ os.environ["OPENAI_API_TYPE"] = st.secrets["OPENAI_API_TYPE"]
66
+ os.environ["OPENAI_API_VERSION"] = st.secrets["OPENAI_API_VERSION"]
67
+ os.environ["AZURE_OPENAI_ENDPOINT"] = st.secrets["AZURE_OPENAI_ENDPOINT"]
68
+ self.make_template()
69
+
70
+ def run(self):
71
+ st.title('Гаррі Поттер асистент')
72
+
73
+ instruction = st.text_input('Питання', '')
74
+
75
+ if st.button('Згенерувати відповідь'):
76
+ result, docs = self.stuff_search(instruction)
77
+ st.subheader('Відповідь')
78
+ st.text(result)
79
+ st.header('Знайдені чанки')
80
+ for doc in docs:
81
+ st.subheader(f'Сторінка {doc.metadata.get("page")}')
82
+ st.text(doc.page_content)
83
+
84
+ def make_template(self):
85
+ template = """Ти ШІ консультант. Твоя задача відповідати на запитання користувачів. Запитання будуть про книгу "Гаррі Поттер та філософський камінь". Додатково тобі будуть надані частини тексту з книги в якості контексту, з яких ти повинен надати відповідь. Ти повинен використовувати для відповіді лише наданий контекст і не додумувати нічого від себе. Якщо в частинах тексту немає відповідної інформації, щоб надати відповідь - вибачся та скажи, що не знаєш відповіді. ВАЖЛИВО відповідати виключно УКРАЇНСЬКОЮ мовою.
86
+ Контекст:
87
+ {context}
88
+ Запитання: {question}
89
+ Відповідь:"""
90
+ self.prompt = PromptTemplate.from_template(template)
91
+
92
+ def load_pdf(self, file_name: str) -> List[Document]:
93
+ log.info("Loading pdf")
94
+ loader = PyPDFLoader(f"files_to_load/{file_name}")
95
+ return loader.load()
96
+
97
+ def split_documents(self, pages: List[dict]) -> Union[List[Document], None]:
98
+ log.info("Splitting pdf")
99
+ text_splitter = CharacterTextSplitter(
100
+ separator="\n",
101
+ chunk_size=1000,
102
+ chunk_overlap=150,
103
+ length_function=len
104
+ )
105
+
106
+ return text_splitter.split_documents(pages)
107
+
108
+ def save_in_db(self, splits: List[Document]):
109
+ log.info("Saving chunks in db")
110
+ if len(splits) == 0:
111
+ log.warning(
112
+ "There are no splits to save in db. Please provide them in arguments or call the split_documents(headers_to_split, pages) method")
113
+ return None
114
+
115
+ vectordb = Chroma.from_documents(
116
+ documents=splits,
117
+ embedding=self.embedding,
118
+ persist_directory=self.db_dir
119
+ )
120
+
121
+ log.info(f"{vectordb._collection.count()} rows were saved")
122
+ log.info(f"{_calc_tokens(splits)} tokens were affected")
123
+ return True
124
+
125
+ def stuff_search(self, question: str):
126
+ vectordb = Chroma(persist_directory=self.db_dir,
127
+ embedding_function=self.embedding)
128
+
129
+ qa_chain = RetrievalQA.from_chain_type(
130
+ self.llm,
131
+ retriever=vectordb.as_retriever(),
132
+ return_source_documents=True,
133
+ chain_type_kwargs={"prompt": self.prompt}
134
+ )
135
+
136
+ result = qa_chain({"query": question})
137
+ log.info(f'Questing: {question}')
138
+ log.info(f'Result: {result["result"]}')
139
+ log.info("DOCUMENTS:")
140
+ for doc in result["source_documents"]:
141
+ log.info(doc)
142
+
143
+ return result["result"], result["source_documents"]
144
+
145
+ if __name__ == "__main__":
146
+ assistant = Assistant()
147
+ vectordb = Chroma(persist_directory="docs/chroma/",
148
+ embedding_function=assistant.embedding)
149
+ if(len(vectordb.get().get("documents")) == 0):
150
+ pdf = assistant.load_pdf("Harry_Potter.pdf")
151
+ splits = assistant.split_documents(pdf)
152
+ assistant.save_in_db(splits)
153
+ assistant.run()
files_to_load/Harry_Potter.pdf ADDED
Binary file (608 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ aiohttp==3.9.1
3
+ aiosignal==1.3.1
4
+ altair==5.2.0
5
+ annotated-types==0.6.0
6
+ anyio==4.2.0
7
+ asgiref==3.7.2
8
+ asttokens==2.4.1
9
+ attrs==23.2.0
10
+ backoff==2.2.1
11
+ bcrypt==4.1.2
12
+ blinker==1.7.0
13
+ build==1.0.3
14
+ cachetools==5.3.2
15
+ certifi==2023.11.17
16
+ charset-normalizer==3.3.2
17
+ chroma-hnswlib==0.7.3
18
+ chromadb==0.4.22
19
+ click==8.1.7
20
+ colorama==0.4.6
21
+ coloredlogs==15.0.1
22
+ comm==0.2.1
23
+ contourpy==1.2.0
24
+ cycler==0.12.1
25
+ dataclasses-json==0.6.3
26
+ debugpy==1.8.0
27
+ decorator==5.1.1
28
+ Deprecated==1.2.14
29
+ distro==1.9.0
30
+ executing==2.0.1
31
+ fastapi==0.109.0
32
+ ffmpy==0.3.1
33
+ filelock==3.13.1
34
+ flatbuffers==23.5.26
35
+ fonttools==4.47.2
36
+ frozenlist==1.4.1
37
+ fsspec==2023.12.2
38
+ gitdb==4.0.11
39
+ GitPython==3.1.41
40
+ google-auth==2.26.2
41
+ googleapis-common-protos==1.62.0
42
+ gradio==4.15.0
43
+ gradio_client==0.8.1
44
+ greenlet==3.0.3
45
+ grpcio==1.60.0
46
+ h11==0.14.0
47
+ httpcore==1.0.2
48
+ httptools==0.6.1
49
+ httpx==0.26.0
50
+ huggingface-hub==0.20.2
51
+ humanfriendly==10.0
52
+ idna==3.6
53
+ importlib-metadata==6.11.0
54
+ importlib-resources==6.1.1
55
+ ipykernel==6.29.0
56
+ ipython==8.20.0
57
+ jedi==0.19.1
58
+ Jinja2==3.1.3
59
+ jsonpatch==1.33
60
+ jsonpointer==2.4
61
+ jsonschema==4.21.1
62
+ jsonschema-specifications==2023.12.1
63
+ jupyter_client==8.6.0
64
+ jupyter_core==5.7.1
65
+ kiwisolver==1.4.5
66
+ kubernetes==29.0.0
67
+ langchain==0.1.1
68
+ langchain-community==0.0.13
69
+ langchain-core==0.1.13
70
+ langchain-openai==0.0.3
71
+ langsmith==0.0.83
72
+ loguru==0.7.2
73
+ markdown-it-py==3.0.0
74
+ MarkupSafe==2.1.4
75
+ marshmallow==3.20.2
76
+ matplotlib==3.8.2
77
+ matplotlib-inline==0.1.6
78
+ mdurl==0.1.2
79
+ mmh3==4.1.0
80
+ monotonic==1.6
81
+ mpmath==1.3.0
82
+ multidict==6.0.4
83
+ mypy-extensions==1.0.0
84
+ nest-asyncio==1.5.9
85
+ numpy==1.26.3
86
+ oauthlib==3.2.2
87
+ onnxruntime==1.16.3
88
+ openai==1.9.0
89
+ opentelemetry-api==1.22.0
90
+ opentelemetry-exporter-otlp-proto-common==1.22.0
91
+ opentelemetry-exporter-otlp-proto-grpc==1.22.0
92
+ opentelemetry-instrumentation==0.43b0
93
+ opentelemetry-instrumentation-asgi==0.43b0
94
+ opentelemetry-instrumentation-fastapi==0.43b0
95
+ opentelemetry-proto==1.22.0
96
+ opentelemetry-sdk==1.22.0
97
+ opentelemetry-semantic-conventions==0.43b0
98
+ opentelemetry-util-http==0.43b0
99
+ orjson==3.9.12
100
+ overrides==7.6.0
101
+ packaging==23.2
102
+ pandas==2.2.0
103
+ parso==0.8.3
104
+ pillow==10.2.0
105
+ platformdirs==4.1.0
106
+ posthog==3.3.2
107
+ prompt-toolkit==3.0.43
108
+ protobuf==4.25.2
109
+ psutil==5.9.8
110
+ pulsar-client==3.4.0
111
+ pure-eval==0.2.2
112
+ pyarrow==14.0.2
113
+ pyasn1==0.5.1
114
+ pyasn1-modules==0.3.0
115
+ pydantic==2.5.3
116
+ pydantic_core==2.14.6
117
+ pydeck==0.8.1b0
118
+ pydub==0.25.1
119
+ Pygments==2.17.2
120
+ pyparsing==3.1.1
121
+ pypdf==4.0.0
122
+ PyPika==0.48.9
123
+ pysqlite3-binary==0.5.2.post1
124
+ pyproject_hooks==1.0.0
125
+ pyreadline3==3.4.1
126
+ python-dateutil==2.8.2
127
+ python-dotenv==1.0.0
128
+ python-multipart==0.0.6
129
+ pytz==2023.3.post1
130
+ PyYAML==6.0.1
131
+ pyzmq==25.1.2
132
+ referencing==0.32.1
133
+ regex==2023.12.25
134
+ requests==2.31.0
135
+ requests-oauthlib==1.3.1
136
+ rich==13.7.0
137
+ rpds-py==0.17.1
138
+ rsa==4.9
139
+ ruff==0.1.14
140
+ semantic-version==2.10.0
141
+ shellingham==1.5.4
142
+ six==1.16.0
143
+ smmap==5.0.1
144
+ sniffio==1.3.0
145
+ SQLAlchemy==2.0.25
146
+ stack-data==0.6.3
147
+ starlette==0.35.1
148
+ streamlit==1.30.0
149
+ sympy==1.12
150
+ tenacity==8.2.3
151
+ tiktoken==0.5.2
152
+ tokenizers==0.15.0
153
+ toml==0.10.2
154
+ tomlkit==0.12.0
155
+ toolz==0.12.0
156
+ tornado==6.4
157
+ tqdm==4.66.1
158
+ traitlets==5.14.1
159
+ typer==0.9.0
160
+ typing-inspect==0.9.0
161
+ typing_extensions==4.9.0
162
+ tzdata==2023.4
163
+ tzlocal==5.2
164
+ urllib3==2.1.0
165
+ uvicorn==0.26.0
166
+ validators==0.22.0
167
+ watchdog==3.0.0
168
+ watchfiles==0.21.0
169
+ wcwidth==0.2.13
170
+ websocket-client==1.7.0
171
+ websockets==11.0.3
172
+ win32-setctime==1.1.0
173
+ wrapt==1.16.0
174
+ yarl==1.9.4
175
+ zipp==3.17.0