gabruarya committed on
Commit
4b183a8
·
1 Parent(s): 4ab0328

Upload 2 files

Browse files
Files changed (2) hide show
  1. requirements.txt +0 -0
  2. streamlit_app.py +157 -0
requirements.txt ADDED
Binary file (8.18 kB). View file
 
streamlit_app.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from dataclasses import dataclass
from typing import Literal
import streamlit as st
import os
from llamaapi import LlamaAPI
from langchain_experimental.llms import ChatLlamaAPI
from langchain.embeddings import HuggingFaceEmbeddings
import pinecone
from langchain.vectorstores import Pinecone
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
import streamlit.components.v1 as components

# Read the Hugging Face API token from Streamlit secrets and export it as an
# environment variable so HuggingFaceEmbeddings (used below) can authenticate.
# NOTE(review): this runs at import time on every Streamlit rerun — harmless,
# but it will raise KeyError if the secret is not configured.
HUGGINGFACEHUB_API_TOKEN = st.secrets['HUGGINGFACEHUB_API_TOKEN']
os.environ["HUGGINGFACEHUB_API_TOKEN"] = HUGGINGFACEHUB_API_TOKEN
18
@dataclass
class Message:
    """One entry in the chat transcript.

    Attributes:
        origin: Which side produced the text — ``"human"`` or ``"ai"``.
        message: The message body as plain text.
    """
    origin: Literal["human", "ai"]  # speaker identity, drives CSS styling
    message: str  # raw message text rendered into the chat bubble
23
+
24
+
25
def load_css():
    """Inject the app's custom stylesheet into the Streamlit page.

    Reads ``static/styles.css`` and embeds it inside a ``<style>`` tag via
    ``st.markdown`` with ``unsafe_allow_html=True`` so the chat-bubble CSS
    classes used elsewhere in the app take effect.

    Raises:
        FileNotFoundError: if ``static/styles.css`` is missing.
    """
    # Explicit encoding: the platform default is not guaranteed to be UTF-8
    # and would corrupt any non-ASCII characters in the stylesheet.
    with open("static/styles.css", "r", encoding="utf-8") as f:
        css = f"<style>{f.read()}</style>"
    st.markdown(css, unsafe_allow_html=True)
29
+
30
+
31
def download_hugging_face_embeddings():
    """Return a HuggingFace embedding model (MiniLM sentence-transformer).

    The model is downloaded/cached by the HuggingFace hub on first use.
    """
    return HuggingFaceEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2'
    )
34
+
35
+
36
def initialize_session_state():
    """Populate st.session_state on first run of the script.

    Creates two keys, each only if absent (so reruns are cheap):
      - ``history``: list[Message] holding the chat transcript.
      - ``conversation``: a LangChain RetrievalQA chain backed by a
        Llama chat model and a Pinecone vector index.
    """
    if "history" not in st.session_state:
        st.session_state.history = []
    if "conversation" not in st.session_state:
        # LLM client: Llama API key comes from Streamlit secrets.
        llama = LlamaAPI(st.secrets["LlamaAPI"])
        model = ChatLlamaAPI(client=llama)

        embeddings = download_hugging_face_embeddings()

        # Initializing the Pinecone
        pinecone.init(
            api_key=st.secrets["PINECONE_API_KEY"],  # find at app.pinecone.io
            environment=st.secrets["PINECONE_API_ENV"]  # next to api key in console
        )
        index_name = "legal-advisor"  # put in the name of your pinecone index here

        # Connect to the pre-built index; assumes it was populated with the
        # same embedding model used here — TODO confirm against the indexer.
        docsearch = Pinecone.from_existing_index(index_name, embeddings)

        # Runtime prompt text — must stay exactly as authored.
        prompt_template = """
        You are a trained bot to guide people about Indian Law. You will answer user's query with your knowledge and the context provided.
        If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
        Do not say thank you and tell you are an AI Assistant and be open about everything.
        Use the following pieces of context to answer the users question.
        Context: {context}
        Question: {question}
        Only return the helpful answer below and nothing else.
        Helpful answer:
        """

        PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
        chain_type_kwargs = {"prompt": PROMPT}
        # "stuff" chain: the top-k (k=2) retrieved chunks are concatenated
        # directly into the prompt's {context} slot.
        retrieval_chain = RetrievalQA.from_chain_type(llm=model,
                                                      chain_type="stuff",
                                                      retriever=docsearch.as_retriever(
                                                          search_kwargs={'k': 2}),
                                                      return_source_documents=True,
                                                      chain_type_kwargs=chain_type_kwargs)

        st.session_state.conversation = retrieval_chain
75
+
76
+
77
def on_click_callback():
    """Form-submit handler: run the RAG chain on the typed prompt and
    append both the human turn and the AI answer to the transcript."""
    query = st.session_state.human_prompt
    response = st.session_state.conversation(query)
    answer = response['result']
    print(answer)
    # Record both sides of the exchange so the next rerun redraws them.
    st.session_state.history.append(Message("human", query))
    st.session_state.history.append(Message("ai", answer))
90
+
91
+
92
# ---- Page assembly: runs top-to-bottom on every Streamlit rerun ----
load_css()
initialize_session_state()

st.title("Hello Custom CSS Chatbot 🤖")

# Transcript container renders above; the input form sits below it.
chat_placeholder = st.container()
prompt_placeholder = st.form("chat-form")

with chat_placeholder:
    # Re-render the entire conversation from session history on each rerun.
    # The HTML relies on classes defined in static/styles.css: AI messages
    # flow left, human messages are reversed to the right.
    for chat in st.session_state.history:
        div = f"""
<div class="chat-row
    {'' if chat.origin == 'ai' else 'row-reverse'}">
    <img class="chat-icon" src="app/static/{
        'ai_icon.png' if chat.origin == 'ai'
                      else 'user_icon.png'}"
         width=32 height=32>
    <div class="chat-bubble
    {'ai-bubble' if chat.origin == 'ai' else 'human-bubble'}">
        &#8203;{chat.message}
    </div>
</div>
        """
        st.markdown(div, unsafe_allow_html=True)

    # Spacer between transcript and the input form.
    for _ in range(3):
        st.markdown("")

with prompt_placeholder:
    st.markdown("**Chat**")
    cols = st.columns((6, 1))
    # The text input is keyed "human_prompt" so on_click_callback can read it
    # from st.session_state.
    cols[0].text_input(
        "Chat",
        value="Hello bot",
        label_visibility="collapsed",
        key="human_prompt",
    )
    cols[1].form_submit_button(
        "Submit",
        type="primary",
        on_click=on_click_callback,
    )

# Invisible (0x0) HTML component that injects a keydown listener into the
# parent Streamlit document so pressing Enter clicks the Submit button.
components.html("""
<script>
const streamlitDoc = window.parent.document;

const buttons = Array.from(
    streamlitDoc.querySelectorAll('.stButton > button')
);
const submitButton = buttons.find(
    el => el.innerText === 'Submit'
);

streamlitDoc.addEventListener('keydown', function(e) {
    switch (e.key) {
        case 'Enter':
            submitButton.click();
            break;
    }
});
</script>
""",
    height=0,
    width=0,
)