import os
from typing import Optional

from dotenv import load_dotenv
from groq import Groq
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, ChatResult

load_dotenv()
groq_api_key = os.getenv("GROQ_API_KEY")


class ChatGroq(BaseChatModel):
    """Minimal LangChain chat model backed by the Groq completions API.

    All incoming messages are flattened into a single user turn, so this
    wrapper is suited to single-shot prompts (e.g. RetrievalQA), not
    multi-turn conversations.
    """

    model: str = "llama3-8b-8192"
    temperature: float = 0.3
    # Optional[str]: `None` is the documented default, so the annotation
    # must admit it (bare `str = None` fails strict pydantic validation).
    groq_api_key: Optional[str] = None

    def __init__(
        self,
        model: str = "llama3-8b-8192",
        temperature: float = 0.3,
        api_key: Optional[str] = None,
    ):
        # Route field values through the pydantic constructor rather than
        # assigning them one by one after super().__init__().
        super().__init__(model=model, temperature=temperature, groq_api_key=api_key)
        # `_client` is not a declared pydantic field; pydantic's __setattr__
        # rejects undeclared underscore attributes, so bypass it explicitly.
        object.__setattr__(self, "_client", Groq(api_key=api_key))

    def _generate(self, messages, stop=None, run_manager=None, **kwargs) -> ChatResult:
        """Call Groq chat completions and wrap the reply as a ChatResult.

        `run_manager` and `**kwargs` are accepted for LangChain callback
        compatibility; `stop` is forwarded to the API (it was previously
        accepted but silently ignored).
        """
        prompt = [{"role": "user", "content": self._get_message_text(messages)}]
        response = self._client.chat.completions.create(
            model=self.model,
            messages=prompt,
            temperature=self.temperature,
            max_tokens=1024,
            stop=stop,
        )
        content = response.choices[0].message.content.strip()
        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=content))]
        )

    def _get_message_text(self, messages) -> str:
        """Collapse a message list (or a single message) into one string."""
        if isinstance(messages, list):
            return " ".join(msg.content for msg in messages)
        return messages.content

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for logging/serialization."""
        return "chat-groq"


def create_qa_chain_from_pdf(pdf_path: str) -> RetrievalQA:
    """Build a RetrievalQA chain over a single PDF.

    Loads the PDF, splits it into overlapping chunks, embeds the chunks
    with BGE-M3, indexes them in an in-memory FAISS store, and wires a
    Groq-backed LLM on top.

    Args:
        pdf_path: Filesystem path to the PDF document.

    Returns:
        A RetrievalQA chain configured to also return source documents.
    """
    loader = PyPDFLoader(pdf_path)
    documents = loader.load()

    # 1000/200 chunking keeps each chunk within the model's comfortable
    # context while the overlap preserves continuity across boundaries.
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = splitter.split_documents(documents)

    embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-m3")
    vectorstore = FAISS.from_documents(texts, embeddings)

    llm = ChatGroq(model="llama3-8b-8192", temperature=0.3, api_key=groq_api_key)

    # NOTE: the assignment below was previously severed from its right-hand
    # side by a line break (a SyntaxError); it is now a single statement.
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vectorstore.as_retriever(search_kwargs={"k": 1}),
        return_source_documents=True,
    )
    return qa_chain