Spaces:
Sleeping
Sleeping
initial commit
Browse files- .gitattributes +2 -0
- .gitignore +9 -0
- app.py +202 -0
- gradio_app.py +38 -0
- log.txt +68 -0
- preprocess.py +193 -0
- rag.py +171 -0
- readme.md +235 -0
- requirements.txt +13 -0
- test_app.py +35 -0
- test_rag.py +69 -0
- vectorize.py +149 -0
- vectors_data/med_data_vec/index.faiss +3 -0
- vectors_data/med_data_vec/index.pkl +3 -0
- vectors_data/msd_data_vec/index.faiss +3 -0
- vectors_data/msd_data_vec/index.pkl +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
vectors_data/med_data_vec/index.faiss filter=lfs diff=lfs merge=lfs -text
|
37 |
+
vectors_data/msd_data_vec/index.faiss filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Python
|
2 |
+
__pycache__/
|
3 |
+
.env
|
4 |
+
*.env
|
5 |
+
chatbot/
|
6 |
+
processed_data/
|
7 |
+
src/
|
8 |
+
vectors_data/
|
9 |
+
logs/
|
app.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging
import asyncio
import uuid
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, constr, Field, validator  # fixed: 'constr' was imported twice
from rag import generate_answer  # Runs the RAG pipeline for a question

# Create logs directory if it doesn't exist (FileHandler below needs it)
Path("logs").mkdir(exist_ok=True)

# Initialize FastAPI app with metadata
app = FastAPI(
    title="Question Answering API",
    description="API for generating answers using RAG",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins; restrict in production for security
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Configure detailed logging: console plus one log file per calendar day
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler(f'logs/app_{datetime.now().strftime("%Y%m%d")}.log')
    ]
)
logger = logging.getLogger(__name__)  # Module-level logger for this API
def generate_request_id() -> str:
    """Return a fresh UUID4 string used to correlate logs and responses."""
    new_id = uuid.uuid4()  # random UUID per request
    return str(new_id)
49 |
+
|
class QuestionRequest(BaseModel):
    """Request model for the question answering endpoint."""
    # Non-empty question text; min_length=1 rejects the empty string outright.
    question: constr(min_length=1) = Field(
        ...,
        description="The question to be answered",
        example="What is the capital of France?"
    )

    @validator('question')
    def validate_question(cls, v):
        """Reject questions that are empty after trimming whitespace.

        Fixed: the original raised HTTPException(500, "Error processing
        request") here, reporting a client mistake as a server error.
        Raising ValueError lets pydantic/FastAPI turn it into a proper
        validation (422) response instead.
        """
        if v.strip() == "":
            raise ValueError("Question must not be empty or whitespace only")
        return v.strip()  # Store the trimmed question
66 |
+
|
class ErrorResponse(BaseModel):
    """Standard error payload returned by the API's error responses."""
    # Human-readable description of what went wrong.
    detail: str
    # Correlation ID (set by middleware) so the failure can be traced in logs.
    request_id: Optional[str] = Field(default=None)
71 |
+
|
@app.middleware("http")
async def add_request_id(request: Request, call_next):
    """Attach a correlation ID to every request and echo it back to the client."""
    rid = generate_request_id()
    # Downstream handlers read this from request.state for logging.
    request.state.request_id = rid
    response = await call_next(request)
    # Expose the same ID to the caller for support/debugging.
    response.headers["X-Request-ID"] = rid
    return response
80 |
+
|
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Serialize HTTPExceptions together with the request's correlation ID."""
    payload = {"detail": exc.detail, "request_id": request.state.request_id}
    return JSONResponse(status_code=exc.status_code, content=payload)
88 |
+
|
@app.post("/answer",
          response_model=Dict[str, str],
          summary="Generate an answer for the given question",
          response_description="Returns the generated answer",
          responses={
              400: {"model": ErrorResponse},
              500: {"model": ErrorResponse},
              504: {"model": ErrorResponse}
          }
          )
async def get_answer(request: Request, question_request: QuestionRequest) -> Dict[str, str]:
    """
    Generate an answer for the given question.

    Args:
        request: FastAPI request object (carries the correlation ID in state)
        question_request: Validated question payload

    Returns:
        Dict containing the generated answer under the "answer" key

    Raises:
        HTTPException: 400 for oversized questions, 504 on generation
            timeout, 500 for unexpected failures
    """
    request_id = request.state.request_id
    logger.info("Processing request %s - Question: %s", request_id, question_request.question)

    try:
        # Additional validation beyond the pydantic model
        if len(question_request.question) > 1000:
            logger.warning("Request %s: Question exceeds maximum length", request_id)
            raise HTTPException(
                status_code=400,
                detail="Question length exceeds maximum allowed characters (1000)"
            )

        # generate_answer is synchronous; run it in the default thread pool so
        # the event loop stays responsive while the RAG pipeline works.
        async def async_generate_answer(question: str):
            # Fixed: get_event_loop() is deprecated inside coroutines;
            # get_running_loop() is the supported way to reach the active loop.
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, generate_answer, question)

        # Generate answer with timeout handling
        answer = await asyncio.wait_for(
            async_generate_answer(question_request.question),
            timeout=120.0  # 120 second timeout
        )

        logger.info("Request %s: Successfully generated answer", request_id)
        return {"answer": answer}

    except asyncio.TimeoutError:
        logger.error("Request %s: Generation timeout", request_id)
        raise HTTPException(
            status_code=504,
            detail="Answer generation timed out"
        )

    except HTTPException:
        # Fixed: the 400 raised above was previously swallowed by the generic
        # `except Exception` below and re-reported as a 500. Re-raise as-is.
        raise

    except ValueError as ve:
        logger.error("Request %s: Validation error - %s", request_id, str(ve))
        raise HTTPException(
            status_code=400,
            detail=str(ve)
        )

    except Exception as e:
        logger.error(
            "Request %s: Unexpected error - %s",
            request_id,
            str(e),
            exc_info=True  # Include stack trace in logs
        )
        raise HTTPException(
            status_code=500,
            detail="Internal server error occurred. Please try again later."
        )
164 |
+
|
@app.get("/health",
         summary="Health check endpoint",
         response_description="Returns the API health status"
         )
async def health_check():
    """Liveness probe: confirms the API process is up and serving requests."""
    status_payload = {"status": "healthy"}
    return status_payload
172 |
+
|
@app.on_event("startup")
async def startup_event():
    """Run one-time initialization when the API process starts."""
    logger.info("Starting up Question Answering API")
    # Resource initialization (e.g. model loading) would go here.
178 |
+
|
@app.on_event("shutdown")
async def shutdown_event():
    """Release resources when the API process is shutting down."""
    logger.info("Shutting down Question Answering API")
    # Resource cleanup would go here.
