import os
from typing import Optional
from dotenv import load_dotenv
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain_core.language_models import BaseChatModel
from langchain_core.outputs import ChatResult, ChatGeneration
from langchain_core.messages import AIMessage
from groq import Groq

load_dotenv()
groq_api_key = os.getenv("GROQ_API_KEY")

# ✅ Custom wrapper that exposes the Groq SDK as a LangChain chat model
class ChatGroq(BaseChatModel):
    model: str = "llama3-8b-8192"
    temperature: float = 0.3
    groq_api_key: Optional[str] = None

    def __init__(self, model="llama3-8b-8192", temperature=0.3, api_key=None):
        # BaseChatModel is a Pydantic model, so declared fields are set via super().__init__()
        super().__init__(model=model, temperature=temperature, groq_api_key=api_key)
        # ✅ use _client instead of client: the leading underscore keeps the SDK
        # handle out of Pydantic's field validation
        self._client = Groq(api_key=api_key)

    def _generate(self, messages, stop=None, run_manager=None, **kwargs):
        # Flatten the incoming LangChain messages into a single user turn for Groq
        prompt = [{"role": "user", "content": self._get_message_text(messages)}]
        response = self._client.chat.completions.create(
            model=self.model,
            messages=prompt,
            temperature=self.temperature,
            max_tokens=1024
        )
        content = response.choices[0].message.content.strip()
        return ChatResult(generations=[ChatGeneration(message=AIMessage(content=content))])

    def _get_message_text(self, messages):
        # Concatenate message contents; roles are dropped since everything is sent as "user"
        if isinstance(messages, list):
            return " ".join([msg.content for msg in messages])
        return messages.content

    @property
    def _llm_type(self) -> str:
        return "chat-groq"

# ✅ Build a RetrievalQA chain over a single PDF
def create_qa_chain_from_pdf(pdf_path):
    # Load the PDF into one Document per page
    loader = PyPDFLoader(pdf_path)
    documents = loader.load()

    # Split pages into overlapping chunks so retrieved context stays small
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = splitter.split_documents(documents)

    # Embed the chunks and index them in an in-memory FAISS store
    embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-m3")
    vectorstore = FAISS.from_documents(texts, embeddings)

    llm = ChatGroq(model="llama3-8b-8192", temperature=0.3, api_key=groq_api_key)

    # "stuff" inserts the retrieved chunk(s) directly into the prompt; k=1 keeps it minimal
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vectorstore.as_retriever(search_kwargs={"k": 1}),
        return_source_documents=True
    )
    return qa_chain
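
# Minimal usage sketch; "sample.pdf" and the question below are illustrative
# placeholders, not files or values from this repo.
if __name__ == "__main__":
    qa_chain = create_qa_chain_from_pdf("sample.pdf")
    result = qa_chain.invoke({"query": "What is this document about?"})
    print(result["result"])
    # return_source_documents=True exposes the retrieved chunk(s)
    for doc in result["source_documents"]:
        print("Source page:", doc.metadata.get("page"))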