File size: 3,949 Bytes
248b0d6
de080b7
 
 
248b0d6
 
de080b7
248b0d6
de080b7
 
 
 
 
 
 
248b0d6
de080b7
 
248b0d6
de080b7
 
 
 
 
248b0d6
de080b7
 
248b0d6
de080b7
 
248b0d6
 
 
 
 
de080b7
248b0d6
de080b7
248b0d6
de080b7
 
 
 
 
 
248b0d6
de080b7
248b0d6
de080b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248b0d6
de080b7
 
248b0d6
de080b7
 
 
 
248b0d6
 
de080b7
 
 
 
248b0d6
 
 
de080b7
 
 
 
 
 
 
 
 
 
 
248b0d6
de080b7
 
 
248b0d6
de080b7
 
 
 
 
248b0d6
 
 
 
 
 
 
 
 
de080b7
248b0d6
 
de080b7
248b0d6
de080b7
 
 
248b0d6
de080b7
 
248b0d6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# File loading and environment variables
import os
from dotenv import load_dotenv

# Gemini Library
import google.generativeai as genai

# Langchain
from langchain.document_loaders import TextLoader
from langchain.prompts import PromptTemplate
from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import MongoDBAtlasVectorSearch
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings

# MongoDB
from pymongo import MongoClient

# Function type hints
from typing import Dict, Any

# Streamlit
import streamlit as st

# Load variables from a local .env file (if one exists) into the process environment
load_dotenv()

# Retrieve environment variables
MONGO_URI = os.getenv("MONGO_URI")
HF_TOKEN = os.getenv("HF_TOKEN")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

# Fail fast when the MongoDB URI is missing: MongoClient(None) does not raise,
# it silently targets localhost:27017, so the app would otherwise only fail
# later with a server-selection timeout on the first vector search.
if not MONGO_URI:
    raise RuntimeError(
        "MONGO_URI is not set; define it in the environment or in a .env file."
    )

# Configure Gemini
genai.configure(api_key=GEMINI_API_KEY)
model = genai.GenerativeModel("gemini-1.5-flash")

# MongoDB Atlas location of the first-aid intent documents, plus the name of the
# Atlas Vector Search index that is queried on that collection.
DB_NAME, COLLECTION_NAME = "ericguan04", "first_aid_intents"
vector_search_index = "vector_index"

@st.cache_resource
def get_mongodb_collection():
    """Return the MongoDB collection that stores the first-aid intents.

    Wrapped in ``st.cache_resource`` so the ``MongoClient`` built from
    ``MONGO_URI`` is created once and reused across Streamlit reruns rather
    than reconnecting each time the script re-executes.
    """
    client = MongoClient(MONGO_URI)
    database = client[DB_NAME]
    return database[COLLECTION_NAME]


MONGODB_COLLECTION = get_mongodb_collection()

@st.cache_resource
def load_embedding_model():
    """Build the Hugging Face Inference API embedding client.

    Uses ``sentence-transformers/all-mpnet-base-v2`` authenticated with
    ``HF_TOKEN``; ``st.cache_resource`` keeps a single instance across reruns.
    """
    model_name = "sentence-transformers/all-mpnet-base-v2"
    return HuggingFaceInferenceAPIEmbeddings(model_name=model_name, api_key=HF_TOKEN)

embedding_model = load_embedding_model()

# Build the vector store on top of the cached collection. The previous
# MongoDBAtlasVectorSearch.from_connection_string() call created a brand-new
# MongoClient (with its own connection pool) every time Streamlit re-ran this
# script, i.e. on every user interaction.
vector_search = MongoDBAtlasVectorSearch(
    collection=MONGODB_COLLECTION,
    embedding=embedding_model,
    index_name=vector_search_index,
)

# k to search for only the X most relevant documents
k = 10

# score_threshold to use only documents with a relevance score of at least 0.80
score_threshold = 0.80

# Build your retriever.
# With search_type="similarity" the retriever forwards search_kwargs to
# similarity_search(), which does not accept "score_threshold" and silently
# ignores it. Apply the threshold server-side instead: the store exposes each
# hit's Atlas vectorSearchScore as a "score" field, so a post-filter $match
# keeps only documents at or above the threshold.
retriever_1 = vector_search.as_retriever(
    search_type="similarity",
    search_kwargs={
        "k": k,
        "post_filter_pipeline": [{"$match": {"score": {"$gte": score_threshold}}}],
    },
)

# Define the prompt template.
# Its {context} and {question} placeholders are filled by generate_response()
# through prompt.format(**input_dict); in rag_chain, "context" is the
# format_docs() output of retriever_1 and "question" is the chain's input.
# NOTE(review): the template keeps the 4-space source indentation, so those
# leading spaces are part of the text sent to Gemini.
prompt = PromptTemplate.from_template(
    """You are playing the role of a medical assistant. A patient has come to you with a minor medical issue.
    Use the following pieces of context to answer the question at the end.
    To be more natural, do not mention you are referring to the context.

    START OF CONTEXT:
    {context}
    END OF CONTEXT:

    START OF QUESTION:
    {question}
    END OF QUESTION:

    If you do not know the answer, just say that you do not know.
    NEVER assume things.
    If the question is not relevant to the context, just say that it is not relevant.
    """
)

def format_docs(docs):
    """Join the ``page_content`` of each retrieved document with a blank line.

    The resulting single string is what fills the ``{context}`` slot of the
    prompt template. An empty iterable yields an empty string.
    """
    separator = "\n\n"
    return separator.join([document.page_content for document in docs])

@st.cache_data(max_entries=256)
def generate_response(input_dict: Dict[str, Any]) -> str:
    """
    Generate a response using the Gemini model.

    Responses are memoised per input with ``st.cache_data`` (the decorator meant
    for per-input data, unlike ``st.cache_resource`` which is for shared global
    resources). ``max_entries`` bounds the cache so an ever-growing set of
    distinct questions cannot grow memory without limit.

    Parameters:
        input_dict (Dict[str, Any]): Values for the prompt template's
            ``context`` and ``question`` placeholders.

    Returns:
        str: Generated response text from the Gemini model, or a short fallback
        message when Gemini returns no usable text.
    """
    formatted_prompt = prompt.format(**input_dict)
    response = model.generate_content(formatted_prompt)
    try:
        # The ``text`` accessor raises ValueError when the response contains no
        # text part (e.g. the candidate was blocked); don't crash the page.
        return response.text
    except ValueError:
        return "I'm sorry, I couldn't generate a response to that question."

# Assemble the RAG pipeline: the chain input goes both to retriever_1 (whose
# documents format_docs joins into "context") and straight through as
# "question"; the resulting dict is then passed to generate_response.
rag_chain = {
    "context": retriever_1 | RunnableLambda(format_docs),
    "question": RunnablePassthrough(),
} | RunnableLambda(generate_response)