184 |
+
|
if __name__ == "__main__":  # Entry point when run as a script
    import os
    import uvicorn

    # Fixed: the comment promised env-var configuration but host/port were
    # hard-coded. Honor HOST/PORT with the previous values as defaults so
    # existing deployments keep working unchanged.
    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "8000"))

    logger.info(f"Starting server on {host}:{port}")

    uvicorn.run(
        "app:app",          # Application import string
        host=host,
        port=port,
        reload=False,       # Disable auto-reload in production
        log_level="info",
        access_log=True,
        workers=1,          # Single worker process
    )
gradio_app.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gradio as gr
|
3 |
+
from rag import generate_answer # Import the generate_answer function from rag.py
|
4 |
+
|
def chat_with_bot(query: str) -> tuple:
    """
    Forward the user's query to the RAG pipeline and return its response.

    Args:
        query: User's question

    Returns:
        Whatever generate_answer produces for the query
        (annotated as tuple upstream; NOTE(review): confirm the actual
        return type of generate_answer)
    """
    # Thin pass-through: no extra processing happens here.
    return generate_answer(query)
18 |
+
|
def main():
    """
    Build and launch the Gradio interface for the medical RAG chatbot.
    """
    try:
        demo = gr.Interface(
            fn=chat_with_bot,  # Handler invoked for each submitted question
            inputs="text",
            outputs="text",
            title="RAG-LLM based Medical Chatbot",
            description="Ask your medical questions and get answers from the chatbot."
        )
        # share=True publishes a temporary public URL.
        demo.launch(share=True)
    except Exception as e:
        print(f"An error occurred: {str(e)}")
36 |
+
|
if __name__ == "__main__":
    # Launch the Gradio UI only when this file is executed as a script.
    main()
