bhuvanmdev committed
Commit a9705c8 · verified · 1 Parent(s): d0c8037

Update app.py

Files changed (1): app.py +80 -45
app.py CHANGED
@@ -30,25 +30,31 @@ embed_model = HuggingFaceBgeEmbeddings(
     encode_kwargs={'normalize_embeddings': True}
 )
 
-model_name = "google/gemma-2-2b-it"#"prithivMLmods/Llama-3.2-3B-GGUF"
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(
-    model_name,
-    trust_remote_code=True,
-    use_auth_token=True
-)
+model_name = "meta-llama/Llama-3.2-3B-Instruct"#"google/gemma-2-2b-it"#"prithivMLmods/Llama-3.2-3B-GGUF"
+from huggingface_hub import InferenceClient
 
-pipe = pipeline(
-    "text-generation",
-    model=model,
-    tokenizer=tokenizer,
-    max_new_tokens=2048*2,
-    temperature=0.3,
-    top_p=0.95,
-    generation_config=model.generation_config
-    # repetition_penalty=1.15
-)
-llm = HuggingFacePipeline(pipeline=pipe)
+client = InferenceClient(model_name)
+
+
+
+# tokenizer = AutoTokenizer.from_pretrained(model_name)
+# model = AutoModelForCausalLM.from_pretrained(
+#     model_name,
+#     trust_remote_code=True,
+#     use_auth_token=True
+# )
+
+# pipe = pipeline(
+#     "text-generation",
+#     model=model,
+#     tokenizer=tokenizer,
+#     max_new_tokens=2048*2,
+#     temperature=0.3,
+#     top_p=0.95,
+#     generation_config=model.generation_config
+#     # repetition_penalty=1.15
+# )
+# llm = HuggingFacePipeline(pipeline=pipe)
 # model.generation_config.pad_token_id = model.generation_config.eos_token_id
 
 
@@ -68,7 +74,7 @@ class RAGConfig:
     chunk_size: int = 500
     chunk_overlap: int = 100
     retriever_k: int = 3
-    persist_directory: str = "./chroma_db"
+    # persist_directory: str = "./chroma_db"
 
 class AdvancedRAGSystem:
     """Advanced RAG System with improved error handling and type safety"""
@@ -96,11 +102,12 @@ Context:
         self.config = config or RAGConfig()
         self.vector_store: Optional[Chroma] = None
         self.last_context: Optional[str] = None
-
-        self.prompt = PromptTemplate(
-            template=self.DEFAULT_TEMPLATE,
-            input_variables=["context", "question"]
-        )
+        self.context = None
+        self.source_documents = 0
+        # self.prompt = PromptTemplate(
+        #     template=self.DEFAULT_TEMPLATE,
+        #     input_variables=["context", "question"]
+        # )
 
     def _validate_file(self, file_path: Path) -> bool:
         """Validate if the file is of supported format and exists"""
@@ -184,20 +191,41 @@ Context:
             retrieved_docs = retriever.get_relevant_documents(question)
             context = self._format_context(retrieved_docs)
             self.last_context = context
+            messages = [
+                {
+                    "role":"system",
+                    "content":f"""<|start_header_id|>system<|end_header_id|>
+            You are a helpful assistant. Use the following pieces of context to answer the question at the end.
+            If you don't know the answer, just say that you don't know, don't try to make up an answer.
 
-            # Generate response using LLM
-            response = self.llm.invoke(
-                self.prompt.format(
-                    context=context,
-                    question=question
-                )
-            )
+            Context:
+            {context}
 
-            return {
-                "answer": response.split("<|end_header_id|>")[-1],
-                "context": context,
-                "source_documents": len(retrieved_docs)
-            }
+            <|eot_id|><|start_header_id|>user<|end_header_id|>
+            {question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+            """
+                },
+                {
+                    "role": "user",
+                    "content": "What is the capital of France?"
+                }
+            ]
+            self.context = context
+            self.source_documents = len(retrieved_docs)
+            # Generate response using LLM ###########
+            # response = self.llm.invoke(
+            #     self.prompt.format(
+            #         context=context,
+            #         question=question
+            #     )
+            # )
+
+            return client.chat.completions.create(
+                model=model_name,
+                messages=messages,
+                max_tokens=500,
+                stream=True
+            )
 
         except Exception as e:
             error_msg = f"Error during query processing: {str(e)}"
@@ -221,16 +249,17 @@ def create_gradio_interface(rag_system: AdvancedRAGSystem) -> gr.Blocks:
         except Exception as e:
             return f"Error: {str(e)}"
 
-    def query_and_update_history(question: str) -> tuple[str, str]:
+    def query_fin(question):
        """Query system and update history with error handling"""
        try:
-            result = rag_system.query(question)
-            return (
-                result["answer"],
-                f"Last context used ({result['source_documents']} documents):\n\n{result['context']}"
-            )
+            for x in rag_system.query(question):
+                yield x.choices[0].delta.content
        except Exception as e:
-            return str(e), "Error occurred while retrieving context"
+            pass
+
+    def update_history(question: str):
+        return f"Last context used ({self.source_documents} documents):\n\n{self.context}"
+
    with gr.Blocks(title="Advanced RAG System") as demo:
        gr.Markdown("# Advanced RAG System with PDF Processing")
 
@@ -286,9 +315,15 @@ def create_gradio_interface(rag_system: AdvancedRAGSystem) -> gr.Blocks:
        )
 
        query_button.click(
-            fn=query_and_update_history,
+            fn=query_fin,
            inputs=[question_input],
-            outputs=[answer_output, history_output]
+            outputs=[answer_output]
+        )
+
+        query_button.click(
+            fn=update_history,
+            inputs=[],
+            outputs=[history_output]
        )
 
    return demo
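Note on the model-loading change: the hunks above drop the local transformers pipeline (AutoTokenizer / AutoModelForCausalLM / HuggingFacePipeline) and route generation through the hosted Inference API via huggingface_hub.InferenceClient, and query() now returns the streaming chat-completion iterator itself rather than a dict with "answer"/"context" keys. The snippet below is a minimal sketch of that call pattern, not code from the repository: the model id is the one set in the commit, while the system/user strings and the printing loop are placeholder illustrations.

# Sketch: stream a chat completion from the Hugging Face Inference API and
# reassemble the answer from the per-chunk deltas, mirroring the updated query().
from huggingface_hub import InferenceClient

model_name = "meta-llama/Llama-3.2-3B-Instruct"
client = InferenceClient(model_name)

messages = [
    {"role": "system", "content": "Use the provided context to answer the question."},  # placeholder prompt
    {"role": "user", "content": "Context: ...\n\nQuestion: ..."},                        # placeholder content
]

stream = client.chat.completions.create(
    model=model_name,
    messages=messages,
    max_tokens=500,
    stream=True,   # yields incremental chunks instead of one full response
)

answer = ""
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:                          # the final chunk may carry no text
        answer += delta
        print(delta, end="", flush=True)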
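On the Gradio wiring: query_fin is a generator, and Gradio streams each yielded value into the bound output component, so answer_output updates as tokens arrive; update_history is bound to the same button to refresh the context panel. Below is a minimal, self-contained sketch of that streaming pattern, with a hypothetical fake_token_stream standing in for rag_system.query(question); accumulating into partial before yielding is one way to keep the whole answer visible in the Textbox rather than only the latest chunk.

# Sketch: a generator click-handler streaming text into a Gradio Textbox.
import time
import gradio as gr

def fake_token_stream(question: str):
    # Stand-in for a streaming LLM call; yields one token at a time.
    for token in ["This ", "is ", "a ", "streamed ", "answer."]:
        time.sleep(0.1)
        yield token

def query_fin(question: str):
    partial = ""
    for token in fake_token_stream(question):
        partial += token   # accumulate so the Textbox shows the full answer so far
        yield partial      # each yield replaces the Textbox contents

with gr.Blocks(title="Streaming demo") as demo:
    question_input = gr.Textbox(label="Question")
    answer_output = gr.Textbox(label="Answer")
    query_button = gr.Button("Ask")
    query_button.click(fn=query_fin, inputs=[question_input], outputs=[answer_output])

if __name__ == "__main__":
    demo.launch()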