Nol00 commited on
Commit
1e5f915
·
verified ·
1 Parent(s): 9b43388

Update crs_arena/utils.py

Browse files
Files changed (1) hide show
  1. crs_arena/utils.py +31 -1
crs_arena/utils.py CHANGED
@@ -27,7 +27,37 @@ HF_API = HfApi(token=st.secrets["hf_token"])
27
 
28
 
29
  @st.cache_resource(
30
- show_spinner="Loading CRS...", ttl=timedelta(days=3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  )
32
  def get_crs_model(model_name: str, model_config_file: str) -> CRSModel:
33
  """Returns a CRS model.
 
27
 
28
 
29
  @st.cache_resource(
30
+ show_spinner="<h2 style='text-align: center; top:50%; left:50%; transform: translate(-50%, -50%);'>The fighters are warming up... :robot: :punch: :gun: :boom:</h2>",
31
+ ttl=timedelta(days=3),
32
+ )
33
+ def get_crs_model(model_name: str, model_config_file: str) -> CRSModel:
34
+ """Returns a CRS model.
35
+
36
+ Args:
37
+ model_name: Model name.
38
+ model_config_file: Model configuration file.
39
+
40
+ Raises:
41
+ FileNotFoundError: If model configuration file is not found.
42
+
43
+ Returns:
44
+ CRS model.
45
+ """
46
+ logging.debug(f"Loading CRS model {model_name}.")
47
+ if not os.path.exists(model_config_file):
48
+ raise FileNotFoundError(
49
+ f"Model configuration file {model_config_file} not found."
50
+ )
51
+
52
+ model_args = yaml.safe_load(open(model_config_file, "r"))
53
+
54
+ if "chatgpt" in model_name:
55
+ openai.api_key = st.secrets["openai_api_key"]
56
+
57
+ # Extract crs model from name
58
+ name = model_name.split("_")[0]
59
+
60
+ return CRSModel(name, **model_args), ttl=timedelta(days=3)
61
  )
62
  def get_crs_model(model_name: str, model_config_file: str) -> CRSModel:
63
  """Returns a CRS model.