log.txt
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Query: who is john cena?
|
2 |
+
Response: I'm sorry, but based on the provided context, there is no information related to "John Cena" or any medical relevance to that name. If you have any other medical questions or concerns, please feel free to ask. Remember, for personalized medical advice, it is always best to consult with a healthcare professional.
|
3 |
+
|
4 |
+
Query: i have headache, suggest me some medicine
|
5 |
+
Response: Based on the context provided, if you have a headache, you can consider taking Ibuprofen EG syrup suspension. However, it is important to consult with a healthcare professional before taking any medication, especially if you have other underlying health conditions or are taking other medications. Ibuprofen may cause side effects such as injection site reaction, tenderness, fatigue, myalgia, arthralgia, fever, and nausea. If your headache persists or worsens, it is advisable to seek medical advice for proper evaluation and treatment.
|
6 |
+
|
7 |
+
Query: what is the best medicine when i have a headache?
|
8 |
+
Response: The best medicine for a headache can vary depending on the individual and the underlying cause of the headache. In general, over-the-counter pain relievers such as acetaminophen (oral or IV), ibuprofen, or naproxen can be effective for mild to moderate headaches. However, it is important to consult with a healthcare professional before taking any medication, especially if you are immunosuppressed or have papilledema. Additionally, it is important to avoid overusing medications, as this can lead to medication overuse headaches. In some cases, cognitive therapy, biofeedback, and education can also be helpful in managing headaches. If you have frequent headaches, it is recommended to seek medical advice for proper diagnosis and treatment.
|
9 |
+
|
10 |
+
Query: I have Diarrhea, suggest some medicine
|
11 |
+
Response: Based on the provided context, it is important to first consult a healthcare professional for a proper diagnosis and treatment plan. However, if you are experiencing diarrhea, you may consider over-the-counter medications such as loperamide (Imodium) to help alleviate symptoms. It is important to follow the recommended dosage and consult a healthcare provider if symptoms persist or worsen. Additionally, it is important to stay hydrated by drinking plenty of fluids.
|
12 |
+
|
13 |
+
Query: What are the clinical signs and symptoms of mitral stenosis?
|
14 |
+
Response: The clinical signs and symptoms of mitral stenosis include those of heart failure, such as fatigue, cold skin, and right upper quadrant abdominal discomfort. Additionally, signs may include an opening snap and a diastolic murmur. It is important to consult a healthcare professional for proper diagnosis and treatment. Medications such as diuretics and beta-blockers may be prescribed, but potential side effects should be discussed with a healthcare provider.
|
15 |
+
|
16 |
+
Query: How does dehydration affect blood pressure in someone with autonomic dysfunction?
|
17 |
+
Response: Dehydration can lead to hypotension (low blood pressure) in someone with autonomic dysfunction. This is because autonomic dysfunction can already cause orthostatic hypotension, and dehydration further exacerbates this by reducing blood volume and affecting the body's ability to regulate blood pressure. It is important for individuals with autonomic dysfunction to stay well-hydrated to help maintain their blood pressure within a healthy range. However, for personalized advice and treatment options, it is recommended to consult a healthcare professional.
|
18 |
+
|
19 |
+
Query: Can an ear infection cause a heart problem? If yes, how?
|
20 |
+
Response: Yes, in heavy infections, the inflammation from an ear infection may cause complications such as cardiac issues like myocarditis, heart failure, arrhythmia, and even death. The inflammation can spread to the heart and lead to serious complications. It is important to seek medical attention if you suspect an ear infection to prevent any potential complications. Consulting a healthcare professional for proper diagnosis and treatment is recommended.
|
21 |
+
|
22 |
+
Query: When do VPBs require treatment?
|
23 |
+
Response: VPBs, or ventricular premature beats, may require treatment if they are causing symptoms such as palpitations, dizziness, or fainting. Treatment options may include medications such as antiarrhythmics. It is important to consult with a healthcare professional for proper evaluation and management of VPBs.
|
24 |
+
|
25 |
+
Query: When do VPBs require treatment?
|
26 |
+
Response: VPBs, or ventricular premature beats, may require treatment if they are causing symptoms such as palpitations, dizziness, or fainting. Treatment options may include medications such as quetiapine or olanzapine, but it is important to consult a healthcare professional for proper evaluation and management. In some cases, preventive treatment may be indicated for patients with certain medical conditions that require bed rest. If you are experiencing symptoms of VPBs, it is important to seek medical advice for appropriate care.
|
27 |
+
|
28 |
+
Query: How does microvascular angina differ from typical angina?
|
29 |
+
Response: Microvascular angina differs from typical angina in that it is caused by impaired coronary blood flow due to issues in the smaller blood vessels, rather than blockages in the larger epicardial coronary arteries. This can lead to symptoms such as angina pectoris, syncope, or arrhythmias even in the absence of coronary artery disease. To diagnose microvascular angina, guidewires with pressure or flow sensors can be used to estimate blood flow across stenoses, with a fractional flow reserve (FFR) of < 0.75 to 0.8 indicating myocardial ischemia. Treatment for microvascular angina may involve medications like Lipiodol Ultra Fluide or Solu-Medrol, but it is important to consult a healthcare professional for proper diagnosis and management, as well as to discuss potential side effects of these medications.
|
30 |
+
|
31 |
+
Query: What are the common causes of pericarditis?
|
32 |
+
Response: Common causes of pericarditis include systemic rheumatic diseases such as rheumatoid arthritis, systemic lupus erythematosus (SLE), and systemic sclerosis. Bacterial infections from gram-negative bacilli, Haemophilus influenzae, Lyme disease, rheumatic fever, staphylococci, and streptococci can also lead to pericarditis. Additionally, tuberculosis, certain cancers like breast cancer, Kaposi sarcoma (in people with HIV infection), leukemia, and lung cancer, as well as certain medications like anticoagulants, can be common causes of pericarditis. It is important to consult a healthcare professional for proper diagnosis and treatment.
|
33 |
+
|
34 |
+
Query: suggest some medicine for pericarditis
|
35 |
+
Response: For pericarditis, medications such as colchicine or interleukin-1 antagonists may be prescribed based on the underlying cause and severity of the condition. It is important to consult a healthcare professional for a proper diagnosis and treatment plan. Additionally, for patients with newly diagnosed constrictive pericarditis, anti-inflammatory medications may be considered before pericardiectomy. It is crucial to follow the guidance of a healthcare provider and discuss any potential side effects of the prescribed medication.
|
36 |
+
|
37 |
+
Query: What are the common causes of orthostatic hypotension?
|
38 |
+
Response: Common causes of orthostatic hypotension include dehydration, electrolyte disturbances, use of certain drugs (such as antihypertensives and nitrates), aging, heart failure, chronic kidney disease, respiratory disorders, liver failure, renal insufficiency, and sedatives or opioids used for intubation and ventilation. It is important to consult a healthcare professional for proper diagnosis and treatment, as the underlying cause may vary from person to person.
|
39 |
+
|
40 |
+
Query: How is long QT syndrome diagnosed and treated?'
|
41 |
+
Response: Long QT syndrome is diagnosed based on clinical criteria. Treatment typically involves supportive care, including medications as needed for associated symptoms. It is important to consult with a healthcare professional for proper diagnosis and treatment. If medications are prescribed, it is essential to be aware of potential side effects.
|
42 |
+
|
43 |
+
Query: What are the classes of antiarrhythmic drugs and their mechanisms of action?
|
44 |
+
Response: The classes of antiarrhythmic drugs are grouped into four main classes based on their dominant cellular electrophysiologic effect. Class I drugs are sodium channel blockers that slow conduction in fast-channel tissues. They are subdivided into subclasses A, B, and C based on their kinetics of the sodium channel effects. Class Ib drugs have fast kinetics. The effects of these drugs may be reflected in the electrocardiogram as widening of the P wave, widening of the QRS complex, prolongation of the PR interval, or a combination. It is important to consult a healthcare professional for specific information on antiarrhythmic drugs and their mechanisms of action.
|
45 |
+
|
46 |
+
Query: What are the key causes of pitting and non-pitting edema?
|
47 |
+
Response: The key causes of pitting and non-pitting edema can include injury (such as burns, chemicals, toxins, or blunt trauma), severe sepsis causing vascular endothelial leakage, soft-tissue infections like cellulitis or necrotizing myofasciitis, and other inflammatory conditions. Pitting edema may present as painless, symmetrical swelling, while non-pitting edema may be focal and erythematous. It is important to seek clinical evaluation, cultures, and imaging studies as needed to determine the underlying cause of edema. For specific treatment options and management, it is advisable to consult a healthcare professional.
|
48 |
+
|
49 |
+
Query: tell me about this medicine: Rhinathiol Antirhinitis 10/4 compr.
|
50 |
+
Response: Rhinathiol Antirhinitis 10/4 compr. is a medication in the form of tablets. It is indicated for the treatment of rhinitis, which is inflammation of the mucous membrane of the nose. This medication may help reduce symptoms such as nasal congestion, runny nose, and sneezing associated with rhinitis.
|
51 |
+
|
52 |
+
It is important to note that this information is general and for educational purposes only. It is always recommended to consult with a healthcare professional or pharmacist before starting any new medication. Additionally, common side effects of Rhinathiol Antirhinitis 10/4 compr. may include drowsiness, dizziness, dry mouth, and blurred vision. If you experience any severe or persistent side effects, it is important to seek medical attention.
|
53 |
+
|
54 |
+
Query: What are the first-line treatments for type 2 diabetes, and how do they differ in mechanism?
|
55 |
+
Response: The first-line treatments for type 2 diabetes include oral antihyperglycemics, non-insulin injectable medications such as glucagon-like peptide-1 (GLP-1) receptor agonists, insulin, or a combination of these. These medications work by helping to lower blood sugar levels in different ways. Oral antihyperglycemics help improve insulin sensitivity or reduce glucose production in the liver. GLP-1 receptor agonists stimulate insulin secretion and reduce glucagon secretion. Insulin helps regulate blood sugar levels by replacing or supplementing the body's natural insulin production. It is important to consult healthcare professionals for personalized treatment recommendations and to discuss potential side effects of these medications.
|
56 |
+
|
57 |
+
Query: What are the first-line treatments for type 2 diabetes, and how do they differ in mechanism?
|
58 |
+
Response: The first-line treatments for type 2 diabetes include oral antihyperglycemics, non-insulin injectable medications such as glucagon-like peptide-1 (GLP-1) receptor agonists, insulin, or a combination of these. These medications work by helping to lower blood sugar levels in different ways. Oral antihyperglycemics help improve insulin sensitivity or reduce glucose production in the liver. GLP-1 receptor agonists help increase insulin secretion and decrease glucagon secretion. Insulin therapy helps regulate blood sugar levels by replacing or supplementing the body's insulin production. It is important to consult healthcare professionals for personalized treatment recommendations and to discuss potential side effects of these medications.
|
59 |
+
|
60 |
+
Query: What are the first-line treatments for type 2 diabetes, and how do they differ in mechanism?
|
61 |
+
Response: The first-line treatments for type 2 diabetes include oral antihyperglycemics, non-insulin injectable medications such as glucagon-like peptide-1 (GLP-1) receptor agonists, insulin, or a combination of these. These medications work by helping to lower blood sugar levels in different ways. Oral antihyperglycemics help improve insulin sensitivity or reduce glucose production in the liver. GLP-1 receptor agonists help increase insulin secretion and decrease glucagon secretion. Insulin therapy helps regulate blood sugar levels by replacing or supplementing the body's insulin production. It is important to consult healthcare professionals for personalized treatment plans and to discuss potential side effects of these medications.
|
62 |
+
|
63 |
+
Query: who is John Cena?
|
64 |
+
Response: I'm sorry, but based on the provided context, there is no information related to John Cena. If you have any other medical questions or concerns, please feel free to ask. It is always recommended to consult with a healthcare professional for personalized medical advice.
|
65 |
+
|
66 |
+
Query: who is John Cena?
|
67 |
+
Response: I'm sorry, but based on the provided context, there is no information related to a person named John Cena. If you are looking for information on a specific individual, it is recommended to consult reliable sources or conduct a separate search. If you have any medical questions or concerns, it is always best to consult a healthcare professional for personalized advice and guidance.
|
68 |
+
|
preprocess.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging
from pathlib import Path
from typing import Union, Optional
import pandas as pd

