import os
import logging
from pathlib import Path

from dotenv import load_dotenv
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate

# Load environment variables from .env file
load_dotenv()

# Retrieve the OpenAI API key from environment variables
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


def load_faiss_index(folder_path: str, model_name: str) -> FAISS:
"""
Load a FAISS index with a specific embedding model.
Args:
folder_path: Path to the FAISS index folder
model_name: Name of the HuggingFace embedding model
Returns:
FAISS: Loaded FAISS index object
Raises:
ValueError: If the folder path doesn't exist
"""
try:
if not os.path.exists(folder_path):
raise ValueError(f"FAISS index folder not found: {folder_path}")
logger.info(f"Loading FAISS index from {folder_path}")
embeddings = HuggingFaceEmbeddings(model_name=model_name)
return FAISS.load_local(
folder_path=folder_path,
embeddings=embeddings,
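            # FAISS indexes are pickle-backed, so deserialization must be
            # explicitly allowed; only safe for trusted, locally built files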
allow_dangerous_deserialization=True
)
except Exception as e:
logger.error(f"Error loading FAISS index: {str(e)}")
        raise


def generate_answer(query: str) -> str:
"""
Generate an answer for the given query using RAG.
Args:
query: User's question
Returns:
str: Generated answer
Raises:
ValueError: If query is empty or required files are missing
"""
try:
if not query or not query.strip():
raise ValueError("Query cannot be empty")
# Get the current directory and construct paths
current_dir = Path(__file__).parent
vectors_dir = current_dir / "vectors_data"
# Validate vectors directory exists
if not vectors_dir.exists():
raise ValueError(f"Vectors directory not found at {vectors_dir}")
        # Load the FAISS index
        logger.info("Loading FAISS index...")
data_vec = load_faiss_index(
str(vectors_dir / "faiss_v4"),
"sentence-transformers/all-MiniLM-L12-v2"
)
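        # Fail fast with a clear error if the API key was never configured
        # (assumption: the key is required by the ChatOpenAI call below)
        if not OPENAI_API_KEY:
            raise ValueError("OPENAI_API_KEY is not set; add it to your environment or .env file")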
        # Create the LLM instance; temperature=0 keeps answers deterministic
llm = ChatOpenAI(
model="gpt-4o-mini",
temperature=0,
openai_api_key=OPENAI_API_KEY
)
template = """You are a knowledgeable and approachable medical information assistant. Use the context provided to answer the medical question at the end. Follow these guidelines to ensure a clear, user-friendly, and professional response:
Important Guidelines:
1. **Clarity and Accessibility:**
- Write in simple, understandable language suitable for a general audience.
- Explain any technical terms briefly, if used.
2. **Structure:**
- Use clear paragraphs or bullet points for organization.
- Start with a concise summary of the issue before diving into details.
3. **Accuracy and Reliability:**
- Base your response strictly on the context provided.
- If you cannot provide an answer based on the context, state this honestly.
4. **Medical Safety and Disclaimers:**
- Include a disclaimer emphasizing the need to consult a healthcare professional for a personalized diagnosis or treatment plan.
5. **Treatment Information (if applicable):**
- Clearly outline treatment options, including:
- Drug name
- Drug class
- Dosage
- Frequency and duration
- Potential side effects
- Risks and additional recommendations
- Specify that these options are general and should be discussed with a healthcare provider.
6. **Encourage Engagement:**
- Invite users to ask clarifying questions or provide additional details for a more tailored response.

Context: {context}

Question: {question}

Medical Information Assistant:"""
QA_CHAIN_PROMPT = PromptTemplate(
input_variables=["context", "question"],
template=template
)
        # Initialize the retriever over the loaded index
logger.info("Setting up retrieval chain...")
data_retriever = data_vec.as_retriever()
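        # Sketch (not active code): the original MergerRetriever import
        # suggests multiple indices were once merged. If a second, hypothetical
        # index existed, its retriever could be combined with this one:
        #
        #     from langchain.retrievers import MergerRetriever
        #     other_vec = load_faiss_index(str(vectors_dir / "faiss_v5"),  # hypothetical path
        #                                  "sentence-transformers/all-MiniLM-L12-v2")
        #     data_retriever = MergerRetriever(
        #         retrievers=[data_vec.as_retriever(), other_vec.as_retriever()]
        #     )
        #
        # as_retriever() defaults to top-4 similarity search; pass
        # search_kwargs={"k": ...} to widen or narrow the retrieved context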
# Initialize the RetrievalQA chain
qa_chain = RetrievalQA.from_chain_type(
llm=llm,
retriever=data_retriever,
return_source_documents=True,
chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
)
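        # RetrievalQA is a legacy chain in recent LangChain releases (the
        # suggested replacement is create_retrieval_chain), but it still
        # works as used here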
# Run the chain
logger.info("Generating answer...")
result = qa_chain.invoke({"query": query})
logger.info("Answer generated successfully")
        # Extract the source documents that grounded the answer
        extracted_docs = result.get("source_documents", [])
        logger.info(f"Retrieved {len(extracted_docs)} source documents")
return result["result"]
except Exception as e:
logger.error(f"Error generating answer: {str(e)}")
        raise


def main():
"""
Main function to demonstrate the usage of the RAG system.
"""
try:
# Example usage
query = "suggest me some medicine for bronchitis"
logger.info(f"Processing query: {query}")
response = generate_answer(query)
print("\nQuery:", query)
print("\nResponse:", response)
except Exception as e:
logger.error(f"Error in main function: {str(e)}")
print(f"An error occurred: {str(e)}")
if __name__ == "__main__":
main()