shukdevdatta123 commited on
Commit
0eb6419
·
verified ·
1 Parent(s): ed6dab8

Create generate_answer.py

Browse files
Files changed (1) hide show
  1. generate_answer.py +93 -0
generate_answer.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from glob import glob
3
+ # import subprocess
4
+
5
+ import openai
6
+ from openai import OpenAI
7
+ from dotenv import load_dotenv
8
+
9
+ from langchain.embeddings import OpenAIEmbeddings
10
+ from langchain.vectorstores import Chroma
11
+ from langchain.document_loaders import PyPDFLoader
12
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
13
+
14
+ from langchain_community.chat_models import ChatOpenAI
15
+ from langchain.chains import RetrievalQA
16
+ from langchain.memory import ConversationBufferMemory
17
+
18
+
19
# Load OPENAI_API_KEY from a local .env file (if present) into the process environment.
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")

# Configure both the modern client object and the legacy module-level key,
# since this file mixes both OpenAI API styles (client calls + langchain).
client = OpenAI(api_key=api_key)
openai.api_key = api_key
24
+
25
+
26
def base_model_chatbot(messages):
    """Send the chat history to gpt-3.5-turbo-1106 and return the reply text.

    `messages` is a list of OpenAI-style {"role", "content"} dicts; a fixed
    system prompt is prepended before the request is made.
    """
    full_history = [
        {"role": "system", "content": "You are an helpful AI chatbot, that answers questions asked by User."}
    ] + messages
    completion = client.chat.completions.create(
        model="gpt-3.5-turbo-1106",
        messages=full_history,
    )
    # Single (non-streamed) completion: first choice carries the answer.
    return completion.choices[0].message.content
35
+
36
+
37
class VectorDB:
    """Class to manage document loading and vector database creation."""

    def __init__(self, docs_directory: str):
        # Directory scanned for *.pdf files when the vector DB is built.
        self.docs_directory = docs_directory

    def create_vector_db(self):
        """Load every PDF under `docs_directory`, chunk it, and return a Chroma store.

        Pages from all PDFs are flattened into one document list, split into
        1000-character chunks with 100-character overlap, and embedded with
        OpenAIEmbeddings.
        """
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)

        pdf_paths = glob(os.path.join(self.docs_directory, "*.pdf"))

        documents = []
        for path in pdf_paths:
            documents.extend(PyPDFLoader(path).load())

        chunks = splitter.split_documents(documents)
        return Chroma.from_documents(chunks, OpenAIEmbeddings())
58
+
59
class ConversationalRetrievalChain:
    """Class to manage the QA chain setup."""

    def __init__(self, model_name="gpt-3.5-turbo", temperature=0):
        # Chat model identifier and sampling temperature used by create_chain().
        self.model_name = model_name
        self.temperature = temperature

    def create_chain(self):
        """Build and return a RetrievalQA chain over the 'docs/' vector store.

        NOTE(review): the vector store is rebuilt from disk on every call —
        presumably acceptable for small corpora; verify for larger ones.
        """
        llm = ChatOpenAI(
            model_name=self.model_name,
            temperature=self.temperature,
        )

        chat_memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
        )

        # Similarity search returning the top-2 chunks per query.
        retriever = VectorDB('docs/').create_vector_db().as_retriever(
            search_type="similarity",
            search_kwargs={"k": 2},
        )

        return RetrievalQA.from_chain_type(
            llm=llm,
            retriever=retriever,
            memory=chat_memory,
        )
85
+
86
def with_pdf_chatbot(messages):
    """Main function to execute the QA system."""
    # Only the most recent user message is used as the retrieval query.
    query = messages[-1]['content'].strip()

    chain = ConversationalRetrievalChain().create_chain()
    outcome = chain({"query": query})
    return outcome['result']