# Configure logging for this module's console output.
# NOTE(review): basicConfig at import time configures the root logger as a
# side effect; if another application imports this module after configuring
# logging itself, the earlier configuration wins -- confirm this is intended.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)  # Module-level logger used by DataPreprocessor
12 |
+
|
13 |
+
class DataPreprocessor:
|
14 |
+
"""A class to handle data preprocessing operations for different file formats."""
|
15 |
+
|
16 |
+
@staticmethod
|
17 |
+
def _preprocess_dataframe(df: pd.DataFrame) -> pd.DataFrame:
|
18 |
+
"""
|
19 |
+
Applies standard preprocessing steps to a DataFrame.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
df (pd.DataFrame): Input DataFrame to preprocess
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
pd.DataFrame: Preprocessed DataFrame
|
26 |
+
"""
|
27 |
+
try:
|
28 |
+
# Convert text columns to lowercase for standardization
|
29 |
+
df = df.map(lambda x: x.lower() if isinstance(x, str) else x)
|
30 |
+
|
31 |
+
# Drop columns that are fully null
|
32 |
+
df = df.dropna(axis=1, how='all')
|
33 |
+
|
34 |
+
# Fill remaining NaN values with empty strings
|
35 |
+
df = df.fillna('')
|
36 |
+
|
37 |
+
# Remove duplicate rows
|
38 |
+
df = df.drop_duplicates()
|
39 |
+
|
40 |
+
return df
|
41 |
+
|
42 |
+
except Exception as e:
|
43 |
+
logger.error(f"Error during DataFrame preprocessing: {str(e)}")
|
44 |
+
raise
|
45 |
+
|
46 |
+
@classmethod
|
47 |
+
def preprocess_msd(cls,
|
48 |
+
file_path: Union[str, Path],
|
49 |
+
output_path: Union[str, Path],
|
50 |
+
sheet_name: Optional[Union[str, int]] = 0) -> pd.DataFrame:
|
51 |
+
"""
|
52 |
+
Preprocesses an MSD Excel file and saves the result.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
file_path: Path to the Excel file
|
56 |
+
output_path: Directory path for the output file
|
57 |
+
sheet_name: Sheet name or index to load (default: 0)
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
pd.DataFrame: Preprocessed DataFrame
|
61 |
+
|
62 |
+
Raises:
|
63 |
+
FileNotFoundError: If input file doesn't exist
|
64 |
+
PermissionError: If output directory is not writable
|
65 |
+
"""
|
66 |
+
try:
|
67 |
+
# Convert to Path objects
|
68 |
+
file_path = Path(file_path)
|
69 |
+
output_path = Path(output_path)
|
70 |
+
|
71 |
+
# Validate input file
|
72 |
+
if not file_path.exists():
|
73 |
+
raise FileNotFoundError(f"Input file not found: {file_path}")
|
74 |
+
|
75 |
+
# Ensure output directory exists
|
76 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
77 |
+
|
78 |
+
logger.info(f"Processing MSD file: {file_path}")
|
79 |
+
df = pd.read_excel(file_path, sheet_name=sheet_name)
|
80 |
+
|
81 |
+
# Apply preprocessing
|
82 |
+
df = cls._preprocess_dataframe(df)
|
83 |
+
|
84 |
+
# Save processed file
|
85 |
+
output_file = output_path / "msd_processed.csv"
|
86 |
+
df.to_csv(output_file, index=False)
|
87 |
+
logger.info(f"Saved processed file to: {output_file}")
|
88 |
+
|
89 |
+
return df
|
90 |
+
|
91 |
+
except Exception as e:
|
92 |
+
logger.error(f"Error processing MSD file: {str(e)}")
|
93 |
+
raise
|
94 |
+
|
95 |
+
@classmethod
|
96 |
+
def preprocess_cbip(cls,
|
97 |
+
input_dir: Union[str, Path],
|
98 |
+
output_dir: Union[str, Path]) -> None:
|
99 |
+
"""
|
100 |
+
Preprocesses all CSV files in the CBIP directory.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
input_dir: Directory containing input CSV files
|
104 |
+
output_dir: Directory for output files
|
105 |
+
|
106 |
+
Raises:
|
107 |
+
FileNotFoundError: If input directory doesn't exist
|
108 |
+
PermissionError: If output directory is not writable
|
109 |
+
"""
|
110 |
+
try:
|
111 |
+
# Convert to Path objects
|
112 |
+
input_dir = Path(input_dir)
|
113 |
+
output_dir = Path(output_dir)
|
114 |
+
|
115 |
+
# Validate input directory
|
116 |
+
if not input_dir.exists():
|
117 |
+
raise FileNotFoundError(f"Input directory not found: {input_dir}")
|
118 |
+
|
119 |
+
# Ensure output directory exists
|
120 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
121 |
+
|
122 |
+
# Process all CSV files
|
123 |
+
csv_files = list(input_dir.rglob("*.csv"))
|
124 |
+
if not csv_files:
|
125 |
+
logger.warning(f"No CSV files found in: {input_dir}")
|
126 |
+
return
|
127 |
+
|
128 |
+
for file_path in csv_files:
|
129 |
+
try:
|
130 |
+
logger.info(f"Processing CBIP file: {file_path}")
|
131 |
+
|
132 |
+
# Read CSV file
|
133 |
+
df = pd.read_csv(
|
134 |
+
file_path,
|
135 |
+
delimiter=';',
|
136 |
+
quotechar='"',
|
137 |
+
skip_blank_lines=True
|
138 |
+
)
|
139 |
+
|
140 |
+
# Apply preprocessing
|
141 |
+
df = cls._preprocess_dataframe(df)
|
142 |
+
|
143 |
+
# Save processed file
|
144 |
+
output_file = output_dir / file_path.name
|
145 |
+
df.to_csv(output_file, index=False)
|
146 |
+
logger.info(f"Saved processed file to: {output_file}")
|
147 |
+
|
148 |
+
except Exception as e:
|
149 |
+
logger.error(f"Error processing {file_path}: {str(e)}")
|
150 |
+
continue
|
151 |
+
|
152 |
+
except Exception as e:
|
153 |
+
logger.error(f"Error processing CBIP directory: {str(e)}")
|
154 |
+
raise
|
155 |
+
|
156 |
+
def main():
    """Parse command-line arguments and run MSD and CBIP preprocessing.

    Raises:
        Exception: Re-raised after logging if any processing step fails.
    """
    try:
        # Local imports keep script-only dependencies out of library use.
        # (The unused `import os` from the original has been removed.)
        import argparse
        from pathlib import Path

        # All processed output lives under ./processed_data in the CWD.
        output_base = Path.cwd() / "processed_data"
        msd_output = output_base / "msd"
        cbip_output = output_base / "cbip"

        parser = argparse.ArgumentParser(description='Process MSD and CBIP data files.')
        parser.add_argument('--msd-input', required=True, help='Path to MSD Excel file')
        parser.add_argument('--cbip-input', required=True, help='Input directory containing CBIP CSV files')

        args = parser.parse_args()

        preprocessor = DataPreprocessor()

        # Process the MSD Excel workbook.
        preprocessor.preprocess_msd(
            args.msd_input,
            msd_output
        )

        # Process every CSV in the CBIP directory.
        preprocessor.preprocess_cbip(
            args.cbip_input,
            cbip_output
        )

    except Exception as e:
        logger.error(f"Main execution failed: {str(e)}")
        raise

