shamim237 commited on
Commit
8ff45d7
·
verified ·
1 Parent(s): 953293f

initial commit

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ vectors_data/med_data_vec/index.faiss filter=lfs diff=lfs merge=lfs -text
37
+ vectors_data/msd_data_vec/index.faiss filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ .env
4
+ *.env
5
+ chatbot/
6
+ processed_data/
7
+ src/
8
+ vectors_data/
9
+ logs/
app.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import asyncio
3
+ import uuid
4
+ from datetime import datetime
5
+ from pathlib import Path
6
+ from typing import Dict, Optional
7
+
8
+ from fastapi import FastAPI, HTTPException, Request
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ from fastapi.responses import JSONResponse
11
+ from pydantic import BaseModel, constr, Field, validator, constr
12
+ from rag import generate_answer # Importing the function to generate answers
13
+
14
+ # Create logs directory if it doesn't exist
15
+ Path("logs").mkdir(exist_ok=True) # Ensure logs directory is available for logging
16
+
17
+ # Initialize FastAPI app with metadata
18
+ app = FastAPI(
19
+ title="Question Answering API", # Title of the API
20
+ description="API for generating answers using RAG", # Description of the API
21
+ version="1.0.0", # Version of the API
22
+ docs_url="/docs", # URL for API documentation
23
+ redoc_url="/redoc" # URL for ReDoc documentation
24
+ )
25
+
26
+ # Configure CORS
27
+ app.add_middleware(
28
+ CORSMiddleware,
29
+ allow_origins=["*"], # Allow all origins; modify in production for security
30
+ allow_credentials=True, # Allow credentials to be included in requests
31
+ allow_methods=["*"], # Allow all HTTP methods
32
+ allow_headers=["*"], # Allow all headers
33
+ )
34
+
35
+ # Configure detailed logging
36
+ logging.basicConfig(
37
+ level=logging.INFO, # Set logging level to INFO
38
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', # Log format
39
+ handlers=[ # Handlers for logging output
40
+ logging.StreamHandler(), # Log to console
41
+ logging.FileHandler(f'logs/app_{datetime.now().strftime("%Y%m%d")}.log') # Log to file
42
+ ]
43
+ )
44
+ logger = logging.getLogger(__name__) # Create a logger for this module
45
+
46
+ def generate_request_id() -> str:
47
+ """Generate a unique request ID.""" # Docstring for function
48
+ return str(uuid.uuid4()) # Return a new UUID as a string
49
+
50
+ class QuestionRequest(BaseModel):
51
+ """Request model for question answering endpoint.""" # Docstring for model
52
+ question: constr(min_length=1) = Field( # Field for the question with validation
53
+ ...,
54
+ description="The question to be answered", # Description of the field
55
+ example="What is the capital of France?" # Example question
56
+ )
57
+
58
+ @validator('question') # Validator for the question field
59
+ def validate_question(cls, v):
60
+ if v.strip() == "": # Check if the question is empty
61
+ raise HTTPException(
62
+ status_code=500, # Raise HTTP 500 if empty
63
+ detail="Error processing request" # Error message
64
+ )
65
+ return v.strip() # Return the trimmed question
66
+
67
+ class ErrorResponse(BaseModel):
68
+ """Standard error response model.""" # Docstring for error response model
69
+ detail: str # Detail of the error
70
+ request_id: Optional[str] = None # Optional request ID for tracking
71
+
72
+ @app.middleware("http")
73
+ async def add_request_id(request: Request, call_next):
74
+ """Middleware to add request ID to all requests.""" # Docstring for middleware
75
+ request_id = generate_request_id() # Generate a new request ID
76
+ request.state.request_id = request_id # Store request ID in request state
77
+ response = await call_next(request) # Process the request
78
+ response.headers["X-Request-ID"] = request_id # Add request ID to response headers
79
+ return response # Return the response
80
+
81
+ @app.exception_handler(HTTPException)
82
+ async def http_exception_handler(request: Request, exc: HTTPException):
83
+ """Custom exception handler to include request ID.""" # Docstring for exception handler
84
+ return JSONResponse(
85
+ status_code=exc.status_code, # Return the status code from the exception
86
+ content={"detail": exc.detail, "request_id": request.state.request_id} # Include error details and request ID
87
+ )
88
+
89
+ @app.post("/answer",
90
+ response_model=Dict[str, str], # Response model for the endpoint
91
+ summary="Generate an answer for the given question", # Summary of the endpoint
92
+ response_description="Returns the generated answer", # Description of the response
93
+ responses={ # Possible responses
94
+ 400: {"model": ErrorResponse}, # Bad request response
95
+ 500: {"model": ErrorResponse}, # Internal server error response
96
+ 504: {"model": ErrorResponse} # Gateway timeout response
97
+ }
98
+ )
99
+ async def get_answer(request: Request, question_request: QuestionRequest) -> Dict[str, str]:
100
+ """
101
+ Generate an answer for the given question.
102
+
103
+ Args:
104
+ request: FastAPI request object # Description of request argument
105
+ question_request: The question request model # Description of question_request argument
106
+
107
+ Returns:
108
+ Dict containing the generated answer # Description of return value
109
+
110
+ Raises:
111
+ HTTPException: For various error conditions # Description of possible exceptions
112
+ """
113
+ request_id = request.state.request_id # Retrieve request ID from state
114
+ logger.info("Processing request %s - Question: %s", request_id, question_request.question) # Log the processing request
115
+
116
+ try:
117
+ # Additional validation
118
+ if len(question_request.question) > 1000: # Check if question exceeds max length
119
+ logger.warning("Request %s: Question exceeds maximum length", request_id) # Log warning
120
+ raise HTTPException(
121
+ status_code=400, # Raise HTTP 400 for bad request
122
+ detail="Question length exceeds maximum allowed characters (1000)" # Error message
123
+ )
124
+
125
+ # Convert generate_answer to async if it's not already
126
+ async def async_generate_answer(question: str):
127
+ loop = asyncio.get_event_loop() # Get the current event loop
128
+ return await loop.run_in_executor(None, generate_answer, question) # Run the generate_answer function in executor
129
+
130
+ # Generate answer with timeout handling
131
+ answer = await asyncio.wait_for(
132
+ async_generate_answer(question_request.question), # Await the answer generation
133
+ timeout=120.0 # 120 second timeout
134
+ )
135
+
136
+ logger.info("Request %s: Successfully generated answer", request_id) # Log successful answer generation
137
+ return {"answer": answer} # Return the generated answer
138
+
139
+ except asyncio.TimeoutError: # Handle timeout errors
140
+ logger.error("Request %s: Generation timeout", request_id) # Log timeout error
141
+ raise HTTPException(
142
+ status_code=504, # Raise HTTP 504 for timeout
143
+ detail="Answer generation timed out" # Error message
144
+ )
145
+
146
+ except ValueError as ve: # Handle value errors
147
+ logger.error("Request %s: Validation error - %s", request_id, str(ve)) # Log validation error
148
+ raise HTTPException(
149
+ status_code=400, # Raise HTTP 400 for bad request
150
+ detail=str(ve) # Return the validation error message
151
+ )
152
+
153
+ except Exception as e: # Handle all other exceptions
154
+ logger.error(
155
+ "Request %s: Unexpected error - %s",
156
+ request_id,
157
+ str(e),
158
+ exc_info=True # Include stack trace in logs
159
+ )
160
+ raise HTTPException(
161
+ status_code=500, # Raise HTTP 500 for internal server error
162
+ detail="Internal server error occurred. Please try again later." # Error message
163
+ )
164
+
165
+ @app.get("/health",
166
+ summary="Health check endpoint", # Summary of the health check endpoint
167
+ response_description="Returns the API health status" # Description of the response
168
+ )
169
+ async def health_check():
170
+ """Health check endpoint to verify API is running.""" # Docstring for health check
171
+ return {"status": "healthy"} # Return health status
172
+
173
+ @app.on_event("startup")
174
+ async def startup_event():
175
+ """Initialize any resources on startup.""" # Docstring for startup event
176
+ logger.info("Starting up Question Answering API") # Log startup event
177
+ # Add any initialization code here (e.g., loading models)
178
+
179
+ @app.on_event("shutdown")
180
+ async def shutdown_event():
181
+ """Cleanup any resources on shutdown.""" # Docstring for shutdown event
182
+ logger.info("Shutting down Question Answering API") # Log shutdown event
183
+ # Add cleanup code here
184
+
185
+ if __name__ == "__main__": # Entry point for the application
186
+ import uvicorn # Import Uvicorn for running the app
187
+
188
+ # Load configuration from environment variables or use defaults
189
+ host = "0.0.0.0" # Host address
190
+ port = 8000 # Port number
191
+
192
+ logger.info(f"Starting server on {host}:{port}") # Log server start
193
+
194
+ uvicorn.run(
195
+ "app:app", # Application to run
196
+ host=host, # Host address
197
+ port=port, # Port number
198
+ reload=False, # Disable auto-reload in production
199
+ log_level="info", # Set log level
200
+ access_log=True, # Enable access logging
201
+ workers=1 # Number of worker processes
202
+ )
gradio_app.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from rag import generate_answer # Import the generate_answer function from rag.py
4
+
5
+ def chat_with_bot(query: str) -> tuple:
6
+ """
7
+ Chat with the bot using the provided query.
8
+
9
+ Args:
10
+ query: User's question
11
+
12
+ Returns:
13
+ tuple: Generated answer and extracted contents
14
+ """
15
+ response = generate_answer(query) # Call the existing generate_answer function
16
+
17
+ return response
18
+
19
+ def main():
20
+ """
21
+ Main function to create and launch the Gradio interface.
22
+ """
23
+ try:
24
+ # Create Gradio interface
25
+ interface = gr.Interface(
26
+ fn=chat_with_bot, # Use the chat function
27
+ inputs="text", # Input type
28
+ outputs="text", # Output type
29
+ title="RAG-LLM based Medical Chatbot", # Title of the app
30
+ description="Ask your medical questions and get answers from the chatbot." # Description
31
+ )
32
+ interface.launch(share=True) # Launch the Gradio app
33
+
34
+ except Exception as e:
35
+ print(f"An error occurred: {str(e)}")
36
+
37
+ if __name__ == "__main__":
38
+ main()
log.txt ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Query: who is john cena?
2
+ Response: I'm sorry, but based on the provided context, there is no information related to "John Cena" or any medical relevance to that name. If you have any other medical questions or concerns, please feel free to ask. Remember, for personalized medical advice, it is always best to consult with a healthcare professional.
3
+
4
+ Query: i have headache, suggest me some medicine
5
+ Response: Based on the context provided, if you have a headache, you can consider taking Ibuprofen EG syrup suspension. However, it is important to consult with a healthcare professional before taking any medication, especially if you have other underlying health conditions or are taking other medications. Ibuprofen may cause side effects such as injection site reaction, tenderness, fatigue, myalgia, arthralgia, fever, and nausea. If your headache persists or worsens, it is advisable to seek medical advice for proper evaluation and treatment.
6
+
7
+ Query: what is the best medicine when i have a headache?
8
+ Response: The best medicine for a headache can vary depending on the individual and the underlying cause of the headache. In general, over-the-counter pain relievers such as acetaminophen (oral or IV), ibuprofen, or naproxen can be effective for mild to moderate headaches. However, it is important to consult with a healthcare professional before taking any medication, especially if you are immunosuppressed or have papilledema. Additionally, it is important to avoid overusing medications, as this can lead to medication overuse headaches. In some cases, cognitive therapy, biofeedback, and education can also be helpful in managing headaches. If you have frequent headaches, it is recommended to seek medical advice for proper diagnosis and treatment.
9
+
10
+ Query: I have Diarrhea, suggest some medicine
11
+ Response: Based on the provided context, it is important to first consult a healthcare professional for a proper diagnosis and treatment plan. However, if you are experiencing diarrhea, you may consider over-the-counter medications such as loperamide (Imodium) to help alleviate symptoms. It is important to follow the recommended dosage and consult a healthcare provider if symptoms persist or worsen. Additionally, it is important to stay hydrated by drinking plenty of fluids.
12
+
13
+ Query: What are the clinical signs and symptoms of mitral stenosis?
14
+ Response: The clinical signs and symptoms of mitral stenosis include those of heart failure, such as fatigue, cold skin, and right upper quadrant abdominal discomfort. Additionally, signs may include an opening snap and a diastolic murmur. It is important to consult a healthcare professional for proper diagnosis and treatment. Medications such as diuretics and beta-blockers may be prescribed, but potential side effects should be discussed with a healthcare provider.
15
+
16
+ Query: How does dehydration affect blood pressure in someone with autonomic dysfunction?
17
+ Response: Dehydration can lead to hypotension (low blood pressure) in someone with autonomic dysfunction. This is because autonomic dysfunction can already cause orthostatic hypotension, and dehydration further exacerbates this by reducing blood volume and affecting the body's ability to regulate blood pressure. It is important for individuals with autonomic dysfunction to stay well-hydrated to help maintain their blood pressure within a healthy range. However, for personalized advice and treatment options, it is recommended to consult a healthcare professional.
18
+
19
+ Query: Can an ear infection cause a heart problem? If yes, how?
20
+ Response: Yes, in heavy infections, the inflammation from an ear infection may cause complications such as cardiac issues like myocarditis, heart failure, arrhythmia, and even death. The inflammation can spread to the heart and lead to serious complications. It is important to seek medical attention if you suspect an ear infection to prevent any potential complications. Consulting a healthcare professional for proper diagnosis and treatment is recommended.
21
+
22
+ Query: When do VPBs require treatment?
23
+ Response: VPBs, or ventricular premature beats, may require treatment if they are causing symptoms such as palpitations, dizziness, or fainting. Treatment options may include medications such as antiarrhythmics. It is important to consult with a healthcare professional for proper evaluation and management of VPBs.
24
+
25
+ Query: When do VPBs require treatment?
26
+ Response: VPBs, or ventricular premature beats, may require treatment if they are causing symptoms such as palpitations, dizziness, or fainting. Treatment options may include medications such as quetiapine or olanzapine, but it is important to consult a healthcare professional for proper evaluation and management. In some cases, preventive treatment may be indicated for patients with certain medical conditions that require bed rest. If you are experiencing symptoms of VPBs, it is important to seek medical advice for appropriate care.
27
+
28
+ Query: How does microvascular angina differ from typical angina?
29
+ Response: Microvascular angina differs from typical angina in that it is caused by impaired coronary blood flow due to issues in the smaller blood vessels, rather than blockages in the larger epicardial coronary arteries. This can lead to symptoms such as angina pectoris, syncope, or arrhythmias even in the absence of coronary artery disease. To diagnose microvascular angina, guidewires with pressure or flow sensors can be used to estimate blood flow across stenoses, with a fractional flow reserve (FFR) of < 0.75 to 0.8 indicating myocardial ischemia. Treatment for microvascular angina may involve medications like Lipiodol Ultra Fluide or Solu-Medrol, but it is important to consult a healthcare professional for proper diagnosis and management, as well as to discuss potential side effects of these medications.
30
+
31
+ Query: What are the common causes of pericarditis?
32
+ Response: Common causes of pericarditis include systemic rheumatic diseases such as rheumatoid arthritis, systemic lupus erythematosus (SLE), and systemic sclerosis. Bacterial infections from gram-negative bacilli, Haemophilus influenzae, Lyme disease, rheumatic fever, staphylococci, and streptococci can also lead to pericarditis. Additionally, tuberculosis, certain cancers like breast cancer, Kaposi sarcoma (in people with HIV infection), leukemia, and lung cancer, as well as certain medications like anticoagulants, can be common causes of pericarditis. It is important to consult a healthcare professional for proper diagnosis and treatment.
33
+
34
+ Query: suggest some medicine for pericarditis
35
+ Response: For pericarditis, medications such as colchicine or interleukin-1 antagonists may be prescribed based on the underlying cause and severity of the condition. It is important to consult a healthcare professional for a proper diagnosis and treatment plan. Additionally, for patients with newly diagnosed constrictive pericarditis, anti-inflammatory medications may be considered before pericardiectomy. It is crucial to follow the guidance of a healthcare provider and discuss any potential side effects of the prescribed medication.
36
+
37
+ Query: What are the common causes of orthostatic hypotension?
38
+ Response: Common causes of orthostatic hypotension include dehydration, electrolyte disturbances, use of certain drugs (such as antihypertensives and nitrates), aging, heart failure, chronic kidney disease, respiratory disorders, liver failure, renal insufficiency, and sedatives or opioids used for intubation and ventilation. It is important to consult a healthcare professional for proper diagnosis and treatment, as the underlying cause may vary from person to person.
39
+
40
+ Query: How is long QT syndrome diagnosed and treated?'
41
+ Response: Long QT syndrome is diagnosed based on clinical criteria. Treatment typically involves supportive care, including medications as needed for associated symptoms. It is important to consult with a healthcare professional for proper diagnosis and treatment. If medications are prescribed, it is essential to be aware of potential side effects.
42
+
43
+ Query: What are the classes of antiarrhythmic drugs and their mechanisms of action?
44
+ Response: The classes of antiarrhythmic drugs are grouped into four main classes based on their dominant cellular electrophysiologic effect. Class I drugs are sodium channel blockers that slow conduction in fast-channel tissues. They are subdivided into subclasses A, B, and C based on their kinetics of the sodium channel effects. Class Ib drugs have fast kinetics. The effects of these drugs may be reflected in the electrocardiogram as widening of the P wave, widening of the QRS complex, prolongation of the PR interval, or a combination. It is important to consult a healthcare professional for specific information on antiarrhythmic drugs and their mechanisms of action.
45
+
46
+ Query: What are the key causes of pitting and non-pitting edema?
47
+ Response: The key causes of pitting and non-pitting edema can include injury (such as burns, chemicals, toxins, or blunt trauma), severe sepsis causing vascular endothelial leakage, soft-tissue infections like cellulitis or necrotizing myofasciitis, and other inflammatory conditions. Pitting edema may present as painless, symmetrical swelling, while non-pitting edema may be focal and erythematous. It is important to seek clinical evaluation, cultures, and imaging studies as needed to determine the underlying cause of edema. For specific treatment options and management, it is advisable to consult a healthcare professional.
48
+
49
+ Query: tell me about this medicine: Rhinathiol Antirhinitis 10/4 compr.
50
+ Response: Rhinathiol Antirhinitis 10/4 compr. is a medication in the form of tablets. It is indicated for the treatment of rhinitis, which is inflammation of the mucous membrane of the nose. This medication may help reduce symptoms such as nasal congestion, runny nose, and sneezing associated with rhinitis.
51
+
52
+ It is important to note that this information is general and for educational purposes only. It is always recommended to consult with a healthcare professional or pharmacist before starting any new medication. Additionally, common side effects of Rhinathiol Antirhinitis 10/4 compr. may include drowsiness, dizziness, dry mouth, and blurred vision. If you experience any severe or persistent side effects, it is important to seek medical attention.
53
+
54
+ Query: What are the first-line treatments for type 2 diabetes, and how do they differ in mechanism?
55
+ Response: The first-line treatments for type 2 diabetes include oral antihyperglycemics, non-insulin injectable medications such as glucagon-like peptide-1 (GLP-1) receptor agonists, insulin, or a combination of these. These medications work by helping to lower blood sugar levels in different ways. Oral antihyperglycemics help improve insulin sensitivity or reduce glucose production in the liver. GLP-1 receptor agonists stimulate insulin secretion and reduce glucagon secretion. Insulin helps regulate blood sugar levels by replacing or supplementing the body's natural insulin production. It is important to consult healthcare professionals for personalized treatment recommendations and to discuss potential side effects of these medications.
56
+
57
+ Query: What are the first-line treatments for type 2 diabetes, and how do they differ in mechanism?
58
+ Response: The first-line treatments for type 2 diabetes include oral antihyperglycemics, non-insulin injectable medications such as glucagon-like peptide-1 (GLP-1) receptor agonists, insulin, or a combination of these. These medications work by helping to lower blood sugar levels in different ways. Oral antihyperglycemics help improve insulin sensitivity or reduce glucose production in the liver. GLP-1 receptor agonists help increase insulin secretion and decrease glucagon secretion. Insulin therapy helps regulate blood sugar levels by replacing or supplementing the body's insulin production. It is important to consult healthcare professionals for personalized treatment recommendations and to discuss potential side effects of these medications.
59
+
60
+ Query: What are the first-line treatments for type 2 diabetes, and how do they differ in mechanism?
61
+ Response: The first-line treatments for type 2 diabetes include oral antihyperglycemics, non-insulin injectable medications such as glucagon-like peptide-1 (GLP-1) receptor agonists, insulin, or a combination of these. These medications work by helping to lower blood sugar levels in different ways. Oral antihyperglycemics help improve insulin sensitivity or reduce glucose production in the liver. GLP-1 receptor agonists help increase insulin secretion and decrease glucagon secretion. Insulin therapy helps regulate blood sugar levels by replacing or supplementing the body's insulin production. It is important to consult healthcare professionals for personalized treatment plans and to discuss potential side effects of these medications.
62
+
63
+ Query: who is John Cena?
64
+ Response: I'm sorry, but based on the provided context, there is no information related to John Cena. If you have any other medical questions or concerns, please feel free to ask. It is always recommended to consult with a healthcare professional for personalized medical advice.
65
+
66
+ Query: who is John Cena?
67
+ Response: I'm sorry, but based on the provided context, there is no information related to a person named John Cena. If you are looking for information on a specific individual, it is recommended to consult reliable sources or conduct a separate search. If you have any medical questions or concerns, it is always best to consult a healthcare professional for personalized advice and guidance.
68
+
preprocess.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from pathlib import Path
3
+ from typing import Union, Optional
4
+ import pandas as pd
5
+
6
+ # Configure logging
7
+ logging.basicConfig(
8
+ level=logging.INFO,
9
+ format='%(asctime)s - %(levelname)s - %(message)s'
10
+ )
11
+ logger = logging.getLogger(__name__)
12
+
13
+ class DataPreprocessor:
14
+ """A class to handle data preprocessing operations for different file formats."""
15
+
16
+ @staticmethod
17
+ def _preprocess_dataframe(df: pd.DataFrame) -> pd.DataFrame:
18
+ """
19
+ Applies standard preprocessing steps to a DataFrame.
20
+
21
+ Args:
22
+ df (pd.DataFrame): Input DataFrame to preprocess
23
+
24
+ Returns:
25
+ pd.DataFrame: Preprocessed DataFrame
26
+ """
27
+ try:
28
+ # Convert text columns to lowercase for standardization
29
+ df = df.map(lambda x: x.lower() if isinstance(x, str) else x)
30
+
31
+ # Drop columns that are fully null
32
+ df = df.dropna(axis=1, how='all')
33
+
34
+ # Fill remaining NaN values with empty strings
35
+ df = df.fillna('')
36
+
37
+ # Remove duplicate rows
38
+ df = df.drop_duplicates()
39
+
40
+ return df
41
+
42
+ except Exception as e:
43
+ logger.error(f"Error during DataFrame preprocessing: {str(e)}")
44
+ raise
45
+
46
+ @classmethod
47
+ def preprocess_msd(cls,
48
+ file_path: Union[str, Path],
49
+ output_path: Union[str, Path],
50
+ sheet_name: Optional[Union[str, int]] = 0) -> pd.DataFrame:
51
+ """
52
+ Preprocesses an MSD Excel file and saves the result.
53
+
54
+ Args:
55
+ file_path: Path to the Excel file
56
+ output_path: Directory path for the output file
57
+ sheet_name: Sheet name or index to load (default: 0)
58
+
59
+ Returns:
60
+ pd.DataFrame: Preprocessed DataFrame
61
+
62
+ Raises:
63
+ FileNotFoundError: If input file doesn't exist
64
+ PermissionError: If output directory is not writable
65
+ """
66
+ try:
67
+ # Convert to Path objects
68
+ file_path = Path(file_path)
69
+ output_path = Path(output_path)
70
+
71
+ # Validate input file
72
+ if not file_path.exists():
73
+ raise FileNotFoundError(f"Input file not found: {file_path}")
74
+
75
+ # Ensure output directory exists
76
+ output_path.mkdir(parents=True, exist_ok=True)
77
+
78
+ logger.info(f"Processing MSD file: {file_path}")
79
+ df = pd.read_excel(file_path, sheet_name=sheet_name)
80
+
81
+ # Apply preprocessing
82
+ df = cls._preprocess_dataframe(df)
83
+
84
+ # Save processed file
85
+ output_file = output_path / "msd_processed.csv"
86
+ df.to_csv(output_file, index=False)
87
+ logger.info(f"Saved processed file to: {output_file}")
88
+
89
+ return df
90
+
91
+ except Exception as e:
92
+ logger.error(f"Error processing MSD file: {str(e)}")
93
+ raise
94
+
95
+ @classmethod
96
+ def preprocess_cbip(cls,
97
+ input_dir: Union[str, Path],
98
+ output_dir: Union[str, Path]) -> None:
99
+ """
100
+ Preprocesses all CSV files in the CBIP directory.
101
+
102
+ Args:
103
+ input_dir: Directory containing input CSV files
104
+ output_dir: Directory for output files
105
+
106
+ Raises:
107
+ FileNotFoundError: If input directory doesn't exist
108
+ PermissionError: If output directory is not writable
109
+ """
110
+ try:
111
+ # Convert to Path objects
112
+ input_dir = Path(input_dir)
113
+ output_dir = Path(output_dir)
114
+
115
+ # Validate input directory
116
+ if not input_dir.exists():
117
+ raise FileNotFoundError(f"Input directory not found: {input_dir}")
118
+
119
+ # Ensure output directory exists
120
+ output_dir.mkdir(parents=True, exist_ok=True)
121
+
122
+ # Process all CSV files
123
+ csv_files = list(input_dir.rglob("*.csv"))
124
+ if not csv_files:
125
+ logger.warning(f"No CSV files found in: {input_dir}")
126
+ return
127
+
128
+ for file_path in csv_files:
129
+ try:
130
+ logger.info(f"Processing CBIP file: {file_path}")
131
+
132
+ # Read CSV file
133
+ df = pd.read_csv(
134
+ file_path,
135
+ delimiter=';',
136
+ quotechar='"',
137
+ skip_blank_lines=True
138
+ )
139
+
140
+ # Apply preprocessing
141
+ df = cls._preprocess_dataframe(df)
142
+
143
+ # Save processed file
144
+ output_file = output_dir / file_path.name
145
+ df.to_csv(output_file, index=False)
146
+ logger.info(f"Saved processed file to: {output_file}")
147
+
148
+ except Exception as e:
149
+ logger.error(f"Error processing {file_path}: {str(e)}")
150
+ continue
151
+
152
+ except Exception as e:
153
+ logger.error(f"Error processing CBIP directory: {str(e)}")
154
+ raise
155
+
156
+ def main():
157
+ """Main execution function."""
158
+ try:
159
+ import os
160
+ import argparse
161
+ from pathlib import Path
162
+
163
+ # Create processed_data directory in current working directory
164
+ output_base = Path.cwd() / "processed_data"
165
+ msd_output = output_base / "msd"
166
+ cbip_output = output_base / "cbip"
167
+
168
+ parser = argparse.ArgumentParser(description='Process MSD and CBIP data files.')
169
+ parser.add_argument('--msd-input', required=True, help='Path to MSD Excel file')
170
+ parser.add_argument('--cbip-input', required=True, help='Input directory containing CBIP CSV files')
171
+
172
+ args = parser.parse_args()
173
+
174
+ preprocessor = DataPreprocessor()
175
+
176
+ # Process MSD file
177
+ preprocessor.preprocess_msd(
178
+ args.msd_input,
179
+ msd_output
180
+ )
181
+
182
+ # Process CBIP directory
183
+ preprocessor.preprocess_cbip(
184
+ args.cbip_input,
185
+ cbip_output
186
+ )
187
+
188
+ except Exception as e:
189
+ logger.error(f"Main execution failed: {str(e)}")
190
+ raise
191
+
192
+ if __name__ == "__main__":
193
+ main()
rag.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from pathlib import Path
4
+ from typing import Optional
5
+ from dotenv import load_dotenv
6
+ from langchain_community.vectorstores import FAISS
7
+ from langchain_huggingface import HuggingFaceEmbeddings
8
+ from langchain_openai import ChatOpenAI
9
+ from langchain.chains import RetrievalQA
10
+ from langchain.prompts import PromptTemplate
11
+ from langchain.retrievers import MergerRetriever
12
+
13
+ # Load environment variables from .env file
14
+ load_dotenv()
15
+
16
+ # Configure logging
17
+ logging.basicConfig(
18
+ level=logging.INFO,
19
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
20
+ )
21
+ logger = logging.getLogger(__name__)
22
+
23
+ def load_faiss_index(folder_path: str, model_name: str) -> Optional[FAISS]:
24
+ """
25
+ Load a FAISS index with a specific embedding model.
26
+
27
+ Args:
28
+ folder_path: Path to the FAISS index folder
29
+ model_name: Name of the HuggingFace embedding model
30
+
31
+ Returns:
32
+ FAISS: Loaded FAISS index object
33
+
34
+ Raises:
35
+ ValueError: If the folder path doesn't exist
36
+ """
37
+ try:
38
+ if not os.path.exists(folder_path):
39
+ raise ValueError(f"FAISS index folder not found: {folder_path}")
40
+
41
+ logger.info(f"Loading FAISS index from {folder_path}")
42
+ embeddings = HuggingFaceEmbeddings(model_name=model_name)
43
+ return FAISS.load_local(
44
+ folder_path=folder_path,
45
+ embeddings=embeddings,
46
+ allow_dangerous_deserialization=True
47
+ )
48
+ except Exception as e:
49
+ logger.error(f"Error loading FAISS index: {str(e)}")
50
+ raise
51
+
52
+ def generate_answer(query: str) -> str:
53
+ """
54
+ Generate an answer for the given query using RAG.
55
+
56
+ Args:
57
+ query: User's question
58
+
59
+ Returns:
60
+ str: Generated answer
61
+
62
+ Raises:
63
+ ValueError: If query is empty or required files are missing
64
+ """
65
+ try:
66
+ if not query or not query.strip():
67
+ raise ValueError("Query cannot be empty")
68
+
69
+ # Get the current directory and construct paths
70
+ current_dir = Path(__file__).parent
71
+ vectors_dir = current_dir / "vectors_data"
72
+
73
+ # Validate vectors directory exists
74
+ if not vectors_dir.exists():
75
+ raise ValueError(f"Vectors directory not found at {vectors_dir}")
76
+
77
+ # Load FAISS indices
78
+ logger.info("Loading FAISS indices...")
79
+ data_vec = load_faiss_index(
80
+ str(vectors_dir / "msd_data_vec"),
81
+ "sentence-transformers/all-MiniLM-L12-v2"
82
+ )
83
+ med_vec = load_faiss_index(
84
+ str(vectors_dir / "med_data_vec"),
85
+ "sentence-transformers/all-MiniLM-L12-v2"
86
+ )
87
+
88
+ # Create the LLM instance
89
+ llm = ChatOpenAI(
90
+ model="gpt-3.5-turbo",
91
+ temperature=0,
92
+ openai_api_key="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"
93
+ )
94
+
95
+ # Define the prompt template
96
+ template = """You are a helpful medical information assistant. Use the following pieces of context to answer the medical question at the end.
97
+
98
+ Important notes:
99
+ - Base your answer strictly on the provided context and understandable for all readers
100
+ - If you don't know the answer, just say that you don't know
101
+ - Include relevant disclaimers about consulting healthcare professionals
102
+ - If suggesting medications (upon question), mention potential side effects if provided in the context
103
+ - Highlight if the information is general knowledge or requires professional medical advice
104
+
105
+ Context: {context}
106
+
107
+ Question: {question}
108
+
109
+ Medical Information Assistant:"""
110
+
111
+ QA_CHAIN_PROMPT = PromptTemplate(
112
+ input_variables=["context", "question"],
113
+ template=template
114
+ )
115
+
116
+ # Initialize and combine retrievers
117
+ logger.info("Setting up retrieval chain...")
118
+ data_retriever = data_vec.as_retriever()
119
+ med_retriever = med_vec.as_retriever()
120
+ combined_retriever = MergerRetriever(
121
+ retrievers=[data_retriever, med_retriever]
122
+ )
123
+
124
+ # Initialize the RetrievalQA chain
125
+ qa_chain = RetrievalQA.from_chain_type(
126
+ llm=llm,
127
+ retriever=combined_retriever,
128
+ return_source_documents=True,
129
+ chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
130
+ )
131
+
132
+ # Run the chain
133
+ logger.info("Generating answer...")
134
+ result = qa_chain.invoke({"query": query})
135
+ logger.info("Answer generated successfully")
136
+
137
+ # Extracting the relevant documents from the result
138
+ extracted_docs = result.get("source_documents", [])
139
+ logger.info(f"Extracted documents: {extracted_docs}") # Log the extracted documents
140
+
141
+ # New organized printing of extracted documents
142
+ print("\nExtracted Documents:")
143
+ for doc in extracted_docs:
144
+ print(f"Source: {doc.metadata['source']}, Row: {doc.metadata['row']}")
145
+ print(f"Content: {doc.page_content}\n")
146
+
147
+ return result["result"]
148
+
149
+ except Exception as e:
150
+ logger.error(f"Error generating answer: {str(e)}")
151
+ raise
152
+
153
+ def main():
154
+ """
155
+ Main function to demonstrate the usage of the RAG system.
156
+ """
157
+ try:
158
+ # Example usage
159
+ query = "suggest me some medicine for bronchitis"
160
+ logger.info(f"Processing query: {query}")
161
+
162
+ response = generate_answer(query)
163
+ print("\nQuery:", query)
164
+ print("\nResponse:", response)
165
+
166
+ except Exception as e:
167
+ logger.error(f"Error in main function: {str(e)}")
168
+ print(f"An error occurred: {str(e)}")
169
+
170
+ if __name__ == "__main__":
171
+ main()
readme.md ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Building a Retrieval-Augmented Question-Answering System with FastAPI and LangChain
2
+ Develop a simple question-answering web service that leverages Retrieval-Augmented Generation (RAG) to provide answers based on a set of provided documents. The service will be built using Python, FastAPI, and LangChain.
3
+
4
+ ## Installation
5
+
6
+ 1. Clone or download this repository
7
+ ```
8
+ git clone https://github.com/shamim237/artech_med_bot.git
9
+ ```
10
+ 2. Create a virtual environment (optional but recommended):
11
+
12
+ ```
13
+ python -m venv venv
14
+ source venv/bin/activate # On Windows use `venv\Scripts\activate`
15
+ ```
16
+ 3. Install the required packages:
17
+ ```
18
+ pip install -r requirements.txt
19
+ ```
20
+
21
+ # Run the system
22
+
23
+ To run the system on your machine, follow the steps shown below:
24
+
25
+ ## **preprocess.py**:
26
+ This script for preprocessing MSD (Excel) and CBIP (CSV) data files. This script provides standardized data cleaning and transformation operations for both file formats.
27
+
28
+ ### Features
29
+ - Processes MSD Excel files and CBIP CSV files
30
+ - Standardizes text to lowercase
31
+ - Removes empty columns
32
+ - Handles missing values
33
+ - Eliminates duplicate rows
34
+ - Preserves original file structure
35
+ - Comprehensive logging
36
+
37
+ ### Usage
38
+ ```
39
+ python preprocess.py --msd-input "path/to/msd.xlsx" --cbip-input "path/to/cbip/directory"
40
+ ```
41
+ #### Arguments
42
+ - `--msd-input`: Path to the MSD Excel file
43
+ - `--cbip-input`: Path to the directory containing CBIP CSV files
44
+
45
+
46
+ #### Output
47
+
48
+ The script creates a `processed_data` directory in your current working directory with the following structure:
49
+ processed_data/
50
+ ├── msd/
51
+ │ └── msd_processed.csv
52
+ └── cbip/
53
+ └── [processed_csv_files]
54
+
55
+ ### Data Processing Steps
56
+
57
+ 1. **Text Standardization**: Converts all text to lowercase
58
+ 2. **Column Cleaning**: Removes columns that are completely empty
59
+ 3. **Missing Value Handling**: Fills NaN values with empty strings
60
+ 4. **Duplicate Removal**: Removes duplicate rows from the dataset
61
+
62
+ ### Error Handling
63
+
64
+ - The script includes comprehensive error handling and logging
65
+ - Errors are logged with timestamps and detailed messages
66
+ - Processing continues even if individual files fail
67
+
68
+
69
+ ## "vectorize.py":
70
+ This script processes CSV documents and creates FAISS vector stores using LangChain and Hugging Face embeddings. It's designed to handle both MSD (Master Service Data) and medical data sources, converting them into efficient searchable vector representations.
71
+
72
+ ### Features
73
+ - CSV document loading with support for multiple files
74
+ - Text chunking with configurable size and overlap
75
+ - FAISS vector store creation and persistence
76
+ - Comprehensive error handling and logging
77
+ - Support for Hugging Face embedding models
78
+
79
+ ### Configuration
80
+
81
+ The script uses the following default configuration:
82
+
83
+ - MSD Data Path: `./processed_data/msd/msd_processed.csv`
84
+ - Medical CSV Path: `./processed_data/cbip/*.csv`
85
+ - MSD Vector Output: `./vectors_data/msd_data_vec`
86
+ - Medical Vector Output: `./vectors_data/med_data_vec`
87
+ - Embedding Model: `sentence-transformers/all-MiniLM-L12-v2`
88
+
89
+ ### Usage
90
+ - Just run the script to get default output
91
+ ```
92
+ python -m vectorize.py
93
+ ```
94
+ - or change paths of the dataset
95
+
96
+ ## "rag.py":
97
+ This script implements a Retrieval-Augmented Generation (RAG) system using LangChain, FAISS vector store, and OpenAI's GPT-3.5 model. The system combines medical and general data sources to provide informed answers to user queries.
98
+
99
+ ### Features
100
+
101
+ - Dual vector store integration (medical and general data)
102
+ - HuggingFace embeddings using `all-MiniLM-L12-v2` model
103
+ - OpenAI GPT-3.5 for answer generation
104
+ - Comprehensive error handling and logging
105
+ - Environment variable support for API keys
106
+
107
+ ### Prerequisites
108
+
109
+ - OpenAI API key
110
+ - Create a `.env` file in the project root and add your OpenAI API key: OPENAI_API_KEY=your_api_key_here
111
+ - Required vector stores in the `vectors_data` directory:
112
+ - `msd_data_vec/` - General data vector store
113
+ - `med_data_vec/` - Medical data vector store
114
+
115
+ ### Usage
116
+ ```
117
+ python rag.py
118
+ ```
119
+
120
+ ## "app.py":
121
+ This script is a FastAPI-based REST API that generates answers to questions using RAG (Retrieval-Augmented Generation) technology.
122
+
123
+ ### Features
124
+
125
+ - Question answering endpoint with RAG integration
126
+ - Request ID tracking for all API calls
127
+ - Comprehensive error handling and logging
128
+ - Health check endpoint
129
+ - CORS support
130
+ - API documentation (Swagger UI and ReDoc)
131
+
132
+ ### Usage
133
+ ```
134
+ uvicorn app:app --reload
135
+ ```
136
+
137
+ The server will start on `http://localhost:8000`
138
+
139
+ ### API Endpoints
140
+
141
+ #### 1. Question Answering
142
+ - **Endpoint**: `/answer`
143
+ - **Method**: POST
144
+ - **Request Body**:
145
+ ```
146
+ {
147
+ "question": "What is an overactive bladder?"
148
+ }
149
+ ```
150
+ - **Response**:
151
+ ```
152
+ {
153
+ "answer": "The generated answer..."
154
+ }
155
+ ```
156
+
157
+ ## test_rag.py:
158
+
159
+ The test suite validates the functionality of:
160
+ - Individual data retrievers (medicine and general data)
161
+ - Combined retriever functionality
162
+ - Answer generation system
163
+ - Error handling for edge cases
164
+
165
+ ### Test Cases
166
+
167
+ The test suite includes the following test cases:
168
+
169
+ 1. `test_data_retriever`: Tests retrieval from general data store
170
+ 2. `test_med_retriever`: Tests retrieval from medical data store
171
+ 3. `test_combined_retriever`: Tests the merged retriever functionality
172
+ 4. `test_generate_answer`: Validates answer generation
173
+ 5. `test_empty_query`: Tests error handling for invalid inputs
174
+
175
+ ### Usage
176
+ ```
177
+ python -m unittest test_rag.py
178
+ ```
179
+
180
+ ### Vector Store Setup
181
+
182
+ The system expects two FAISS indices in the `vectors/` directory:
183
+ - `msd_data_vec`: General knowledge vector store
184
+ - `med_data_vec`: Medical knowledge vector store
185
+
186
+ Both indices use the `sentence-transformers/all-MiniLM-L12-v2` embedding model.
187
+
188
+ #### Notes
189
+
190
+ - Ensure all vector stores are properly initialized before running tests
191
+ - The system uses the MiniLM-L12-v2 model for embeddings
192
+ - Empty or whitespace-only queries will raise ValueError exceptions
193
+
194
+
195
+ ## test_app.py:
196
+
197
+ The test suite (`test_app.py`) validates the `/answer` endpoint's response to different types of requests, ensuring proper handling of both valid and invalid inputs.
198
+
199
+ ### Test Cases
200
+
201
+ The test suite includes the following test cases:
202
+
203
+ 1. **Valid Question Test**
204
+ - Verifies that the endpoint correctly processes a valid question
205
+ - Expects a 200 status code and an answer in the response
206
+
207
+ 2. **Empty Question Test**
208
+ - Validates handling of empty string inputs
209
+ - Expects a 422 status code (Pydantic validation error)
210
+
211
+ 3. **Whitespace Question Test**
212
+ - Checks handling of whitespace-only inputs
213
+ - Expects a 500 status code with an error message
214
+
215
+ 4. **Missing Question Field Test**
216
+ - Verifies behavior when the question field is omitted
217
+ - Expects a 422 status code (FastAPI validation error)
218
+
219
+ 5. **Invalid JSON Test**
220
+ - Tests handling of malformed JSON requests
221
+ - Expects a 422 status code (FastAPI validation error)
222
+
223
+ ### Usage
224
+ ```
225
+ python -m unittest test_app.py
226
+ ```
227
+
228
+ ## Assumptions and Trade-offs:
229
+ I generated and stored vector embeddings separately for disease/MSD data and medicine/CBIP data, believing that this separation would enhance the LLM's performance.
230
+
231
+ ## Comments:
232
+ The quality of responses from this RAG-based LLM can be further strengthened through the following steps:
233
+ - Organizing the disease-related dataset more systematically.
234
+ - Structuring the medicine-related dataset more effectively.
235
+ - Enhancing disease-treatment and drug recommendations through better-organized mappings.
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pandas
2
+ openpyxl
3
+ langchain
4
+ langchain-community
5
+ langchain-huggingface
6
+ faiss-cpu
7
+ sentence-transformers
8
+ python-dotenv
9
+ langchain-openai
10
+ openai
11
+ fastapi
12
+ uvicorn
13
+ pydantic
test_app.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ from fastapi.testclient import TestClient
3
+ from app import app
4
+
5
+ class TestAnswerEndpoint(unittest.TestCase):
6
+ def setUp(self):
7
+ self.client = TestClient(app)
8
+
9
+ def test_valid_question(self):
10
+ response = self.client.post("/answer", json={"question": "What is an overactive bladder?"})
11
+ self.assertEqual(response.status_code, 200)
12
+ self.assertIn("answer", response.json())
13
+
14
+ def test_empty_question(self):
15
+ # Empty string fails Pydantic validation (constr(min_length=1))
16
+ response = self.client.post("/answer", json={"question": ""})
17
+ self.assertEqual(response.status_code, 422) # Pydantic validation error
18
+ self.assertIn("should have at least 1 character", response.json()["detail"][0]["msg"])
19
+
20
+ def test_whitespace_question(self):
21
+ # Whitespace-only string passes Pydantic but fails our custom validation
22
+ response = self.client.post("/answer", json={"question": " "})
23
+ self.assertEqual(response.status_code, 500) # Internal server error from our validation
24
+ self.assertEqual(response.json()["detail"], "Error processing request")
25
+
26
+ def test_missing_question_field(self):
27
+ response = self.client.post("/answer", json={})
28
+ self.assertEqual(response.status_code, 422) # FastAPI validation error
29
+
30
+ def test_invalid_json(self):
31
+ response = self.client.post("/answer", data="invalid json")
32
+ self.assertEqual(response.status_code, 422) # FastAPI validation error
33
+
34
+ if __name__ == '__main__':
35
+ unittest.main()
test_rag.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ from rag import load_faiss_index, generate_answer
3
+ from langchain.retrievers import MergerRetriever
4
+
5
+ class TestRetrieval(unittest.TestCase):
6
+ @classmethod
7
+ def setUpClass(cls):
8
+ # Initialize FAISS indices
9
+ cls.data_vec = load_faiss_index("./vectors_data/msd_data_vec", "sentence-transformers/all-MiniLM-L12-v2")
10
+ cls.med_vec = load_faiss_index("./vectors_data/med_data_vec", "sentence-transformers/all-MiniLM-L12-v2")
11
+
12
+ # Initialize retrievers
13
+ cls.data_retriever = cls.data_vec.as_retriever()
14
+ cls.med_retriever = cls.med_vec.as_retriever()
15
+ # Combine both retrievers into a single retriever
16
+ cls.combined_retriever = MergerRetriever(retrievers=[cls.data_retriever, cls.med_retriever])
17
+
18
+ def test_data_retriever(self):
19
+ # Test the data retriever with a specific query
20
+ query = "what are the symptoms of diabetes?"
21
+ docs = self.data_retriever.get_relevant_documents(query)
22
+
23
+ # Assert that documents are returned and are not empty
24
+ self.assertIsNotNone(docs)
25
+ self.assertTrue(len(docs) > 0)
26
+ # Check if documents have content
27
+ self.assertTrue(all(doc.page_content.strip() != "" for doc in docs))
28
+
29
+ def test_med_retriever(self):
30
+ # Test the medical retriever with a specific query
31
+ query = "what are common antibiotics?"
32
+ docs = self.med_retriever.get_relevant_documents(query)
33
+
34
+ # Assert that documents are returned and are not empty
35
+ self.assertIsNotNone(docs)
36
+ self.assertTrue(len(docs) > 0)
37
+ self.assertTrue(all(doc.page_content.strip() != "" for doc in docs))
38
+
39
+ def test_combined_retriever(self):
40
+ # Test the combined retriever with a specific query
41
+ query = "what is the treatment for high blood pressure?"
42
+ docs = self.combined_retriever.get_relevant_documents(query)
43
+
44
+ # Assert that documents are returned and are not empty
45
+ self.assertIsNotNone(docs)
46
+ self.assertTrue(len(docs) > 0)
47
+ self.assertTrue(all(doc.page_content.strip() != "" for doc in docs))
48
+
49
+ def test_generate_answer(self):
50
+ # Test the answer generation function with a specific query
51
+ query = "what are the side effects of aspirin?"
52
+ response = generate_answer(query)
53
+
54
+ # Assert that a valid response is returned
55
+ self.assertIsNotNone(response)
56
+ self.assertIsInstance(response, str)
57
+ self.assertTrue(len(response) > 0)
58
+
59
+ def test_empty_query(self):
60
+ # Test the answer generation function with an empty query
61
+ with self.assertRaises(ValueError): # More specific exception
62
+ generate_answer("")
63
+
64
+ # Test the answer generation function with a whitespace-only query
65
+ with self.assertRaises(ValueError):
66
+ generate_answer(" ") # Test whitespace-only query
67
+
68
+ if __name__ == '__main__':
69
+ unittest.main()
vectorize.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import logging
3
+ from pathlib import Path
4
+ from typing import List, Optional
5
+ from langchain_community.vectorstores import FAISS
6
+ from langchain_community.document_loaders import CSVLoader
7
+ from langchain_huggingface import HuggingFaceEmbeddings
8
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
9
+ from langchain.docstore.document import Document
10
+
11
+ # Configure logging
12
+ logging.basicConfig(
13
+ level=logging.INFO,
14
+ format='%(asctime)s - %(levelname)s - %(message)s',
15
+ handlers=[
16
+ logging.FileHandler('vectorize.log'),
17
+ logging.StreamHandler()
18
+ ]
19
+ )
20
+
21
+ class VectorizationError(Exception):
22
+ """Custom exception for vectorization-related errors"""
23
+ pass
24
+
25
+ def load_csv_documents(csv_file_path: str) -> List[Document]:
26
+ """
27
+ Load CSV documents from the specified path.
28
+
29
+ Args:
30
+ csv_file_path (str): Path pattern to search for CSV files.
31
+
32
+ Returns:
33
+ List[Document]: A list of documents loaded from the CSV files.
34
+
35
+ Raises:
36
+ VectorizationError: If no CSV files are found or if there's an error loading them.
37
+ """
38
+ try:
39
+ documents = []
40
+ csv_files = list(glob.glob(csv_file_path))
41
+
42
+ if not csv_files:
43
+ raise VectorizationError(f"No CSV files found at path: {csv_file_path}")
44
+
45
+ for csv_file in csv_files:
46
+ logging.info(f"Loading CSV file: {csv_file}")
47
+ loader = CSVLoader(csv_file, encoding="utf-8")
48
+ documents.extend(loader.load())
49
+
50
+ logging.info(f"Successfully loaded {len(documents)} documents from {len(csv_files)} CSV files")
51
+ return documents
52
+
53
+ except Exception as e:
54
+ raise VectorizationError(f"Error loading CSV documents: {str(e)}")
55
+
56
+ def create_vector_store(
57
+ documents: List[Document],
58
+ embeddings_model: HuggingFaceEmbeddings,
59
+ output_path: str,
60
+ chunk_size: int = 500,
61
+ chunk_overlap: int = 50
62
+ ) -> Optional[FAISS]:
63
+ """
64
+ Create and save a FAISS vector store from documents.
65
+
66
+ Args:
67
+ documents (List[Document]): List of documents to vectorize
68
+ embeddings_model (HuggingFaceEmbeddings): The embeddings model to use
69
+ output_path (str): Path to save the FAISS index
70
+ chunk_size (int, optional): Size of text chunks. Defaults to 500.
71
+ chunk_overlap (int, optional): Overlap between chunks. Defaults to 50.
72
+
73
+ Returns:
74
+ Optional[FAISS]: The created FAISS index if successful, None otherwise
75
+ """
76
+ try:
77
+ text_splitter = RecursiveCharacterTextSplitter(
78
+ chunk_size=chunk_size,
79
+ chunk_overlap=chunk_overlap
80
+ )
81
+
82
+ chunked_documents = text_splitter.split_documents(documents)
83
+ logging.info(f"Created {len(chunked_documents)} chunks from {len(documents)} documents")
84
+
85
+ faiss_index = FAISS.from_documents(chunked_documents, embeddings_model)
86
+
87
+
88
+ Path(output_path).parent.mkdir(parents=True, exist_ok=True)
89
+
90
+ faiss_index.save_local(output_path)
91
+ logging.info(f"Successfully saved FAISS index to {output_path}")
92
+
93
+ return faiss_index
94
+
95
+ except Exception as e:
96
+ logging.error(f"Error creating vector store: {str(e)}")
97
+ return None
98
+
99
+ def main():
100
+ try:
101
+ # Configuration with relative paths
102
+ config = {
103
+ 'msd_data_path': "./processed_data/msd/msd_processed.csv",
104
+ 'medical_csv_path': "./processed_data/cbip/*.csv",
105
+ 'msd_vector_path': "./vectors_data/msd_data_vec",
106
+ 'medical_vector_path': "./vectors_data/med_data_vec",
107
+ 'model_name': "sentence-transformers/all-MiniLM-L12-v2"
108
+ }
109
+
110
+ # Create vectors_data directory if it doesn't exist
111
+ Path("./vectors_data").mkdir(exist_ok=True)
112
+
113
+ logging.info("Starting vectorization process")
114
+
115
+ # Load documents
116
+ msd_data_documents = load_csv_documents(config['msd_data_path'])
117
+ medical_documents = load_csv_documents(config['medical_csv_path'])
118
+
119
+ # Initialize embeddings model
120
+ logging.info(f"Initializing embeddings model: {config['model_name']}")
121
+ embeddings_model = HuggingFaceEmbeddings(model_name=config['model_name'])
122
+
123
+ # Create vector stores
124
+ msd_index = create_vector_store(
125
+ msd_data_documents,
126
+ embeddings_model,
127
+ config['msd_vector_path']
128
+ )
129
+
130
+ medical_index = create_vector_store(
131
+ medical_documents,
132
+ embeddings_model,
133
+ config['medical_vector_path']
134
+ )
135
+
136
+ if msd_index and medical_index:
137
+ logging.info("Vectorization process completed successfully")
138
+ else:
139
+ logging.error("Vectorization process completed with errors")
140
+
141
+ except VectorizationError as ve:
142
+ logging.error(f"Vectorization error: {str(ve)}")
143
+ raise
144
+ except Exception as e:
145
+ logging.error(f"Unexpected error: {str(e)}")
146
+ raise
147
+
148
+ if __name__ == "__main__":
149
+ main()
vectors_data/med_data_vec/index.faiss ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:178a3e2e7fea1a9ca64ca39dd32443d3c90f7fcd7602328470f65a896c43557d
3
+ size 13323309
vectors_data/med_data_vec/index.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f1a9804e83879fc36cc98e68dc88564bd16ce3d552eb0b0b54da145f8d92e14
3
+ size 2648619
vectors_data/msd_data_vec/index.faiss ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:504d3ef09a0197a1f77de7865f23a801596ece8722e024cdcf07859c9ae53a52
3
+ size 370983981
vectors_data/msd_data_vec/index.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57f6ad56b1c1462bf030ab66e7f37bf2616c6c8b0cfcd9841df39606a8157eba
3
+ size 135114811