ayushnoori commited on
Commit
b2ee56b
·
1 Parent(s): c27bf27

Update model version.

Browse files
Files changed (2) hide show
  1. pages/predict.py +8 -3
  2. requirements.txt +1 -2
pages/predict.py CHANGED
@@ -36,16 +36,21 @@ st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' '
36
 
37
  @st.cache_data(show_spinner = 'Downloading AI model...')
38
  def get_embeddings():
 
 
 
 
 
39
  # Get paths to embeddings, relation weights, and edge types
40
  # with st.spinner('Downloading AI model...'):
41
  embed_path = hf_hub_download(repo_id="ayushnoori/galaxy",
42
- filename="2024_03_29_04_12_52_epoch=3-step=54291_embeddings.pt",
43
  token=st.secrets["HF_TOKEN"])
44
  relation_weights_path = hf_hub_download(repo_id="ayushnoori/galaxy",
45
- filename="2024_03_29_04_12_52_epoch=3-step=54291_relation_weights.pt",
46
  token=st.secrets["HF_TOKEN"])
47
  edge_types_path = hf_hub_download(repo_id="ayushnoori/galaxy",
48
- filename="2024_03_29_04_12_52_epoch=3-step=54291_edge_types.pt",
49
  token=st.secrets["HF_TOKEN"])
50
  return embed_path, relation_weights_path, edge_types_path
51
 
 
36
 
37
  @st.cache_data(show_spinner = 'Downloading AI model...')
38
  def get_embeddings():
39
+ # Get checkpoint name
40
+ # best_ckpt = "2024_05_22_11_59_43_epoch=18-step=22912"
41
+ best_ckpt = "2024_05_15_13_05_33_epoch=2-step=40383"
42
+ # best_ckpt = "2024_03_29_04_12_52_epoch=3-step=54291"
43
+
44
  # Get paths to embeddings, relation weights, and edge types
45
  # with st.spinner('Downloading AI model...'):
46
  embed_path = hf_hub_download(repo_id="ayushnoori/galaxy",
47
+ filename=(best_ckpt + "-thresh=4000_embeddings.pt"),
48
  token=st.secrets["HF_TOKEN"])
49
  relation_weights_path = hf_hub_download(repo_id="ayushnoori/galaxy",
50
+ filename=(best_ckpt + "_relation_weights.pt"),
51
  token=st.secrets["HF_TOKEN"])
52
  edge_types_path = hf_hub_download(repo_id="ayushnoori/galaxy",
53
+ filename=(best_ckpt + "_edge_types.pt"),
54
  token=st.secrets["HF_TOKEN"])
55
  return embed_path, relation_weights_path, edge_types_path
56
 
requirements.txt CHANGED
@@ -8,5 +8,4 @@ torch
8
  altair<5
9
  gspread
10
  oauth2client
11
- huggingface_hub
12
- matplotlib
 
8
  altair<5
9
  gspread
10
  oauth2client
11
+ huggingface_hub