mayankchugh-learning commited on
Commit
2b28caa
·
verified ·
1 Parent(s): f63ff6d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -131
app.py CHANGED
@@ -1,131 +1,93 @@
1
- import os
2
- import uuid
3
- import json
4
-
5
- import gradio as gr
6
-
7
- from openai import OpenAI
8
-
9
- from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
10
- from langchain_community.vectorstores import Chroma
11
-
12
- from huggingface_hub import CommitScheduler
13
- from pathlib import Path
14
-
15
-
16
- client = OpenAI(
17
- base_url="https://api.endpoints.anyscale.com/v1",
18
- api_key=os.environ['ANYSCALE_API_KEY']
19
- )
20
-
21
- embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-small')
22
-
23
- tesla_10k_collection = 'tesla-10k-2019-to-2023'
24
-
25
- vectorstore_persisted = Chroma(
26
- collection_name=tesla_10k_collection,
27
- persist_directory='./tesla_db',
28
- embedding_function=embedding_model
29
- )
30
-
31
- retriever = vectorstore_persisted.as_retriever(
32
- search_type='similarity',
33
- search_kwargs={'k': 5}
34
- )
35
-
36
- # Prepare the logging functionality
37
-
38
- log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
39
- log_folder = log_file.parent
40
-
41
- scheduler = CommitScheduler(
42
- repo_id="document-qna-chroma-anyscale-logs",
43
- repo_type="dataset",
44
- folder_path=log_folder,
45
- path_in_repo="data",
46
- every=2
47
- )
48
-
49
- qna_system_message = """
50
- You are an assistant to a financial services firm who answers user queries on annual reports.
51
- Users will ask questions delimited by triple backticks, that is, ```.
52
- User input will have the context required by you to answer user questions.
53
- This context will begin with the token: ###Context.
54
- The context contains references to specific portions of a document relevant to the user query.
55
- Please answer only using the context provided in the input. However, do not mention anything about the context in your answer.
56
- If the answer is not found in the context, respond "I don't know".
57
- """
58
-
59
- qna_user_message_template = """
60
- ###Context
61
- Here are some documents that are relevant to the question.
62
- {context}
63
- ```
64
- {question}
65
- ```
66
- """
67
-
68
- # Define the predict function that runs when 'Submit' is clicked or when a API request is made
69
- def predict(user_input):
70
-
71
- relevant_document_chunks = retriever.invoke(user_input)
72
- context_list = [d.page_content for d in relevant_document_chunks]
73
- context_for_query = ".".join(context_list)
74
-
75
- prompt = [
76
- {'role':'system', 'content': qna_system_message},
77
- {'role': 'user', 'content': qna_user_message_template.format(
78
- context=context_for_query,
79
- question=user_input
80
- )
81
- }
82
- ]
83
-
84
- try:
85
- response = client.chat.completions.create(
86
- model='mlabonne/NeuralHermes-2.5-Mistral-7B',
87
- messages=prompt,
88
- temperature=0
89
- )
90
-
91
- prediction = response.choices[0].message.content
92
-
93
- except Exception as e:
94
- prediction = e
95
-
96
- # While the prediction is made, log both the inputs and outputs to a local log file
97
- # While writing to the log file, ensure that the commit scheduler is locked to avoid parallel
98
- # access
99
-
100
- with scheduler.lock:
101
- with log_file.open("a") as f:
102
- f.write(json.dumps(
103
- {
104
- 'user_input': user_input,
105
- 'retrieved_context': context_for_query,
106
- 'model_response': prediction
107
- }
108
- ))
109
- f.write("\n")
110
-
111
- return prediction
112
-
113
-
114
- textbox = gr.Textbox(placeholder="Enter your query here", lines=6)
115
-
116
- # Create the interface
117
- demo = gr.Interface(
118
- inputs=textbox, fn=predict, outputs="text",
119
- title="AMA on Tesla 10-K statements",
120
- description="This web API presents an interface to ask questions on contents of the Tesla 10-K reports for the period 2019 - 2023.",
121
- article="Note that questions that are not relevant to the Tesla 10-K report will not be answered.",
122
- examples=[["What was the total revenue of the company in 2022?", "$ 81.46 Billion"],
123
- ["Summarize the Management Discussion and Analysis section of the 2021 report in 50 words.", ""],
124
- ["What was the company's debt level in 2020?", ""],
125
- ["Identify 5 key risks identified in the 2019 10k report? Respond with bullet point summaries.", ""]
126
- ],
127
- concurrency_limit=16
128
- )
129
-
130
- demo.queue()
131
- demo.launch()
 
