# PDF question-answering service: loads a PDF, indexes it with FAISS +
# BGE-M3 embeddings, and answers queries through a Groq-hosted Llama 3 model.
import os
from typing import Optional

from dotenv import load_dotenv
from groq import Groq
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, ChatResult
load_dotenv()
groq_api_key = os.getenv("GROQ_API_KEY")
# ✅ Custom wrapper
class ChatGroq(BaseChatModel):
model: str = "llama3-8b-8192"
temperature: float = 0.3
groq_api_key: str = None
def __init__(self, model="llama3-8b-8192", temperature=0.3, api_key=None):
super().__init__()
self.model = model
self.temperature = temperature
self.groq_api_key = api_key
self._client = Groq(api_key=api_key) # ✅ use _client instead of client
def _generate(self, messages, stop=None):
prompt = [{"role": "user", "content": self._get_message_text(messages)}]
response = self._client.chat.completions.create(
model=self.model,
messages=prompt,
temperature=self.temperature,
max_tokens=1024
)
content = response.choices[0].message.content.strip()
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=content))])
def _get_message_text(self, messages):
if isinstance(messages, list):
return " ".join([msg.content for msg in messages])
return messages.content
@property
def _llm_type(self) -> str:
return "chat-groq"
# ✅ Function to return a QA chain
def create_qa_chain_from_pdf(pdf_path):
loader = PyPDFLoader(pdf_path)
documents = loader.load()
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
texts = splitter.split_documents(documents)
embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-m3")
vectorstore = FAISS.from_documents(texts, embeddings)
llm = ChatGroq(model="llama3-8b-8192", temperature=0.3, api_key=groq_api_key)
qa_chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=vectorstore.as_retriever(search_kwargs={"k": 1}),
return_source_documents=True
)
return qa_chain |