marcenacp commited on
Commit
041af8a
·
1 Parent(s): 636f5d2

Add cache.

Browse files
Files changed (5) hide show
  1. app.py +24 -7
  2. core/past_projects.py +2 -2
  3. core/state.py +20 -8
  4. views/previous_files.py +1 -1
  5. views/splash.py +27 -0
app.py CHANGED
@@ -6,27 +6,36 @@ from core.constants import OAUTH_CLIENT_ID
6
  from core.constants import OAUTH_STATE
7
  from core.constants import REDIRECT_URI
8
  from core.state import CurrentStep
 
9
  from core.state import User
10
  from utils import init_state
11
  from views.splash import render_splash
12
  from views.wizard import render_editor
13
 
14
- init_state()
15
-
16
-
17
  st.set_page_config(page_title="Croissant Editor", page_icon="🥐", layout="wide")
18
- col1, col2 = st.columns([10, 1])
19
  col1.header("Croissant Editor")
20
 
21
- if OAUTH_CLIENT_ID and not st.session_state.get(User):
 
 
 
 
22
  query_params = st.experimental_get_query_params()
23
  state = query_params.get("state")
24
  if state and state[0] == OAUTH_STATE:
25
  code = query_params.get("code")
26
  if not code:
27
  st.stop()
28
- st.session_state[User] = User.connect(code)
29
- st.experimental_set_query_params()
 
 
 
 
 
 
 
30
  else:
31
  redirect_uri = urllib.parse.quote(REDIRECT_URI, safe="")
32
  client_id = urllib.parse.quote(OAUTH_CLIENT_ID, safe="")
@@ -42,10 +51,18 @@ def _back_to_menu():
42
  init_state(force=True)
43
 
44
 
 
 
 
 
 
45
  if st.session_state[CurrentStep] != CurrentStep.splash:
46
  col2.write("\n") # Vertical box to shift the button menu
47
  col2.button("Menu", on_click=_back_to_menu)
48
 
 
 
 
49
 
50
  if st.session_state[CurrentStep] == CurrentStep.splash:
51
  render_splash()
 
6
  from core.constants import OAUTH_STATE
7
  from core.constants import REDIRECT_URI
8
  from core.state import CurrentStep
9
+ from core.state import get_cached_user
10
  from core.state import User
11
  from utils import init_state
12
  from views.splash import render_splash
13
  from views.wizard import render_editor
14
 
 
 
 
15
  st.set_page_config(page_title="Croissant Editor", page_icon="🥐", layout="wide")
16
+ col1, col2, col3 = st.columns([10, 1, 1])
17
  col1.header("Croissant Editor")
18
 
19
+ init_state()
20
+
21
+ user = get_cached_user()
22
+
23
+ if OAUTH_CLIENT_ID and not user:
24
  query_params = st.experimental_get_query_params()
25
  state = query_params.get("state")
26
  if state and state[0] == OAUTH_STATE:
27
  code = query_params.get("code")
28
  if not code:
29
  st.stop()
30
+ try:
31
+ st.session_state[User] = User.connect(code)
32
+ # Clear the cache to force retrieving the new user.
33
+ get_cached_user.clear()
34
+ get_cached_user()
35
+ except:
36
+ raise
37
+ finally:
38
+ st.experimental_set_query_params()
39
  else:
40
  redirect_uri = urllib.parse.quote(REDIRECT_URI, safe="")
41
  client_id = urllib.parse.quote(OAUTH_CLIENT_ID, safe="")
 
51
  init_state(force=True)
52
 
53
 
54
+ def _logout():
55
+ """Logs the user out."""
56
+ st.cache_data.clear()
57
+
58
+
59
  if st.session_state[CurrentStep] != CurrentStep.splash:
60
  col2.write("\n") # Vertical box to shift the button menu
61
  col2.button("Menu", on_click=_back_to_menu)
62
 
63
+ col3.write("\n") # Vertical box to shift the lgout menu
64
+ col3.button("Log out", on_click=_logout)
65
+
66
 
67
  if st.session_state[CurrentStep] == CurrentStep.splash:
68
  render_splash()
core/past_projects.py CHANGED
@@ -6,12 +6,12 @@ import streamlit as st
6
 
7
  from core.constants import PAST_PROJECTS_PATH
8
  from core.state import CurrentProject
 
9
  from core.state import Metadata
10
- from core.state import User
11
 
12
 
13
  def load_past_projects_paths() -> list[epath.Path]:
14
- user = st.session_state.get(User)
15
  past_projects_path = PAST_PROJECTS_PATH(user)
16
  past_projects_path.mkdir(parents=True, exist_ok=True)
17
  return sorted(list(past_projects_path.iterdir()), reverse=True)
 
6
 
7
  from core.constants import PAST_PROJECTS_PATH
8
  from core.state import CurrentProject
9
+ from core.state import get_cached_user
10
  from core.state import Metadata
 
11
 
12
 
13
  def load_past_projects_paths() -> list[epath.Path]:
14
+ user = get_cached_user()
15
  past_projects_path = PAST_PROJECTS_PATH(user)
16
  past_projects_path.mkdir(parents=True, exist_ok=True)
17
  return sorted(list(past_projects_path.iterdir()), reverse=True)
core/state.py CHANGED
@@ -8,7 +8,6 @@ from __future__ import annotations
8
  import base64
9
  import dataclasses
10
  import datetime
11
- import hashlib
12
  from typing import Any
13
 
14
  from etils import epath
@@ -64,17 +63,30 @@ class User:
64
  access_token = response.get("access_token")
65
  id_token = response.get("id_token")
66
  if access_token and id_token:
67
- # Warning: this is temporary while being able to retrieve the username.
68
- username = hashlib.sha256(access_token.encode()).hexdigest()
69
- return User(
70
- access_token=access_token, username=username, id_token=id_token
71
- )
 
 
 
 
 
 
 
72
  raise Exception(
73
  f"Could not connect to Hugging Face. Please, go to {REDIRECT_URI}."
74
  f" ({response=})."
75
  )
76
 
77
 
 
 
 
 
 
 
78
  class CurrentStep:
79
  """Holds all major state variables for the application."""
80
 
@@ -91,8 +103,8 @@ class CurrentProject:
91
  @classmethod
92
  def create_new(cls) -> CurrentProject:
93
  timestamp = datetime.datetime.now().strftime(PROJECT_FOLDER_PATTERN)
94
- user = st.session_state.get(User)
95
- if user is None:
96
  return None
97
  else:
98
  path = PAST_PROJECTS_PATH(user)
 
8
  import base64
9
  import dataclasses
10
  import datetime
 
11
  from typing import Any
12
 
13
  from etils import epath
 
63
  access_token = response.get("access_token")
64
  id_token = response.get("id_token")
65
  if access_token and id_token:
66
+ url = "https://huggingface.co/oauth/userinfo"
67
+ headers = {"Authorization": f"Bearer {access_token}"}
68
+ response = requests.get(url, headers=headers)
69
+ if response.status_code == 200:
70
+ response = response.json()
71
+ username = response.get("preferred_username")
72
+ if username:
73
+ return User(
74
+ access_token=access_token,
75
+ username=username,
76
+ id_token=id_token,
77
+ )
78
  raise Exception(
79
  f"Could not connect to Hugging Face. Please, go to {REDIRECT_URI}."
80
  f" ({response=})."
81
  )
82
 
83
 
84
+ @st.cache_data(ttl=datetime.timedelta(hours=1))
85
+ def get_cached_user():
86
+ """Caches user in session_state."""
87
+ return st.session_state.get(User)
88
+
89
+
90
  class CurrentStep:
91
  """Holds all major state variables for the application."""
92
 
 
103
  @classmethod
104
  def create_new(cls) -> CurrentProject:
105
  timestamp = datetime.datetime.now().strftime(PROJECT_FOLDER_PATTERN)
106
+ user = get_cached_user()
107
+ if user is None and OAUTH_CLIENT_ID:
108
  return None
109
  else:
110
  path = PAST_PROJECTS_PATH(user)
views/previous_files.py CHANGED
@@ -29,12 +29,12 @@ def render_previous_files():
29
  else:
30
  for index, path in enumerate(paths):
31
  try:
32
- col1, col2 = st.columns([10, 1])
33
  metadata = open_project(path)
34
  timestamp = datetime.datetime.strptime(
35
  path.name, PROJECT_FOLDER_PATTERN
36
  ).strftime("%Y/%m/%d %H:%M")
37
  label = f"{metadata.name or 'Unnamed dataset'} - {timestamp}"
 
38
  col1.button(
39
  label,
40
  key=f"splash-{index}-load",
 
29
  else:
30
  for index, path in enumerate(paths):
31
  try:
 
32
  metadata = open_project(path)
33
  timestamp = datetime.datetime.strptime(
34
  path.name, PROJECT_FOLDER_PATTERN
35
  ).strftime("%Y/%m/%d %H:%M")
36
  label = f"{metadata.name or 'Unnamed dataset'} - {timestamp}"
37
+ col1, col2 = st.columns([10, 1])
38
  col1.button(
39
  label,
40
  key=f"splash-{index}-load",
views/splash.py CHANGED
@@ -1,9 +1,13 @@
 
 
 
1
  import streamlit as st
2
 
3
  from core.constants import OAUTH_CLIENT_ID
4
  from core.state import CurrentProject
5
  from core.state import CurrentStep
6
  from core.state import Metadata
 
7
  from utils import jump_to
8
  from views.load import render_load
9
  from views.previous_files import render_previous_files
@@ -33,6 +37,29 @@ def render_splash():
33
  on_click=create_new_croissant,
34
  type="primary",
35
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  with col2:
37
  with st.expander("**Past projects**", expanded=True):
38
  render_previous_files()
 
1
+ import logging
2
+
3
+ import requests
4
  import streamlit as st
5
 
6
  from core.constants import OAUTH_CLIENT_ID
7
  from core.state import CurrentProject
8
  from core.state import CurrentStep
9
  from core.state import Metadata
10
+ import mlcroissant as mlc
11
  from utils import jump_to
12
  from views.load import render_load
13
  from views.previous_files import render_previous_files
 
37
  on_click=create_new_croissant,
38
  type="primary",
39
  )
40
+ with st.expander("**Try out an example!**", expanded=True):
41
+
42
+ def create_example():
43
+ url = "https://raw.githubusercontent.com/mlcommons/croissant/main/datasets/titanic/metadata.json"
44
+ try:
45
+ json = requests.get(url).json()
46
+ metadata = mlc.Metadata.from_json(mlc.Issues(), json, None)
47
+ st.session_state[Metadata] = Metadata.from_canonical(metadata)
48
+ st.session_state[CurrentProject] = CurrentProject.create_new()
49
+ jump_to(CurrentStep.editor)
50
+ except Exception as exception:
51
+ logging.error(exception)
52
+ st.write(
53
+ "Sorry, it seems that the example is broken... Can you please"
54
+ " [open an issue on"
55
+ " GitHub](https://github.com/mlcommons/croissant/issues/new)?"
56
+ )
57
+
58
+ st.button(
59
+ "Titanic dataset",
60
+ on_click=create_example,
61
+ type="primary",
62
+ )
63
  with col2:
64
  with st.expander("**Past projects**", expanded=True):
65
  render_previous_files()