1
+ import os
2
+
3
+ import gradio as gr
4
+
5
+ from openai import OpenAI
6
+
7
+ from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
8
+ from langchain_community.vectorstores import Chroma
9
+
10
+ client = OpenAI(api_key=os.environ['OPENAI_API_KEY'])
11
+
12
+ embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-small')
13
+
14
+ tesla_10k_collection = 'tesla-10k-2019-to-2023'
15
+
16
+ vectorstore_persisted = Chroma(
17
+ collection_name=tesla_10k_collection,
18
+ persist_directory='./tesla_db',
19
+ embedding_function=embedding_model
20
+ )
21
+
22
+ retriever = vectorstore_persisted.as_retriever(
23
+ search_type='similarity',
24
+ search_kwargs={'k': 5}
25
+ )
26
+
27
+ qna_system_message = """
28
+ You are an assistant to a financial services firm who answers user queries on annual reports.
29
+ Users will ask questions delimited by triple backticks, that is, ```.
30
+ User input will have the context required by you to answer user questions.
31
+ This context will begin with the token: ###Context.
32
+ The context contains references to specific portions of a document relevant to the user query.
33
+ Please answer only using the context provided in the input. However, do not mention anything about the context in your answer.
34
+ If the answer is not found in the context, respond "I don't know".
35
+ """
36
+
37
+ qna_user_message_template = """
38
+ ###Context
39
+ Here are some documents that are relevant to the question.
40
+ {context}
41
+ ```
42
+ {question}
43
+ ```
44
+ """
45
+
46
+ def predict(user_input):
47
+
48
+ relevant_document_chunks = retriever.get_relevant_documents(user_input)
49
+ context_list = [d.page_content for d in relevant_document_chunks]
50
+ context_for_query = ".".join(context_list)
51
+
52
+ prompt = [
53
+ {'role':'system', 'content': qna_system_message},
54
+ {'role': 'user', 'content': qna_user_message_template.format(
55
+ context=context_for_query,
56
+ question=user_input
57
+ )
58
+ }
59
+ ]
60
+
61
+ try:
62
+ response = client.chat.completions.create(
63
+ model="gpt-3.5-turbo",
64
+ messages=prompt,
65
+ temperature=0
66
+ )
67
+
68
+ prediction = response.choices[0].message.content
69
+
70
+ except Exception as e:
71
+ prediction = e
72
+
73
+ return prediction
74
+
75
+
76
+ textbox = gr.Textbox(placeholder="Enter your query here", lines=6)
77
+
78
+ demo = gr.Interface(
79
+ inputs=textbox, fn=predict, outputs="text",
80
+ title="AMA on Tesla 10-K statements",
81
+ description="This web API presents an interface to ask questions on contents of the Tesla 10-K reports for the period 2019 - 2023.",
82
+ article="Note that questions that are not relevant to the Tesla 10-K report will not be answered.",
83
+ examples=[["What was the total revenue of the company in 2022?", "$ 81.46 Billion"],
84
+ ["Summarize the Management Discussion and Analysis section of the 2021 report in 50 words.", ""],
85
+ ["What was the company's debt level in 2020?", ""],
86
+ ["Identify 5 key risks identified in the 2019 10k report? Respond with bullet point summaries.", ""]
87
+ ],
88
+ concurrency_limit=16
89
+ )
90
+
91
+
92
+ demo.queue()
93
+ demo.launch(auth=("demouser", os.getenv('PASSWD')))