naotakigawa commited on
Commit
bbe0b27
·
1 Parent(s): 478c7ef

try commit

Browse files
Files changed (1) hide show
  1. app.py +215 -191
app.py CHANGED
@@ -1,191 +1,215 @@
1
- import streamlit as st
2
- import os
3
- import pickle
4
- import faiss
5
- import logging
6
-
7
- from multiprocessing import Lock
8
- from multiprocessing.managers import BaseManager
9
- from llama_index.callbacks import CallbackManager, LlamaDebugHandler
10
- from llama_index import VectorStoreIndex, Document,Prompt, SimpleDirectoryReader, ServiceContext, StorageContext, load_index_from_storage
11
- from llama_index.chat_engine import CondenseQuestionChatEngine;
12
- from llama_index.node_parser import SimpleNodeParser
13
- from llama_index.langchain_helpers.text_splitter import TokenTextSplitter
14
- from llama_index.constants import DEFAULT_CHUNK_OVERLAP
15
- from llama_index.response_synthesizers import get_response_synthesizer
16
- from llama_index.vector_stores.faiss import FaissVectorStore
17
- from llama_index.graph_stores import SimpleGraphStore
18
- from llama_index.storage.docstore import SimpleDocumentStore
19
- from llama_index.storage.index_store import SimpleIndexStore
20
- import tiktoken
21
- import requests
22
-
23
- from requests_oauthlib import OAuth2Session
24
- from time import time
25
- from dotenv import load_dotenv
26
- from streamlit import net_util
27
-
28
- load_dotenv()
29
-
30
- # 接続元制御
31
- ALLOW_IP_ADDRESS = os.environ["ALLOW_IP_ADDRESS"]
32
-
33
- # Azure AD app registration details
34
- CLIENT_ID = os.environ["CLIENT_ID"]
35
- CLIENT_SECRET = os.environ["CLIENT_SECRET"]
36
- TENANT_ID = os.environ["TENANT_ID"]
37
-
38
- # Azure API
39
- AUTHORITY = f"https://login.microsoftonline.com/{TENANT_ID}"
40
- REDIRECT_PATH = os.environ["REDIRECT_PATH"]
41
- TOKEN_URL = f"{AUTHORITY}/oauth2/v2.0/token"
42
- AUTHORIZATION_URL = f"{AUTHORITY}/oauth2/v2.0/authorize"
43
- SCOPES = ["openid", "profile", "User.Read"]
44
-
45
- # 認証用URL取得
46
- def authorization_request():
47
- oauth = OAuth2Session(CLIENT_ID, redirect_uri=REDIRECT_PATH, scope=SCOPES)
48
- authorization_url, state = oauth.authorization_url(AUTHORIZATION_URL)
49
- return authorization_url, state
50
-
51
- # 認証トークン取得
52
- def token_request(authorization_response, state):
53
- oauth = OAuth2Session(CLIENT_ID, state=state)
54
- token = oauth.fetch_token(
55
- TOKEN_URL,
56
- code=authorization_response[0],
57
- authorization_response=authorization_response,
58
- client_secret=CLIENT_SECRET,
59
-
60
- )
61
- return token
62
-
63
- index_name = "./data/storage"
64
- pkl_name = "./data/stored_documents.pkl"
65
-
66
- custom_prompt = Prompt("""\
67
- 以下はこれまでの会話履歴と、ドキュメントを検索して回答する必要がある、ユーザーからの会話文です。
68
- 会話と新しい会話文に基づいて、検索クエリを作成します。回答は日本語で行います。
69
- 新しい会話文が挨拶の場合、挨拶を返してください。
70
- 新しい会話文が質問の場合、検索した結果の回答を返してください。
71
- 答えがわからない場合は正直にわからないと回答してください。
72
- 会話履歴:
73
- {chat_history}
74
- 新しい会話文:
75
- {question}
76
- Search query:
77
- """)
78
-
79
- chat_history = []
80
-
81
- logging.basicConfig(level=logging.INFO)
82
- logger = logging.getLogger("__name__")
83
- logger.debug("調査用ログ")
84
-
85
- def initialize_index():
86
- logger.info("initialize_index start")
87
- text_splitter = TokenTextSplitter(separator="。", chunk_size=1500
88
- , chunk_overlap=DEFAULT_CHUNK_OVERLAP
89
- , tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode)
90
- node_parser = SimpleNodeParser(text_splitter=text_splitter)
91
- d = 1536
92
- k=2
93
- faiss_index = faiss.IndexFlatL2(d)
94
- # デバッグ用
95
- llama_debug_handler = LlamaDebugHandler()
96
- callback_manager = CallbackManager([llama_debug_handler])
97
- service_context = ServiceContext.from_defaults(node_parser=node_parser,callback_manager=callback_manager)
98
- lock = Lock()
99
- with lock:
100
- if os.path.exists(index_name):
101
- storage_context = StorageContext.from_defaults(
102
- docstore=SimpleDocumentStore.from_persist_dir(persist_dir=index_name),
103
- graph_store=SimpleGraphStore.from_persist_dir(persist_dir=index_name),
104
- vector_store=FaissVectorStore.from_persist_dir(persist_dir=index_name),
105
- index_store=SimpleIndexStore.from_persist_dir(persist_dir=index_name),
106
- )
107
- st.session_state.index = load_index_from_storage(storage_context=storage_context,service_context=service_context)
108
- response_synthesizer = get_response_synthesizer(response_mode='refine')
109
- st.session_state.query_engine = st.session_state.index.as_query_engine(response_synthesizer=response_synthesizer,service_context=service_context)
110
- st.session_state.chat_engine = CondenseQuestionChatEngine.from_defaults(
111
- query_engine=st.session_state.query_engine,
112
- condense_question_prompt=custom_prompt,
113
- chat_history=chat_history,
114
- verbose=True
115
- )
116
- else:
117
- documents = SimpleDirectoryReader("./documents").load_data()
118
- vector_store = FaissVectorStore(faiss_index=faiss_index)
119
- storage_context = StorageContext.from_defaults(vector_store=vector_store)
120
- st.session_state.index = VectorStoreIndex.from_documents(documents, storage_context=storage_context,service_context=service_context)
121
- st.session_state.index.storage_context.persist(persist_dir=index_name)
122
- response_synthesizer = get_response_synthesizer(response_mode='refine')
123
- st.session_state.query_engine = st.session_state.index.as_query_engine(response_synthesizer=response_synthesizer,service_context=service_context)
124
- st.session_state.chat_engine = CondenseQuestionChatEngine.from_defaults(
125
- query_engine=st.session_state.query_engine,
126
- condense_question_prompt=custom_prompt,
127
- chat_history=chat_history,
128
- verbose=True
129
- )
130
- if os.path.exists(pkl_name):
131
- with open(pkl_name, "rb") as f:
132
- st.session_state.stored_docs = pickle.load(f)
133
- else:
134
- st.session_state.stored_docs=list()
135
-
136
- def logout():
137
- st.session_state["token"] = None
138
- st.session_state["token_expires"] = None
139
- st.session_state["authorization_state"] = None
140
-
141
- # メイン
142
- def app():
143
- # 初期化
144
- st.session_state["token"] = None
145
- st.session_state["token_expires"] = time()
146
- st.session_state["authorization_state"] = None
147
-
148
- # 接続元IP許可判定
149
- if not net_util.get_external_ip() in ALLOW_IP_ADDRESS:
150
- st.title("HTTP 403 Forbidden")
151
- return
152
-
153
- # 接続元OK
154
- st.title("Azure AD Login with Streamlit")
155
-
156
- # 認証後のリダイレクトのGETパラメータ値を取得
157
- authorization_response = st.experimental_get_query_params().get("code")
158
-
159
- # 認証OK、トークン無し
160
- if authorization_response and st.session_state["token"] is None:
161
- # トークン設定
162
- token = token_request(authorization_response, st.session_state["authorization_state"])
163
- st.session_state["token"] = token
164
- st.session_state["token_expires"] = token["expires_at"]
165
-
166
- # トークン無し or 期限切れ
167
- if st.session_state["token"] is None or float(st.session_state["token_expires"]) <= time():
168
- # 認証用リンク表示
169
- authorization_url, st.session_state["authorization_state"] = authorization_request()
170
- st.markdown(f'[Click here to log in]({authorization_url})', unsafe_allow_html=True)
171
- else:
172
- # 認証OK
173
- st.markdown(f"Logged in successfully. Welcome, {st.session_state['token']['token_type']}!")
174
- if st.button("logout",use_container_width=True):
175
- logout()
176
- st.experimental_set_query_params(test="test")
177
- st.experimental_rerun()
178
- st.text("サイドバーから利用するメニューをお選びください。")
179
- initialize_index()
180
-
181
- if __name__ == "__main__":
182
- if "token" not in st.session_state or st.session_state["token"] is None or float(st.session_state["token_expires"]) <= time():
183
- app()
184
- else:
185
- st.title("Azure AD Login with Streamlit")
186
- if st.button("logout",use_container_width=True):
187
- logout()
188
- st.experimental_set_query_params(test="test")
189
- st.experimental_rerun()
190
- st.text("ログイン済みです。")
191
- st.text("サイドバーから利用するメニューをお選びください。")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import pickle
4
+ import faiss
5
+ import logging
6
+
7
+ from multiprocessing import Lock
8
+ from multiprocessing.managers import BaseManager
9
+ from llama_index.callbacks import CallbackManager, LlamaDebugHandler
10
+ from llama_index import VectorStoreIndex, Document,Prompt, SimpleDirectoryReader, ServiceContext, StorageContext, load_index_from_storage
11
+ from llama_index.chat_engine import CondenseQuestionChatEngine;
12
+ from llama_index.node_parser import SimpleNodeParser
13
+ from llama_index.langchain_helpers.text_splitter import TokenTextSplitter
14
+ from llama_index.constants import DEFAULT_CHUNK_OVERLAP
15
+ from llama_index.response_synthesizers import get_response_synthesizer
16
+ from llama_index.vector_stores.faiss import FaissVectorStore
17
+ from llama_index.graph_stores import SimpleGraphStore
18
+ from llama_index.storage.docstore import SimpleDocumentStore
19
+ from llama_index.storage.index_store import SimpleIndexStore
20
+ import tiktoken
21
+ from streamlit import runtime
22
+ from streamlit.runtime.scriptrunner import get_script_run_ctx
23
+ import ipaddress
24
+
25
+ from requests_oauthlib import OAuth2Session
26
+ from time import time
27
+ from dotenv import load_dotenv
28
+ from streamlit import net_util
29
+
30
+ load_dotenv()
31
+
32
+ # 接続元制御
33
+ ALLOW_IP_ADDRESS = os.environ["ALLOW_IP_ADDRESS"]
34
+
35
+ # Azure AD app registration details
36
+ CLIENT_ID = os.environ["CLIENT_ID"]
37
+ CLIENT_SECRET = os.environ["CLIENT_SECRET"]
38
+ TENANT_ID = os.environ["TENANT_ID"]
39
+
40
+ # Azure API
41
+ AUTHORITY = f"https://login.microsoftonline.com/{TENANT_ID}"
42
+ REDIRECT_PATH = os.environ["REDIRECT_PATH"]
43
+ TOKEN_URL = f"{AUTHORITY}/oauth2/v2.0/token"
44
+ AUTHORIZATION_URL = f"{AUTHORITY}/oauth2/v2.0/authorize"
45
+ SCOPES = ["openid", "profile", "User.Read"]
46
+
47
+ # 認証用URL取得
48
+ def authorization_request():
49
+ oauth = OAuth2Session(CLIENT_ID, redirect_uri=REDIRECT_PATH, scope=SCOPES)
50
+ authorization_url, state = oauth.authorization_url(AUTHORIZATION_URL)
51
+ return authorization_url, state
52
+
53
+ # 認証トークン取得
54
+ def token_request(authorization_response, state):
55
+ oauth = OAuth2Session(CLIENT_ID, state=state)
56
+ token = oauth.fetch_token(
57
+ TOKEN_URL,
58
+ code=authorization_response[0],
59
+ authorization_response=authorization_response,
60
+ client_secret=CLIENT_SECRET,
61
+
62
+ )
63
+ return token
64
+
65
+ index_name = "./data/storage"
66
+ pkl_name = "./data/stored_documents.pkl"
67
+
68
+ custom_prompt = Prompt("""\
69
+ 以下はこれまでの会話履歴と、ドキュメントを検索して回答する必要がある、ユーザーからの会話文です。
70
+ 会話と新しい会話文に基づいて、検索クエリを作成します。回答は日本語で行います。
71
+ 新しい会話文が挨拶の場合、挨拶を返してください。
72
+ 新しい会話文が質問の場合、検索した結果の回答を返してください。
73
+ 答えがわからない場合は正直にわからないと回答してください。
74
+ 会話履歴:
75
+ {chat_history}
76
+ 新しい会話文:
77
+ {question}
78
+ Search query:
79
+ """)
80
+
81
+ chat_history = []
82
+
83
+ logging.basicConfig(level=logging.INFO)
84
+ logger = logging.getLogger("__name__")
85
+ logger.debug("調査用ログ")
86
+
87
+ def initialize_index():
88
+ logger.info("initialize_index start")
89
+ text_splitter = TokenTextSplitter(separator="", chunk_size=1500
90
+ , chunk_overlap=DEFAULT_CHUNK_OVERLAP
91
+ , tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode)
92
+ node_parser = SimpleNodeParser(text_splitter=text_splitter)
93
+ d = 1536
94
+ k=2
95
+ faiss_index = faiss.IndexFlatL2(d)
96
+ # デバッグ用
97
+ llama_debug_handler = LlamaDebugHandler()
98
+ callback_manager = CallbackManager([llama_debug_handler])
99
+ service_context = ServiceContext.from_defaults(node_parser=node_parser,callback_manager=callback_manager)
100
+ lock = Lock()
101
+ with lock:
102
+ if os.path.exists(index_name):
103
+ storage_context = StorageContext.from_defaults(
104
+ docstore=SimpleDocumentStore.from_persist_dir(persist_dir=index_name),
105
+ graph_store=SimpleGraphStore.from_persist_dir(persist_dir=index_name),
106
+ vector_store=FaissVectorStore.from_persist_dir(persist_dir=index_name),
107
+ index_store=SimpleIndexStore.from_persist_dir(persist_dir=index_name),
108
+ )
109
+ st.session_state.index = load_index_from_storage(storage_context=storage_context,service_context=service_context)
110
+ response_synthesizer = get_response_synthesizer(response_mode='refine')
111
+ st.session_state.query_engine = st.session_state.index.as_query_engine(response_synthesizer=response_synthesizer,service_context=service_context)
112
+ st.session_state.chat_engine = CondenseQuestionChatEngine.from_defaults(
113
+ query_engine=st.session_state.query_engine,
114
+ condense_question_prompt=custom_prompt,
115
+ chat_history=chat_history,
116
+ verbose=True
117
+ )
118
+ else:
119
+ documents = SimpleDirectoryReader("./documents").load_data()
120
+ vector_store = FaissVectorStore(faiss_index=faiss_index)
121
+ storage_context = StorageContext.from_defaults(vector_store=vector_store)
122
+ st.session_state.index = VectorStoreIndex.from_documents(documents, storage_context=storage_context,service_context=service_context)
123
+ st.session_state.index.storage_context.persist(persist_dir=index_name)
124
+ response_synthesizer = get_response_synthesizer(response_mode='refine')
125
+ st.session_state.query_engine = st.session_state.index.as_query_engine(response_synthesizer=response_synthesizer,service_context=service_context)
126
+ st.session_state.chat_engine = CondenseQuestionChatEngine.from_defaults(
127
+ query_engine=st.session_state.query_engine,
128
+ condense_question_prompt=custom_prompt,
129
+ chat_history=chat_history,
130
+ verbose=True
131
+ )
132
+ if os.path.exists(pkl_name):
133
+ with open(pkl_name, "rb") as f:
134
+ st.session_state.stored_docs = pickle.load(f)
135
+ else:
136
+ st.session_state.stored_docs=list()
137
+
138
+ # 接続元IP取得
139
+ def get_remote_ip():
140
+ ctx = get_script_run_ctx()
141
+ session_info = runtime.get_instance().get_client(ctx.session_id)
142
+ return session_info.request.remote_ip
143
+
144
+ # 接続元IP許可判定
145
+ def is_allow_ip_address():
146
+ remote_ip = get_remote_ip()
147
+
148
+ # localhost
149
+ if remote_ip == "::1":
150
+ return True
151
+
152
+ # プライベートIP
153
+ ipaddr = ipaddress.IPv4Address(remote_ip)
154
+ if ipaddr.is_private:
155
+ return True
156
+
157
+ # そ��他(許可リスト判定)
158
+ return remote_ip in ALLOW_IP_ADDRESS
159
+
160
+ def logout():
161
+ st.session_state["token"] = None
162
+ st.session_state["token_expires"] = None
163
+ st.session_state["authorization_state"] = None
164
+
165
+ # メイン
166
+ def app():
167
+ # 初期化
168
+ st.session_state["token"] = None
169
+ st.session_state["token_expires"] = time()
170
+ st.session_state["authorization_state"] = None
171
+
172
+ # 接続元IP許可判定
173
+ if not is_allow_ip_address():
174
+ st.title("HTTP 403 Forbidden")
175
+ return
176
+
177
+ # 接続元OK
178
+ st.title("Azure AD Login with Streamlit")
179
+
180
+ # 認証後のリダイレクトのGETパラメータ値を取得
181
+ authorization_response = st.experimental_get_query_params().get("code")
182
+
183
+ # 認証OK、トークン無し
184
+ if authorization_response and st.session_state["token"] is None:
185
+ # トークン設定
186
+ token = token_request(authorization_response, st.session_state["authorization_state"])
187
+ st.session_state["token"] = token
188
+ st.session_state["token_expires"] = token["expires_at"]
189
+
190
+ # トークン無し or 期限切れ
191
+ if st.session_state["token"] is None or float(st.session_state["token_expires"]) <= time():
192
+ # 認証用リンク表示
193
+ authorization_url, st.session_state["authorization_state"] = authorization_request()
194
+ st.markdown(f'[Click here to log in]({authorization_url})', unsafe_allow_html=True)
195
+ else:
196
+ # 認証OK
197
+ st.markdown(f"Logged in successfully. Welcome, {st.session_state['token']['token_type']}!")
198
+ if st.button("logout",use_container_width=True):
199
+ logout()
200
+ st.experimental_set_query_params()
201
+ st.experimental_rerun()
202
+ st.text("サイドバーから利用するメニューをお選びください。")
203
+ initialize_index()
204
+
205
+ if __name__ == "__main__":
206
+ if "token" not in st.session_state or st.session_state["token"] is None or float(st.session_state["token_expires"]) <= time():
207
+ app()
208
+ else:
209
+ st.title("Azure AD Login with Streamlit")
210
+ if st.button("logout",use_container_width=True):
211
+ logout()
212
+ st.experimental_set_query_params()
213
+ st.experimental_rerun()
214
+ st.text("ログイン済みです。")
215
+ st.text("サイドバーから利用するメニューをお選びください。")