if __name__ == "__main__":
    main()
|
rag.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import Optional
|
5 |
+
from dotenv import load_dotenv
|
6 |
+
from langchain_community.vectorstores import FAISS
|
7 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
8 |
+
from langchain_openai import ChatOpenAI
|
9 |
+
from langchain.chains import RetrievalQA
|
10 |
+
from langchain.prompts import PromptTemplate
|
11 |
+
from langchain.retrievers import MergerRetriever
|
12 |
+
|
13 |
+
# Load environment variables from .env file
|
14 |
+
load_dotenv()
|
15 |
+
|
16 |
+
# Configure logging
|
17 |
+
logging.basicConfig(
|
18 |
+
level=logging.INFO,
|
19 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
20 |
+
)
|
21 |
+
logger = logging.getLogger(__name__)
|
22 |
+
|
23 |
+
def load_faiss_index(folder_path: str, model_name: str) -> Optional[FAISS]:
    """
    Load a persisted FAISS index using the given HuggingFace embedding model.

    Args:
        folder_path: Directory containing the saved FAISS index.
        model_name: HuggingFace model identifier used for embeddings.

    Returns:
        FAISS: The deserialized FAISS index.

    Raises:
        ValueError: If the index directory does not exist.
    """
    try:
        # Guard against a missing index directory before loading the model.
        if not os.path.exists(folder_path):
            raise ValueError(f"FAISS index folder not found: {folder_path}")

        logger.info(f"Loading FAISS index from {folder_path}")
        embedder = HuggingFaceEmbeddings(model_name=model_name)

        # NOTE(review): deserialization flag assumes the index was written
        # by our own vectorize step — confirm the source is trusted.
        return FAISS.load_local(
            folder_path=folder_path,
            embeddings=embedder,
            allow_dangerous_deserialization=True
        )
    except Exception as err:
        logger.error(f"Error loading FAISS index: {str(err)}")
        raise
51 |
+
|
52 |
+
def generate_answer(query: str) -> str:
    """
    Generate an answer for the given query using RAG.

    Args:
        query: User's question

    Returns:
        str: Generated answer

    Raises:
        ValueError: If the query is empty, required vector stores are
            missing, or the OPENAI_API_KEY environment variable is unset.
    """
    try:
        if not query or not query.strip():
            raise ValueError("Query cannot be empty")

        # Get the current directory and construct paths
        current_dir = Path(__file__).parent
        vectors_dir = current_dir / "vectors_data"

        # Validate vectors directory exists
        if not vectors_dir.exists():
            raise ValueError(f"Vectors directory not found at {vectors_dir}")

        # Load FAISS indices (both stores share the same embedding model)
        logger.info("Loading FAISS indices...")
        data_vec = load_faiss_index(
            str(vectors_dir / "msd_data_vec"),
            "sentence-transformers/all-MiniLM-L12-v2"
        )
        med_vec = load_faiss_index(
            str(vectors_dir / "med_data_vec"),
            "sentence-transformers/all-MiniLM-L12-v2"
        )

        # SECURITY FIX: never hard-code the API key in source. Read it from
        # the environment instead (.env is loaded at module import time via
        # load_dotenv()). The previously committed key must be revoked.
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable is not set")

        # Create the LLM instance
        llm = ChatOpenAI(
            model="gpt-3.5-turbo",
            temperature=0,
            openai_api_key=api_key
        )

        # Define the prompt template
        template = """You are a helpful medical information assistant. Use the following pieces of context to answer the medical question at the end.

Important notes:
- Base your answer strictly on the provided context and understandable for all readers
- If you don't know the answer, just say that you don't know
- Include relevant disclaimers about consulting healthcare professionals
- If suggesting medications (upon question), mention potential side effects if provided in the context
- Highlight if the information is general knowledge or requires professional medical advice

Context: {context}

Question: {question}

Medical Information Assistant:"""

        QA_CHAIN_PROMPT = PromptTemplate(
            input_variables=["context", "question"],
            template=template
        )

        # Initialize and combine retrievers
        logger.info("Setting up retrieval chain...")
        data_retriever = data_vec.as_retriever()
        med_retriever = med_vec.as_retriever()
        combined_retriever = MergerRetriever(
            retrievers=[data_retriever, med_retriever]
        )

        # Initialize the RetrievalQA chain
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            retriever=combined_retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
        )

        # Run the chain
        logger.info("Generating answer...")
        result = qa_chain.invoke({"query": query})
        logger.info("Answer generated successfully")

        # Extracting the relevant documents from the result
        extracted_docs = result.get("source_documents", [])
        logger.info(f"Extracted documents: {extracted_docs}")  # Log the extracted documents

        # Organized printing of extracted documents. Use .get() so a
        # document missing 'source'/'row' metadata cannot crash the answer
        # generation (the original indexed the dict directly).
        print("\nExtracted Documents:")
        for doc in extracted_docs:
            print(f"Source: {doc.metadata.get('source', 'unknown')}, Row: {doc.metadata.get('row', 'n/a')}")
            print(f"Content: {doc.page_content}\n")

        return result["result"]

    except Exception as e:
        logger.error(f"Error generating answer: {str(e)}")
        raise
152 |
+
|
153 |
+
def main():
    """Demonstrate the RAG pipeline with a sample medical query."""
    query = "suggest me some medicine for bronchitis"
    try:
        logger.info(f"Processing query: {query}")
        response = generate_answer(query)

        print("\nQuery:", query)
        print("\nResponse:", response)
    except Exception as err:
        logger.error(f"Error in main function: {str(err)}")
        print(f"An error occurred: {str(err)}")

if __name__ == "__main__":
    main()
