naotakigawa commited on
Commit
e6a7ab6
·
1 Parent(s): 231ac24

ログイン修正

Browse files
Files changed (2) hide show
  1. app.py +29 -72
  2. common.py +4 -25
app.py CHANGED
@@ -21,7 +21,8 @@ import tiktoken
21
  from streamlit import runtime
22
  from streamlit.runtime.scriptrunner import get_script_run_ctx
23
  import ipaddress
24
-
 
25
  from requests_oauthlib import OAuth2Session
26
  from time import time
27
  from dotenv import load_dotenv
@@ -32,36 +33,6 @@ load_dotenv()
32
  # 接続元制御
33
  ALLOW_IP_ADDRESS = os.environ["ALLOW_IP_ADDRESS"]
34
 
35
- # Azure AD app registration details
36
- CLIENT_ID = os.environ["CLIENT_ID"]
37
- CLIENT_SECRET = os.environ["CLIENT_SECRET"]
38
- TENANT_ID = os.environ["TENANT_ID"]
39
-
40
- # Azure API
41
- AUTHORITY = f"https://login.microsoftonline.com/{TENANT_ID}"
42
- REDIRECT_PATH = os.environ["REDIRECT_PATH"]
43
- TOKEN_URL = f"{AUTHORITY}/oauth2/v2.0/token"
44
- AUTHORIZATION_URL = f"{AUTHORITY}/oauth2/v2.0/authorize"
45
- SCOPES = ["openid", "profile", "User.Read"]
46
-
47
- # 認証用URL取得
48
- def authorization_request():
49
- oauth = OAuth2Session(CLIENT_ID, redirect_uri=REDIRECT_PATH, scope=SCOPES)
50
- authorization_url, state = oauth.authorization_url(AUTHORIZATION_URL)
51
- return authorization_url, state
52
-
53
- # 認証トークン取得
54
- def token_request(authorization_response, state):
55
- oauth = OAuth2Session(CLIENT_ID, state=state)
56
- token = oauth.fetch_token(
57
- TOKEN_URL,
58
- code=authorization_response[0],
59
- authorization_response=authorization_response,
60
- client_secret=CLIENT_SECRET,
61
-
62
- )
63
- return token
64
-
65
  index_name = "./data/storage"
66
  pkl_name = "./data/stored_documents.pkl"
67
 
@@ -160,11 +131,6 @@ def is_allow_ip_address():
160
  # その他(許可リスト判定)
161
  return remote_ip in ALLOW_IP_ADDRESS
162
 
163
- def logout():
164
- st.session_state["token"] = None
165
- st.session_state["token_expires"] = None
166
- st.session_state["authorization_state"] = None
167
-
168
  # メイン
169
  def app():
170
  # 初期化
@@ -180,39 +146,30 @@ def app():
180
  # 接続元OK
181
  st.title("Azure AD Login with Streamlit")
182
 
183
- # 認証後のリダイレクトのGETパラメータ値を取得
184
- authorization_response = st.experimental_get_query_params().get("code")
185
-
186
- # 認証OK、トークン無し
187
- if authorization_response and st.session_state["token"] is None:
188
- # トークン設定
189
- token = token_request(authorization_response, st.session_state["authorization_state"])
190
- st.session_state["token"] = token
191
- st.session_state["token_expires"] = token["expires_at"]
192
-
193
- # トークン無し or 期限切れ
194
- if st.session_state["token"] is None or float(st.session_state["token_expires"]) <= time():
195
- # 認証用リンク表示
196
- authorization_url, st.session_state["authorization_state"] = authorization_request()
197
- st.markdown(f'[Click here to log in]({authorization_url})', unsafe_allow_html=True)
198
- else:
199
- # 認証OK
200
- st.markdown(f"Logged in successfully. Welcome, {st.session_state['token']['token_type']}!")
201
- if st.button("logout",use_container_width=True):
202
- logout()
203
- st.experimental_set_query_params()
204
- st.experimental_rerun()
205
- st.text("サイドバーから利用するメニューをお選びください。")
206
- initialize_index()
207
-
208
- if __name__ == "__main__":
209
- if "token" not in st.session_state or st.session_state["token"] is None or float(st.session_state["token_expires"]) <= time():
210
- app()
211
- else:
212
- st.title("Azure AD Login with Streamlit")
213
- if st.button("logout",use_container_width=True):
214
- logout()
215
- st.experimental_set_query_params()
216
- st.experimental_rerun()
217
- st.text("ログイン済みです。")
218
- st.text("サイドバーから利用するメニューをお選びください。")
 
