from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import HumanMessage
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from base64 import b64decode
import os
from dotenv import load_dotenv
load_dotenv()  # actually call it so .env variables (e.g. GROQ_API_KEY) are loaded


def parse_docs(docs):
    """Split retrieved documents into base64-encoded images and plain texts."""
    b64 = []
    text = []
    for doc in docs:
        try:
            # validate=True rejects strings with non-base64 characters
            # (e.g. spaces), so ordinary prose is not misclassified as an image.
            b64decode(doc, validate=True)
            b64.append(doc)
        except Exception:
            text.append(doc)
    return {"images": b64, "texts": text}


def build_prompt(kwargs):
    docs_by_type = kwargs["context"]
    user_question = kwargs["question"]

    # Extract text context
    context_text = ""
    if docs_by_type.get("texts"):
        for text_element in docs_by_type["texts"]:
            if isinstance(text_element, list):
                # Flatten nested lists before joining
                flat_text = " ".join(
                    " ".join(map(str, sub_element)) if isinstance(sub_element, list) else str(sub_element)
                    for sub_element in text_element
                )
                context_text += flat_text + "\n"
            else:
                context_text += str(text_element) + "\n"

    # Extract table context. Note: parse_docs only emits "images" and "texts",
    # so this branch only runs if a "tables" key is added to the context upstream.
    context_tables = ""
    if docs_by_type.get("tables"):
        for table in docs_by_type["tables"]:
            # Render each table as pipe-separated rows
            table_str = "\n".join(" | ".join(map(str, row)) for row in table)
            context_tables += f"\nTable:\n{table_str}\n"

    # Construct the prompt with context (text and tables)
    prompt_template = f"""
    Answer the question based only on the following context, which includes text and tables.

    If you don't know the answer, just say that you don't know; don't try to make up an answer.

    Be specific.

    Context:
    {context_text}
    {context_tables}

    Question: {user_question}
    """

    prompt_content = [{"type": "text", "text": prompt_template}]

    # If images are provided, include them
    # if docs_by_type.get("images"):
    #     for image in docs_by_type["images"]:
    #         prompt_content.append(
    #             {
    #                 "type": "image_url",
    #                 "image_url": {"url": f"data:image/jpeg;base64,{image}"},
    #             }
    #         )

    return ChatPromptTemplate.from_messages(
        [
            HumanMessage(content=prompt_content),
        ]
    )
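
# Sketch of what build_prompt produces (the dict below is a hand-built
# stand-in for the chain's real input):
#   build_prompt({
#       "context": {"texts": ["Policy term: 2024-01-01 to 2025-01-01."], "images": []},
#       "question": "What is the policy start and expiry date?",
#   })
# returns a ChatPromptTemplate wrapping one already-formatted HumanMessage.
# Because RunnableLambda invokes a returned Runnable with the original input,
# the template is rendered in-chain before reaching the LLM.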

def create_rag_chain(retriever):
    # One shared LLM instance for both chains
    llm = ChatGroq(model="llama-3.3-70b-versatile", api_key=os.environ["GROQ_API_KEY"])

    # Plain chain: retrieve -> split docs -> build prompt -> LLM -> answer string
    chain = (
        {
            "context": retriever | RunnableLambda(parse_docs),
            "question": RunnablePassthrough(),
        }
        | RunnableLambda(build_prompt)
        | llm
        | StrOutputParser()
    )

    # Variant that keeps the parsed context and question alongside the
    # response, so callers can attribute sources.
    chain_with_sources = {
        "context": retriever | RunnableLambda(parse_docs),
        "question": RunnablePassthrough(),
    } | RunnablePassthrough().assign(
        response=(
            RunnableLambda(build_prompt)
            | llm
            | StrOutputParser()
        )
    )
    return chain, chain_with_sources
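
# Usage sketch (hypothetical `vectorstore`; any LangChain retriever works):
#   retriever = vectorstore.as_retriever()
#   chain, chain_with_sources = create_rag_chain(retriever)
#   chain_with_sources.invoke("What is the policy start and expiry date?")
#   # -> {"context": {...}, "question": "...", "response": "..."}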

def invoke_chain(chain, question="What is the policy start and expiry date?"):
    response = chain.invoke(question)
    return response
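

if __name__ == "__main__":
    # Smoke test with a canned-document retriever (a stand-in, not a real
    # vector store). Requires GROQ_API_KEY in the environment or a .env file.
    fake_retriever = RunnableLambda(
        lambda _question: ["The policy starts on 2024-01-01 and expires on 2025-01-01."]
    )
    chain, _chain_with_sources = create_rag_chain(fake_retriever)
    print(invoke_chain(chain))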