saswattulo commited on
Commit
be78402
Β·
verified Β·
1 Parent(s): dc76787

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -0
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import streamlit as st
4
+ from cassandra.auth import PlainTextAuthProvider
5
+ from cassandra.cluster import Cluster
6
+ from llama_index import ServiceContext
7
+ from llama_index import set_global_service_context
8
+ from llama_index import VectorStoreIndex, SimpleDirectoryReader, StorageContext
9
+ from llama_index.embeddings import GradientEmbedding
10
+ from llama_index.llms import GradientBaseModelLLM
11
+ from llama_index.vector_stores import CassandraVectorStore
12
+ from copy import deepcopy
13
+ from tempfile import NamedTemporaryFile
14
+
15
+ os.environ['GRADIENT_ACCESS_TOKEN'] = "sevG6Rqb0ztaquM4xjr83SBNSYj91cux"
16
+ os.environ['GRADIENT_WORKSPACE_ID'] = "4de36c1f-5ee6-41da-8f95-9d2fb1ded33a_workspace"
17
+
18
+ @st.cache_resource
19
+ def create_datastax_connection():
20
+
21
+ cloud_config= {'secure_connect_bundle': 'secure-connect-temp-db.zip'}
22
+
23
+ with open("temp_db-token.json") as f:
24
+ secrets = json.load(f)
25
+
26
+ CLIENT_ID = secrets["clientId"]
27
+ CLIENT_SECRET = secrets["secret"]
28
+
29
+ auth_provider = PlainTextAuthProvider(CLIENT_ID, CLIENT_SECRET)
30
+ cluster = Cluster(cloud=cloud_config, auth_provider=auth_provider)
31
+ astra_session = cluster.connect()
32
+ return astra_session
33
+
34
+ def main():
35
+
36
+ index_placeholder = None
37
+ st.set_page_config(page_title = "NyayMitra", page_icon="πŸ¦™")
38
+ st.header('NyayMitra')
39
+
40
+ if "conversation" not in st.session_state:
41
+ st.session_state.conversation = None
42
+
43
+ if "activate_chat" not in st.session_state:
44
+ st.session_state.activate_chat = False
45
+
46
+ if "messages" not in st.session_state:
47
+ st.session_state.messages = []
48
+
49
+ for message in st.session_state.messages:
50
+ with st.chat_message(message["role"], avatar = message['avatar']):
51
+ st.markdown(message["content"])
52
+
53
+ session = create_datastax_connection()
54
+
55
+ os.environ['GRADIENT_ACCESS_TOKEN'] = "sevG6Rqb0ztaquM4xjr83SBNSYj91cux"
56
+ os.environ['GRADIENT_WORKSPACE_ID'] = "4de36c1f-5ee6-41da-8f95-9d2fb1ded33a_workspace"
57
+
58
+ llm = GradientBaseModelLLM(base_model_slug="llama2-7b-chat", max_tokens=400)
59
+
60
+ embed_model = GradientEmbedding(
61
+ gradient_access_token = os.environ["GRADIENT_ACCESS_TOKEN"],
62
+ gradient_workspace_id = os.environ["GRADIENT_WORKSPACE_ID"],
63
+ gradient_model_slug="bge-large")
64
+
65
+ service_context = ServiceContext.from_defaults(
66
+ llm = llm,
67
+ embed_model = embed_model,
68
+ chunk_size=256)
69
+
70
+ set_global_service_context(service_context)
71
+
72
+ with st.sidebar:
73
+ st.subheader('Start your chat here')
74
+ if st.button('Process'):
75
+ with st.spinner('Processing'):
76
+ reader = 'data'
77
+
78
+ documents = SimpleDirectoryReader(reader).load_data()
79
+ index = VectorStoreIndex.from_documents(documents,
80
+ service_context=service_context)
81
+ query_engine = index.as_query_engine()
82
+ if "query_engine" not in st.session_state:
83
+ st.session_state.query_engine = query_engine
84
+ st.session_state.activate_chat = True
85
+
86
+ if st.session_state.activate_chat == True:
87
+ if prompt := st.chat_input("Ask your question"):
88
+ with st.chat_message("user", avatar = 'πŸ‘¨πŸ»'):
89
+ st.markdown(prompt)
90
+ st.session_state.messages.append({"role": "user",
91
+ "avatar" :'πŸ‘¨πŸ»',
92
+ "content": prompt})
93
+
94
+ query_index_placeholder = st.session_state.query_engine
95
+ pdf_response = query_index_placeholder.query(prompt)
96
+ cleaned_response = pdf_response.response
97
+ with st.chat_message("assistant", avatar='πŸ€–'):
98
+ st.markdown(cleaned_response)
99
+ st.session_state.messages.append({"role": "assistant",
100
+ "avatar" :'πŸ€–',
101
+ "content": cleaned_response})
102
+ else:
103
+ st.markdown(
104
+ ' '
105
+ )
106
+
107
+
108
+ if __name__ == '__main__':
109
+ main()