# File loading and environment variables
import os
from dotenv import load_dotenv
# Gemini Library
import google.generativeai as genai
# LangChain
from langchain.prompts import PromptTemplate
from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
from langchain_community.vectorstores import MongoDBAtlasVectorSearch
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
# MongoDB
from pymongo import MongoClient
# Function type hints
from typing import Dict, Any
# Streamlit
import streamlit as st
# Load environment variables
load_dotenv()
# Retrieve environment variables
MONGO_URI = os.getenv("MONGO_URI")
HF_TOKEN = os.getenv("HF_TOKEN")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
# Configure Gemini
genai.configure(api_key=GEMINI_API_KEY)
model = genai.GenerativeModel("gemini-1.5-flash")
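# Optional sanity check (a sketch, assuming the API key is valid); uncomment to
# confirm the model responds before wiring up the full chain:
# print(model.generate_content("Reply with one short sentence.").text)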
# Set up the MongoDB Atlas connection and vector store
DB_NAME = "ericguan04"
COLLECTION_NAME = "first_aid_intents"
VECTOR_SEARCH_INDEX = "vector_index"
@st.cache_resource
def get_mongodb_collection():
    # Connect to the MongoDB Atlas cluster using the connection string
    cluster = MongoClient(MONGO_URI)
    # Return the target collection in the database
    return cluster[DB_NAME][COLLECTION_NAME]
MONGODB_COLLECTION = get_mongodb_collection()
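# st.cache_resource keeps a single shared MongoClient across Streamlit reruns
# and sessions, so a new connection is not opened on every interaction.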
@st.cache_resource
def load_embedding_model():
    return HuggingFaceInferenceAPIEmbeddings(
        api_key=HF_TOKEN, model_name="sentence-transformers/all-mpnet-base-v2"
    )
embedding_model = load_embedding_model()
vector_search = MongoDBAtlasVectorSearch.from_connection_string(
    connection_string=MONGO_URI,
    namespace=f"{DB_NAME}.{COLLECTION_NAME}",
    embedding=embedding_model,
    index_name=VECTOR_SEARCH_INDEX,
)
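# The index named above must already exist on the Atlas collection. A minimal
# illustrative definition (an assumption; your vector field path may differ --
# all-mpnet-base-v2 produces 768-dimensional embeddings):
#
#   {
#     "fields": [
#       {"type": "vector", "path": "embedding", "numDimensions": 768, "similarity": "cosine"}
#     ]
#   }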
# k: retrieve only the k most relevant documents
k = 10
# score_threshold: keep only documents with a relevance score above 0.80
score_threshold = 0.80
# Build the retriever
retriever_1 = vector_search.as_retriever(
    # "similarity_score_threshold" is required for score_threshold to take effect;
    # other search types: "similarity", "mmr". See:
    # https://api.python.langchain.com/en/latest/vectorstores/langchain_core.vectorstores.VectorStore.html#langchain_core.vectorstores.VectorStore.as_retriever
    search_type="similarity_score_threshold",
    search_kwargs={"k": k, "score_threshold": score_threshold},
)
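# Usage sketch (the question text is illustrative): the retriever can be
# exercised on its own and returns Documents whose page_content holds the
# matched chunks.
# docs = retriever_1.invoke("How should I treat a minor burn?")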
# Define the prompt template
prompt = PromptTemplate.from_template(
    """You are playing the role of a medical assistant. A patient has come to you with a minor medical issue.
Use the following pieces of context to answer the question at the end.
To sound natural, do not mention that you are referring to the context.

START OF CONTEXT:
{context}
END OF CONTEXT

START OF QUESTION:
{question}
END OF QUESTION

If you do not know the answer, just say that you do not know.
NEVER assume things.
If the question is not relevant to the context, just say that it is not relevant.
"""
)
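# The template expects two variables, {context} and {question}; the chain
# below supplies both through a parallel dict of runnables.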
# Join the retrieved documents into one block before inserting them into the prompt
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)
def generate_response(input_dict: Dict[str, Any]) -> str:
    """
    Generate a response using the Gemini model.

    Parameters:
        input_dict (Dict[str, Any]): Dictionary with the formatted context and question.

    Returns:
        str: Generated response from the Gemini model.
    """
    formatted_prompt = prompt.format(**input_dict)
    response = model.generate_content(formatted_prompt)
    # response.text holds the generated text of the first candidate
    return response.text
# Build the RAG chain with retriever_1
rag_chain = (
    {
        "context": retriever_1 | RunnableLambda(format_docs),
        "question": RunnablePassthrough(),
    }
    | RunnableLambda(generate_response)
)
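# Minimal invocation sketch (an assumption; the widget label and layout are
# illustrative, not part of the original UI):
# question = st.text_input("Describe your medical issue:")
# if question:
#     st.write(rag_chain.invoke(question))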