muhammadsalmanalfaridzi committed · verified
Commit a0cd1a3 · 1 parent: f21f5b6

Update app.py

Files changed (1): app.py (+30, -21)
app.py CHANGED

@@ -3,10 +3,11 @@ import gc
 import tempfile
 import uuid
 import pandas as pd
-import openai # Import openai for Sambanova API
+import openai
 
 from gitingest import ingest
 from llama_index.core import Settings
+from llama_index.llms.sambanova import SambaNovaCloud
 from llama_index.core import PromptTemplate
 from llama_index.embeddings.huggingface import HuggingFaceEmbedding
 from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
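
Note: the added llama_index.llms.sambanova import ships as a separate LlamaIndex integration package. Assuming LlamaIndex's usual one-package-per-integration naming (the exact package name is not stated in this commit), installing it would look something like:

    pip install llama-index-llms-sambanova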

@@ -14,6 +15,9 @@ from llama_index.core.node_parser import MarkdownNodeParser
 
 import streamlit as st
 
+# Set up SambaNova API key
+os.environ["SAMBANOVA_API_KEY"] = "your_sambanova_api_key" # Replace with your actual SambaNova API key
+
 if "id" not in st.session_state:
     st.session_state.id = uuid.uuid4()
     st.session_state.file_cache = {}
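
Note: the new line above writes a placeholder literal into os.environ at import time. A safer variant, not part of this commit, reads the key from the environment (or from st.secrets) and fails fast when it is missing; a minimal sketch:

    import os
    import streamlit as st

    # Sketch only: require an externally provisioned key instead of a hardcoded literal.
    # The commit relies on SambaNovaCloud picking the key up from the environment.
    api_key = os.environ.get("SAMBANOVA_API_KEY")
    if not api_key:
        st.error("SAMBANOVA_API_KEY is not set; export it before launching the app.")
        st.stop()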

@@ -21,15 +25,18 @@ if "id" not in st.session_state:
 session_id = st.session_state.id
 client = None
 
-# Update the load_llm function to use Sambanova's API
 @st.cache_resource
 def load_llm():
-    # Initialize the Sambanova OpenAI client
-    client = openai.OpenAI(
-        api_key=os.environ.get("SAMBANOVA_API_KEY"),
-        base_url="https://api.sambanova.ai/v1",
+    # Instantiate the SambaNova model
+    llm = SambaNovaCloud(
+        model="Meta-Llama-3.1-405B-Instruct", # Use the correct model name
+        context_window=100000,
+        max_tokens=1024,
+        temperature=0.7,
+        top_k=1,
+        top_p=0.01,
     )
-    return client
+    return llm
 
 def reset_chat():
     st.session_state.messages = []
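
Note: load_llm() now returns a LlamaIndex LLM object (SambaNovaCloud) rather than a raw OpenAI-compatible client, so it can be assigned to Settings.llm directly. A minimal standalone smoke test for this configuration, assuming the integration package is installed and SAMBANOVA_API_KEY is set:

    from llama_index.llms.sambanova import SambaNovaCloud

    # Sketch: confirm the SambaNova endpoint responds before wiring it into the app.
    llm = SambaNovaCloud(model="Meta-Llama-3.1-405B-Instruct", max_tokens=64)
    print(llm.complete("Reply with the single word: pong"))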

@@ -77,26 +84,26 @@ with st.sidebar:
                 docs = loader.load_data()
 
                 # setup llm & embedding model
-                llm = load_llm() # Load the Sambanova LLM client
+                llm = load_llm()
                 embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5", trust_remote_code=True)
                 # Creating an index over loaded data
                 Settings.embed_model = embed_model
                 node_parser = MarkdownNodeParser()
                 index = VectorStoreIndex.from_documents(documents=docs, transformations=[node_parser], show_progress=True)
 
-                # Create the query engine, where we use a cohere reranker on the fetched nodes
+                # Create the query engine
                 Settings.llm = llm
                 query_engine = index.as_query_engine(streaming=True)
 
                 # ====== Customise prompt template ======
                 qa_prompt_tmpl_str = (
-                    "Context information is below.\n"
-                    "---------------------\n"
-                    "{context_str}\n"
-                    "---------------------\n"
-                    "Given the context information above I want you to think step by step to answer the query in a highly precise and crisp manner focused on the final answer, incase case you don't know the answer say 'I don't know!'.\n"
-                    "Query: {query_str}\n"
-                    "Answer: "
+                    "Context information is below.\n"
+                    "---------------------\n"
+                    "{context_str}\n"
+                    "---------------------\n"
+                    "Given the context information above I want you to think step by step to answer the query in a highly precise and crisp manner focused on the final answer, in case you don't know the answer say 'I don't know!'.\n"
+                    "Query: {query_str}\n"
+                    "Answer: "
                 )
                 qa_prompt_tmpl = PromptTemplate(qa_prompt_tmpl_str)
 
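Note: this hunk builds qa_prompt_tmpl, but the step that attaches it to the query engine sits outside the diff's context lines. In LlamaIndex that is conventionally done with update_prompts and the standard QA-template key; a sketch of what the surrounding (unchanged) code presumably does:

    # Sketch: attach the custom template to the query engine (standard LlamaIndex API;
    # the actual call lives in app.py outside this diff).
    query_engine.update_prompts(
        {"response_synthesizer:text_qa_template": qa_prompt_tmpl}
    )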

@@ -108,12 +115,13 @@ with st.sidebar:
            else:
                query_engine = st.session_state.file_cache[file_key]
 
-           # Inform the user that the file is processed and Display the PDF uploaded
+           # Inform the user that the file is processed and display the PDF uploaded
            st.success("Ready to Chat!")
        except Exception as e:
            st.error(f"An error occurred: {e}")
            st.stop()
 
+
 col1, col2 = st.columns([6, 1])
 
 with col1:

@@ -126,11 +134,13 @@ with col2:
 if "messages" not in st.session_state:
     reset_chat()
 
+
 # Display chat messages from history on app rerun
 for message in st.session_state.messages:
     with st.chat_message(message["role"]):
         st.markdown(message["content"])
 
+
 # Accept user input
 if prompt := st.chat_input("What's up?"):
     # Add user message to chat history

@@ -156,9 +166,9 @@ if prompt := st.chat_input("What's up?"):
                st.error("Please load a repository first!")
                st.stop()
 
-           # Use the query engine to get the context for the query
+           # Use the query engine
            response = query_engine.query(prompt)
-
+
            # Handle streaming response
            if hasattr(response, 'response_gen'):
                for chunk in response.response_gen:

@@ -171,11 +181,10 @@ if prompt := st.chat_input("What's up?"):
                    message_placeholder.markdown(full_response)
 
            message_placeholder.markdown(full_response)
-
        except Exception as e:
            st.error(f"An error occurred while processing your query: {str(e)}")
            full_response = "Sorry, I encountered an error while processing your request."
            message_placeholder.markdown(full_response)
 
    # Add assistant response to chat history
-   st.session_state.messages.append({"role": "assistant", "content": full_response})
+   st.session_state.messages.append({"role": "assistant", "content": full_response})
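
Note: the chat handler streams the answer by accumulating chunks from response.response_gen into a single placeholder. Condensed, the pattern is as below (a sketch: message_placeholder, the fallback branch, and the cursor glyph stand in for code outside this diff's changed lines):

    # Sketch of the streaming render loop used in the chat handler.
    message_placeholder = st.empty()
    full_response = ""
    response = query_engine.query(prompt)
    if hasattr(response, "response_gen"):
        # Streaming: append each chunk and re-render the partial answer.
        for chunk in response.response_gen:
            full_response += chunk
            message_placeholder.markdown(full_response + "▌")
    else:
        # Non-streaming fallback: render the whole answer at once.
        full_response = str(response)
    message_placeholder.markdown(full_response)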
 