Update crs_arena/utils.py
Browse files- crs_arena/utils.py +31 -1
crs_arena/utils.py
CHANGED
@@ -27,7 +27,37 @@ HF_API = HfApi(token=st.secrets["hf_token"])
|
|
27 |
|
28 |
|
29 |
@st.cache_resource(
|
30 |
-
show_spinner="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
)
|
32 |
def get_crs_model(model_name: str, model_config_file: str) -> CRSModel:
|
33 |
"""Returns a CRS model.
|
|
|
27 |
|
28 |
|
29 |
@st.cache_resource(
|
30 |
+
show_spinner="<h2 style='text-align: center; top:50%; left:50%; transform: translate(-50%, -50%);'>The fighters are warming up... :robot: :punch: :gun: :boom:</h2>",
|
31 |
+
ttl=timedelta(days=3),
|
32 |
+
)
|
33 |
+
def get_crs_model(model_name: str, model_config_file: str) -> CRSModel:
|
34 |
+
"""Returns a CRS model.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
model_name: Model name.
|
38 |
+
model_config_file: Model configuration file.
|
39 |
+
|
40 |
+
Raises:
|
41 |
+
FileNotFoundError: If model configuration file is not found.
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
CRS model.
|
45 |
+
"""
|
46 |
+
logging.debug(f"Loading CRS model {model_name}.")
|
47 |
+
if not os.path.exists(model_config_file):
|
48 |
+
raise FileNotFoundError(
|
49 |
+
f"Model configuration file {model_config_file} not found."
|
50 |
+
)
|
51 |
+
|
52 |
+
model_args = yaml.safe_load(open(model_config_file, "r"))
|
53 |
+
|
54 |
+
if "chatgpt" in model_name:
|
55 |
+
openai.api_key = st.secrets["openai_api_key"]
|
56 |
+
|
57 |
+
# Extract crs model from name
|
58 |
+
name = model_name.split("_")[0]
|
59 |
+
|
60 |
+
return CRSModel(name, **model_args), ttl=timedelta(days=3)
|
61 |
)
|
62 |
def get_crs_model(model_name: str, model_config_file: str) -> CRSModel:
|
63 |
"""Returns a CRS model.
|