|
readme.md
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Building a Retrieval-Augmented Question-Answering System with FastAPI and LangChain
|
2 |
+
Develop a simple question-answering web service that leverages Retrieval-Augmented Generation (RAG) to provide answers based on a set of provided documents. The service will be built using Python, FastAPI, and LangChain.
|
3 |
+
|
4 |
+
## Installation
|
5 |
+
|
6 |
+
1. Clone or download this repository
|
7 |
+
```
|
8 |
+
git clone https://github.com/shamim237/artech_med_bot.git
|
9 |
+
```
|
10 |
+
2. Create a virtual environment (optional but recommended):
|
11 |
+
|
12 |
+
```
|
13 |
+
python -m venv venv
|
14 |
+
source venv/bin/activate # On Windows use `venv\Scripts\activate`
|
15 |
+
```
|
16 |
+
3. Install the required packages:
|
17 |
+
```
|
18 |
+
pip install -r requirements.txt
|
19 |
+
```
|
20 |
+
|
21 |
+
# Run the system
|
22 |
+
|
23 |
+
To run the system on your machine, follow the steps shown below:
|
24 |
+
|
25 |
+
## **preprocess.py**:
|
26 |
+
This script preprocesses MSD (Excel) and CBIP (CSV) data files, providing standardized data cleaning and transformation operations for both file formats.
|
27 |
+
|
28 |
+
### Features
|
29 |
+
- Processes MSD Excel files and CBIP CSV files
|
30 |
+
- Standardizes text to lowercase
|
31 |
+
- Removes empty columns
|
32 |
+
- Handles missing values
|
33 |
+
- Eliminates duplicate rows
|
34 |
+
- Preserves original file structure
|
35 |
+
- Comprehensive logging
|
36 |
+
|
37 |
+
### Usage
|
38 |
+
```
|
39 |
+
python preprocess.py --msd-input "path/to/msd.xlsx" --cbip-input "path/to/cbip/directory"
|
40 |
+
```
|
41 |
+
#### Arguments
|
42 |
+
- `--msd-input`: Path to the MSD Excel file
|
43 |
+
- `--cbip-input`: Path to the directory containing CBIP CSV files
|
44 |
+
|
45 |
+
|
46 |
+
#### Output
|
47 |
+
|
48 |
+
The script creates a `processed_data` directory in your current working directory with the following structure:
|
49 |
+
processed_data/
|
50 |
+
├── msd/
|
51 |
+
│ └── msd_processed.csv
|
52 |
+
└── cbip/
|
53 |
+
└── [processed_csv_files]
|
54 |
+
|
55 |
+
### Data Processing Steps
|
56 |
+
|
57 |
+
1. **Text Standardization**: Converts all text to lowercase
|
58 |
+
2. **Column Cleaning**: Removes columns that are completely empty
|
59 |
+
3. **Missing Value Handling**: Fills NaN values with empty strings
|
60 |
+
4. **Duplicate Removal**: Removes duplicate rows from the dataset
|
61 |
+
|
62 |
+
### Error Handling
|
63 |
+
|
64 |
+
- The script includes comprehensive error handling and logging
|
65 |
+
- Errors are logged with timestamps and detailed messages
|
66 |
+
- Processing continues even if individual files fail
|
67 |
+
|
68 |
+
|
69 |
+
## "vectorize.py":
|
70 |
+
This script processes CSV documents and creates FAISS vector stores using LangChain and Hugging Face embeddings. It's designed to handle both MSD (Master Service Data) and medical data sources, converting them into efficient searchable vector representations.
|
71 |
+
|
72 |
+
### Features
|
73 |
+
- CSV document loading with support for multiple files
|
74 |
+
- Text chunking with configurable size and overlap
|
75 |
+
- FAISS vector store creation and persistence
|
76 |
+
- Comprehensive error handling and logging
|
77 |
+
- Support for Hugging Face embedding models
|
78 |
+
|
79 |
+
### Configuration
|
80 |
+
|
81 |
+
The script uses the following default configuration:
|
82 |
+
|
83 |
+
- MSD Data Path: `./processed_data/msd/msd_processed.csv`
|
84 |
+
- Medical CSV Path: `./processed_data/cbip/*.csv`
|
85 |
+
- MSD Vector Output: `./vectors_data/msd_data_vec`
|
86 |
+
- Medical Vector Output: `./vectors_data/med_data_vec`
|
87 |
+
- Embedding Model: `sentence-transformers/all-MiniLM-L12-v2`
|
88 |
+
|
89 |
+
### Usage
|
90 |
+
- Just run the script to get default output
|
91 |
+
```
|
92 |
+
python vectorize.py
|
93 |
+
```
|
94 |
+
- or change paths of the dataset
|
95 |
+
|
96 |
+
## "rag.py":
|
97 |
+
This script implements a Retrieval-Augmented Generation (RAG) system using LangChain, FAISS vector store, and OpenAI's GPT-3.5 model. The system combines medical and general data sources to provide informed answers to user queries.
|
98 |
+
|
99 |
+
### Features
|
100 |
+
|
101 |
+
- Dual vector store integration (medical and general data)
|
102 |
+
- HuggingFace embeddings using `all-MiniLM-L12-v2` model
|
103 |
+
- OpenAI GPT-3.5 for answer generation
|
104 |
+
- Comprehensive error handling and logging
|
105 |
+
- Environment variable support for API keys
|
106 |
+
|
107 |
+
### Prerequisites
|
108 |
+
|
109 |
+
- OpenAI API key
|
110 |
+
- Create a `.env` file in the project root and add your OpenAI API key: OPENAI_API_KEY=your_api_key_here
|
111 |
+
- Required vector stores in the `vectors_data` directory:
|
112 |
+
- `msd_data_vec/` - General data vector store
|
113 |
+
- `med_data_vec/` - Medical data vector store
|
114 |
+
|
115 |
+
### Usage
|
116 |
+
```
|
117 |
+
python rag.py
|
118 |
+
```
|
119 |
+
|
120 |
+
## "app.py":
|
121 |
+
This script is a FastAPI-based REST API that generates answers to questions using RAG (Retrieval-Augmented Generation) technology.
|
122 |
+
|
123 |
+
### Features
|
124 |
+
|
125 |
+
- Question answering endpoint with RAG integration
|
126 |
+
- Request ID tracking for all API calls
|
127 |
+
- Comprehensive error handling and logging
|
128 |
+
- Health check endpoint
|
129 |
+
- CORS support
|
130 |
+
- API documentation (Swagger UI and ReDoc)
|
131 |
+
|
132 |
+
### Usage
|
133 |
+
```
|
134 |
+
uvicorn app:app --reload
|
135 |
+
```
|
136 |
+
|
137 |
+
The server will start on `http://localhost:8000`
|
138 |
+
|
139 |
+
### API Endpoints
|
140 |
+
|
141 |
+
#### 1. Question Answering
|
142 |
+
- **Endpoint**: `/answer`
|
143 |
+
- **Method**: POST
|
144 |
+
- **Request Body**:
|
145 |
+
```
|
146 |
+
{
|
147 |
+
"question": "What is an overactive bladder?"
|
148 |
+
}
|
149 |
+
```
|
150 |
+
- **Response**:
|
151 |
+
```
|
152 |
+
{
|
153 |
+
"answer": "The generated answer..."
|
154 |
+
}
|
155 |
+
```
|
156 |
+
|
157 |
+
## test_rag.py:
|
158 |
+
|
159 |
+
The test suite validates the functionality of:
|
160 |
+
- Individual data retrievers (medicine and general data)
|
161 |
+
- Combined retriever functionality
|
162 |
+
- Answer generation system
|
163 |
+
- Error handling for edge cases
|
164 |
+
|
165 |
+
### Test Cases
|
166 |
+
|
167 |
+
The test suite includes the following test cases:
|
168 |
+
|
169 |
+
1. `test_data_retriever`: Tests retrieval from general data store
|
170 |
+
2. `test_med_retriever`: Tests retrieval from medical data store
|
171 |
+
3. `test_combined_retriever`: Tests the merged retriever functionality
|
172 |
+
4. `test_generate_answer`: Validates answer generation
|
173 |
+
5. `test_empty_query`: Tests error handling for invalid inputs
|
174 |
+
|
175 |
+
### Usage
|
176 |
+
```
|
177 |
+
python -m unittest test_rag.py
|
178 |
+
```
|
179 |
+
|
180 |
+
### Vector Store Setup
|
181 |
+
|
182 |
+
The system expects two FAISS indices in the `vectors_data/` directory:
|
183 |
+
- `msd_data_vec`: General knowledge vector store
|
184 |
+
- `med_data_vec`: Medical knowledge vector store
|
185 |
+
|
186 |
+
Both indices use the `sentence-transformers/all-MiniLM-L12-v2` embedding model.
|
187 |
+
|
188 |
+
#### Notes
|
189 |
+
|
190 |
+
- Ensure all vector stores are properly initialized before running tests
|
191 |
+
- The system uses the MiniLM-L12-v2 model for embeddings
|
192 |
+
- Empty or whitespace-only queries will raise ValueError exceptions
|
193 |
+
|
194 |
+
|
195 |
+
## test_app.py:
|
196 |
+
|
197 |
+
The test suite (`test_app.py`) validates the `/answer` endpoint's response to different types of requests, ensuring proper handling of both valid and invalid inputs.
|
198 |
+
|
199 |
+
### Test Cases
|
200 |
+
|
201 |
+
The test suite includes the following test cases:
|
202 |
+
|
203 |
+
1. **Valid Question Test**
|
204 |
+
- Verifies that the endpoint correctly processes a valid question
|
205 |
+
- Expects a 200 status code and an answer in the response
|
206 |
+
|
207 |
+
2. **Empty Question Test**
|
208 |
+
- Validates handling of empty string inputs
|
209 |
+
- Expects a 422 status code (Pydantic validation error)
|
210 |
+
|
211 |
+
3. **Whitespace Question Test**
|
212 |
+
- Checks handling of whitespace-only inputs
|
213 |
+
- Expects a 500 status code with an error message
|
214 |
+
|
215 |
+
4. **Missing Question Field Test**
|
216 |
+
- Verifies behavior when the question field is omitted
|
217 |
+
- Expects a 422 status code (FastAPI validation error)
|
218 |
+
|
219 |
+
5. **Invalid JSON Test**
|
220 |
+
- Tests handling of malformed JSON requests
|
221 |
+
- Expects a 422 status code (FastAPI validation error)
|
222 |
+
|
223 |
+
### Usage
|
224 |
+
```
|
225 |
+
python -m unittest test_app.py
|
226 |
+
```
|
227 |
+
|
228 |
+
## Assumptions and Trade-offs:
|
229 |
+
I generated and stored vector embeddings separately for disease/MSD data and medicine/CBIP data, believing that this separation would enhance the LLM's performance.
|
230 |
+
|
231 |
+
## Comments:
|
232 |
+
The quality of responses from this RAG-based LLM can be further strengthened through the following steps:
|
233 |
+
- Organizing the disease-related dataset more systematically.
|
234 |
+
- Structuring the medicine-related dataset more effectively.
|
235 |
+
- Enhancing disease-treatment and drug recommendations through better-organized mappings.
|
requirements.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pandas
|
2 |
+
openpyxl
|
3 |
+
langchain
|
4 |
+
langchain-community
|
5 |
+
langchain-huggingface
|
6 |
+
faiss-cpu
|
7 |
+
sentence-transformers
|
8 |
+
python-dotenv
|
9 |
+
langchain-openai
|
10 |
+
openai
|
11 |
+
fastapi
|
12 |
+
uvicorn
|
13 |
+
pydantic
|
test_app.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import unittest
|
2 |
+
from fastapi.testclient import TestClient
|
3 |
+
from app import app
|
4 |
+
|
5 |
+
class TestAnswerEndpoint(unittest.TestCase):
    """Integration tests for the /answer REST endpoint."""

    def setUp(self):
        # Fresh test client for every test case.
        self.client = TestClient(app)

    def _post(self, **kwargs):
        """Helper: POST to /answer with the given request arguments."""
        return self.client.post("/answer", **kwargs)

    def test_valid_question(self):
        # A well-formed question yields 200 and an answer payload.
        response = self._post(json={"question": "What is an overactive bladder?"})
        self.assertEqual(response.status_code, 200)
        self.assertIn("answer", response.json())

    def test_empty_question(self):
        # Empty string fails Pydantic validation (constr(min_length=1)).
        response = self._post(json={"question": ""})
        self.assertEqual(response.status_code, 422)  # Pydantic validation error
        self.assertIn("should have at least 1 character", response.json()["detail"][0]["msg"])

    def test_whitespace_question(self):
        # Whitespace-only input passes Pydantic but fails custom validation.
        response = self._post(json={"question": " "})
        self.assertEqual(response.status_code, 500)  # Internal server error from our validation
        self.assertEqual(response.json()["detail"], "Error processing request")

    def test_missing_question_field(self):
        # Omitting the question field triggers FastAPI validation.
        response = self._post(json={})
        self.assertEqual(response.status_code, 422)  # FastAPI validation error

    def test_invalid_json(self):
        # Malformed JSON bodies are rejected by FastAPI.
        response = self._post(data="invalid json")
        self.assertEqual(response.status_code, 422)  # FastAPI validation error

