midrees2806 commited on
Commit
ae1fcab
·
verified ·
1 Parent(s): 5b57f0c

Update pdf_bot.py

Browse files
Files changed (1) hide show
  1. pdf_bot.py +8 -4
pdf_bot.py CHANGED
@@ -13,6 +13,7 @@ from groq import Groq
13
  load_dotenv()
14
  groq_api_key = os.getenv("GROQ_API_KEY")
15
 
 
16
  class ChatGroq(BaseChatModel):
17
  model: str = "llama3-8b-8192"
18
  temperature: float = 0.3
@@ -23,7 +24,7 @@ class ChatGroq(BaseChatModel):
23
  self.model = model
24
  self.temperature = temperature
25
  self.groq_api_key = api_key
26
- self._client = Groq(api_key=api_key)
27
 
28
  def _generate(self, messages, stop=None):
29
  prompt = [{"role": "user", "content": self._get_message_text(messages)}]
@@ -37,16 +38,19 @@ class ChatGroq(BaseChatModel):
37
  return ChatResult(generations=[ChatGeneration(message=AIMessage(content=content))])
38
 
39
  def _get_message_text(self, messages):
40
- return " ".join([msg.content for msg in messages])
 
 
41
 
42
  @property
43
  def _llm_type(self) -> str:
44
  return "chat-groq"
45
 
 
46
  def create_qa_chain_from_pdf(pdf_path):
47
  loader = PyPDFLoader(pdf_path)
48
  documents = loader.load()
49
-
50
  splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
51
  texts = splitter.split_documents(documents)
52
 
@@ -61,4 +65,4 @@ def create_qa_chain_from_pdf(pdf_path):
61
  retriever=vectorstore.as_retriever(search_kwargs={"k": 1}),
62
  return_source_documents=True
63
  )
64
- return qa_chain
 
13
  load_dotenv()
14
  groq_api_key = os.getenv("GROQ_API_KEY")
15
 
16
+ # ✅ Custom wrapper
17
  class ChatGroq(BaseChatModel):
18
  model: str = "llama3-8b-8192"
19
  temperature: float = 0.3
 
24
  self.model = model
25
  self.temperature = temperature
26
  self.groq_api_key = api_key
27
+ self._client = Groq(api_key=api_key) # ✅ use _client instead of client
28
 
29
  def _generate(self, messages, stop=None):
30
  prompt = [{"role": "user", "content": self._get_message_text(messages)}]
 
38
  return ChatResult(generations=[ChatGeneration(message=AIMessage(content=content))])
39
 
40
  def _get_message_text(self, messages):
41
+ if isinstance(messages, list):
42
+ return " ".join([msg.content for msg in messages])
43
+ return messages.content
44
 
45
  @property
46
  def _llm_type(self) -> str:
47
  return "chat-groq"
48
 
49
+ # ✅ Function to return a QA chain
50
  def create_qa_chain_from_pdf(pdf_path):
51
  loader = PyPDFLoader(pdf_path)
52
  documents = loader.load()
53
+
54
  splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
55
  texts = splitter.split_documents(documents)
56
 
 
65
  retriever=vectorstore.as_retriever(search_kwargs={"k": 1}),
66
  return_source_documents=True
67
  )
68
+ return qa_chain