bhuvanmdev committed (verified)
Commit fe9b76b · 1 Parent(s): 21a0347

Update app.py

Files changed (1):
  1. app.py +30 -97
app.py CHANGED
@@ -8,22 +8,17 @@ from enum import Enum
 import gradio as gr
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import Chroma
-from langchain.prompts import PromptTemplate
 from langchain.schema import BaseRetriever
 from langchain.embeddings.base import Embeddings
 from langchain.llms.base import BaseLanguageModel
 import PyPDF2
+from huggingface_hub import InferenceClient
 # Install required packages
 
 
 # Initialize models
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 from langchain_community.embeddings import HuggingFaceBgeEmbeddings
-from langchain_community.llms import HuggingFacePipeline
-from transformers import pipeline
-from sentence_transformers import SentenceTransformer
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
 embed_model = HuggingFaceBgeEmbeddings(
     model_name="all-MiniLM-L6-v2",#"dunzhang/stella_en_1.5B_v5",
     model_kwargs={'device': 'cpu'},
@@ -31,34 +26,10 @@ embed_model = HuggingFaceBgeEmbeddings(
 )
 
 model_name = "meta-llama/Llama-3.2-3B-Instruct"#"google/gemma-2-2b-it"#"prithivMLmods/Llama-3.2-3B-GGUF"
-from huggingface_hub import InferenceClient
-
-client = InferenceClient(model_name)
 
 
+client = InferenceClient(model_name)
 
-# tokenizer = AutoTokenizer.from_pretrained(model_name)
-# model = AutoModelForCausalLM.from_pretrained(
-#     model_name,
-#     trust_remote_code=True,
-#     use_auth_token=True
-# )
-
-# pipe = pipeline(
-#     "text-generation",
-#     model=model,
-#     tokenizer=tokenizer,
-#     max_new_tokens=2048*2,
-#     temperature=0.3,
-#     top_p=0.95,
-#     generation_config=model.generation_config
-#     # repetition_penalty=1.15
-# )
-# llm = HuggingFacePipeline(pipeline=pipe)
-# model.generation_config.pad_token_id = model.generation_config.eos_token_id
-
-
-# embed_model = embedding_model
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -71,24 +42,14 @@ class DocumentFormat(Enum):
 @dataclass
 class RAGConfig:
     """Configuration for RAG system parameters"""
-    chunk_size: int = 500
-    chunk_overlap: int = 100
+    chunk_size: int = 100
+    chunk_overlap: int = 10
     retriever_k: int = 3
     persist_directory: str = "./chroma_db"
 
 class AdvancedRAGSystem:
     """Advanced RAG System with improved error handling and type safety"""
 
-    DEFAULT_TEMPLATE = """<|start_header_id|>system<|end_header_id|>
-    You are a helpful assistant. Use the following pieces of context to answer the question at the end.
-    If you don't know the answer, just say that you don't know, don't try to make up an answer.
-
-    Context:
-    {context}
-
-    <|eot_id|><|start_header_id|>user<|end_header_id|>
-    {question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
-    """
 
     def __init__(
         self,
@@ -104,10 +65,6 @@ Context:
         self.last_context: Optional[str] = None
         self.context = None
         self.source_documents = 0
-        # self.prompt = PromptTemplate(
-        #     template=self.DEFAULT_TEMPLATE,
-        #     input_variables=["context", "question"]
-        # )
 
     def _validate_file(self, file_path: Path) -> bool:
         """Validate if the file is of supported format and exists"""
@@ -191,48 +148,44 @@ Context:
             retrieved_docs = retriever.get_relevant_documents(question)
             context = self._format_context(retrieved_docs)
             self.last_context = context
+            self.context = context
+            self.source_documents = len(retrieved_docs)
             messages = [
                 {
                     "role":"system",
-                    "content":f"""<|start_header_id|>system<|end_header_id|>
-    You are a helpful assistant. Use the following pieces of context to answer the question at the end.
+                    "content":f"""You are a helpful assistant. Use the following pieces of context to answer the question at the end.
     If you don't know the answer, just say that you don't know, don't try to make up an answer.
 
     Context:
     {context}
-
-    <|eot_id|><|start_header_id|>user<|end_header_id|>
-    {question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
     """
                 },
                 {
                     "role": "user",
-                    "content": "What is the capital of France?"
+                    "content": question
                 }
             ]
-            self.context = context
-            self.source_documents = len(retrieved_docs)
-            # Generate response using LLM ###########
-            # response = self.llm.invoke(
-            #     self.prompt.format(
-            #         context=context,
-            #         question=question
-            #     )
-            # )
 
-            for x in self.llm.chat.completions.create(
-                model=model_name,
-                messages=messages,
-                max_tokens=500,
-                stream=True
-            ):
-                yield x
+            response_text = ""
+            for chunk in self.llm.chat.completions.create(
+                model=model_name,
+                messages=messages,
+                max_tokens=500,
+                stream=True
+            ):
+                if hasattr(chunk.choices[0].delta, 'content'):
+                    content = chunk.choices[0].delta.content
+                    if content is not None:
+                        response_text += content
+                        yield response_text
 
         except Exception as e:
             error_msg = f"Error during query processing: {str(e)}"
             logger.error(error_msg)
-            raise RuntimeError(error_msg)
+            yield error_msg
+
 
+
 def create_gradio_interface(rag_system: AdvancedRAGSystem) -> gr.Blocks:
     """Create an improved Gradio interface for the RAG system"""
 
@@ -274,14 +227,14 @@ def create_gradio_interface(rag_system: AdvancedRAGSystem) -> gr.Blocks:
             chunk_size = gr.Slider(
                 minimum=100,
                 maximum=10000,
-                value=500,
+                value=100,
                 step=100,
                 label="Chunk Size"
             )
             overlap = gr.Slider(
                 minimum=10,
                 maximum=5000,
-                value=100,
+                value=10,
                 step=10,
                 label="Chunk Overlap"
            )
@@ -315,40 +268,20 @@ def create_gradio_interface(rag_system: AdvancedRAGSystem) -> gr.Blocks:
         )
 
         query_button.click(
-            fn=query_fin,
+            fn=query_streaming,
             inputs=[question_input],
             outputs=[answer_output],
             api_name="stream_response",
-            show_progress=False
-        )
-
-        query_button.click(
-            fn=update_history,
+            queue=False
+        ).then(
+            fn=update_context,
             inputs=[question_input],
-            outputs=[history_output]
+            outputs=[context_output]
         )
 
     return demo
 
 
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-# demo = gr.ChatInterface(
-#     respond,
-#     additional_inputs=[
-#         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-#         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-#         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-#         gr.Slider(
-#             minimum=0.1,
-#             maximum=1.0,
-#             value=0.95,
-#             step=0.05,
-#             label="Top-p (nucleus sampling)",
-#         ),
-#     ],
-# )
 rag_system = AdvancedRAGSystem(embed_model, client)
 demo = create_gradio_interface(rag_system)
 
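
Note on the change: the commit drops the local transformers pipeline and instead streams answers from a hosted model via huggingface_hub's InferenceClient, yielding a progressively longer string so Gradio can render the answer as it arrives. Below is a minimal, self-contained sketch of that pattern. The model name, max_tokens, stream=True loop, and delta access mirror the diff; the function name query_streaming_demo, the simplified system prompt, and the standalone Gradio wiring are illustrative assumptions, not the app's actual code.

# Sketch only: assumes huggingface_hub's InferenceClient chat-completion API
# and Gradio's support for generator callbacks; names marked below are made up.
import gradio as gr
from huggingface_hub import InferenceClient

model_name = "meta-llama/Llama-3.2-3B-Instruct"
client = InferenceClient(model_name)

def query_streaming_demo(question: str):
    """Yield the accumulated answer so the UI updates on every chunk (hypothetical helper)."""
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": question},
    ]
    response_text = ""
    # stream=True returns an iterator of chunks, each carrying an incremental delta,
    # matching the loop the commit adds inside query().
    for chunk in client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=500,
        stream=True,
    ):
        content = chunk.choices[0].delta.content
        if content:
            response_text += content
            yield response_text  # Gradio replaces the output value on each yield

with gr.Blocks() as demo:
    question = gr.Textbox(label="Question")
    answer = gr.Textbox(label="Answer")
    ask = gr.Button("Ask")
    # .click() returns an event object, so a follow-up step could be chained with
    # .then(), which is how the commit updates the context display after streaming.
    ask.click(fn=query_streaming_demo, inputs=[question], outputs=[answer])

if __name__ == "__main__":
    demo.launch()

Yielding the accumulated string rather than each raw delta is what lets the answer textbox always show the full response so far, which is the behavior the rewritten query() relies on.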