Spaces:
Sleeping
Sleeping
Commit
·
99a3f34
1
Parent(s):
10330bc
creating embedding from docs
Browse files- app.py +36 -2
- requirements.txt +2 -1
- src/config.py +9 -1
- src/model.py +2 -2
- src/utils.py +49 -8
app.py
CHANGED
@@ -4,7 +4,7 @@ import logging
|
|
4 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
5 |
from langchain.embeddings.openai import OpenAIEmbeddings
|
6 |
import chainlit as cl
|
7 |
-
from src.utils import get_docSearch
|
8 |
from src.model import load_chain
|
9 |
|
10 |
|
@@ -13,11 +13,13 @@ from src.model import load_chain
|
|
13 |
|
14 |
|
15 |
|
|
|
16 |
welcome_message = """ Upload your file here"""
|
17 |
|
18 |
@cl.on_chat_start
|
19 |
async def start():
|
20 |
await cl.Message("you are in ").send()
|
|
|
21 |
files = None
|
22 |
while files is None:
|
23 |
files = await cl.AskFileMessage(
|
@@ -26,12 +28,16 @@ async def start():
|
|
26 |
max_size_mb=10,
|
27 |
timeout=90
|
28 |
).send()
|
|
|
29 |
file = files[0]
|
30 |
msg = cl.Message(content=f"Processing `{type(files)}` {file.name}....")
|
31 |
await msg.send()
|
32 |
|
33 |
-
|
|
|
|
|
34 |
|
|
|
35 |
|
36 |
chain = load_chain(docsearch)
|
37 |
|
@@ -44,6 +50,34 @@ async def start():
|
|
44 |
|
45 |
await msg.update()
|
46 |
|
|
|
|
|
47 |
cl.user_session.set("chain", chain)
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
|
|
4 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
5 |
from langchain.embeddings.openai import OpenAIEmbeddings
|
6 |
import chainlit as cl
|
7 |
+
from src.utils import get_docSearch, get_source
|
8 |
from src.model import load_chain
|
9 |
|
10 |
|
|
|
13 |
|
14 |
|
15 |
|
16 |
+
|
17 |
welcome_message = """ Upload your file here"""
|
18 |
|
19 |
@cl.on_chat_start
|
20 |
async def start():
|
21 |
await cl.Message("you are in ").send()
|
22 |
+
logging.info(f"app started")
|
23 |
files = None
|
24 |
while files is None:
|
25 |
files = await cl.AskFileMessage(
|
|
|
28 |
max_size_mb=10,
|
29 |
timeout=90
|
30 |
).send()
|
31 |
+
logging.info("uploader excecuted")
|
32 |
file = files[0]
|
33 |
msg = cl.Message(content=f"Processing `{type(files)}` {file.name}....")
|
34 |
await msg.send()
|
35 |
|
36 |
+
logging.info("processing started")
|
37 |
+
|
38 |
+
docsearch = get_docSearch(file,cl)
|
39 |
|
40 |
+
logging.info("document uploaded success")
|
41 |
|
42 |
chain = load_chain(docsearch)
|
43 |
|
|
|
50 |
|
51 |
await msg.update()
|
52 |
|
53 |
+
logging.info("processing completed")
|
54 |
+
|
55 |
cl.user_session.set("chain", chain)
|
56 |
|
57 |
+
@cl.on_message
|
58 |
+
async def main(message):
|
59 |
+
chain = cl.user_session.get("chain")
|
60 |
+
cb = cl.AsyncLangchainCallbackHandler(
|
61 |
+
stream_final_answer=True, answer_prefix_tokens=["FINAL","ANSWER"]
|
62 |
+
)
|
63 |
+
|
64 |
+
cb.answer_reached = True
|
65 |
+
res = await chain.acall(message, callbacks=[cb])
|
66 |
+
|
67 |
+
answer = res["answer"]
|
68 |
+
sources = res["sources"].strip()
|
69 |
+
|
70 |
+
|
71 |
+
## get doc from user session
|
72 |
+
docs = cl.user_session.get("docs")
|
73 |
+
metadatas = [doc.metadata for doc in docs]
|
74 |
+
all_sources = [m["source"]for m in metadatas]
|
75 |
+
|
76 |
+
source_elements,answer = get_source(sources,all_sources,docs,cl)
|
77 |
+
|
78 |
+
if cb.has_streamed_final_answer:
|
79 |
+
cb.final_stream.elements = source_elements
|
80 |
+
await cb.final_stream.update()
|
81 |
+
else:
|
82 |
+
await cl.Message(content=answer, elements=source_elements).send()
|
83 |
|
requirements.txt
CHANGED
@@ -3,4 +3,5 @@ openai
|
|
3 |
python-dotenv
|
4 |
chainlit
|
5 |
chromadb
|
6 |
-
tiktoken
|
|
|
|
3 |
python-dotenv
|
4 |
chainlit
|
5 |
chromadb
|
6 |
+
tiktoken
|
7 |
+
tokenizers
|
src/config.py
CHANGED
@@ -1,5 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
class Config:
|
2 |
temperature = 0
|
3 |
streaming = True
|
4 |
chain_type = "stuff"
|
5 |
-
max_token_limit = 4098
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from langchain.embeddings.openai import OpenAIEmbeddings
|
3 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
4 |
+
|
5 |
+
|
6 |
+
|
7 |
class Config:
|
8 |
temperature = 0
|
9 |
streaming = True
|
10 |
chain_type = "stuff"
|
11 |
+
max_token_limit = 4098
|
12 |
+
embeddings = OpenAIEmbeddings(api_key=os.getenv('OPENAI_API_KEY'))
|
13 |
+
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
|
src/model.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
from langchain.chains import RetrievalQAWithSourcesChain
|
2 |
from langchain.chat_models import ChatOpenAI
|
3 |
import logging
|
4 |
-
|
5 |
|
6 |
|
7 |
from src.config import Config
|
@@ -12,7 +12,7 @@ from src.config import Config
|
|
12 |
|
13 |
def load_model():
|
14 |
model = ChatOpenAI(temperature=Config.temperature,
|
15 |
-
streaming=Config.streaming)
|
16 |
return model
|
17 |
|
18 |
|
|
|
1 |
from langchain.chains import RetrievalQAWithSourcesChain
|
2 |
from langchain.chat_models import ChatOpenAI
|
3 |
import logging
|
4 |
+
import os
|
5 |
|
6 |
|
7 |
from src.config import Config
|
|
|
12 |
|
13 |
def load_model():
|
14 |
model = ChatOpenAI(temperature=Config.temperature,
|
15 |
+
streaming=Config.streaming,api_key=os.getenv('OPENAI_API_KEY'))
|
16 |
return model
|
17 |
|
18 |
|
src/utils.py
CHANGED
@@ -1,12 +1,21 @@
|
|
1 |
from chainlit.types import AskFileResponse
|
|
|
2 |
from langchain.document_loaders import TextLoader
|
3 |
from langchain.document_loaders import PyPDFDirectoryLoader
|
4 |
-
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
5 |
from langchain.vectorstores import Chroma
|
6 |
-
from langchain.embeddings import OpenAIEmbeddings
|
7 |
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
def process_file(file: AskFileResponse):
|
12 |
import tempfile
|
@@ -21,17 +30,49 @@ def process_file(file: AskFileResponse):
|
|
21 |
loader = Loader(tempfile.name)
|
22 |
documents = loader.load()
|
23 |
# text_splitter = text_splitter()
|
24 |
-
docs = text_splitter.split_documents(documents)
|
25 |
|
26 |
for i, doc in enumerate(docs):
|
27 |
doc.metadata["source"] = f"source_{i}"
|
28 |
return docs
|
29 |
|
30 |
-
def get_docSearch(file
|
31 |
docs = process_file(file)
|
32 |
|
|
|
|
|
33 |
## save data in user session
|
|
|
|
|
|
|
34 |
|
35 |
-
docsearch = Chroma.from_documents(docs, embeddings)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
-
|
|
|
|
|
|
|
|
|
|
1 |
from chainlit.types import AskFileResponse
|
2 |
+
import click
|
3 |
from langchain.document_loaders import TextLoader
|
4 |
from langchain.document_loaders import PyPDFDirectoryLoader
|
|
|
5 |
from langchain.vectorstores import Chroma
|
|
|
6 |
|
7 |
+
|
8 |
+
from src.config import Config
|
9 |
+
# import chainlit as cl
|
10 |
+
import logging
|
11 |
+
import openai
|
12 |
+
import os
|
13 |
+
from dotenv import load_dotenv
|
14 |
+
|
15 |
+
load_dotenv()
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
|
20 |
def process_file(file: AskFileResponse):
|
21 |
import tempfile
|
|
|
30 |
loader = Loader(tempfile.name)
|
31 |
documents = loader.load()
|
32 |
# text_splitter = text_splitter()
|
33 |
+
docs = Config.text_splitter.split_documents(documents)
|
34 |
|
35 |
for i, doc in enumerate(docs):
|
36 |
doc.metadata["source"] = f"source_{i}"
|
37 |
return docs
|
38 |
|
39 |
+
def get_docSearch(file,cl):
|
40 |
docs = process_file(file)
|
41 |
|
42 |
+
logging.info("files loaded ")
|
43 |
+
|
44 |
## save data in user session
|
45 |
+
cl.user_session.set("docs",docs)
|
46 |
+
|
47 |
+
logging.info("docs saved in active session")
|
48 |
|
49 |
+
docsearch = Chroma.from_documents(docs, Config.embeddings)
|
50 |
+
|
51 |
+
logging.info("embedding completed")
|
52 |
+
|
53 |
+
return docsearch
|
54 |
+
|
55 |
+
def get_source(sources,all_sources,docs,cl):
|
56 |
+
answer = []
|
57 |
+
source_elements = []
|
58 |
+
if sources:
|
59 |
+
found_sources = []
|
60 |
+
|
61 |
+
# Add the sources to the message
|
62 |
+
for source in sources.split(","):
|
63 |
+
source_name = source.strip().replace(".", "")
|
64 |
+
# Get the index of the source
|
65 |
+
try:
|
66 |
+
index = all_sources.index(source_name)
|
67 |
+
except ValueError:
|
68 |
+
continue
|
69 |
+
text = docs[index].page_content
|
70 |
+
found_sources.append(source_name)
|
71 |
+
# Create the text element referenced in the message
|
72 |
+
source_elements.append(cl.Text(content=text, name=source_name))
|
73 |
|
74 |
+
if found_sources:
|
75 |
+
answer += f"\nSources: {', '.join(found_sources)}"
|
76 |
+
else:
|
77 |
+
answer += "\nNo sources found"
|
78 |
+
return source_elements,answer
|