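"""Chainlit app for asking questions about arXiv papers via a Llama 2 SageMaker endpoint.

The app prompts the user for a topic, downloads the top three matching arXiv
PDFs, embeds their pages into a Chroma vector store, and answers follow-up
questions with a RetrievalQAWithSourcesChain backed by a custom LLM wrapper
that calls a Llama 2 model hosted on SageMaker behind an API Gateway.
"""
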
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.document_loaders import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chat_models import ChatOpenAI
from typing import Any, List, Mapping, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
import os
import requests
import arxiv
import chainlit as cl
from chainlit import user_session
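
# URL of the API Gateway that fronts the deployed Llama 2 SageMaker endpoint.
# Assumed here to come from an environment variable; the variable name is only
# a placeholder, so point it at your own gateway URL.
model_api_gateway = os.environ.get("MODEL_API_GATEWAY_URL")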

@cl.langchain_factory(use_async=True)
async def init():
    arxiv_query = None

    # Wait for the user to ask an Arxiv question
    while arxiv_query is None:
        arxiv_query = await cl.AskUserMessage(
            content="Please enter a topic to begin!", timeout=15
        ).send()

    # Obtain the top 3 results from Arxiv for the query
    search = arxiv.Search(
        query=arxiv_query["content"],
        max_results=3,
        sort_by=arxiv.SortCriterion.Relevance,
    )

    await cl.Message(content="Downloading and chunking articles...").send()
    # download each of the pdfs
    pdf_data = []
    for result in search.results():
        loader = PyMuPDFLoader(result.pdf_url)
        loaded_pdf = loader.load()

        for document in loaded_pdf:
            document.metadata["source"] = result.entry_id
            document.metadata["file_path"] = result.pdf_url
            document.metadata["title"] = result.title
            pdf_data.append(document)

    # Create a Chroma vector store
    embeddings = OpenAIEmbeddings(
        disallowed_special=(),
    )
    
    # If this operation takes too long, make_async lets it run in a separate thread:
    # docsearch = await cl.make_async(Chroma.from_documents)(pdf_data, embeddings)
    docsearch = Chroma.from_documents(pdf_data, embeddings)

    # Custom LangChain LLM wrapper for a Llama 2 model hosted on SageMaker
    class Llama2SageMaker(LLM):
        max_new_tokens: int = 256
        top_p: float = 0.9
        temperature: float = 0.1
    
        @property
        def _llm_type(self) -> str:
            return "Llama2SageMaker"
    
        def _call(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
        ) -> str:
            if stop is not None:
                raise ValueError("stop kwargs are not permitted.")
            
            # Build the request payload in the Llama 2 chat format expected by
            # the SageMaker endpoint behind the API Gateway.
            json_body = {
                "inputs" : [
                  [{"role" : "user", "content" : prompt}]
                ],
                "parameters" : {
                    "max_new_tokens" : self.max_new_tokens,
                    "top_p" : self.top_p,
                    "temperature" : self.temperature
                }
            }
    
            # Call the model through the API Gateway and extract the generated text.
            response = requests.post(model_api_gateway, json=json_body)

            return response.json()[0]["generation"]["content"]
    
        @property
        def _identifying_params(self) -> Mapping[str, Any]:
            """Get the identifying parameters."""
            return {
                "max_new_tokens" : self.max_new_tokens,
                "top_p" : self.top_p,
                "temperature" : self.temperature
            }

    # set our llm to the custom Llama2SageMaker endpoint model
    llm = Llama2SageMaker()

    # Create a chain that uses the Chroma vector store
    chain = RetrievalQAWithSourcesChain.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=docsearch.as_retriever(),
        return_source_documents=True,
    )

    # Let the user know that the system is ready
    await cl.Message(
        content=f"We found a few papers about `{arxiv_query['content']}` you can now ask questions!"
    ).send()

    return chain


@cl.langchain_postprocess
async def process_response(res):
    answer = res["answer"]
    source_elements_dict = {}
    source_elements = []
    for source in res["source_documents"]:
        title = source.metadata["title"]

        if title not in source_elements_dict:
            source_elements_dict[title] = {
                "page_number": [source.metadata["page"]],
                "url": source.metadata["file_path"],
            }

        else:
            source_elements_dict[title]["page_number"].append(source.metadata["page"])

        # sort the page numbers
        source_elements_dict[title]["page_number"].sort()

    for title, source in source_elements_dict.items():
        # create a string for the page numbers
        page_numbers = ", ".join([str(x) for x in source["page_number"]])
        text_for_source = f"Page Number(s): {page_numbers}\nURL: {source['url']}"
        source_elements.append(
            cl.Text(name=title, content=text_for_source, display="inline")
        )

    await cl.Message(content=answer, elements=source_elements).send()