Spaces:
Sleeping
Sleeping
Commit
·
b2ee56b
1
Parent(s):
c27bf27
Update model version.
Browse files- pages/predict.py +8 -3
- requirements.txt +1 -2
pages/predict.py
CHANGED
@@ -36,16 +36,21 @@ st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' '
|
|
36 |
|
37 |
@st.cache_data(show_spinner = 'Downloading AI model...')
|
38 |
def get_embeddings():
|
|
|
|
|
|
|
|
|
|
|
39 |
# Get paths to embeddings, relation weights, and edge types
|
40 |
# with st.spinner('Downloading AI model...'):
|
41 |
embed_path = hf_hub_download(repo_id="ayushnoori/galaxy",
|
42 |
-
filename="
|
43 |
token=st.secrets["HF_TOKEN"])
|
44 |
relation_weights_path = hf_hub_download(repo_id="ayushnoori/galaxy",
|
45 |
-
filename="
|
46 |
token=st.secrets["HF_TOKEN"])
|
47 |
edge_types_path = hf_hub_download(repo_id="ayushnoori/galaxy",
|
48 |
-
filename="
|
49 |
token=st.secrets["HF_TOKEN"])
|
50 |
return embed_path, relation_weights_path, edge_types_path
|
51 |
|
|
|
36 |
|
37 |
@st.cache_data(show_spinner = 'Downloading AI model...')
|
38 |
def get_embeddings():
|
39 |
+
# Get checkpoint name
|
40 |
+
# best_ckpt = "2024_05_22_11_59_43_epoch=18-step=22912"
|
41 |
+
best_ckpt = "2024_05_15_13_05_33_epoch=2-step=40383"
|
42 |
+
# best_ckpt = "2024_03_29_04_12_52_epoch=3-step=54291"
|
43 |
+
|
44 |
# Get paths to embeddings, relation weights, and edge types
|
45 |
# with st.spinner('Downloading AI model...'):
|
46 |
embed_path = hf_hub_download(repo_id="ayushnoori/galaxy",
|
47 |
+
filename=(best_ckpt + "-thresh=4000_embeddings.pt"),
|
48 |
token=st.secrets["HF_TOKEN"])
|
49 |
relation_weights_path = hf_hub_download(repo_id="ayushnoori/galaxy",
|
50 |
+
filename=(best_ckpt + "_relation_weights.pt"),
|
51 |
token=st.secrets["HF_TOKEN"])
|
52 |
edge_types_path = hf_hub_download(repo_id="ayushnoori/galaxy",
|
53 |
+
filename=(best_ckpt + "_edge_types.pt"),
|
54 |
token=st.secrets["HF_TOKEN"])
|
55 |
return embed_path, relation_weights_path, edge_types_path
|
56 |
|
requirements.txt
CHANGED
@@ -8,5 +8,4 @@ torch
|
|
8 |
altair<5
|
9 |
gspread
|
10 |
oauth2client
|
11 |
-
huggingface_hub
|
12 |
-
matplotlib
|
|
|
8 |
altair<5
|
9 |
gspread
|
10 |
oauth2client
|
11 |
+
huggingface_hub
|
|