Lrosado committed
Commit 3a5fc49 · verified · 1 parent: 29ce833

Update app.py

Files changed (1):
  1. app.py (+130 -131)
app.py CHANGED
@@ -1,132 +1,131 @@
  import os
  import uuid
  import json

  import gradio as gr

  from openai import OpenAI
- pip install --upgrade openai

  from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
  from langchain_community.vectorstores import Chroma

  from huggingface_hub import CommitScheduler
  from pathlib import Path


  client = OpenAI(
      # base_url="https://aibe.mygreatlearning.com/openai/v1",
      api_key="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"
  )

  embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-small')

  tesla_10k_collection = 'tesla-10k-2019-to-2023'

  vectorstore_persisted = Chroma(
      collection_name=tesla_10k_collection,
      persist_directory='./tesla_db',
      embedding_function=embedding_model
  )

  retriever = vectorstore_persisted.as_retriever(
      search_type='similarity',
      search_kwargs={'k': 5}
  )

  # Prepare the logging functionality

  log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
  log_folder = log_file.parent

  scheduler = CommitScheduler(
      repo_id="document-qna-chroma-anyscale-logs",
      repo_type="dataset",
      folder_path=log_folder,
      path_in_repo="data",
      every=2  # commit accumulated logs every 2 minutes
  )

  qna_system_message = """
  You are an assistant to a financial services firm who answers user queries on annual reports.
  Users will ask questions delimited by triple backticks, that is, ```.
  User input will have the context required by you to answer user questions.
  This context will begin with the token: ###Context.
  The context contains references to specific portions of a document relevant to the user query.
  Please answer only using the context provided in the input. However, do not mention anything about the context in your answer.
  If the answer is not found in the context, respond "I don't know".
  """

  qna_user_message_template = """
  ###Context
  Here are some documents that are relevant to the question.
  {context}
  ```
  {question}
  ```
  """

  # Define the predict function that runs when 'Submit' is clicked or when an API request is made
  def predict(user_input):
      relevant_document_chunks = retriever.invoke(user_input)
      context_list = [d.page_content for d in relevant_document_chunks]
      context_for_query = ".".join(context_list)

      prompt = [
          {'role': 'system', 'content': qna_system_message},
          {'role': 'user', 'content': qna_user_message_template.format(
              context=context_for_query,
              question=user_input
          )}
      ]

      try:
          response = client.chat.completions.create(
              model='mlabonne/NeuralHermes-2.5-Mistral-7B',
              messages=prompt,
              temperature=0
          )

          prediction = response.choices[0].message.content

      except Exception as e:
          prediction = str(e)  # stringify so the value stays JSON-serializable below

      # Log both the inputs and outputs to a local log file.
      # Hold the commit scheduler's lock while writing to avoid parallel access
      # from the scheduled background commit.
      with scheduler.lock:
          with log_file.open("a") as f:
              f.write(json.dumps(
                  {
                      'user_input': user_input,
                      'retrieved_context': context_for_query,
                      'model_response': prediction
                  }
              ))
              f.write("\n")

      return prediction


  textbox = gr.Textbox(placeholder="Enter your query here", lines=6)

  # Create the interface
  demo = gr.Interface(
      inputs=textbox, fn=predict, outputs="text",
      title="AMA on Tesla 10-K statements",
      description="This web API presents an interface to ask questions on contents of the Tesla 10-K reports for the period 2019 - 2023.",
      article="Note that questions that are not relevant to the Tesla 10-K report will not be answered.",
      examples=[["What was the total revenue of the company in 2022?", "$ 81.46 Billion"],
                ["Summarize the Management Discussion and Analysis section of the 2021 report in 50 words.", ""],
                ["What was the company's debt level in 2020?", ""],
                ["Identify 5 key risks identified in the 2019 10k report? Respond with bullet point summaries.", ""]
               ],
      concurrency_limit=16
  )

  demo.queue()
  demo.launch()
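The one functional change in this commit is the removed line: "pip install --upgrade openai" is a shell command, not Python, so the previous app.py raised a SyntaxError at import time, before Gradio could even start. On a Hugging Face Space the dependency belongs in requirements.txt instead. If an in-process upgrade were ever genuinely needed, a minimal sketch (assuming pip is available in the runtime environment) would be:

# Minimal sketch, not part of this commit: run pip through the current
# interpreter rather than pasting a shell command into the module.
import subprocess
import sys

subprocess.run(
    [sys.executable, "-m", "pip", "install", "--upgrade", "openai"],
    check=True,  # raise CalledProcessError if the install fails
)

In practice, an "openai" line in the Space's requirements.txt achieves the same thing without runtime installs.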
 
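One issue survives the commit: app.py still hardcodes what appears to be a live OpenAI API key, and a public commit exposes it to anyone browsing the repo, so it should be revoked and moved into a Space secret. A minimal sketch of the client setup under that assumption (the secret name OPENAI_API_KEY is illustrative, not something this commit defines; it also happens to be the variable the OpenAI client reads by default):

# Sketch: pull the key from the environment (e.g. a Hugging Face Space
# secret); OPENAI_API_KEY is an assumed name, not defined by this commit.
import os

from openai import OpenAI

client = OpenAI(
    api_key=os.environ["OPENAI_API_KEY"],  # fail fast at startup if the secret is missing
)

This would also give the existing "import os" at the top of app.py something to do, since nothing else in the file uses it.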