Fred Premji commited on
Commit
b3935e8
·
verified ·
1 Parent(s): e3fbdd4

Upload 9 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tesla_db/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,63 +1,131 @@
1
- import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
-
9
-
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
-
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
-
26
- messages.append({"role": "user", "content": message})
27
-
28
- response = ""
29
-
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
- """
43
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
44
- """
45
- demo = gr.ChatInterface(
46
- respond,
47
- additional_inputs=[
48
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
49
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
50
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
51
- gr.Slider(
52
- minimum=0.1,
53
- maximum=1.0,
54
- value=0.95,
55
- step=0.05,
56
- label="Top-p (nucleus sampling)",
57
- ),
58
- ],
59
- )
60
-
61
-
62
- if __name__ == "__main__":
63
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ import json
4
+
5
+ import gradio as gr
6
+
7
+ from openai import OpenAI
8
+
9
+ from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
10
+ from langchain_community.vectorstores import Chroma
11
+
12
+ from huggingface_hub import CommitScheduler
13
+ from pathlib import Path
14
+
15
+
16
+ client = OpenAI(
17
+ base_url="https://api.endpoints.anyscale.com/v1",
18
+ api_key=os.environ['ANYSCALE_API_KEY']
19
+ )
20
+
21
+ embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-small')
22
+
23
+ tesla_10k_collection = 'tesla-10k-2019-to-2023'
24
+
25
+ vectorstore_persisted = Chroma(
26
+ collection_name=tesla_10k_collection,
27
+ persist_directory='./tesla_db',
28
+ embedding_function=embedding_model
29
+ )
30
+
31
+ retriever = vectorstore_persisted.as_retriever(
32
+ search_type='similarity',
33
+ search_kwargs={'k': 5}
34
+ )
35
+
36
+ # Prepare the logging functionality
37
+
38
+ log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
39
+ log_folder = log_file.parent
40
+
41
+ scheduler = CommitScheduler(
42
+ repo_id="document-qna-chroma-anyscale-logs",
43
+ repo_type="dataset",
44
+ folder_path=log_folder,
45
+ path_in_repo="data",
46
+ every=2
47
+ )
48
+
49
+ qna_system_message = """
50
+ You are an assistant to a financial services firm who answers user queries on annual reports.
51
+ Users will ask questions delimited by triple backticks, that is, ```.
52
+ User input will have the context required by you to answer user questions.
53
+ This context will begin with the token: ###Context.
54
+ The context contains references to specific portions of a document relevant to the user query.
55
+ Please answer only using the context provided in the input. However, do not mention anything about the context in your answer.
56
+ If the answer is not found in the context, respond "I don't know".
57
+ """
58
+
59
+ qna_user_message_template = """
60
+ ###Context
61
+ Here are some documents that are relevant to the question.
62
+ {context}
63
+ ```
64
+ {question}
65
+ ```
66
+ """
67
+
68
+ # Define the predict function that runs when 'Submit' is clicked or when a API request is made
69
+ def predict(user_input):
70
+
71
+ relevant_document_chunks = retriever.invoke(user_input)
72
+ context_list = [d.page_content for d in relevant_document_chunks]
73
+ context_for_query = ".".join(context_list)
74
+
75
+ prompt = [
76
+ {'role':'system', 'content': qna_system_message},
77
+ {'role': 'user', 'content': qna_user_message_template.format(
78
+ context=context_for_query,
79
+ question=user_input
80
+ )
81
+ }
82
+ ]
83
+
84
+ try:
85
+ response = client.chat.completions.create(
86
+ model='mlabonne/NeuralHermes-2.5-Mistral-7B',
87
+ messages=prompt,
88
+ temperature=0
89
+ )
90
+
91
+ prediction = response.choices[0].message.content
92
+
93
+ except Exception as e:
94
+ prediction = e
95
+
96
+ # While the prediction is made, log both the inputs and outputs to a local log file
97
+ # While writing to the log file, ensure that the commit scheduler is locked to avoid parallel
98
+ # access
99
+
100
+ with scheduler.lock:
101
+ with log_file.open("a") as f:
102
+ f.write(json.dumps(
103
+ {
104
+ 'user_input': user_input,
105
+ 'retrieved_context': context_for_query,
106
+ 'model_response': prediction
107
+ }
108
+ ))
109
+ f.write("\n")
110
+
111
+ return prediction
112
+
113
+
114
+ textbox = gr.Textbox(placeholder="Enter your query here", lines=6)
115
+
116
+ # Create the interface
117
+ demo = gr.Interface(
118
+ inputs=textbox, fn=predict, outputs="text",
119
+ title="AMA on Tesla 10-K statements",
120
+ description="This web API presents an interface to ask questions on contents of the Tesla 10-K reports for the period 2019 - 2023.",
121
+ article="Note that questions that are not relevant to the Tesla 10-K report will not be answered.",
122
+ examples=[["What was the total revenue of the company in 2022?", "$ 81.46 Billion"],
123
+ ["Summarize the Management Discussion and Analysis section of the 2021 report in 50 words.", ""],
124
+ ["What was the company's debt level in 2020?", ""],
125
+ ["Identify 5 key risks identified in the 2019 10k report? Respond with bullet point summaries.", ""]
126
+ ],
127
+ concurrency_limit=16
128
+ )
129
+
130
+ demo.queue()
131
+ demo.launch()
requirements.txt CHANGED
@@ -1 +1,5 @@
1
- huggingface_hub==0.22.2
 
 
 
 
 
1
+ openai==1.23.2
2
+ chromadb==0.4.22
3
+ langchain==0.1.9
4
+ langchain-community==0.0.32
5
+ sentence-transformers==2.3.1
tesla_db/.DS_Store ADDED
Binary file (6.15 kB). View file
 
tesla_db/908b9485-d351-4c65-93e1-a9a76f864b14/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:27662eb029bbc3d09e4f17de86b0ec8d081222e7f0c2bf8b38c0f76588eb2878
3
+ size 5028000
tesla_db/908b9485-d351-4c65-93e1-a9a76f864b14/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed67034369667b449e6a14ff788b4f792483e794a2ef335429a48cf9bec3a897
3
+ size 100
tesla_db/908b9485-d351-4c65-93e1-a9a76f864b14/index_metadata.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:708978bce9accecfb0aca5a97e8ab92a9e3237aacac297c5e164da3a17f8802e
3
+ size 172004
tesla_db/908b9485-d351-4c65-93e1-a9a76f864b14/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d247935870a99da0b73b5ff4fd66239d718bf01b76581c25351a875c7740c79e
3
+ size 12000
tesla_db/908b9485-d351-4c65-93e1-a9a76f864b14/link_lists.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08c8d1f08527c0f21b5814da5dad0aa7f1ba98cab0738289e9d92a564c00cc2f
3
+ size 25736
tesla_db/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c00be5a240a3470d8494233a24d6881d28bc2f2fbe900c5a82b5ef2c20bd26fb
3
+ size 36327424