Arjun Moorthy committed
Commit 2720b05 · 1 Parent(s): da47961

Optimize for hardware constraints - make RAG optional and lightweight

Files changed (2):
  1. Oncolife/app.py +59 -52
  2. requirements.txt +1 -11
Oncolife/app.py CHANGED
@@ -4,7 +4,7 @@ OncoLife Symptom & Triage Assistant
 A medical chatbot that performs both symptom assessment and clinical triage for chemotherapy patients.
 Updated: Using BioMistral-7B base model for medical conversations.
 REBUILD: Simplified to use only base model, no adapters.
-RAG: Added document retrieval capabilities for PDFs and other reference materials.
+RAG: Added document retrieval capabilities for PDFs and other reference materials (optional).
 """
 
 import gradio as gr
@@ -15,14 +15,19 @@ from transformers import AutoTokenizer, MistralForCausalLM
 import torch
 from spaces import GPU
 
-# RAG imports
-import chromadb
-from sentence_transformers import SentenceTransformer
-import PyPDF2
-import pdfplumber
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.embeddings import HuggingFaceEmbeddings
-import fitz  # PyMuPDF for better PDF handling
+# RAG imports (optional)
+try:
+    import chromadb
+    from sentence_transformers import SentenceTransformer
+    import PyPDF2
+    import pdfplumber
+    from langchain.text_splitter import RecursiveCharacterTextSplitter
+    from langchain.embeddings import HuggingFaceEmbeddings
+    import fitz  # PyMuPDF for better PDF handling
+    RAG_AVAILABLE = True
+except ImportError:
+    print("⚠️ RAG libraries not available, running in instruction-only mode")
+    RAG_AVAILABLE = False
 
 # Force GPU detection for HF Spaces
 @GPU
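Note on the pattern: this import guard pairs with the initialization guard in the next hunk. Heavy imports are attempted once and recorded in RAG_AVAILABLE, and __init__ then catches runtime failures and downgrades to instruction-only mode instead of crashing the Space. A minimal self-contained sketch of the combined pattern (vector_search is a hypothetical stand-in for the real dependencies):

```python
# Sketch of the optional-feature pattern; `vector_search` is a
# hypothetical stand-in for the heavy RAG dependencies.
try:
    import vector_search
    RAG_AVAILABLE = True
except ImportError:
    RAG_AVAILABLE = False

class Assistant:
    def __init__(self):
        self.rag_enabled = False
        if RAG_AVAILABLE:
            try:
                self.index = vector_search.Index()  # hypothetical; may still fail at runtime
                self.rag_enabled = True
            except Exception as e:
                print(f"RAG init failed, continuing without it: {e}")

    def respond(self, query):
        # Call sites branch on the flag instead of re-checking imports.
        docs = self.index.search(query) if self.rag_enabled else []
        return docs

print(Assistant().respond("fever"))  # [] when the dependency is absent
```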
@@ -51,8 +56,18 @@ class OncoLifeAssistant:
         # Load the OncoLife instructions
         self._load_instructions()
 
-        # Initialize RAG system
-        self._initialize_rag()
+        # Initialize RAG system (optional)
+        self.rag_enabled = False
+        if RAG_AVAILABLE:
+            try:
+                self._initialize_rag()
+                self.rag_enabled = True
+                print("✅ RAG system initialized successfully")
+            except Exception as e:
+                print(f"⚠️ RAG initialization failed: {e}")
+                print("🔄 Continuing with instruction-only mode")
+        else:
+            print("🔄 Running in instruction-only mode (no RAG)")
 
     def _load_instructions(self):
         """Load the OncoLife instructions from the text file"""
@@ -70,15 +85,15 @@ class OncoLifeAssistant:
             self.instructions = ""
 
     def _initialize_rag(self):
-        """Initialize the RAG system with document embeddings"""
+        """Initialize the RAG system with document embeddings (lightweight version)"""
         try:
-            print("🔍 Initializing RAG system...")
+            print("🔍 Initializing lightweight RAG system...")
 
-            # Initialize embedding model
+            # Use a smaller embedding model
             self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
             print("✅ Loaded embedding model")
 
-            # Initialize ChromaDB
+            # Initialize ChromaDB with persistence disabled for memory efficiency
             self.chroma_client = chromadb.Client()
             self.collection = self.chroma_client.create_collection(
                 name="oncolife_documents",
@@ -86,19 +101,20 @@
             )
             print("✅ Initialized ChromaDB collection")
 
-            # Load and process documents
-            self._load_documents()
+            # Load and process documents (limited to essential files)
+            self._load_documents_lightweight()
 
         except Exception as e:
             print(f"❌ Error initializing RAG: {e}")
             self.embedding_model = None
             self.collection = None
+            raise e
 
-    def _load_documents(self):
-        """Load and process all reference documents"""
+    def _load_documents_lightweight(self):
+        """Load only essential documents to save memory"""
         try:
             docs_path = Path(__file__).parent / "guideline-docs"
-            print(f"📚 Loading documents from: {docs_path}")
+            print(f"📚 Loading essential documents from: {docs_path}")
 
             if not docs_path.exists():
                 print("⚠️ guideline-docs directory not found")
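Two details worth flagging in this hunk: chromadb.Client() is the in-memory client, so the index lives in RAM and is rebuilt from guideline-docs on every restart (the "persistence disabled" trade-off named in the comment), and the added raise e lets the failure escape to the __init__ guard, which then falls back to instruction-only mode. For contrast, a short sketch of the ephemeral client next to the persistent alternative (the cache path is invented):

```python
import chromadb

# In-memory client (what the app uses): nothing written to disk,
# so the collection must be rebuilt on every Space restart.
ephemeral = chromadb.Client()

# Persistent alternative: survives restarts at the cost of disk writes.
durable = chromadb.PersistentClient(path="./chroma-cache")

for client in (ephemeral, durable):
    col = client.get_or_create_collection(name="oncolife_documents")
    print(type(client).__name__, col.count())
```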
@@ -106,27 +122,14 @@
 
             # Text splitter for chunking documents
             text_splitter = RecursiveCharacterTextSplitter(
-                chunk_size=1000,
-                chunk_overlap=200,
+                chunk_size=500,  # Smaller chunks to save memory
+                chunk_overlap=100,
                 separators=["\n\n", "\n", ". ", " ", ""]
             )
 
             documents_loaded = 0
 
-            # Process PDF files
-            for pdf_file in docs_path.glob("*.pdf"):
-                try:
-                    print(f"📄 Processing PDF: {pdf_file.name}")
-                    text = self._extract_pdf_text(pdf_file)
-                    if text:
-                        chunks = text_splitter.split_text(text)
-                        self._add_chunks_to_db(chunks, pdf_file.name)
-                        documents_loaded += 1
-                        print(f"✅ Added {len(chunks)} chunks from {pdf_file.name}")
-                except Exception as e:
-                    print(f"❌ Error processing {pdf_file.name}: {e}")
-
-            # Process JSON files
+            # Only process JSON files (lightweight)
             for json_file in docs_path.glob("*.json"):
                 try:
                     print(f"📄 Processing JSON: {json_file.name}")
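Halving chunk_size and chunk_overlap keeps the overlap ratio at 20% while shrinking each stored chunk and the text carried into the prompt per hit. A quick standalone check of what the new splitter settings produce (sample text invented):

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,      # max characters per chunk (len() by default)
    chunk_overlap=100,   # characters shared between neighboring chunks
    separators=["\n\n", "\n", ". ", " ", ""],  # coarsest boundary that fits wins
)

sample = "Chemotherapy patients should report fever promptly. " * 30
chunks = splitter.split_text(sample)
print(len(chunks), [len(c) for c in chunks[:3]])  # every chunk <= 500 chars
```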
@@ -141,7 +144,7 @@
                 except Exception as e:
                     print(f"❌ Error processing {json_file.name}: {e}")
 
-            # Process text files
+            # Process text files (lightweight)
             for txt_file in docs_path.glob("*.txt"):
                 try:
                     print(f"📄 Processing TXT: {txt_file.name}")
@@ -222,10 +225,10 @@
         except Exception as e:
             print(f"❌ Error adding chunks to database: {e}")
 
-    def _retrieve_relevant_documents(self, query, top_k=5):
+    def _retrieve_relevant_documents(self, query, top_k=3):
         """Retrieve relevant document chunks for a query"""
         try:
-            if not self.collection or not self.embedding_model:
+            if not self.collection or not self.embedding_model or not self.rag_enabled:
                 return []
 
             # Generate query embedding
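Retrieval keeps the same shape after this change, just with a tighter default top_k and the extra rag_enabled guard: embed the query with the same model used at indexing time, then ask Chroma for the nearest chunks. A minimal end-to-end sketch of that flow (documents invented):

```python
import chromadb
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")   # same model at index and query time
client = chromadb.Client()
collection = client.create_collection(name="retrieval_sketch")

docs = ["Report any fever above 100.4F to the care team.",
        "For mild nausea, hydrate and eat small meals."]
collection.add(
    ids=[f"chunk-{i}" for i in range(len(docs))],
    documents=docs,
    metadatas=[{"source": "sketch.json"}] * len(docs),
    embeddings=model.encode(docs).tolist(),
)

results = collection.query(
    query_embeddings=model.encode(["what should I do about a fever"]).tolist(),
    n_results=1,  # plays the role of top_k
)
print(results["documents"][0], results["distances"][0])
```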
@@ -254,7 +257,7 @@
             return []
 
     def _load_model(self, model_id, gpu_available):
-        """Load the BioMistral base model"""
+        """Load the BioMistral base model with memory optimization"""
         try:
             print("🔄 Loading BioMistral base model...")
 
@@ -275,14 +278,16 @@
                 trust_remote_code=True
             )
 
-            # Load the model
+            # Load the model with memory optimization
             print(f"📦 Loading model: {model_id}")
             self.model = MistralForCausalLM.from_pretrained(
                 model_id,
                 trust_remote_code=True,
                 device_map="auto",
                 torch_dtype=dtype,
-                low_cpu_mem_usage=True
+                low_cpu_mem_usage=True,
+                # Add memory optimization
+                max_memory={0: "8GB", "cpu": "16GB"} if gpu_available else {"cpu": "8GB"}
             )
 
             # Add pad token if not present
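The new max_memory argument caps what device_map="auto" may place on each device; keys are GPU indices plus "cpu", and anything over the GPU budget is sharded to CPU. A sketch for verifying where weights actually land, using a tiny stand-in model so the dry run is cheap (the real Space loads BioMistral-7B):

```python
import torch
from transformers import AutoModelForCausalLM

# Dry run with a tiny stand-in model; requires accelerate for device_map.
gpu_available = torch.cuda.is_available()

model = AutoModelForCausalLM.from_pretrained(
    "sshleifer/tiny-gpt2",
    device_map="auto",
    torch_dtype=torch.float16 if gpu_available else torch.float32,
    low_cpu_mem_usage=True,
    # Per-device caps; layers exceeding the GPU budget shard to CPU.
    max_memory={0: "8GB", "cpu": "16GB"} if gpu_available else {"cpu": "8GB"},
)

# With device_map="auto", transformers records the final placement:
print(model.hf_device_map)  # e.g. {"": 0}, or per-module entries when sharded
```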
@@ -297,7 +302,7 @@
             self.tokenizer = None
 
     def generate_oncolife_response(self, user_input, conversation_history):
-        """Generate response using OncoLife instructions and RAG"""
+        """Generate response using OncoLife instructions and optional RAG"""
         try:
             if self.model is None or self.tokenizer is None:
                 return """❌ **Model Loading Error**
@@ -311,15 +316,17 @@ Please check the Space logs for details."""
 
             print(f"🔄 Generating OncoLife response for: {user_input}")
 
-            # Retrieve relevant documents using RAG
-            relevant_docs = self._retrieve_relevant_documents(user_input, top_k=3)
-
-            # Format retrieved documents
+            # Retrieve relevant documents using RAG (if available)
             context_text = ""
-            if relevant_docs:
-                context_text = "\n\n**Relevant Reference Information:**\n"
-                for i, doc in enumerate(relevant_docs):
-                    context_text += f"\n--- Source: {doc['source']} ---\n{doc['content'][:500]}...\n"
+            if self.rag_enabled:
+                try:
+                    relevant_docs = self._retrieve_relevant_documents(user_input, top_k=2)
+                    if relevant_docs:
+                        context_text = "\n\n**Relevant Reference Information:**\n"
+                        for i, doc in enumerate(relevant_docs):
+                            context_text += f"\n--- Source: {doc['source']} ---\n{doc['content'][:300]}...\n"
+                except Exception as e:
+                    print(f"⚠️ RAG retrieval failed: {e}")
 
             # Create prompt using the loaded instructions and retrieved context
             system_prompt = f"""You are the OncoLife Symptom & Triage Assistant. Follow these instructions exactly:
@@ -426,7 +433,7 @@ Please try a simpler question or check the logs for more details."""
                 "assistant": assistant_msg
             })
 
-        # Generate response using OncoLife instructions and RAG
+        # Generate response using OncoLife instructions and optional RAG
         response = self.generate_oncolife_response(message, conversation_history)
 
         return response

requirements.txt CHANGED
@@ -1,5 +1,3 @@
-# Medical Chatbot HF Space Requirements
-
 # Web framework
 gradio==4.44.0
 
@@ -9,18 +7,10 @@ transformers==4.36.2
 accelerate==0.25.0
 
 # HF Spaces GPU support
-spaces>=0.1.0
-
-# Basic utilities
-numpy>=1.21.0,<2.0.0
-requests>=2.28.0
-
-# Additional dependencies for better device handling
 safetensors==0.4.1
 tokenizers>=0.15.0
 
-# RAG implementation
-bitsandbytes==0.41.3
+# RAG implementation (optional - will fallback gracefully if not available)
 sentence-transformers==2.2.2
 chromadb==0.4.22
 pypdf2==3.0.1
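With the imports now guarded in app.py, the RAG pins below the comment are effectively optional: removing them shrinks the image, and the app falls back to instruction-only mode. A small sketch to check which optional modules are importable in the built environment (module names as imported by app.py):

```python
import importlib.util

# Module names as imported at the top of Oncolife/app.py.
optional_rag_modules = ["chromadb", "sentence_transformers", "PyPDF2",
                        "pdfplumber", "langchain", "fitz"]

for name in optional_rag_modules:
    found = importlib.util.find_spec(name) is not None
    print(f"{name}: {'available' if found else 'missing -> instruction-only mode'}")
```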