# Dull / pdf_bot.py
# midrees2806's picture
# Update pdf_bot.py
# ae1fcab verified
import os
from typing import Optional

from dotenv import load_dotenv
from groq import Groq
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, ChatResult
# Pull variables from a local .env file into the process environment, then
# read the Groq credential from it (None when GROQ_API_KEY is not set).
load_dotenv()
groq_api_key = os.environ.get("GROQ_API_KEY")
# ✅ Custom LangChain chat-model wrapper around the Groq SDK
class ChatGroq(BaseChatModel):
    """LangChain chat model backed by the Groq chat-completions endpoint.

    Attributes are declared as model fields so they can be assigned in
    ``__init__``; the Groq SDK client itself is kept on ``_client``.
    """

    model: str = "llama3-8b-8192"
    temperature: float = 0.3
    # None is a legal value here, so the annotation must allow it.
    groq_api_key: Optional[str] = None

    def __init__(self, model="llama3-8b-8192", temperature=0.3, api_key=None):
        """Create the wrapper and its Groq client.

        Args:
            model: Groq model identifier sent with every request.
            temperature: Sampling temperature sent with every request.
            api_key: Groq API key, handed to ``Groq(api_key=...)`` as-is.
        """
        super().__init__()
        self.model = model
        self.temperature = temperature
        self.groq_api_key = api_key
        # Kept under an underscore name (not ``client``) because it is not a
        # declared field of this model.
        self._client = Groq(api_key=api_key)

    def _generate(self, messages, stop=None, run_manager=None, **kwargs):
        """Send ``messages`` to Groq and wrap the reply in a ``ChatResult``.

        Args:
            messages: A list of LangChain messages (or a single message).
            stop: Optional stop sequences; forwarded to the API when given.
            run_manager: Accepted so the signature matches what the base
                class may pass; not used here.
            **kwargs: Accepted for the same reason; not forwarded.

        Returns:
            A ``ChatResult`` holding one ``AIMessage`` with the stripped
            reply text (empty string if the API returned no content).
        """
        request = {
            "model": self.model,
            "messages": self._to_groq_messages(messages),
            "temperature": self.temperature,
            "max_tokens": 1024,
        }
        # Only include ``stop`` when the caller supplied sequences, so the
        # request is unchanged for callers that never used it.
        if stop:
            request["stop"] = stop
        response = self._client.chat.completions.create(**request)
        # ``content`` may be None; guard before stripping.
        content = (response.choices[0].message.content or "").strip()
        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=content))]
        )

    def _to_groq_messages(self, messages):
        """Convert LangChain messages into Groq ``{"role", "content"}`` dicts.

        System messages keep the ``system`` role and AI messages become
        ``assistant``; every other message type is sent as ``user``.
        """
        # Defined locally: an un-annotated class attribute is not safe to add
        # to a pydantic-based model class.
        role_by_type = {"system": "system", "ai": "assistant"}
        if not isinstance(messages, list):
            messages = [messages]
        return [
            {
                "role": role_by_type.get(getattr(msg, "type", None), "user"),
                "content": msg.content,
            }
            for msg in messages
        ]

    def _get_message_text(self, messages):
        """Return message contents as one space-joined string.

        Retained for backward compatibility; ``_generate`` now sends
        role-tagged messages instead of this flattened text.
        """
        if isinstance(messages, list):
            return " ".join(msg.content for msg in messages)
        return messages.content

    @property
    def _llm_type(self) -> str:
        """Identifier string for this chat-model implementation."""
        return "chat-groq"
# ✅ Build a RetrievalQA chain from a PDF file
def create_qa_chain_from_pdf(
    pdf_path,
    *,
    chunk_size=1000,
    chunk_overlap=200,
    embedding_model="BAAI/bge-m3",
    llm_model="llama3-8b-8192",
    temperature=0.3,
    k=1,
):
    """Build a ``RetrievalQA`` chain that answers questions about one PDF.

    The PDF is loaded, split into overlapping chunks, embedded into an
    in-memory FAISS index, and wired to a ``ChatGroq`` model with the
    "stuff" chain type.

    Args:
        pdf_path: Path of the PDF to index.
        chunk_size: Maximum characters per chunk.
        chunk_overlap: Characters shared between consecutive chunks.
        embedding_model: HuggingFace model name used for embeddings.
        llm_model: Groq model identifier used to answer questions.
        temperature: Sampling temperature for the Groq model.
        k: Number of chunks the retriever returns per query.

    Returns:
        A ``RetrievalQA`` chain configured to return source documents.

    Raises:
        ValueError: If no text chunks could be produced from the PDF.
    """
    pages = PyPDFLoader(pdf_path).load()
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    chunks = splitter.split_documents(pages)
    # An empty chunk list (e.g. a PDF with no extractable text) would
    # otherwise fail inside the vector-store build with an unhelpful error.
    if not chunks:
        raise ValueError(f"No extractable text found in PDF: {pdf_path!r}")

    embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
    vectorstore = FAISS.from_documents(chunks, embeddings)
    llm = ChatGroq(model=llm_model, temperature=temperature, api_key=groq_api_key)
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vectorstore.as_retriever(search_kwargs={"k": k}),
        return_source_documents=True,
    )