# File loading and environment variables
import os
from dotenv import load_dotenv

# Gemini library
import google.generativeai as genai

# LangChain
from langchain_community.document_loaders import TextLoader
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import MongoDBAtlasVectorSearch
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings

# MongoDB
from pymongo import MongoClient

# Function type hints
from typing import Dict, Any

# Streamlit
import streamlit as st
# Load environment variables
load_dotenv()

# Retrieve environment variables
MONGO_URI = os.getenv("MONGO_URI")
HF_TOKEN = os.getenv("HF_TOKEN")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

# Configure Gemini
genai.configure(api_key=GEMINI_API_KEY)
model = genai.GenerativeModel("gemini-1.5-flash")
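# Optional sanity check (a sketch, not part of the original app): one cheap
# generation call confirms the API key and model name work before the rest of
# the pipeline is wired up. Commented out to avoid an API call at import time.
# print(model.generate_content("Reply with OK.").text)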
# Set up the vector store and MongoDB Atlas connection
DB_NAME = "ericguan04"
COLLECTION_NAME = "first_aid_intents"
vector_search_index = "vector_index"

def get_mongodb_collection():
    # Connect to the MongoDB Atlas cluster using the connection string
    cluster = MongoClient(MONGO_URI)
    # Return the specific collection in the database
    return cluster[DB_NAME][COLLECTION_NAME]

MONGODB_COLLECTION = get_mongodb_collection()
def load_embedding_model():
    return HuggingFaceInferenceAPIEmbeddings(
        api_key=HF_TOKEN, model_name="sentence-transformers/all-mpnet-base-v2"
    )

embedding_model = load_embedding_model()
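# For illustration (commented out so it does not call the inference API at
# import time): embed_query() maps a string to a dense vector, and
# all-mpnet-base-v2 produces 768-dimensional embeddings. The query text here
# is an assumption, not part of the original app.
# sample_vector = embedding_model.embed_query("how to treat a small cut")
# assert len(sample_vector) == 768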
vector_search = MongoDBAtlasVectorSearch.from_connection_string(
    connection_string=MONGO_URI,
    namespace=f"{DB_NAME}.{COLLECTION_NAME}",
    embedding=embedding_model,
    index_name=vector_search_index,
)
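# For reference, the "vector_index" named above must already exist on the
# Atlas collection. A typical Atlas Vector Search index definition for this
# embedding model is sketched below; the field path "embedding" is an
# assumption about how the documents were ingested, not something shown here.
# {
#   "fields": [
#     {"type": "vector", "path": "embedding", "numDimensions": 768, "similarity": "cosine"}
#   ]
# }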
# k limits the search to the k most relevant documents
k = 10
# score_threshold keeps only documents with a relevance score above 0.80
score_threshold = 0.80

# Build the retriever. Note: score_threshold is only honored when search_type
# is "similarity_score_threshold"; with plain "similarity" it is ignored.
retriever_1 = vector_search.as_retriever(
    # Options: similarity, mmr, similarity_score_threshold.
    # https://api.python.langchain.com/en/latest/vectorstores/langchain_core.vectorstores.VectorStore.html#langchain_core.vectorstores.VectorStore.as_retriever
    search_type="similarity_score_threshold",
    search_kwargs={"k": k, "score_threshold": score_threshold},
)
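# Standalone example of the retriever (commented out; the query string is an
# assumption): invoke() returns a list of Documents whose page_content holds
# the matched first-aid text that format_docs() later joins together.
# docs = retriever_1.invoke("how do I treat a sprained ankle?")
# print(len(docs), docs[0].page_content[:100] if docs else "no match")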
# Define the prompt template
prompt = PromptTemplate.from_template(
    """You are playing the role of a medical assistant. A patient has come to you with a minor medical issue.
Use the following pieces of context to answer the question at the end.
To sound natural, do not mention that you are referring to the context.

START OF CONTEXT:
{context}
END OF CONTEXT

START OF QUESTION:
{question}
END OF QUESTION

If you do not know the answer, just say that you do not know.
NEVER assume things.
If the question is not relevant to the context, just say that it is not relevant.
"""
)
# Format the retrieved documents before inserting them into the prompt template
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)
def generate_response(input_dict: Dict[str, Any]) -> str:
    """
    Generate a response using the Gemini model.

    Parameters:
        input_dict (Dict[str, Any]): Dictionary with the formatted context and question.

    Returns:
        str: Generated response from the Gemini model.
    """
    formatted_prompt = prompt.format(**input_dict)
    response = model.generate_content(formatted_prompt)
    return response.text
# Build the chain with retriever_1
rag_chain = (
    {
        "context": retriever_1 | RunnableLambda(format_docs),
        "question": RunnablePassthrough(),
    }
    | RunnableLambda(generate_response)
)
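# A minimal sketch of the Streamlit UI implied by the `st` import; the page
# title and widget labels are assumptions, not part of the original app.
st.title("First Aid Assistant")
user_question = st.text_input("Describe your minor medical issue:")
if user_question:
    with st.spinner("Thinking..."):
        st.write(rag_chain.invoke(user_question))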