if __name__ == '__main__':
    unittest.main()
|
test_rag.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import unittest
|
2 |
+
from rag import load_faiss_index, generate_answer
|
3 |
+
from langchain.retrievers import MergerRetriever
|
4 |
+
|
5 |
+
class TestRetrieval(unittest.TestCase):
    """Tests for the FAISS retrievers and the RAG answer generator."""

    @classmethod
    def setUpClass(cls):
        # Load both FAISS indices once for the whole test class.
        cls.data_vec = load_faiss_index("./vectors_data/msd_data_vec", "sentence-transformers/all-MiniLM-L12-v2")
        cls.med_vec = load_faiss_index("./vectors_data/med_data_vec", "sentence-transformers/all-MiniLM-L12-v2")

        # Wrap each store in a retriever and merge the two.
        cls.data_retriever = cls.data_vec.as_retriever()
        cls.med_retriever = cls.med_vec.as_retriever()
        cls.combined_retriever = MergerRetriever(retrievers=[cls.data_retriever, cls.med_retriever])

    def _check_docs(self, docs):
        """Shared assertions: results exist and carry non-empty content."""
        self.assertIsNotNone(docs)
        self.assertTrue(len(docs) > 0)
        self.assertTrue(all(doc.page_content.strip() != "" for doc in docs))

    def test_data_retriever(self):
        # Retrieval from the general data store.
        docs = self.data_retriever.get_relevant_documents("what are the symptoms of diabetes?")
        self._check_docs(docs)

    def test_med_retriever(self):
        # Retrieval from the medical data store.
        docs = self.med_retriever.get_relevant_documents("what are common antibiotics?")
        self._check_docs(docs)

    def test_combined_retriever(self):
        # Retrieval through the merged retriever.
        docs = self.combined_retriever.get_relevant_documents("what is the treatment for high blood pressure?")
        self._check_docs(docs)

    def test_generate_answer(self):
        # Answer generation returns a non-empty string.
        response = generate_answer("what are the side effects of aspirin?")
        self.assertIsNotNone(response)
        self.assertIsInstance(response, str)
        self.assertTrue(len(response) > 0)

    def test_empty_query(self):
        # Empty and whitespace-only queries must both raise ValueError.
        with self.assertRaises(ValueError):
            generate_answer("")
        with self.assertRaises(ValueError):
            generate_answer(" ")

