Harsh2001 commited on
Commit
65cd502
·
verified ·
1 Parent(s): 836db00

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -87
app.py CHANGED
@@ -1,87 +1,72 @@
1
- import os
2
- import warnings
3
- import nest_asyncio
4
- import streamlit as st
5
- from dotenv import load_dotenv
6
- from DataLoading.Data import get_data
7
- from llama_index.core import Settings
8
- from llama_index.llms.groq import Groq
9
- from llama_index.vector_stores.faiss import FaissVectorStore
10
- from llama_index.embeddings.huggingface import HuggingFaceEmbedding
11
- from llama_index.core import StorageContext, load_index_from_storage
12
-
13
- nest_asyncio.apply()
14
- load_dotenv()
15
- warnings.filterwarnings("ignore")
16
-
17
- def init_llm(model_name):
18
- return Groq(model=model_name, api_key=os.getenv("GROQ_API_KEY"))
19
-
20
- @st.cache_resource
21
- def load_index(selected_model):
22
- curr_direc = os.getcwd()
23
- file_path = os.path.join(curr_direc, 'processed_data.csv')
24
- # print(file_path)
25
- get_data(file_path)
26
- model = init_llm(selected_model)
27
- embedding_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
28
-
29
- Settings.embed_model = embedding_model
30
- Settings.llm = model
31
-
32
- vector_store = FaissVectorStore.from_persist_dir('storage')
33
- storage_context = StorageContext.from_defaults(
34
- vector_store=vector_store, persist_dir='storage'
35
- )
36
- index = load_index_from_storage(storage_context=storage_context)
37
- return index.as_query_engine()
38
-
39
- st.title("Chatbot from ClienterAI")
40
-
41
- st.sidebar.header("Settings")
42
- selected_model = st.sidebar.selectbox(
43
- "Select Groq Model:",
44
- options=["mixtral-8x7b-32768", "gemma2-9b-it", "llama-3.1-70b-versatile", "llama3-8b-8192", "llava-v1.5-7b-4096-preview"],
45
- index=0
46
- )
47
-
48
- query_engine = load_index(selected_model)
49
-
50
- if "messages" not in st.session_state:
51
- st.session_state["messages"] = []
52
-
53
-
54
- with st.form("chat_form", clear_on_submit=True):
55
- user_input = st.text_input("Ask a question based on your data:", "")
56
- submitted = st.form_submit_button("Send")
57
-
58
- if submitted and user_input:
59
- st.session_state["messages"].append({"role": "user", "content": user_input})
60
- response = query_engine.query(user_input)
61
- ai_response = response
62
- st.session_state["messages"].append({"role": "assistant", "content": ai_response})
63
-
64
- for message in st.session_state["messages"]:
65
- if message["role"] == "user":
66
- st.markdown(f"**You:** {message['content']}")
67
- else:
68
- st.markdown(f"**Assistant:** {message['content']}")
69
-
70
- if st.sidebar.button("Clear Chat"):
71
- st.session_state["messages"] = []
72
- st.sidebar.success("Chat cleared!")
73
-
74
-
75
- st.markdown("""
76
- <style>
77
- .stForm {
78
- position: fixed;
79
- align-self: center;
80
- bottom: 0;
81
- width: 50%;
82
- left: 25%;
83
- right: 50%;
84
- padding: 10px;
85
- }
86
- <style>
87
- """, unsafe_allow_html=True)
 
1
+ import os
2
+ import warnings
3
+ import nest_asyncio
4
+ import streamlit as st
5
+ from dotenv import load_dotenv
6
+ from DataLoading.Data import get_data
7
+ from llama_index.core import Settings
8
+ from llama_index.llms.groq import Groq
9
+ from llama_index.vector_stores.faiss import FaissVectorStore
10
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
11
+ from llama_index.core import StorageContext, load_index_from_storage
12
+
13
+ nest_asyncio.apply()
14
+ load_dotenv()
15
+ warnings.filterwarnings("ignore")
16
+
17
+ def init_llm(model_name):
18
+ return Groq(model=model_name, api_key=os.getenv("GROQ_API_KEY"))
19
+
20
+ @st.cache_resource
21
+ def load_index(selected_model):
22
+ curr_direc = os.getcwd()
23
+ file_path = os.path.join(curr_direc, 'processed_data.csv')
24
+ # print(file_path)
25
+ get_data(file_path)
26
+ model = init_llm(selected_model)
27
+ embedding_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
28
+
29
+ Settings.embed_model = embedding_model
30
+ Settings.llm = model
31
+
32
+ vector_store = FaissVectorStore.from_persist_dir('storage')
33
+ storage_context = StorageContext.from_defaults(
34
+ vector_store=vector_store, persist_dir='storage'
35
+ )
36
+ index = load_index_from_storage(storage_context=storage_context)
37
+ return index.as_query_engine()
38
+
39
+ st.title("Chatbot from ClienterAI")
40
+
41
+ st.sidebar.header("Settings")
42
+ selected_model = st.sidebar.selectbox(
43
+ "Select Groq Model:",
44
+ options=["mixtral-8x7b-32768", "gemma2-9b-it", "llama-3.1-70b-versatile", "llama3-8b-8192", "llava-v1.5-7b-4096-preview"],
45
+ index=0
46
+ )
47
+
48
+ query_engine = load_index(selected_model)
49
+
50
+ if "messages" not in st.session_state:
51
+ st.session_state["messages"] = []
52
+
53
+
54
+ with st.form("chat_form", clear_on_submit=True):
55
+ user_input = st.text_input("Ask a question based on your data:", "")
56
+ submitted = st.form_submit_button("Send")
57
+
58
+ if submitted and user_input:
59
+ st.session_state["messages"].append({"role": "user", "content": user_input})
60
+ response = query_engine.query(user_input)
61
+ ai_response = response
62
+ st.session_state["messages"].append({"role": "assistant", "content": ai_response})
63
+
64
+ for message in st.session_state["messages"]:
65
+ if message["role"] == "user":
66
+ st.markdown(f"**You:** {message['content']}")
67
+ else:
68
+ st.markdown(f"**Assistant:** {message['content']}")
69
+
70
+ if st.sidebar.button("Clear Chat"):
71
+ st.session_state["messages"] = []
72
+ st.sidebar.success("Chat cleared!")