muhammadsalmanalfaridzi committed on
Commit dba1f58 · verified · 1 Parent(s): 8ad304a

Update app.py

Files changed (1)
  1. app.py +123 -277
app.py CHANGED
@@ -1,46 +1,27 @@
 import os
-import sys
 import gc
 import tempfile
 import uuid
 import logging
-import requests
-import time
-from typing import List, Any
+from typing import Any

 import streamlit as st
 from dotenv import load_dotenv
-import openai
+
 from gitingest import ingest
-from llama_index.core import Settings, PromptTemplate, VectorStoreIndex, SimpleDirectoryReader
-from llama_index.core.node_parser import MarkdownNodeParser
-from llama_index.vector_stores.faiss import FaissVectorStore
-from llama_index.embeddings.base import BaseEmbedding
-import faiss
-from llama_index.llms.sambanovasystems import SambaNovaCloud
+from llama_index import (
+    SimpleDirectoryReader,
+    VectorStoreIndex,
+    PromptTemplate,
+    ServiceContext,
+    LLMPredictor,
+)
+from llama_index.node_parser import MarkdownNodeParser
+from llama_index.embeddings import HuggingFaceEmbedding
+from llama_index.llms import OpenAI

-# ------------------ Configuration ------------------
+# Load environment
 load_dotenv()

-# Configure SambaNova OpenAI-compatible client
-SAMBA_API_KEY = os.getenv("SAMBANOVA_API_KEY")
-SAMBA_BASE_URL = os.getenv("SAMBANOVA_BASE_URL", "https://api.sambanova.ai/v1")
-
-# Nomic AI API Key
-NOMIC_API_KEY = os.getenv("NOMIC_API_KEY")
-
-if not SAMBA_API_KEY:
-    raise ValueError("Missing SAMBANOVA_API_KEY in environment")
-
-if not NOMIC_API_KEY:
-    raise ValueError("Missing NOMIC_API_KEY in environment")
-
-# Initialize SambaNova client
-sambanova_client = openai.OpenAI(
-    api_key=SAMBA_API_KEY,
-    base_url=SAMBA_BASE_URL
-)
-
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -49,97 +30,10 @@ logger = logging.getLogger(__name__)
 MAX_REPO_SIZE = 100 * 1024 * 1024  # 100MB
 SUPPORTED_REPO_TYPES = ['.py', '.md', '.ipynb', '.js', '.ts', '.json']

-# ------------------ Exceptions ------------------
 class GitHubRAGError(Exception):
     """Custom exception for GitHub RAG application errors"""
     pass

-# ------------------ Embedding Cache ------------------
-embedding_cache = {}
-
-# ------------------ Nomic AI Embedding Implementation ------------------
-class NomicEmbedding(BaseEmbedding):
-    """Custom embedding class for Nomic AI"""
-    def __init__(self, model_name="nomic-embed-text-v1.5", task_type="search_document"):
-        self.model_name = model_name
-        self.task_type = task_type
-        self.api_key = NOMIC_API_KEY
-        super().__init__()
-
-    def _get_query_embedding(self, query: str) -> List[float]:
-        """Get embedding for a query string"""
-        return self._get_embedding(query)
-
-    def _get_text_embedding(self, text: str) -> List[float]:
-        """Get embedding for a text string"""
-        return self._get_embedding(text)
-
-    def _get_embedding(self, text: str) -> List[float]:
-        """Get embedding from Nomic AI"""
-        # Check if text is already in cache
-        if text in embedding_cache:
-            return embedding_cache[text]
-
-        try:
-            url = "https://api-atlas.nomic.ai/v1/embedding/text"
-            headers = {
-                "Authorization": f"Bearer {self.api_key}",
-                "Content-Type": "application/json",
-                "Accept": "application/json"
-            }
-            payload = {
-                "texts": [text],
-                "model": self.model_name,
-                "task_type": self.task_type
-            }
-
-            # Retry logic with exponential backoff
-            max_retries = 3
-            retry_delay = 1  # Start with 1 second delay
-
-            for retry in range(max_retries):
-                try:
-                    response = requests.post(
-                        url,
-                        headers=headers,
-                        json=payload,
-                        timeout=30  # 30 seconds timeout
-                    )
-
-                    if response.status_code == 200:
-                        embedding = response.json()["embeddings"][0]
-                        # Cache the result
-                        embedding_cache[text] = embedding
-                        return embedding
-                    else:
-                        logger.error(f"Error from Nomic API: {response.status_code} - {response.text}")
-                        if retry < max_retries - 1:
-                            # Wait with exponential backoff before retry
-                            time.sleep(retry_delay)
-                            retry_delay *= 2  # Double the delay for next retry
-                        else:
-                            # Last retry failed
-                            raise Exception(f"Failed to get embedding after {max_retries} attempts")
-                except requests.exceptions.RequestException as e:
-                    logger.error(f"Request error (attempt {retry+1}/{max_retries}): {e}")
-                    if retry < max_retries - 1:
-                        time.sleep(retry_delay)
-                        retry_delay *= 2
-                    else:
-                        raise
-        except Exception as e:
-            logger.error(f"Error connecting to Nomic API: {e}")
-            raise  # Propagate the error without fallback
-
-    async def _aget_query_embedding(self, query: str) -> List[float]:
-        """Async version of get_query_embedding"""
-        return self._get_query_embedding(query)
-
-    async def _aget_text_embedding(self, text: str) -> List[float]:
-        """Async version of get_text_embedding"""
-        return self._get_text_embedding(text)
-
-# ------------------ Utility Functions ------------------

 def validate_github_url(url: str) -> bool:
     return url.startswith(('https://github.com/', 'http://github.com/'))
@@ -147,7 +41,7 @@ def validate_github_url(url: str) -> bool:

 def get_repo_name(url: str) -> str:
     try:
-        return url.rstrip('/').split('/')[-1].replace('.git', '')
+        return url.split('/')[-1].replace('.git', '')
     except Exception as e:
         raise GitHubRAGError(f"Invalid repository URL: {e}")
 
@@ -171,179 +65,131 @@ def process_with_gitingets(github_url: str) -> tuple:
         return summary, tree, content
     except Exception as e:
         logger.error(f"Error processing repository: {e}")
-        raise GitHubRAGError(f"Failed to process repository: {e}")
-
-
-def create_query_engine(content_path: str, repo_name: str) -> Any:
-    """Create and configure LlamaIndex RAG query engine with FAISS vector store."""
-    try:
-        # Load documents from local folder
-        loader = SimpleDirectoryReader(input_dir=content_path)
-        docs = loader.load_data()
-
-        # Create a Nomic embedding instance
-        embed_model = NomicEmbedding()
-
-        # Set up LlamaIndex to use Nomic embeddings
-        Settings.embed_model = embed_model
-
-        # Create FAISS index - using L2 distance (Euclidean)
-        dimension = len(embed_model._get_text_embedding("test"))  # Get dimensionality from a sample embedding
-        faiss_index = faiss.IndexFlatL2(dimension)
-
-        # Initialize FAISS vector store
-        vector_store = FaissVectorStore(faiss_index=faiss_index)
-
-        # Build vector index with markdown parsing and FAISS
-        node_parser = MarkdownNodeParser()
-        index = VectorStoreIndex.from_documents(
-            documents=docs,
-            transformations=[node_parser],
-            vector_store=vector_store,
-            show_progress=True
-        )
-
-        # Custom QA prompt template
-        qa_prompt = PromptTemplate(
-            template_str="""
-            You are an AI assistant specialized in analyzing GitHub repositories.
-
-            Repository structure:
-            {tree}
-
-            Context information:
-            {context_str}
-
-            Answer the following query about the repository. If unknown, say you don't have enough information.
-
-            Query: {query_str}
-            Answer:"""
-        )
-
-        # Configure query engine with streaming and template
-        query_engine = index.as_query_engine(streaming=True)
-        query_engine.update_prompts({
-            "response_synthesizer:text_qa_template": qa_prompt
-        })
-
-        # And then configure it within llama-index
-        llm = SambaNovaCloud(
-            model_name="QwQ-32B",
-            api_key=SAMBA_API_KEY,
-            base_url=SAMBA_BASE_URL
-        )
-        Settings.llm = llm
-
-        return query_engine
-    except Exception as e:
-        logger.error(f"Error creating query engine: {e}")
-        raise GitHubRAGError(f"Failed to create query engine: {e}")
-
-# ------------------ Streamlit App ------------------
-# Initialize session state
+        raise GitHubRAGError(str(e))
+
+
+def create_query_engine(content_dir: str, tree: str) -> Any:
+    """
+    Build an index with Nomic embeddings and query it via the SambaNova LLM.
+    """
+    # Reader & parser
+    loader = SimpleDirectoryReader(input_dir=content_dir)
+    docs = loader.load_data()
+    node_parser = MarkdownNodeParser()
+
+    # Embedding model using Nomic Embed v2 MoE
+    embed_model = HuggingFaceEmbedding(
+        model_name="nomic-ai/nomic-embed-text-v2-moe",
+        device="cpu",  # or 'cuda'
+        normalize=True,
+        trust_remote_code=True,
+    )
+
+    # LLM predictor using SambaNova Cloud via its OpenAI-compatible API
+    llm_predictor = LLMPredictor(
+        llm=OpenAI(
+            api_key=os.environ.get("SAMBANOVA_API_KEY"),
+            api_base="https://api.sambanova.ai/v1",
+            model="QwQ-32B",
+            temperature=0.1,
+            additional_kwargs={"top_p": 0.1},
+        )
+    )
+
+    # Service context
+    service_context = ServiceContext.from_defaults(
+        embed_model=embed_model,
+        llm_predictor=llm_predictor,
+    )
+
+    # Build index
+    index = VectorStoreIndex.from_documents(
+        documents=docs,
+        service_context=service_context,
+        transformations=[node_parser],
+        show_progress=True,
+    )
+
+    # Custom QA prompt; {tree} is filled in up front via partial_format
+    qa_template = PromptTemplate(
+        "You are an AI assistant specialized in analyzing GitHub repositories.\n\n"
+        "Repository files and structure:\n{tree}\n---\n"
+        "Context:\n{context_str}\n---\n"
+        "Question: {query_str}\nAnswer:"
+    )
+
+    # Create query engine with streaming and the custom prompt
+    return index.as_query_engine(
+        streaming=True,
+        service_context=service_context,
+        text_qa_template=qa_template.partial_format(tree=tree),
+    )
+
+
+def reset_chat():
+    """Clear the chat history."""
+    st.session_state.messages = []
+    gc.collect()
+
+
+# Streamlit App
 if "id" not in st.session_state:
     st.session_state.id = uuid.uuid4()
-    st.session_state.file_cache = {}
+    st.session_state.cache = {}
     st.session_state.messages = []

 session_id = st.session_state.id

-# Sidebar inputs
 with st.sidebar:
-    st.header("Add your GitHub repository!")
-    github_url = st.text_input("Enter GitHub repository URL",
-                               placeholder="https://github.com/username/repo")
-
-    load_repo = st.button("Load Repository", type="primary")
-
-    if github_url and load_repo:
-        try:
-            if not validate_github_url(github_url):
-                st.error("Please enter a valid GitHub repository URL")
-                st.stop()
-
-            repo_name = get_repo_name(github_url)
-            file_key = f"{session_id}-{repo_name}"
-
-            if file_key not in st.session_state.file_cache:
-                with st.spinner("Processing your repository..."):
-                    with tempfile.TemporaryDirectory() as temp_dir:
-                        summary, tree, content = process_with_gitingets(github_url)
-                        # Write content for RAG
-                        content_path = temp_dir
-                        # Save full content as a doc
-                        md_path = os.path.join(temp_dir, f"{repo_name}.md")
-                        with open(md_path, "w", encoding="utf-8") as f:
-                            f.write(content)
-
-                        # Create query engine and cache
-                        query_engine = create_query_engine(content_path, repo_name)
-                        st.session_state.file_cache[file_key] = dict(
-                            engine=query_engine,
-                            tree=tree
-                        )
-                        st.success("Repository loaded successfully! Ready to chat.")
-            else:
-                st.info("Repository already loaded. Ready to chat!")
-        except GitHubRAGError as e:
-            st.error(str(e))
+    st.header("GitHub RAG with SambaNova & Nomic Embed")
+    github_url = st.text_input("GitHub Repo URL", help="e.g. https://github.com/user/repo")
+    load_btn = st.button("Load Repository")
+    if github_url and load_btn:
+        if not validate_github_url(github_url):
+            st.error("Invalid GitHub URL")
             st.stop()
-
-# Main chat UI
-col1, col2 = st.columns([6, 1])
+        repo_name = get_repo_name(github_url)
+        key = f"{session_id}-{repo_name}"
+        if key not in st.session_state.cache:
+            with st.spinner("Processing repository..."):
+                try:
+                    summary, tree, content = process_with_gitingets(github_url)
+                    with tempfile.TemporaryDirectory() as td:
+                        # Save the ingested content into a working directory
+                        content_path = os.path.join(td, repo_name)
+                        os.makedirs(content_path, exist_ok=True)
+                        with open(os.path.join(content_path, f"{repo_name}.md"), "w", encoding="utf-8") as f:
+                            f.write(content)
+                        # Build query engine
+                        qe = create_query_engine(content_path, tree)
+                        st.session_state.cache[key] = (qe, tree)
+                    st.success("Repository loaded!")
+                except GitHubRAGError as e:
+                    st.error(str(e))
+                    st.stop()
+        else:
+            st.info("Repository already loaded.")
+
+col1, col2 = st.columns([6, 1])
 with col1:
-    st.header("Chat with GitHub using RAG + Sambanova")
+    st.header("Chat with your Repo")
 with col2:
-    st.button("Clear Chat", on_click=reset_chat)
+    st.button("Clear Chat", on_click=reset_chat)

-# Display chat history
+# Display chat
 for msg in st.session_state.messages:
-    with st.chat_message(msg["role"]):
-        st.markdown(msg["content"])
+    with st.chat_message(msg['role']):
+        st.markdown(msg['content'])

-# Chat input
-if prompt := st.chat_input("Ask your question..."):
-    st.session_state.messages.append({"role": "user", "content": prompt})
+if prompt := st.chat_input("Ask a question about the repository..."):
+    st.session_state.messages.append({"role": "user", "content": prompt})
     with st.chat_message("user"):
         st.markdown(prompt)
-
+    key = f"{session_id}-{get_repo_name(github_url)}"
+    if key not in st.session_state.cache:
+        st.error("Load a repository first.")
+        st.stop()
+    qe, tree = st.session_state.cache[key]
     with st.chat_message("assistant"):
-        file_key = f"{session_id}-{get_repo_name(github_url)}"
-        cache = st.session_state.file_cache.get(file_key)
-        if not cache:
-            st.error("Please load a repository first!")
-            st.stop()
-
-        query_engine = cache['engine']
-        tree_str = cache['tree']
-        # Generate RAG response (streamed chunks)
-        rag_response = query_engine.query(prompt)
-        context_str = rag_response.context_str if hasattr(rag_response, 'context_str') else ''
-
-        # Build the messages for the SambaNova model
-        messages = [
-            {"role": "system", "content": "You are a knowledgeable assistant combining GitHub repository context with user queries."},
-            {"role": "user", "content": f"Repository structure:\n{tree_str}\nContext:\n{context_str}\nQuestion: {prompt}"}
-        ]
-
-        # Call the SambaNova API
-        try:
-            stream = sambanova_client.chat.completions.create(
-                model="QwQ-32B",  # Replace with the appropriate model
-                messages=messages,
-                temperature=0.1,
-                top_p=0.1
-            )
-
-            full_resp = ""
-            for chunk in stream:
-                if chunk.choices[0].delta.content:
-                    full_resp += chunk.choices[0].delta.content
-                    st.write(full_resp + "▌")
-            st.write(full_resp)
-            st.session_state.messages.append({"role": "assistant", "content": full_resp})
-
-        except Exception as e:
-            logger.error(f"API Error: {str(e)}")
-            st.error(f"Error generating response: {str(e)}")
-            st.stop()
+        placeholder = st.empty()
+        answer = ""
+        for chunk in qe.query(prompt).response_gen:
+            answer += chunk
+            placeholder.markdown(answer + "▌")
+        placeholder.markdown(answer)
+        st.session_state.messages.append({"role": "assistant", "content": answer})
 
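For reference, a minimal smoke test of the rewritten pipeline outside Streamlit. This is a sketch, not part of the commit: it assumes llama_index 0.9.x with torch and transformers available, a local docs/ directory holding the gitingest dump, and a guessed context_window. It also swaps OpenAI for OpenAILike, since the plain OpenAI wrapper looks the model name up in OpenAI's own context-window table, and a non-OpenAI id like QwQ-32B will not resolve there.

# Smoke test sketch for the new RAG pipeline (assumptions noted above).
import os

from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex
from llama_index.embeddings import HuggingFaceEmbedding
from llama_index.llms import OpenAILike

embed_model = HuggingFaceEmbedding(
    model_name="nomic-ai/nomic-embed-text-v2-moe",
    trust_remote_code=True,
)
llm = OpenAILike(
    model="QwQ-32B",
    api_base="https://api.sambanova.ai/v1",  # SambaNova's OpenAI-compatible endpoint
    api_key=os.environ["SAMBANOVA_API_KEY"],
    is_chat_model=True,    # route through the chat completions endpoint
    context_window=32768,  # assumption: adjust to the model's real limit
)
service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=llm)

# "docs" is a placeholder for the directory holding the gitingest markdown dump
docs = SimpleDirectoryReader("docs").load_data()
index = VectorStoreIndex.from_documents(docs, service_context=service_context)

qe = index.as_query_engine(streaming=True)
for token in qe.query("What does this repository do?").response_gen:
    print(token, end="", flush=True)

Run with SAMBANOVA_API_KEY set; the streamed tokens print as they arrive, which exercises the same response_gen path the Streamlit chat loop consumes.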