pediatric_RAG / app.py
import os
import streamlit as st
from langchain_community.document_loaders import PDFPlumberLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.chains import LLMChain
from langchain_community.llms import CTransformers
from langchain_community.llms import HuggingFaceHub
import torch
# ==== Configuration ====
pdfs_directory = 'pdfs'
vectorstores_directory = 'vectorstores_medical'
os.makedirs(pdfs_directory, exist_ok=True)
os.makedirs(vectorstores_directory, exist_ok=True)
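# Any PDF already present in the pdfs/ folder is offered as a selectable book.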
PREDEFINED_BOOKS = [f for f in os.listdir(pdfs_directory) if f.endswith(".pdf")]
TEMPLATE = """
You are a medical assistant with deep clinical knowledge.
Use the following retrieved context to answer the question.
If unsure, say "I don't know." Keep answers accurate, concise, and clear.
Question: {question}
Context: {context}
Answer:
"""
# ==== Embedding Model (Medical) ====
embedding_model = HuggingFaceEmbeddings(
    model_name='pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb',
    model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"},
    encode_kwargs={"normalize_embeddings": False}
)
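# BioBERT sentence embeddings fine-tuned on general and biomedical NLI/STS datasets;
# the device is picked automatically (GPU if available, otherwise CPU).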
# ==== LLM (Local Quantized Medical Model) ====
# llm = CTransformers(
#     model='TheBloke/MedAlpaca-7B-GGUF',
#     model_file='medalpaca-7b.Q4_K_M.gguf',
#     model_type='llama',
#     config={'max_new_tokens': 512, 'temperature': 0.4}
# )
llm = HuggingFaceHub(
    repo_id="epfl-llm/meditron-7b",  # Or BioGPT, GatorTron, ClinicalT5, etc.
    model_kwargs={"temperature": 0.4, "max_new_tokens": 512},
    # repo_id="microsoft/BioGPT-Large",
    # model_kwargs={"temperature": 0.4, "max_new_tokens": 512},
    # repo_id="emilyalsentzer/Bio_ClinicalBERT",  # Encoder-only, fast
    # model_kwargs={"temperature": 0.3, "max_new_tokens": 256},
    huggingfacehub_api_token=os.getenv('hf_token')
)
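# The token is read from the 'hf_token' environment variable (e.g. a Space secret).
# Note: recent LangChain versions deprecate HuggingFaceHub in favour of HuggingFaceEndpoint
# (langchain_huggingface). A roughly equivalent call, kept here only as a sketch, would be:
# llm = HuggingFaceEndpoint(repo_id="epfl-llm/meditron-7b", temperature=0.4,
#                           max_new_tokens=512, huggingfacehub_api_token=os.getenv("hf_token"))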
# ==== Helpers ====
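# Split loaded pages into overlapping chunks; add_start_index records each chunk's
# character offset in the source document as metadata.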
def split_text(documents):
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        add_start_index=True
    )
    return splitter.split_documents(documents)
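# Map 'book.pdf' to its cache directory 'vectorstores_medical/book'.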
def get_vectorstore_path(book_filename):
    base_name = os.path.splitext(book_filename)[0]
    return os.path.join(vectorstores_directory, base_name)
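# Load the cached FAISS index if it exists, otherwise build and persist one from the
# given documents. allow_dangerous_deserialization is needed because the saved docstore
# is pickled, and recent LangChain versions require an explicit opt-in to load it.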
def load_or_create_vectorstore(book_filename, documents=None):
    vs_path = get_vectorstore_path(book_filename)
    if os.path.exists(os.path.join(vs_path, "index.faiss")):
        return FAISS.load_local(vs_path, embedding_model, allow_dangerous_deserialization=True)
    if documents is None:
        raise ValueError("Documents required to create vector store.")
    with st.spinner(f"⏳ Creating vector store for '{book_filename}'..."):
        os.makedirs(vs_path, exist_ok=True)
        chunks = split_text(documents)
        vector_store = FAISS.from_documents(chunks, embedding_model)
        vector_store.save_local(vs_path)
        st.success(f"✅ Vector store created for '{book_filename}'.")
    return vector_store
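# Return the chunks most similar to the query (LangChain's similarity_search defaults to k=4).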
def retrieve_docs(vector_store, query):
    return vector_store.similarity_search(query)
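# Stuff the retrieved chunks into the prompt as context and run the LLM chain.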
def answer_question(question, documents):
    context = "\n\n".join(doc.page_content for doc in documents)
    prompt = ChatPromptTemplate.from_template(TEMPLATE)
    chain = LLMChain(llm=llm, prompt=prompt)
    return chain.run({"question": question, "context": context})
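# Persist the uploaded file into the pdfs/ folder so it is listed on the next run.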
def upload_pdf(file):
    save_path = os.path.join(pdfs_directory, file.name)
    with open(save_path, "wb") as f:
        f.write(file.getbuffer())
    return file.name
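# PDFPlumberLoader returns one Document per PDF page.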
def load_pdf(file_path):
    loader = PDFPlumberLoader(file_path)
    return loader.load()
# ==== Streamlit App ====
st.set_page_config(page_title="🩺 Medical PDF Chat", layout="centered")
st.title("📚 Medical Assistant - PDF Q&A")
with st.sidebar:
    st.header("Select or Upload a Medical Book")
    selected_book = st.selectbox("Choose a PDF", PREDEFINED_BOOKS + ["Upload new book"])
    if selected_book == "Upload new book":
        uploaded_file = st.file_uploader("Upload Medical PDF", type="pdf")
        if uploaded_file:
            filename = upload_pdf(uploaded_file)
            st.success(f"📥 Uploaded: {filename}")
            selected_book = filename
# ==== Main Logic ====
if selected_book and selected_book != "Upload new book":
    st.info(f"📖 You selected: {selected_book}")
    file_path = os.path.join(pdfs_directory, selected_book)
    vectorstore_path = get_vectorstore_path(selected_book)
    try:
        if os.path.exists(os.path.join(vectorstore_path, "index.faiss")):
            st.success("✅ Vector store already exists. Using cached version.")
            vector_store = load_or_create_vectorstore(selected_book)
        else:
            documents = load_pdf(file_path)
            vector_store = load_or_create_vectorstore(selected_book, documents)
        # Chat Input
        question = st.chat_input("Ask your medical question...")
        if question:
            st.chat_message("user").write(question)
            related_docs = retrieve_docs(vector_store, question)
            answer = answer_question(question, related_docs)
            st.chat_message("assistant").write(answer)
    except Exception as e:
        st.error(f"❌ Error loading or processing the PDF: {e}")
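# Run locally with: streamlit run app.py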