21
  from streamlit import runtime
22
  from streamlit.runtime.scriptrunner import get_script_run_ctx
23
  import ipaddress
24
+ import streamlit_authenticator as stauth
25
+ import yaml
26
  from requests_oauthlib import OAuth2Session
27
  from time import time
28
  from dotenv import load_dotenv
 
33
  # 接続元制御
34
  ALLOW_IP_ADDRESS = os.environ["ALLOW_IP_ADDRESS"]
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  index_name = "./data/storage"
37
  pkl_name = "./data/stored_documents.pkl"
38
 
 
131
  # その他(許可リスト判定)
132
  return remote_ip in ALLOW_IP_ADDRESS
133
 
 
 
 
 
 
134
  # メイン
135
  def app():
136
  # 初期化
 
146
  # 接続元OK
147
  st.title("Azure AD Login with Streamlit")
148
 
149
+ with open('config.yaml') as file:
150
+ config = yaml.load(file, Loader=yaml.SafeLoader)
151
+
152
+ authenticator = stauth.Authenticate(
153
+ config['credentials'],
154
+ config['cookie']['name'],
155
+ config['cookie']['key'],
156
+ config['cookie']['expiry_days'],
157
+ config['preauthorized'],
158
+ )
159
+
160
+ name, authentication_status, username = authenticator.login('Login', 'main')
161
+
162
+
163
+ if 'authentication_status' not in st.session_state:
164
+ st.session_state['authentication_status'] = None
165
+
166
+ if st.session_state["authentication_status"]:
167
+ authenticator.logout('Logout', 'main')
168
+ st.write(f'ログインに成功しました')
169
+ initialize_index()
170
+ # ここにログイン後の処理を書く。
171
+ elif st.session_state["authentication_status"] is False:
172
+ st.error('ユーザ名またはパスワードが間違っています')
173
+ elif st.session_state["authentication_status"] is None:
174
+ st.warning('ユーザ名やパスワードを入力してください')
175
+
 
 
 
 
 
 
 
 
 
common.py CHANGED
@@ -17,22 +17,6 @@ logger.debug("調査用ログ")
17
  # 接続元制御
18
  ALLOW_IP_ADDRESS = os.environ["ALLOW_IP_ADDRESS"]
19
 
20
- # Azure AD app registration details
21
- CLIENT_ID = os.environ["CLIENT_ID"]
22
- TENANT_ID = os.environ["TENANT_ID"]
23
-
24
- # Azure API
25
- AUTHORITY = f"https://login.microsoftonline.com/{TENANT_ID}"
26
- REDIRECT_PATH = os.environ["REDIRECT_PATH"]
27
- AUTHORIZATION_URL = f"{AUTHORITY}/oauth2/v2.0/authorize"
28
- SCOPES = ["openid", "profile", "User.Read"]
29
-
30
- # 認証用URL取得
31
- def authorization_request():
32
- oauth = OAuth2Session(CLIENT_ID, redirect_uri=REDIRECT_PATH, scope=SCOPES)
33
- authorization_url, state = oauth.authorization_url(AUTHORIZATION_URL)
34
- return authorization_url, state
35
-
36
  # 接続元IP取得
37
  def get_remote_ip():
38
  ctx = get_script_run_ctx()
@@ -60,13 +44,8 @@ def is_allow_ip_address():
60
 
61
  #ログインの確認
62
  def check_login():
63
- # 接続元IP許可判定
64
- if not is_allow_ip_address():
65
- st.title("HTTP 403 Forbidden")
66
- return
67
-
68
- if "token" not in st.session_state or st.session_state["token"] is None or float(st.session_state["token_expires"]) <= time():
69
- # 認証用リンク表示
70
- authorization_url, st.session_state["authorization_state"] = authorization_request()
71
- st.markdown(f'[Click here to log in]({authorization_url})', unsafe_allow_html=True)
72
  st.stop()
 
17
  # 接続元制御
18
  ALLOW_IP_ADDRESS = os.environ["ALLOW_IP_ADDRESS"]
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  # 接続元IP取得
21
  def get_remote_ip():
22
  ctx = get_script_run_ctx()
 
44
 
45
  #ログインの確認
46
  def check_login():
47
+ if 'authentication_status' not in st.session_state:
48
+ st.session_state['authentication_status'] = None
49
+ if st.session_state["authentication_status"] is None or False:
50
+ st.warning("**ログインしてください**")
 
 
 
 
 
51
  st.stop()