shukdevdatta123 commited on
Commit
315655d
·
verified ·
1 Parent(s): b6f8341

Update generate_answer.py

Browse files
Files changed (1) hide show
  1. generate_answer.py +16 -28
generate_answer.py CHANGED
@@ -1,8 +1,9 @@
 
 
1
  import os
2
  from glob import glob
3
- # import subprocess
4
-
5
  import openai
 
6
  from dotenv import load_dotenv
7
 
8
  from langchain.embeddings import OpenAIEmbeddings
@@ -14,34 +15,31 @@ from langchain_community.chat_models import ChatOpenAI
14
  from langchain.chains import RetrievalQA
15
  from langchain.memory import ConversationBufferMemory
16
 
17
-
18
  load_dotenv()
19
  api_key = os.getenv("OPENAI_API_KEY")
20
 
21
- client = openai(api_key=api_key)
22
  openai.api_key = api_key
23
 
24
-
25
  def base_model_chatbot(messages):
26
  system_message = [
27
- {"role": "system", "content": "You are an helpful AI chatbot, that answers questions asked by User."}]
 
28
  messages = system_message + messages
29
- response = client.chat.completions.create(
30
- model="gpt-3.5-turbo-1106",
31
  messages=messages
32
  )
33
- return response.choices[0].message.content
34
 
35
 
36
  class VectorDB:
37
  """Class to manage document loading and vector database creation."""
38
 
39
- def __init__(self, docs_directory:str):
40
-
41
  self.docs_directory = docs_directory
42
 
43
  def create_vector_db(self):
44
-
45
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
46
 
47
  files = glob(os.path.join(self.docs_directory, "*.pdf"))
@@ -61,32 +59,22 @@ class ConversationalRetrievalChain:
61
  def __init__(self, model_name="gpt-3.5-turbo", temperature=0):
62
  self.model_name = model_name
63
  self.temperature = temperature
64
-
65
- def create_chain(self):
66
-
67
- model = ChatOpenAI(model_name=self.model_name,
68
- temperature=self.temperature,
69
- )
70
 
71
- memory = ConversationBufferMemory(
72
- memory_key="chat_history",
73
- return_messages=True
74
- )
75
  vector_db = VectorDB('docs/')
76
- retriever = vector_db.create_vector_db().as_retriever(search_type="similarity",
77
- search_kwargs={"k": 2},
78
- )
79
  return RetrievalQA.from_chain_type(
80
  llm=model,
81
  retriever=retriever,
82
  memory=memory,
83
- )
84
 
85
  def with_pdf_chatbot(messages):
86
  """Main function to execute the QA system."""
87
  query = messages[-1]['content'].strip()
88
 
89
-
90
  qa_chain = ConversationalRetrievalChain().create_chain()
91
  result = qa_chain({"query": query})
92
- return result['result']
 
1
+ # generate_answer.py
2
+
3
  import os
4
  from glob import glob
 
 
5
  import openai
6
+ from openai import OpenAI
7
  from dotenv import load_dotenv
8
 
9
  from langchain.embeddings import OpenAIEmbeddings
 
15
  from langchain.chains import RetrievalQA
16
  from langchain.memory import ConversationBufferMemory
17
 
 
18
  load_dotenv()
19
  api_key = os.getenv("OPENAI_API_KEY")
20
 
21
+ # Corrected line: Set the OpenAI API key correctly
22
  openai.api_key = api_key
23
 
 
24
  def base_model_chatbot(messages):
25
  system_message = [
26
+ {"role": "system", "content": "You are an helpful AI chatbot, that answers questions asked by User."}
27
+ ]
28
  messages = system_message + messages
29
+ response = openai.ChatCompletion.create(
30
+ model="gpt-3.5-turbo-1106", # Ensure the model is specified correctly
31
  messages=messages
32
  )
33
+ return response.choices[0].message['content']
34
 
35
 
36
  class VectorDB:
37
  """Class to manage document loading and vector database creation."""
38
 
39
+ def __init__(self, docs_directory: str):
 
40
  self.docs_directory = docs_directory
41
 
42
  def create_vector_db(self):
 
43
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
44
 
45
  files = glob(os.path.join(self.docs_directory, "*.pdf"))
 
59
  def __init__(self, model_name="gpt-3.5-turbo", temperature=0):
60
  self.model_name = model_name
61
  self.temperature = temperature
 
 
 
 
 
 
62
 
63
+ def create_chain(self):
64
+ model = ChatOpenAI(model_name=self.model_name, temperature=self.temperature)
65
+ memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
 
66
  vector_db = VectorDB('docs/')
67
+ retriever = vector_db.create_vector_db().as_retriever(search_type="similarity", search_kwargs={"k": 2})
 
 
68
  return RetrievalQA.from_chain_type(
69
  llm=model,
70
  retriever=retriever,
71
  memory=memory,
72
+ )
73
 
74
  def with_pdf_chatbot(messages):
75
  """Main function to execute the QA system."""
76
  query = messages[-1]['content'].strip()
77
 
 
78
  qa_chain = ConversationalRetrievalChain().create_chain()
79
  result = qa_chain({"query": query})
80
+ return result['result']