File size: 3,630 Bytes
05a8b3a
 
 
 
 
 
 
 
 
16522e2
 
 
05a8b3a
52bc1cc
05a8b3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16522e2
05a8b3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52bc1cc
05a8b3a
16522e2
 
52bc1cc
16522e2
05a8b3a
16522e2
 
 
 
 
 
 
05a8b3a
 
 
 
 
 
 
16522e2
05a8b3a
16522e2
 
 
 
 
05a8b3a
16522e2
 
05a8b3a
16522e2
05a8b3a
 
 
 
 
 
16522e2
05a8b3a
16522e2
05a8b3a
16522e2
 
 
 
05a8b3a
16522e2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from operator import itemgetter

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.base import format_document

from climateqa.engine.chains.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
from climateqa.engine.chains.prompts import papers_prompt_template
import time
from ..utils import rename_chain, pass_values


# Template used by _combine_documents to render every non-string document
# before it is numbered and joined into the prompt context.
_DOCUMENT_TEMPLATE = "Source : {source} - Content : {page_content}"
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template=_DOCUMENT_TEMPLATE)

def _combine_documents(
    docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, sep="\n\n"
):
    """Render ``docs`` as numbered entries joined by ``sep``.

    Every entry is prefixed with ``"Doc <n>: "`` where ``n`` is 1-based.
    Items that are already ``str`` are used verbatim; any other item is
    rendered with ``format_document(item, document_prompt)``. All newlines
    inside an entry (prefix included) are replaced by single spaces, so each
    document contributes exactly one line of text.

    Args:
        docs: Iterable of strings and/or document objects.
        document_prompt: Prompt used to render non-string items.
        sep: Separator placed between rendered entries.

    Returns:
        The joined string (empty string when ``docs`` is empty).
    """

    def _render(item):
        # Strings are assumed to be pre-formatted; everything else goes
        # through the document prompt.
        if isinstance(item, str):
            return item
        return format_document(item, document_prompt)

    return sep.join(
        (f"Doc {position}: " + _render(item)).replace("\n", " ")
        for position, item in enumerate(docs, start=1)
    )


def get_text_docs(x):
    """Return, in order, the items of ``x`` whose ``metadata["chunk_type"]`` is ``"text"``.

    Raises ``KeyError`` if an item's metadata lacks ``"chunk_type"``.
    """
    return list(filter(lambda doc: doc.metadata["chunk_type"] == "text", x))

def get_image_docs(x):
    """Return, in order, the items of ``x`` whose ``metadata["chunk_type"]`` is ``"image"``.

    Raises ``KeyError`` if an item's metadata lacks ``"chunk_type"``.
    """
    def _is_image(candidate):
        return candidate.metadata["chunk_type"] == "image"

    return list(filter(_is_image, x))

def make_rag_chain(llm):
    """Build the document-grounded answer chain.

    The input mapping must contain ``documents``, ``query``, ``language`` and
    ``audience``. ``documents`` is rendered into the ``context`` prompt
    variable with ``_combine_documents``; the other three are forwarded
    unchanged. The prompt is ``answer_prompt_template``, and the model output
    is parsed to a plain string.
    """
    def _context(inputs):
        return _combine_documents(inputs["documents"])

    def _log_context_length(inputs):
        # Debug side effect only: print() returns None, so the prompt
        # receives ``context_length=None``. Note the context is rendered a
        # second time here.
        print("CONTEXT LENGTH : " , len(_combine_documents(inputs["documents"])))

    answer_prompt = ChatPromptTemplate.from_template(answer_prompt_template)
    prompt_inputs = {
        "context": _context,
        "context_length": _log_context_length,
        **{key: itemgetter(key) for key in ("query", "language", "audience")},
    }
    return prompt_inputs | answer_prompt | llm | StrOutputParser()

def make_rag_chain_without_docs(llm):
    """Build the answer chain that uses no retrieved documents.

    The state is fed straight into ``answer_prompt_without_docs_template``;
    the model output is parsed to a plain string.
    """
    return (
        ChatPromptTemplate.from_template(answer_prompt_without_docs_template)
        | llm
        | StrOutputParser()
    )

def make_rag_node(llm,with_docs = True):
    """Build an async node that answers the current state with ``llm``.

    Args:
        llm: Chat model piped into the answer chain.
        with_docs: If True, use ``make_rag_chain`` (formats
            ``state["documents"]`` into the prompt context); otherwise use
            ``make_rag_chain_without_docs``.

    Returns:
        Coroutine function ``answer_rag(state, config)`` that returns
        ``{"answer": <model output string>}``.
    """
    rag_chain = make_rag_chain(llm) if with_docs else make_rag_chain_without_docs(llm)

    async def answer_rag(state,config):
        print("---- Answer RAG ----")
        # perf_counter is monotonic, so the measurement is immune to
        # wall-clock adjustments.
        start_time = time.perf_counter()

        # The documents are only logged here. The without-docs chain never
        # reads them, so tolerate a state where the key is absent or None
        # instead of failing before the model is called.
        # NOTE(review): assumes ``state`` is a dict-like mapping.
        documents = state.get("documents") or []
        print("Sources used : " + "\n".join(
            f"{doc.metadata['short_name']} - page {doc.metadata['page_number']}"
            for doc in documents
        ))

        answer = await rag_chain.ainvoke(state,config)

        elapsed_time = time.perf_counter() - start_time
        print("RAG elapsed time: ", elapsed_time)
        print("Answer size : ", len(answer))

        return {"answer":answer}

    return answer_rag




def make_rag_papers_chain(llm):
    """Build the papers answer chain.

    The input mapping must contain ``docs``, which is rendered into the
    ``context`` prompt variable with ``_combine_documents``. The entries
    produced by ``pass_values(["question", "language"])`` are merged in
    alongside it (presumably forwarding those keys unchanged; confirm in
    ``..utils``). The prompt is ``papers_prompt_template``; the string
    output chain is finally wrapped with ``rename_chain(chain, "answer")``.
    """
    def _papers_context(inputs):
        return _combine_documents(inputs["docs"])

    papers_prompt = ChatPromptTemplate.from_template(papers_prompt_template)
    prompt_inputs = {"context": _papers_context}
    prompt_inputs.update(pass_values(["question","language"]))

    answer_chain = prompt_inputs | papers_prompt | llm | StrOutputParser()
    return rename_chain(answer_chain,"answer")






def make_illustration_chain(llm):
    """Build the chain that prompts the model with image documents.

    Only the items of ``docs`` whose ``chunk_type`` is ``"image"`` are
    rendered (via ``_combine_documents``) into the ``images`` prompt
    variable. The entries from ``pass_values`` for ``question``,
    ``audience``, ``language`` and ``answer`` are merged in (presumably
    forwarded unchanged; confirm in ``..utils``). The prompt is
    ``answer_prompt_images_template``; the output is parsed to a string.
    """
    def _describe_images(inputs):
        return _combine_documents(get_image_docs(inputs["docs"]))

    images_prompt = ChatPromptTemplate.from_template(answer_prompt_images_template)
    prompt_inputs = dict(
        images=_describe_images,
        **pass_values(["question","audience","language","answer"]),
    )
    return prompt_inputs | images_prompt | llm | StrOutputParser()