bhuvanmdev committed
Commit 42ae41b · verified · 1 Parent(s): 630ca55

Update app.py

Files changed (1)
  1. app.py +295 -50
app.py CHANGED
@@ -1,63 +1,308 @@
 import gradio as gr
-from huggingface_hub import InferenceClient
 
 """
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-):
-    messages = [{"role": "system", "content": system_message}]
-
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
-
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
     ):
-        token = message.choices[0].delta.content
 
-        response += token
-        yield response
 
 
 """
 For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
 """
-demo = gr.ChatInterface(
-    respond,
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
-)
 
 
 if __name__ == "__main__":
 
+from dataclasses import dataclass
+from operator import itemgetter
+from pathlib import Path
+from typing import List, Optional, Dict, Any
+import logging
+from enum import Enum
+
 import gradio as gr
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.vectorstores import Chroma
+from langchain.prompts import PromptTemplate
+from langchain.schema import BaseRetriever
+from langchain.embeddings.base import Embeddings
+from langchain.schema.language_model import BaseLanguageModel  # not importable from langchain.llms.base
+import PyPDF2
+# Install required packages
 
+
+# Initialize models
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from langchain_community.embeddings import HuggingFaceBgeEmbeddings
+from langchain_community.llms import HuggingFacePipeline
+from transformers import pipeline
+from sentence_transformers import SentenceTransformer
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+embed_model = HuggingFaceBgeEmbeddings(
+    model_name="all-MiniLM-L6-v2",  # "dunzhang/stella_en_1.5B_v5"
+    model_kwargs={'device': 'cpu'},
+    encode_kwargs={'normalize_embeddings': True}
+)
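+# Note (assumption): HuggingFaceBgeEmbeddings prepends a BGE-style query
+# instruction to every query string; with a non-BGE checkpoint such as
+# all-MiniLM-L6-v2, the plain HuggingFaceEmbeddings wrapper may be a more
+# natural fit.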
+
+model_name = "meta-llama/Llama-3.2-3B-Instruct"  # "google/gemma-2-2b-it", "prithivMLmods/Llama-3.2-3B-GGUF"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    trust_remote_code=True,
+    use_auth_token=True
+)
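+
+# The commit passes `llm` into AdvancedRAGSystem below but never defines it.
+# A minimal sketch consistent with the imports above (transformers.pipeline
+# wrapped in HuggingFacePipeline); max_new_tokens and return_full_text are
+# assumed values, not part of the original commit.
+text_generation = pipeline(
+    "text-generation",
+    model=model,
+    tokenizer=tokenizer,
+    max_new_tokens=512,
+    return_full_text=False,  # yield only the newly generated tokens
+)
+llm = HuggingFacePipeline(pipeline=text_generation)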
+
+# model.generation_config.pad_token_id = model.generation_config.eos_token_id
+
+
+# embed_model = embedding_model
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class DocumentFormat(Enum):
+    PDF = ".pdf"
+    # Can be extended for other document types
+
+@dataclass
+class RAGConfig:
+    """Configuration for RAG system parameters"""
+    chunk_size: int = 500
+    chunk_overlap: int = 100
+    retriever_k: int = 3
+    persist_directory: str = "./chroma_db"
+
+class AdvancedRAGSystem:
+    """Advanced RAG System with improved error handling and type safety"""
+
+    DEFAULT_TEMPLATE = """<|start_header_id|>system<|end_header_id|>
+You are a helpful assistant. Use the following pieces of context to answer the question at the end.
+If you don't know the answer, just say that you don't know, don't try to make up an answer.
+
+Context:
+{context}
+
+<|eot_id|><|start_header_id|>user<|end_header_id|>
+{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
 """
+
+    def __init__(
+        self,
+        embed_model: Embeddings,
+        llm: BaseLanguageModel,
+        config: Optional[RAGConfig] = None
     ):
+        """Initialize the RAG system with required models and optional configuration"""
+        self.embed_model = embed_model
+        self.llm = llm
+        self.config = config or RAGConfig()
+        self.vector_store: Optional[Chroma] = None
+        self.last_context: Optional[str] = None
+
+        self.prompt = PromptTemplate(
+            template=self.DEFAULT_TEMPLATE,
+            input_variables=["context", "question"]
+        )
+
+    def _validate_file(self, file_path: Path) -> bool:
+        """Validate that the file exists and is of a supported format"""
+        return file_path.suffix.lower() == DocumentFormat.PDF.value and file_path.exists()
+
+    def _extract_text_from_pdf(self, pdf_path: Path) -> str:
+        """Extract text from a PDF file with proper error handling"""
+        try:
+            with open(pdf_path, 'rb') as file:
+                pdf_reader = PyPDF2.PdfReader(file)
+                return "\n".join(
+                    page.extract_text()
+                    for page in pdf_reader.pages
+                )
+        except Exception as e:
+            logger.error(f"Error processing PDF {pdf_path}: {str(e)}")
+            raise ValueError(f"Failed to process PDF {pdf_path}: {str(e)}")
+
+    def _create_document_chunks(self, texts: List[str]) -> List[Any]:
+        """Split documents into chunks using the configured parameters"""
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=self.config.chunk_size,
+            chunk_overlap=self.config.chunk_overlap,
+            length_function=len,
+            add_start_index=True,
+        )
+        return text_splitter.create_documents(texts)
 
+    def process_pdfs(self, pdf_files: List[Any]) -> str:
+        """Process and index PDF documents with improved error handling"""
+        try:
+            # Convert uploaded file objects to Path objects and validate
+            pdf_paths = [Path(pdf.name) for pdf in pdf_files]
+            invalid_files = [f for f in pdf_paths if not self._validate_file(f)]
+
+            if invalid_files:
+                raise ValueError(f"Invalid or missing files: {invalid_files}")
+
+            # Extract text from valid PDFs
+            documents = [
+                self._extract_text_from_pdf(pdf_path)
+                for pdf_path in pdf_paths
+            ]
+
+            # Create document chunks
+            doc_chunks = self._create_document_chunks(documents)
+
+            # Initialize or update vector store
+            self.vector_store = Chroma.from_documents(
+                documents=doc_chunks,
+                embedding=self.embed_model,
+                persist_directory=self.config.persist_directory
+            )
+
+            logger.info(f"Successfully processed {len(doc_chunks)} chunks from {len(pdf_files)} PDF files")
+            return f"Successfully processed {len(doc_chunks)} chunks from {len(pdf_files)} PDF files"
+
+        except Exception as e:
+            error_msg = f"Error during PDF processing: {str(e)}"
+            logger.error(error_msg)
+            raise RuntimeError(error_msg)
+
+    def get_retriever(self) -> BaseRetriever:
+        """Get the document retriever with the current configuration"""
+        if not self.vector_store:
+            raise RuntimeError("Vector store not initialized. Please process documents first.")
+        return self.vector_store.as_retriever(search_kwargs={"k": self.config.retriever_k})
+
+    def _format_context(self, documents: List[Any]) -> str:
+        """Format retrieved documents into a single context string"""
+        return "\n\n".join(doc.page_content for doc in documents)
+
+    def query(self, question: str) -> Dict[str, Any]:
+        """Query the RAG system with improved error handling and response formatting"""
+        try:
+            if not self.vector_store:
+                raise RuntimeError("Please process PDF documents first before querying")
+
+            # Retrieve relevant documents
+            retriever = self.get_retriever()
+            retrieved_docs = retriever.get_relevant_documents(question)
+            context = self._format_context(retrieved_docs)
+            self.last_context = context
+
+            # Generate a response using the LLM
+            response = self.llm.invoke(
+                self.prompt.format(
+                    context=context,
+                    question=question
+                )
+            )
+
+            return {
+                "answer": response.split("<|end_header_id|>")[-1],
+                "context": context,
+                "source_documents": len(retrieved_docs)
+            }
+
+        except Exception as e:
+            error_msg = f"Error during query processing: {str(e)}"
+            logger.error(error_msg)
+            raise RuntimeError(error_msg)
+
+def create_gradio_interface(rag_system: AdvancedRAGSystem) -> gr.Blocks:
+    """Create an improved Gradio interface for the RAG system"""
+
+    def process_files(files: List[Any], chunk_size: int, overlap: int) -> str:
+        """Process uploaded files with updated configuration"""
+        if not files:
+            return "Please upload PDF files"
+
+        # Update configuration with new parameters
+        rag_system.config.chunk_size = chunk_size
+        rag_system.config.chunk_overlap = overlap
+
+        try:
+            return rag_system.process_pdfs(files)
+        except Exception as e:
+            return f"Error: {str(e)}"
+
+    def query_and_update_history(question: str) -> tuple[str, str]:
+        """Query system and update history with error handling"""
+        try:
+            result = rag_system.query(question)
+            return (
+                result["answer"],
+                f"Last context used ({result['source_documents']} documents):\n\n{result['context']}"
+            )
+        except Exception as e:
+            return str(e), "Error occurred while retrieving context"
+
+    with gr.Blocks(title="Advanced RAG System") as demo:
+        gr.Markdown("# Advanced RAG System with PDF Processing")
+
+        with gr.Tab("Upload & Process PDFs"):
+            with gr.Row():
+                with gr.Column():
+                    file_input = gr.File(
+                        file_count="multiple",
+                        label="Upload PDF Documents",
+                        file_types=[".pdf"]
+                    )
+                    chunk_size = gr.Slider(
+                        minimum=100,
+                        maximum=10000,
+                        value=500,
+                        step=100,
+                        label="Chunk Size"
+                    )
+                    overlap = gr.Slider(
+                        minimum=10,
+                        maximum=5000,
+                        value=100,
+                        step=10,
+                        label="Chunk Overlap"
+                    )
+                    process_button = gr.Button("Process PDFs", variant="primary")
+                    process_output = gr.Textbox(label="Processing Status")
+
+        with gr.Tab("Query System"):
+            with gr.Row():
+                with gr.Column(scale=2):
+                    question_input = gr.Textbox(
+                        label="Your Question",
+                        placeholder="Enter your question here...",
+                        lines=3
+                    )
+                    query_button = gr.Button("Get Answer", variant="primary")
+                    answer_output = gr.Textbox(
+                        label="Answer",
+                        lines=10
+                    )
+                with gr.Column(scale=1):
+                    history_output = gr.Textbox(
+                        label="Retrieved Context",
+                        lines=15
+                    )
+
+        # Set up event handlers
+        process_button.click(
+            fn=process_files,
+            inputs=[file_input, chunk_size, overlap],
+            outputs=[process_output]
+        )
+
+        query_button.click(
+            fn=query_and_update_history,
+            inputs=[question_input],
+            outputs=[answer_output, history_output]
+        )
+
+    return demo
 
 
 """
 For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
 """
+# demo = gr.ChatInterface(
+#     respond,
+#     additional_inputs=[
+#         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
+#         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+#         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+#         gr.Slider(
+#             minimum=0.1,
+#             maximum=1.0,
+#             value=0.95,
+#             step=0.05,
+#             label="Top-p (nucleus sampling)",
+#         ),
+#     ],
+# )
+rag_system = AdvancedRAGSystem(embed_model, llm)
+demo = create_gradio_interface(rag_system)
 
 
 if __name__ == "__main__":