if __name__ == '__main__':
    unittest.main()
|
vectorize.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import logging
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import List, Optional
|
5 |
+
from langchain_community.vectorstores import FAISS
|
6 |
+
from langchain_community.document_loaders import CSVLoader
|
7 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
8 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
9 |
+
from langchain.docstore.document import Document
|
10 |
+
|
11 |
+
# Configure logging
|
12 |
+
logging.basicConfig(
|
13 |
+
level=logging.INFO,
|
14 |
+
format='%(asctime)s - %(levelname)s - %(message)s',
|
15 |
+
handlers=[
|
16 |
+
logging.FileHandler('vectorize.log'),
|
17 |
+
logging.StreamHandler()
|
18 |
+
]
|
19 |
+
)
|
20 |
+
|
21 |
+
class VectorizationError(Exception):
    """Raised for failures in CSV loading or vector-store creation."""
    pass
24 |
+
|
25 |
+
def load_csv_documents(csv_file_path: str) -> List[Document]:
    """
    Load CSV documents from the specified path.

    Args:
        csv_file_path (str): Glob pattern used to locate CSV files.

    Returns:
        List[Document]: Documents loaded from every matching CSV file.

    Raises:
        VectorizationError: If no CSV files match or a file fails to load.
    """
    # glob.glob already returns a list; the extra list() call was redundant.
    csv_files = glob.glob(csv_file_path)

    # Raise outside the try block so this error is not re-wrapped below
    # (the original double-wrapped its own "no files" error).
    if not csv_files:
        raise VectorizationError(f"No CSV files found at path: {csv_file_path}")

    documents: List[Document] = []
    try:
        for csv_file in csv_files:
            logging.info(f"Loading CSV file: {csv_file}")
            loader = CSVLoader(csv_file, encoding="utf-8")
            documents.extend(loader.load())
    except Exception as e:
        # Chain the original exception so the root cause stays visible.
        raise VectorizationError(f"Error loading CSV documents: {str(e)}") from e

    logging.info(f"Successfully loaded {len(documents)} documents from {len(csv_files)} CSV files")
    return documents
55 |
+
|
56 |
+
def create_vector_store(
    documents: List[Document],
    embeddings_model: HuggingFaceEmbeddings,
    output_path: str,
    chunk_size: int = 500,
    chunk_overlap: int = 50
) -> Optional[FAISS]:
    """
    Chunk the given documents, embed them, and persist a FAISS index.

    Args:
        documents (List[Document]): Documents to vectorize.
        embeddings_model (HuggingFaceEmbeddings): Embeddings model to apply.
        output_path (str): Where to save the FAISS index.
        chunk_size (int, optional): Size of each text chunk. Defaults to 500.
        chunk_overlap (int, optional): Overlap between chunks. Defaults to 50.

    Returns:
        Optional[FAISS]: The FAISS index on success, None if anything failed.
    """
    try:
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )

        chunks = splitter.split_documents(documents)
        logging.info(f"Created {len(chunks)} chunks from {len(documents)} documents")

        index = FAISS.from_documents(chunks, embeddings_model)

        # Make sure the parent directory exists before saving.
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)

        index.save_local(output_path)
        logging.info(f"Successfully saved FAISS index to {output_path}")

        return index

    except Exception as e:
        # Failures are logged; the caller sees None instead of an exception.
        logging.error(f"Error creating vector store: {str(e)}")
        return None
98 |
+
|
99 |
+
def main():
    """Build FAISS vector stores for the MSD and medical CSV datasets."""
    try:
        # All paths are relative to the working directory.
        config = {
            'msd_data_path': "./processed_data/msd/msd_processed.csv",
            'medical_csv_path': "./processed_data/cbip/*.csv",
            'msd_vector_path': "./vectors_data/msd_data_vec",
            'medical_vector_path': "./vectors_data/med_data_vec",
            'model_name': "sentence-transformers/all-MiniLM-L12-v2"
        }

        # Make sure the destination directory exists.
        Path("./vectors_data").mkdir(exist_ok=True)

        logging.info("Starting vectorization process")

        # Load both document collections.
        msd_docs = load_csv_documents(config['msd_data_path'])
        med_docs = load_csv_documents(config['medical_csv_path'])

        # One embeddings model is shared by both stores.
        logging.info(f"Initializing embeddings model: {config['model_name']}")
        embedder = HuggingFaceEmbeddings(model_name=config['model_name'])

        # Vectorize each collection into its own FAISS index.
        msd_store = create_vector_store(
            msd_docs,
            embedder,
            config['msd_vector_path']
        )
        med_store = create_vector_store(
            med_docs,
            embedder,
            config['medical_vector_path']
        )

        # create_vector_store returns None on failure.
        if msd_store and med_store:
            logging.info("Vectorization process completed successfully")
        else:
            logging.error("Vectorization process completed with errors")

    except VectorizationError as ve:
        logging.error(f"Vectorization error: {str(ve)}")
        raise
    except Exception as e:
        logging.error(f"Unexpected error: {str(e)}")
        raise

if __name__ == "__main__":
    main()
|
vectors_data/med_data_vec/index.faiss
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:178a3e2e7fea1a9ca64ca39dd32443d3c90f7fcd7602328470f65a896c43557d
|
3 |
+
size 13323309
|
vectors_data/med_data_vec/index.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8f1a9804e83879fc36cc98e68dc88564bd16ce3d552eb0b0b54da145f8d92e14
|
3 |
+
size 2648619
|
vectors_data/msd_data_vec/index.faiss
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:504d3ef09a0197a1f77de7865f23a801596ece8722e024cdcf07859c9ae53a52
|
3 |
+
size 370983981
|
vectors_data/msd_data_vec/index.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:57f6ad56b1c1462bf030ab66e7f37bf2616c6c8b0cfcd9841df39606a8157eba
|
3 |
+
size 135114811
|