muhammadsalmanalfaridzi committed
Commit 40fd220 · verified · 1 Parent(s): 9404b5f

Create app.py

Files changed (1):
  app.py  +181  -0
app.py ADDED
@@ -0,0 +1,181 @@
+ import os
+ import gc
+ import tempfile
+ import uuid
+ import pandas as pd
+
+ from gitingest import ingest
+ from llama_index.core import Settings
+ from llama_index.core import PromptTemplate
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+ from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
+ from llama_index.core.node_parser import MarkdownNodeParser
+ # OpenAILike (llama-index-llms-openai-like package) lets LlamaIndex talk to any
+ # OpenAI-compatible endpoint, such as Sambanova's API
+ from llama_index.llms.openai_like import OpenAILike
+
+ import streamlit as st
+
+ if "id" not in st.session_state:
+     st.session_state.id = uuid.uuid4()
+     st.session_state.file_cache = {}
+
+ session_id = st.session_state.id
+ client = None
+
+ # Load the LLM served through Sambanova's OpenAI-compatible API
+ @st.cache_resource
+ def load_llm():
+     # Wrap the Sambanova endpoint in a LlamaIndex-compatible LLM so it can be
+     # assigned to Settings.llm below (a raw OpenAI client is not accepted there).
+     # NOTE: the model name is a placeholder; use a model available on your
+     # Sambanova account.
+     return OpenAILike(
+         model="Meta-Llama-3.1-8B-Instruct",
+         api_base="https://api.sambanova.ai/v1",
+         api_key=os.environ.get("SAMBANOVA_API_KEY"),
+         is_chat_model=True,
+     )
+
+ def reset_chat():
+     st.session_state.messages = []
+     st.session_state.context = None
+     gc.collect()
+
+ def process_with_gitingest(github_url):
+     # gitingest accepts either a local path or a repository URL
+     summary, tree, content = ingest(github_url)
+     return summary, tree, content
+
+
+ with st.sidebar:
+     st.header("Add your GitHub repository!")
+
+     github_url = st.text_input("Enter GitHub repository URL", placeholder="GitHub URL")
+     load_repo = st.button("Load Repository")
+
+     if github_url and load_repo:
+         try:
+             with tempfile.TemporaryDirectory() as temp_dir:
+                 st.write("Processing your repository...")
+                 repo_name = github_url.split('/')[-1]
+                 file_key = f"{session_id}-{repo_name}"
+
+                 if file_key not in st.session_state.get('file_cache', {}):
+                     if os.path.exists(temp_dir):
+                         summary, tree, content = process_with_gitingest(github_url)
+
+                         # Write the ingested repository content to a markdown file
+                         with open("content.md", "w", encoding="utf-8") as f:
+                             f.write(content)
+
+                         # Also write it into the temp directory so SimpleDirectoryReader can load it
+                         content_path = os.path.join(temp_dir, f"{repo_name}_content.md")
+                         with open(content_path, "w", encoding="utf-8") as f:
+                             f.write(content)
+                         loader = SimpleDirectoryReader(
+                             input_dir=temp_dir,
+                         )
+                     else:
+                         st.error('Could not access the temporary directory, please try again...')
+                         st.stop()
+
+                     docs = loader.load_data()
+
+                     # Set up the LLM & embedding model
+                     llm = load_llm()  # Sambanova-backed LLM
+                     embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5", trust_remote_code=True)
+                     # Create an index over the loaded data
+                     Settings.embed_model = embed_model
+                     node_parser = MarkdownNodeParser()
+                     index = VectorStoreIndex.from_documents(documents=docs, transformations=[node_parser], show_progress=True)
+
+                     # Create a streaming query engine over the index
+                     Settings.llm = llm
+                     query_engine = index.as_query_engine(streaming=True)
+
+                     # ====== Customise the prompt template ======
+                     qa_prompt_tmpl_str = (
+                         "Context information is below.\n"
+                         "---------------------\n"
+                         "{context_str}\n"
+                         "---------------------\n"
+                         "Given the context information above, think step by step to answer the query in a highly precise and crisp manner focused on the final answer. In case you don't know the answer, say 'I don't know!'.\n"
+                         "Query: {query_str}\n"
+                         "Answer: "
+                     )
+                     qa_prompt_tmpl = PromptTemplate(qa_prompt_tmpl_str)
+
+                     query_engine.update_prompts(
+                         {"response_synthesizer:text_qa_template": qa_prompt_tmpl}
+                     )
+
+                     st.session_state.file_cache[file_key] = query_engine
+                 else:
+                     query_engine = st.session_state.file_cache[file_key]
+
+                 # Inform the user that the repository has been processed
+                 st.success("Ready to Chat!")
+         except Exception as e:
+             st.error(f"An error occurred: {e}")
+             st.stop()
+
+ col1, col2 = st.columns([6, 1])
+
+ with col1:
+     st.header("Chat with GitHub using RAG </>")
+
+ with col2:
+     st.button("Clear ↺", on_click=reset_chat)
+
+ # Initialize chat history
+ if "messages" not in st.session_state:
+     reset_chat()
+
+ # Display chat messages from history on app rerun
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"]):
+         st.markdown(message["content"])
+
+ # Accept user input
+ if prompt := st.chat_input("What's up?"):
+     # Add user message to chat history
+     st.session_state.messages.append({"role": "user", "content": prompt})
+     # Display user message in chat message container
+     with st.chat_message("user"):
+         st.markdown(prompt)
+
+     # Display assistant response in chat message container
+     with st.chat_message("assistant"):
+         message_placeholder = st.empty()
+         full_response = ""
+
+         try:
+             # Get the repo name from the GitHub URL
+             repo_name = github_url.split('/')[-1]
+             file_key = f"{session_id}-{repo_name}"
+
+             # Get the query engine from the session-state cache
+             query_engine = st.session_state.file_cache.get(file_key)
+
+             if query_engine is None:
+                 st.error("Please load a repository first!")
+                 st.stop()
+
+             # Run the query through the RAG query engine
+             response = query_engine.query(prompt)
+
+             # Handle streaming response
+             if hasattr(response, 'response_gen'):
+                 for chunk in response.response_gen:
+                     if isinstance(chunk, str):  # Only process string chunks
+                         full_response += chunk
+                         message_placeholder.markdown(full_response + "▌")
+             else:
+                 # Handle non-streaming response
+                 full_response = str(response)
+                 message_placeholder.markdown(full_response)
+
+             message_placeholder.markdown(full_response)
+
+         except Exception as e:
+             st.error(f"An error occurred while processing your query: {str(e)}")
+             full_response = "Sorry, I encountered an error while processing your request."
+             message_placeholder.markdown(full_response)
+
+     # Add assistant response to chat history
+     st.session_state.messages.append({"role": "assistant", "content": full_response})