midrees2806 committed on
Commit
e3fae22
·
verified ·
1 Parent(s): d7f7c81

Update pdf_bot.py

Browse files
Files changed (1) hide show
  1. pdf_bot.py +33 -8
pdf_bot.py CHANGED
@@ -5,28 +5,53 @@ from langchain_community.embeddings import HuggingFaceEmbeddings
5
  from langchain_community.vectorstores import FAISS
6
  from langchain.text_splitter import RecursiveCharacterTextSplitter
7
  from langchain.chains import RetrievalQA
 
 
 
8
  from groq import Groq
9
 
10
  load_dotenv()
11
  groq_api_key = os.getenv("GROQ_API_KEY")
12
- hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
13
 
14
- # Load PDF and prepare QA chain
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def create_qa_chain_from_pdf(pdf_path):
16
  loader = PyPDFLoader(pdf_path)
17
  documents = loader.load()
18
 
19
  splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
20
  texts = splitter.split_documents(documents)
21
-
22
  embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-m3")
23
  vectorstore = FAISS.from_documents(texts, embeddings)
24
 
25
- llm = ChatGroq(
26
- model="llama3-8b-8192",
27
- temperature=0.3,
28
- api_key=groq_api_key,
29
- )
30
 
31
  qa_chain = RetrievalQA.from_chain_type(
32
  llm=llm,
 
5
  from langchain_community.vectorstores import FAISS
6
  from langchain.text_splitter import RecursiveCharacterTextSplitter
7
  from langchain.chains import RetrievalQA
8
+ from langchain_core.language_models import BaseChatModel
9
+ from langchain_core.outputs import ChatResult, ChatGeneration
10
+ from langchain_core.messages import AIMessage
11
  from groq import Groq
12
 
13
  load_dotenv()
14
  groq_api_key = os.getenv("GROQ_API_KEY")
 
15
 
16
+ # Custom LangChain chat-model wrapper around the Groq SDK client
17
class ChatGroq(BaseChatModel):
    """Minimal LangChain chat model that forwards prompts to the Groq SDK.

    All incoming messages are flattened into a single ``user`` message (see
    ``_get_message_text``) and sent to ``client.chat.completions.create``.
    The first choice of the response is returned as an ``AIMessage``.

    Args:
        model: Groq model identifier to request.
        temperature: Sampling temperature passed to the completion call.
        api_key: Groq API key used to build the SDK client.
        **kwargs: Extra fields forwarded to ``BaseChatModel`` (for example
            ``max_tokens`` or a pre-built ``client``).
    """

    # BaseChatModel is a pydantic model: instance attributes must be declared
    # as fields, otherwise assigning them in __init__ raises at construction.
    client: object = None
    model: str = "llama3-8b-8192"
    temperature: float = 0.3
    max_tokens: int = 1024

    def __init__(self, model="llama3-8b-8192", temperature=0.3, api_key=None, **kwargs):
        # Only build an SDK client when the caller did not inject one.
        if kwargs.get("client") is None:
            kwargs["client"] = Groq(api_key=api_key)
        # Route everything through pydantic's initializer so the model's
        # internal state is set up before any attribute is touched.
        super().__init__(model=model, temperature=temperature, **kwargs)

    def _generate(self, messages, stop=None, run_manager=None, **kwargs):
        """Send ``messages`` to Groq and wrap the reply in a ``ChatResult``.

        Args:
            messages: LangChain message(s) making up the prompt.
            stop: Optional stop sequence(s); forwarded to the API when given.
            run_manager: Accepted for interface compatibility; not used.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            A ``ChatResult`` holding one ``ChatGeneration`` with the stripped
            text of the first completion choice.
        """
        request = {
            "model": self.model,
            "messages": [
                {"role": "user", "content": self._get_message_text(messages)}
            ],
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
        if stop:
            request["stop"] = stop

        response = self.client.chat.completions.create(**request)
        # The SDK may return None as content; treat that as an empty reply
        # instead of crashing on .strip().
        content = (response.choices[0].message.content or "").strip()
        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=content))]
        )

    def _get_message_text(self, messages):
        """Return the text of ``messages`` joined by single spaces.

        Accepts either a list of message objects or a single message object;
        each must expose a ``content`` attribute.
        """
        if isinstance(messages, list):
            return " ".join(msg.content for msg in messages)
        return messages.content

    @property
    def _llm_type(self):
        """Identifier string for this chat model implementation."""
        return "chat-groq"
42
+
43
+ # ✅ Function to return a QA chain
44
  def create_qa_chain_from_pdf(pdf_path):
45
  loader = PyPDFLoader(pdf_path)
46
  documents = loader.load()
47
 
48
  splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
49
  texts = splitter.split_documents(documents)
50
+
51
  embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-m3")
52
  vectorstore = FAISS.from_documents(texts, embeddings)
53
 
54
+ llm = ChatGroq(model="llama3-8b-8192", temperature=0.3, api_key=groq_api_key)
 
 
 
 
55
 
56
  qa_chain = RetrievalQA.from_chain_type(
57
  llm=llm,