pgurazada1 committed on
Commit e7c2848 · verified · 1 Parent(s): 82e7f41

Create app.py

Files changed (1)
  1. app.py +163 -0
app.py ADDED
@@ -0,0 +1,163 @@
import os
import gradio as gr

from dotenv import load_dotenv
from openai import AzureOpenAI

from langchain_openai import AzureOpenAIEmbeddings

from langchain_community.vectorstores import Chroma
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

load_dotenv()

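# The two environment variables read below via os.environ are the only
# configuration this app needs. A minimal .env file (assumed to sit next to
# app.py and be picked up by load_dotenv above; not part of this commit)
# would look like:
#
#   AZURE_OPENAI_KEY=<your Azure OpenAI API key>
#   AZURE_OPENAI_ENDPOINT=https://<your-resource-name>.openai.azure.com/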

# Azure OpenAI client used for chat completions
client = AzureOpenAI(
    api_key=os.environ['AZURE_OPENAI_KEY'],
    azure_endpoint=os.environ['AZURE_OPENAI_ENDPOINT'],
    api_version='2024-02-01'
)

model_name = 'gpt-4o-mini'

# Embedding model used to query the persisted Chroma collection
embedding_model = AzureOpenAIEmbeddings(
    api_key=os.environ['AZURE_OPENAI_KEY'],
    azure_endpoint=os.environ['AZURE_OPENAI_ENDPOINT'],
    api_version='2024-02-01',
    azure_deployment="text-embedding-ada-002"
)

tesla_10k_collection = 'tesla-10k-2021-2023'

# Load the persisted Chroma vector store of 10-K chunks
# (the ./tesla_db directory must already exist)
vectorstore_persisted = Chroma(
    collection_name=tesla_10k_collection,
    persist_directory='./tesla_db',
    embedding_function=embedding_model
)

# First-stage retriever: fetch the 20 most similar chunks
retriever = vectorstore_persisted.as_retriever(
    search_type='similarity',
    search_kwargs={'k': 20}
)

# Second-stage reranker: keep the 5 chunks the cross-encoder scores highest
cross_encoder_model = HuggingFaceCrossEncoder(model_name="cross-encoder/ms-marco-MiniLM-L-6-v2")
compressor = CrossEncoderReranker(model=cross_encoder_model, top_n=5)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=retriever
)

# RAG Q&A

qna_system_message = """
You are an expert analyst at a financial services firm who answers user queries on annual reports.
User input will have the context required by you to answer user questions.
This context will begin with the word: ###Context.
The context contains documents relevant to the user query.
It also contains references to the metadata associated with the relevant documents.
In sum, the context provided to you will be a combination of information and the metadata for the source of that information.

User questions will begin with the word: ###Question.

Please answer user questions only using the context provided in the input and provide citations.
Remember, you must return both an answer and citations. A citation consists of a VERBATIM quote that
justifies the answer and the metadata of the quote's source document.
Return a citation for every quote, across all source documents, that justifies the answer.
Use the following format for your final output:

<cited_answer>
    <answer></answer>
    <citations>
        <citation><source_doc_year></source_doc_year><source_page></source_page><quote></quote></citation>
        <citation><source_doc_year></source_doc_year><source_page></source_page><quote></quote></citation>
        ...
    </citations>
</cited_answer>

If the answer is not found in the context, respond "I don't know".
"""

qna_user_message_template = """
###Context
Here are some documents that are relevant to the question mentioned below.
{context}

###Question
{question}
"""


def predict(user_input: str):

    # Retrieve candidate chunks and rerank them down to the top 5
    # (the compression retriever runs the base retriever internally)
    relevant_document_chunks = compression_retriever.invoke(user_input)

    # Pair each chunk's text with its metadata so the model can cite sources
    context_citation_list = [
        f'Information: {d.page_content}\nMetadata: {d.metadata}'
        for d in relevant_document_chunks
    ]

    context_for_query = "\n---\n".join(context_citation_list)

    prompt = [
        {'role': 'system', 'content': qna_system_message},
        {'role': 'user', 'content': qna_user_message_template.format(
            context=context_for_query,
            question=user_input
        )}
    ]

    try:
        response = client.chat.completions.create(
            model=model_name,
            messages=prompt,
            temperature=0
        )

        prediction = response.choices[0].message.content.strip()
    except Exception as e:
        prediction = f'Sorry, I encountered the following error: \n {e}'

    return prediction


def parse_prediction(user_input: str):

    answer = predict(user_input)

    # Pull the answer text and the individual <citation> blocks out of the
    # model's <cited_answer> response
    final_answer = answer[answer.find('<answer>') + len('<answer>'): answer.find('</answer>')]
    citations_block = answer[answer.find('<citations>') + len('<citations>'): answer.find('</citations>')]
    citations = [c for c in citations_block.strip().split('</citation>') if c.strip()]
    references = ''

    for i, citation in enumerate(citations):
        quote = citation[citation.find('<quote>') + len('<quote>'): citation.find('</quote>')]
        year = citation[citation.find('<source_doc_year>') + len('<source_doc_year>'): citation.find('</source_doc_year>')]
        page = citation[citation.find('<source_page>') + len('<source_page>'): citation.find('</source_page>')]
        references += f'{i + 1}. Quote: {quote}, Annual Report: {year}, Page: {page}\n'

    return f'Answer: {final_answer}\n' + f'\nReferences:\n{references}'


# UI

textbox = gr.Textbox(placeholder="Enter your query here", lines=6)

demo = gr.Interface(
    inputs=textbox, fn=parse_prediction, outputs="text",
    title="AMA on Tesla 10-K statements",
    description="This app presents an interface to ask questions about the contents of the Tesla 10-K reports for the period 2021 - 2023.",
    article="Note that questions that are not relevant to the Tesla 10-K reports will not be answered.",
    examples=[["What was the total revenue of the company in 2022?"],
              ["Summarize the Management Discussion and Analysis section of the 2021 report in 50 words."],
              ["What was the company's debt level in 2023?"],
              ["Summarize 5 key risks identified in the 2023 10-K report. Respond with bullet-point summaries."],
              ["What is the view of the management on the future of electric vehicle batteries?"]
              ],
    cache_examples=False,
    theme=gr.themes.Base(),
    concurrency_limit=16
)

demo.queue()
demo.launch()
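
This commit adds only app.py, so the Space's dependency file is not shown here. Going purely by the imports in this file, a plausible requirements.txt would need at least the packages below (exact names for the chromadb and sentence-transformers backends are an assumption; pins are omitted):

gradio
python-dotenv
openai
langchain
langchain-openai
langchain-community
chromadb
sentence-transformers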
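app.py only loads an already-persisted Chroma collection from ./tesla_db; the script that builds it is not part of this commit. A minimal ingestion sketch is below. The PDF file names, loader, and chunking parameters are illustrative assumptions; only the collection name, persist directory, and embedding deployment come from app.py, and the exact metadata keys carried by the original store (the prompt cites a source year and page) are unknown.

import os

from dotenv import load_dotenv
from langchain_openai import AzureOpenAIEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter

load_dotenv()

embedding_model = AzureOpenAIEmbeddings(
    api_key=os.environ['AZURE_OPENAI_KEY'],
    azure_endpoint=os.environ['AZURE_OPENAI_ENDPOINT'],
    api_version='2024-02-01',
    azure_deployment="text-embedding-ada-002"
)

# Hypothetical input files; the actual 10-K sources are not part of this commit.
pdf_paths = ['tsla-10k-2021.pdf', 'tsla-10k-2022.pdf', 'tsla-10k-2023.pdf']

docs = []
for path in pdf_paths:
    # Each loaded page keeps its source path and page number in doc.metadata;
    # the original store presumably also attaches a filing-year field.
    docs.extend(PyPDFLoader(path).load())

# Chunk sizes are illustrative, not taken from the original project.
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
chunks = splitter.split_documents(docs)

# Persist into the same collection name and directory that app.py expects.
Chroma.from_documents(
    chunks,
    embedding_model,
    collection_name='tesla-10k-2021-2023',
    persist_directory='./tesla_db'
)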