import os
from typing import Optional

from dotenv import load_dotenv
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain_core.language_models import BaseChatModel
from langchain_core.outputs import ChatResult, ChatGeneration
from langchain_core.messages import AIMessage
from groq import Groq

# Load GROQ_API_KEY from a local .env file.
load_dotenv()
groq_api_key = os.getenv("GROQ_API_KEY")
# Custom LangChain-compatible chat model wrapper around the Groq SDK.
class ChatGroq(BaseChatModel):
    model: str = "llama3-8b-8192"
    temperature: float = 0.3
    groq_api_key: Optional[str] = None

    def __init__(self, model="llama3-8b-8192", temperature=0.3, api_key=None):
        super().__init__()
        self.model = model
        self.temperature = temperature
        self.groq_api_key = api_key
        # The underscore prefix keeps the raw SDK client out of pydantic's
        # field validation, so it can be stored on the model instance.
        self._client = Groq(api_key=api_key)

    def _generate(self, messages, stop=None, run_manager=None, **kwargs):
        # Flatten the incoming LangChain messages into a single user turn.
        prompt = [{"role": "user", "content": self._get_message_text(messages)}]
        response = self._client.chat.completions.create(
            model=self.model,
            messages=prompt,
            temperature=self.temperature,
            max_tokens=1024,
        )
        content = response.choices[0].message.content.strip()
        return ChatResult(generations=[ChatGeneration(message=AIMessage(content=content))])

    def _get_message_text(self, messages):
        # Accept either a list of messages or a single message object.
        if isinstance(messages, list):
            return " ".join(msg.content for msg in messages)
        return messages.content

    @property
    def _llm_type(self) -> str:
        # BaseChatModel expects _llm_type to be a property, not a plain method.
        return "chat-groq"
# Build a RetrievalQA chain over a single PDF.
def create_qa_chain_from_pdf(pdf_path):
    # Load the PDF and split it into overlapping chunks for retrieval.
    loader = PyPDFLoader(pdf_path)
    documents = loader.load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = splitter.split_documents(documents)

    # Embed the chunks and index them in an in-memory FAISS store.
    embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-m3")
    vectorstore = FAISS.from_documents(texts, embeddings)

    # "stuff" inserts the retrieved chunk(s) directly into the prompt;
    # k=1 retrieves only the single best-matching chunk per query.
    llm = ChatGroq(model="llama3-8b-8192", temperature=0.3, api_key=groq_api_key)
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vectorstore.as_retriever(search_kwargs={"k": 1}),
        return_source_documents=True,
    )
    return qa_chain
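
# Minimal usage sketch; "sample.pdf" and the question are placeholders, not
# part of the original module, and GROQ_API_KEY must be set in .env.
if __name__ == "__main__":
    chain = create_qa_chain_from_pdf("sample.pdf")
    result = chain.invoke({"query": "What is this document about?"})
    print(result["result"])
    # return_source_documents=True also exposes the retrieved chunk(s):
    for doc in result["source_documents"]:
        print(doc.metadata)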