gizemsarsinlar commited on
Commit
776ae2c
1 Parent(s): 17cc0b5

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +169 -63
  2. requirements.txt +0 -0
app.py CHANGED
@@ -1,63 +1,169 @@
1
- import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
-
9
-
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
-
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
-
26
- messages.append({"role": "user", "content": message})
27
-
28
- response = ""
29
-
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
- """
43
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
44
- """
45
- demo = gr.ChatInterface(
46
- respond,
47
- additional_inputs=[
48
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
49
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
50
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
51
- gr.Slider(
52
- minimum=0.1,
53
- maximum=1.0,
54
- value=0.95,
55
- step=0.05,
56
- label="Top-p (nucleus sampling)",
57
- ),
58
- ],
59
- )
60
-
61
-
62
- if __name__ == "__main__":
63
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.document_transformers import LongContextReorder
2
+ from langchain_core.runnables import RunnableLambda
3
+ from langchain_core.runnables.passthrough import RunnableAssign
4
+ from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
5
+
6
+ from langchain_core.prompts import ChatPromptTemplate
7
+ from langchain_core.output_parsers import StrOutputParser
8
+
9
+ import gradio as gr
10
+ from functools import partial
11
+ from operator import itemgetter
12
+
13
+ from faiss import IndexFlatL2
14
+ from langchain_community.docstore.in_memory import InMemoryDocstore
15
+ import json
16
+ from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
17
+
18
+ from langchain_community.vectorstores import FAISS
19
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
20
+ from langchain.document_loaders import ArxivLoader
21
+
22
+ # NVIDIAEmbeddings.get_available_models()
23
+ embedder = NVIDIAEmbeddings(model="nvidia/embed-qa-4", truncate="END")
24
+ # ChatNVIDIA.get_available_models()
25
+ instruct_llm = ChatNVIDIA(model="mistralai/mixtral-8x7b-instruct-v0.1")
26
+
27
+ embed_dims = len(embedder.embed_query("test"))
28
+ def default_FAISS():
29
+ '''Useful utility for making an empty FAISS vectorstore'''
30
+ return FAISS(
31
+ embedding_function=embedder,
32
+ index=IndexFlatL2(embed_dims),
33
+ docstore=InMemoryDocstore(),
34
+ index_to_docstore_id={},
35
+ normalize_L2=False
36
+ )
37
+
38
+ def aggregate_vstores(vectorstores):
39
+ ## Initialize an empty FAISS Index and merge others into it
40
+ ## We'll use default_faiss for simplicity, though it's tied to your embedder by reference
41
+ agg_vstore = default_FAISS()
42
+ for vstore in vectorstores:
43
+ agg_vstore.merge_from(vstore)
44
+ return agg_vstore
45
+
46
+ text_splitter = RecursiveCharacterTextSplitter(
47
+ chunk_size=1000, chunk_overlap=100,
48
+ separators=["\n\n", "\n", ".", ";", ",", " "],
49
+ )
50
+
51
+ docs = [
52
+ ArxivLoader(query="1706.03762").load(), ## Attention Is All You Need Paper
53
+ ArxivLoader(query="1810.04805").load(), ## BERT Paper
54
+ ArxivLoader(query="2005.11401").load(), ## RAG Paper
55
+ ArxivLoader(query="2205.00445").load(), ## MRKL Paper
56
+ ArxivLoader(query="2310.06825").load(), ## Mistral Paper
57
+ ArxivLoader(query="2306.05685").load(), ## LLM-as-a-Judge
58
+ ## Some longer papers
59
+ ArxivLoader(query="2210.03629").load(), ## ReAct Paper
60
+ ArxivLoader(query="2112.10752").load(), ## Latent Stable Diffusion Paper
61
+ ArxivLoader(query="2103.00020").load(), ## CLIP Paper
62
+ ## TODO: Feel free to add more
63
+ ]
64
+
65
+ ## Cut the paper short if references is included.
66
+ ## This is a standard string in papers.
67
+ for doc in docs:
68
+ content = json.dumps(doc[0].page_content)
69
+ if "References" in content:
70
+ doc[0].page_content = content[:content.index("References")]
71
+
72
+ ## Split the documents and also filter out stubs (overly short chunks)
73
+ print("Chunking Documents")
74
+ docs_chunks = [text_splitter.split_documents(doc) for doc in docs]
75
+ docs_chunks = [[c for c in dchunks if len(c.page_content) > 200] for dchunks in docs_chunks]
76
+
77
+ ## Make some custom Chunks to give big-picture details
78
+ doc_string = "Available Documents:"
79
+ doc_metadata = []
80
+ for chunks in docs_chunks:
81
+ metadata = getattr(chunks[0], 'metadata', {})
82
+ doc_string += "\n - " + metadata.get('Title')
83
+ doc_metadata += [str(metadata)]
84
+
85
+ extra_chunks = [doc_string] + doc_metadata
86
+
87
+ vecstores = [FAISS.from_texts(extra_chunks, embedder)]
88
+ vecstores += [FAISS.from_documents(doc_chunks, embedder) for doc_chunks in docs_chunks]
89
+
90
+ ## Unintuitive optimization; merge_from seems to optimize constituent vector stores away
91
+ docstore = aggregate_vstores(vecstores)
92
+
93
+ print(f"Constructed aggregate docstore with {len(docstore.docstore._dict)} chunks")
94
+
95
+ convstore = default_FAISS()
96
+
97
+ def save_memory_and_get_output(d, vstore):
98
+ """Accepts 'input'/'output' dictionary and saves to convstore"""
99
+ vstore.add_texts([
100
+ f"User previously responded with {d.get('input')}",
101
+ f"Agent previously responded with {d.get('output')}"
102
+ ])
103
+ return d.get('output')
104
+
105
+ initial_msg = (
106
+ "Hello! I am a document chat agent here to help the user!"
107
+ f" I have access to the following documents: {doc_string}\n\nHow can I help you?"
108
+ )
109
+
110
+ chat_prompt = ChatPromptTemplate.from_messages([("system",
111
+ "You are a document chatbot. Help the user as they ask questions about documents."
112
+ " User messaged just asked: {input}\n\n"
113
+ " From this, we have retrieved the following potentially-useful info: "
114
+ " Conversation History Retrieval:\n{history}\n\n"
115
+ " Document Retrieval:\n{context}\n\n"
116
+ " (Answer only from retrieval. Only cite sources that are used. Make your response conversational.)"
117
+ ), ('user', '{input}')])
118
+
119
+ stream_chain = chat_prompt| RPrint() | instruct_llm | StrOutputParser()
120
+
121
+ def RPrint(preface=""):
122
+ """Simple passthrough "prints, then returns" chain"""
123
+ def print_and_return(x, preface):
124
+ if preface: print(preface, end="")
125
+ return x
126
+ return RunnableLambda(partial(print_and_return, preface=preface))
127
+
128
+ retrieval_chain = (
129
+ {'input' : (lambda x: x)}
130
+ ## TODO: Make sure to retrieve history & context from convstore & docstore, respectively.
131
+ ## HINT: Our solution uses RunnableAssign, itemgetter, long_reorder, and docs2str
132
+ | RunnableAssign({'history' : itemgetter('input') | convstore.as_retriever() | long_reorder | docs2str})
133
+ | RunnableAssign({'context' : itemgetter('input') | docstore.as_retriever() | long_reorder | docs2str})
134
+ | RPrint()
135
+ )
136
+
137
+ def chat_gen(message, history=[], return_buffer=True):
138
+ buffer = ""
139
+ ## First perform the retrieval based on the input message
140
+ retrieval = retrieval_chain.invoke(message)
141
+ line_buffer = ""
142
+
143
+ ## Then, stream the results of the stream_chain
144
+ for token in stream_chain.stream(retrieval):
145
+ buffer += token
146
+ ## If you're using standard print, keep line from getting too long
147
+ yield buffer if return_buffer else token
148
+
149
+ ## Lastly, save the chat exchange to the conversation memory buffer
150
+ save_memory_and_get_output({'input': message, 'output': buffer}, convstore)
151
+
152
+
153
+ # ## Start of Agent Event Loop
154
+ # test_question = "Tell me about RAG!" ## <- modify as desired
155
+
156
+ # ## Before you launch your gradio interface, make sure your thing works
157
+ # for response in chat_gen(test_question, return_buffer=False):
158
+ # print(response, end='')
159
+
160
+ chatbot = gr.Chatbot(value = [[None, initial_msg]])
161
+ demo = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()
162
+
163
+ try:
164
+ demo.launch(debug=True, share=True, show_api=False)
165
+ demo.close()
166
+ except Exception as e:
167
+ demo.close()
168
+ print(e)
169
+ raise e
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