Tony Wu committed on
Commit
66d537f
·
1 Parent(s): 658164b

style: apply linter

Browse files
Files changed (1) hide show
  1. data/model_handler.py +16 -14
data/model_handler.py CHANGED
@@ -1,12 +1,15 @@
1
  import json
2
  import os
3
  from typing import Dict
4
- from huggingface_hub import HfApi, hf_hub_download, metadata_load
5
  import pandas as pd
6
- from .dataset_handler import get_datasets_nickname, VIDORE_DATASETS_KEYWORDS
 
 
7
 
8
  BLOCKLIST = ["impactframes"]
9
 
 
10
  class ModelHandler:
11
  def __init__(self, model_infos_path="model_infos.json"):
12
  self.api = HfApi()
@@ -28,21 +31,20 @@ class ModelHandler:
28
  repositories = [model.modelId for model in models] # type: ignore
29
 
30
  for repo_id in repositories:
31
- org_name = repo_id.split('/')[0]
32
  if org_name in BLOCKLIST:
33
  continue
34
-
35
- files = [f for f in self.api.list_repo_files(repo_id) if f.endswith('_metrics.json') or f == 'results.json']
36
 
37
-
 
38
  if len(files) == 0:
39
  continue
40
  else:
41
  for file in files:
42
- if file.endswith('results.json'):
43
- model_name = repo_id.replace('/', '_')
44
  else:
45
- model_name = file.split('_metrics.json')[0]
46
 
47
  if model_name not in self.model_infos:
48
  readme_path = hf_hub_download(repo_id, filename="README.md")
@@ -61,7 +63,7 @@ class ModelHandler:
61
  print(f"Error loading {model_name} - {e}")
62
  continue
63
 
64
- #self._save_model_infos()
65
 
66
  model_res = {}
67
  if len(self.model_infos) > 0:
@@ -69,7 +71,7 @@ class ModelHandler:
69
  res = self.model_infos[model]["results"]
70
  dataset_res = {}
71
  for dataset in res.keys():
72
- #for each keyword check if it is in the dataset name if not continue
73
  if not any(keyword in dataset for keyword in VIDORE_DATASETS_KEYWORDS):
74
  print(f"{dataset} not found in ViDoRe datasets. Skipping ...")
75
  continue
@@ -77,9 +79,9 @@ class ModelHandler:
77
  dataset_nickname = get_datasets_nickname(dataset)
78
  dataset_res[dataset_nickname] = res[dataset][metric]
79
  model_res[model] = dataset_res
80
-
81
  df = pd.DataFrame(model_res).T
82
-
83
  return df
84
  return pd.DataFrame()
85
 
@@ -104,7 +106,7 @@ class ModelHandler:
104
  df.insert(len(df.columns) - len(cols_to_rank), "Average", df[cols_to_rank].mean(axis=1, skipna=False))
105
  df.sort_values("Average", ascending=False, inplace=True)
106
  df.insert(0, "Rank", list(range(1, len(df) + 1)))
107
- #multiply values by 100 if they are floats and round to 1 decimal place
108
  for col in df.columns:
109
  if df[col].dtype == "float64":
110
  df[col] = df[col].apply(lambda x: round(x * 100, 1))
 
1
  import json
2
  import os
3
  from typing import Dict
4
+
5
  import pandas as pd
6
+ from huggingface_hub import HfApi, hf_hub_download, metadata_load
7
+
8
+ from .dataset_handler import VIDORE_DATASETS_KEYWORDS, get_datasets_nickname
9
 
10
  BLOCKLIST = ["impactframes"]
11
 
12
+
13
  class ModelHandler:
14
  def __init__(self, model_infos_path="model_infos.json"):
15
  self.api = HfApi()
 
31
  repositories = [model.modelId for model in models] # type: ignore
32
 
33
  for repo_id in repositories:
34
+ org_name = repo_id.split("/")[0]
35
  if org_name in BLOCKLIST:
36
  continue
 
 
37
 
38
+ files = [f for f in self.api.list_repo_files(repo_id) if f.endswith("_metrics.json") or f == "results.json"]
39
+
40
  if len(files) == 0:
41
  continue
42
  else:
43
  for file in files:
44
+ if file.endswith("results.json"):
45
+ model_name = repo_id.replace("/", "_")
46
  else:
47
+ model_name = file.split("_metrics.json")[0]
48
 
49
  if model_name not in self.model_infos:
50
  readme_path = hf_hub_download(repo_id, filename="README.md")
 
63
  print(f"Error loading {model_name} - {e}")
64
  continue
65
 
66
+ # self._save_model_infos()
67
 
68
  model_res = {}
69
  if len(self.model_infos) > 0:
 
71
  res = self.model_infos[model]["results"]
72
  dataset_res = {}
73
  for dataset in res.keys():
74
+ # for each keyword check if it is in the dataset name if not continue
75
  if not any(keyword in dataset for keyword in VIDORE_DATASETS_KEYWORDS):
76
  print(f"{dataset} not found in ViDoRe datasets. Skipping ...")
77
  continue
 
79
  dataset_nickname = get_datasets_nickname(dataset)
80
  dataset_res[dataset_nickname] = res[dataset][metric]
81
  model_res[model] = dataset_res
82
+
83
  df = pd.DataFrame(model_res).T
84
+
85
  return df
86
  return pd.DataFrame()
87
 
 
106
  df.insert(len(df.columns) - len(cols_to_rank), "Average", df[cols_to_rank].mean(axis=1, skipna=False))
107
  df.sort_values("Average", ascending=False, inplace=True)
108
  df.insert(0, "Rank", list(range(1, len(df) + 1)))
109
+ # multiply values by 100 if they are floats and round to 1 decimal place
110
  for col in df.columns:
111
  if df[col].dtype == "float64":
112
  df[col] = df[col].apply(lambda x: round(x * 100, 1))