# first_aid_ai/rag_model.py
# File loading and environment variables
import os
from dotenv import load_dotenv
# Gemini Library
import google.generativeai as genai
# Langchain
from langchain.prompts import PromptTemplate
from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
from langchain_community.vectorstores import MongoDBAtlasVectorSearch
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
# MongoDB
from pymongo import MongoClient
# Function type hints
from typing import Dict, Any
# Streamlit
import streamlit as st
# Load environment variables
load_dotenv()
# Retrieve environment variables
MONGO_URI = os.getenv("MONGO_URI")
HF_TOKEN = os.getenv("HF_TOKEN")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
# Configure Gemini
genai.configure(api_key=GEMINI_API_KEY)
model = genai.GenerativeModel("gemini-1.5-flash")
# Set up the vector store and the MongoDB Atlas connection
DB_NAME = "ericguan04"
COLLECTION_NAME = "first_aid_intents"
VECTOR_SEARCH_INDEX = "vector_index"
@st.cache_resource
def get_mongodb_collection():
    """Connect to the MongoDB Atlas cluster and return the target collection."""
    cluster = MongoClient(MONGO_URI)
    return cluster[DB_NAME][COLLECTION_NAME]
MONGODB_COLLECTION = get_mongodb_collection()
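# Note: retrieval assumes the collection has already been populated with
# embedded documents (e.g., by a separate ingestion script using
# MongoDBAtlasVectorSearch.from_documents); this module only reads from it.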
@st.cache_resource
def load_embedding_model():
    """Load the Hugging Face Inference API embedding model."""
    return HuggingFaceInferenceAPIEmbeddings(
        api_key=HF_TOKEN, model_name="sentence-transformers/all-mpnet-base-v2"
    )
embedding_model = load_embedding_model()
vector_search = MongoDBAtlasVectorSearch.from_connection_string(
    connection_string=MONGO_URI,
    namespace=f"{DB_NAME}.{COLLECTION_NAME}",
    embedding=embedding_model,
    index_name=VECTOR_SEARCH_INDEX,
)
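# For reference, the Atlas Vector Search index this store expects is typically
# defined with JSON like the sketch below (assumed, not taken from this repo):
# "all-mpnet-base-v2" produces 768-dimensional embeddings, and "embedding" is
# the default field MongoDBAtlasVectorSearch stores vectors under.
#
# {
#   "fields": [
#     {"type": "vector", "path": "embedding", "numDimensions": 768, "similarity": "cosine"}
#   ]
# }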
# k limits retrieval to the 10 most relevant documents
k = 10
# score_threshold keeps only documents with a relevance score above 0.80
score_threshold = 0.80
# Build the retriever. The "similarity_score_threshold" search type is required
# for score_threshold to take effect (plain "similarity" ignores it).
# Options: similarity, mmr, similarity_score_threshold. https://api.python.langchain.com/en/latest/vectorstores/langchain_core.vectorstores.VectorStore.html#langchain_core.vectorstores.VectorStore.as_retriever
retriever_1 = vector_search.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={"k": k, "score_threshold": score_threshold},
)
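# Quick sanity check (illustrative only; the query string below is made up):
# docs = retriever_1.get_relevant_documents("how do I treat a small cut?")
# print(len(docs), docs[0].page_content[:100] if docs else "no matches")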
# Define the prompt template
prompt = PromptTemplate.from_template(
    """You are playing the role of a medical assistant. A patient has come to you with a minor medical issue.
Use the following pieces of context to answer the question at the end.
To sound natural, do not mention that you are referring to the context.
START OF CONTEXT:
{context}
END OF CONTEXT

START OF QUESTION:
{question}
END OF QUESTION

If you do not know the answer, just say that you do not know.
NEVER assume things.
If the question is not relevant to the context, just say that it is not relevant.
"""
)
# Join the retrieved documents into one block before inserting it into the prompt
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)
def generate_response(input_dict: Dict[str, Any]) -> str:
    """
    Generate a response using the Gemini model.

    Note: this function is intentionally not cached. st.cache_resource is meant
    for shared resources (clients, models), not per-question LLM calls.

    Parameters:
        input_dict (Dict[str, Any]): Dictionary with formatted context and question.

    Returns:
        str: Generated response from the Gemini model.
    """
    formatted_prompt = prompt.format(**input_dict)
    response = model.generate_content(formatted_prompt)
    return response.text
# Build the RAG chain: retrieve and format the context, pass the question
# through unchanged, then generate the final answer with Gemini
rag_chain = (
    {
        "context": retriever_1 | RunnableLambda(format_docs),
        "question": RunnablePassthrough(),
    }
    | RunnableLambda(generate_response)
)
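# Minimal usage sketch (illustrative; the question below is made up):
if __name__ == "__main__":
    answer = rag_chain.invoke("How should I treat a minor burn at home?")
    print(answer)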