File size: 3,489 Bytes
49dcda7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import streamlit as st
from sentence_transformers import SentenceTransformer
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain_community.llms import Ollama
import faiss
import numpy as np
import pickle
import requests
import json

# Load the FAISS index
@st.cache(allow_output_mutation=True)
def load_faiss_index():
    try:
        return faiss.read_index("database/pdf_sections_index.faiss")
    except FileNotFoundError:
        st.error("FAISS index file not found. Please ensure 'pdf_sections_index.faiss' exists.")
        st.stop()

# Load the embedding model
@st.cache(allow_output_mutation=True)
def load_embedding_model():
    return SentenceTransformer('all-MiniLM-L6-v2')

# Load sections data
@st.cache(allow_output_mutation=True)
def load_sections_data():
    try:
        with open('database/pdf_sections_data.pkl', 'rb') as f:
            return pickle.load(f)
    except FileNotFoundError:
        st.error("Sections data file not found. Please ensure 'pdf_sections_data.pkl' exists.")
        st.stop()

# Initialize resources
index = load_faiss_index()
model = load_embedding_model()
sections_data = load_sections_data()

def search_faiss(query, k=3):
    query_vector = model.encode([query])[0].astype('float32')
    query_vector = np.expand_dims(query_vector, axis=0)
    distances, indices = index.search(query_vector, k)
    
    results = []
    for dist, idx in zip(distances[0], indices[0]):
        results.append({
            'distance': dist,
            'content': sections_data[idx]['content'],
            'metadata': sections_data[idx]['metadata']
        })
    
    return results

prompt_template = """
You are an AI assistant specialized in Mental Health & wellness guidelines. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context:
{context}

Question: {question}

Answer:"""

prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

@st.cache(allow_output_mutation=True)
def load_llm():
    return Ollama(model="phi3")  

llm = load_llm()
chain = LLMChain(llm=llm, prompt=prompt)

def answer_question(query):
    search_results = search_faiss(query)
    context = "\n\n".join([result['content'] for result in search_results])
    response = chain.run(context=context, question=query)
    return response, context

# Streamlit UI
st.title("Mental Health & Wellness Assistant")

query = st.text_input("Enter your question about Mental Health:")

if st.button("Get Answer"):
    if query:
        with st.spinner("Searching, Thinking and generating answer..."):
            answer, context = answer_question(query)
            st.subheader("Answer:")
            st.write(answer)
            with st.expander("Show Context"):
                st.write(context)
    else:
        st.warning("Please enter a question.")
        
# Footer section with social links
st.markdown("""
    <div class="social-icons">
        <a href="https://github.com/yadavadit" target="_blank"><img src="https://img.icons8.com/material-outlined/48/e50914/github.png"/></a>
        <a href="https://www.linkedin.com/in/yaditi/" target="_blank"><img src="https://img.icons8.com/color/48/e50914/linkedin.png"/></a>
        <a href="mailto:[email protected]"><img src="https://img.icons8.com/color/48/e50914/gmail.png"/></a>
    </div>
    """, unsafe_allow_html=True)