sasha HF Staff commited on
Commit
eddabf1
·
verified ·
1 Parent(s): efdbdd2

Update app.py

Browse files

adding task-model compatibility check

Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -5,8 +5,7 @@ from dataclasses import dataclass
5
  from datasets import load_dataset, Dataset
6
  import pandas as pd
7
  import gradio as gr
8
- from huggingface_hub import HfApi, snapshot_download
9
- from huggingface_hub.hf_api import ModelInfo
10
  from enum import Enum
11
 
12
 
@@ -52,22 +51,25 @@ def add_new_eval(
52
  current_time = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
53
  requests= load_dataset("EnergyStarAI/requests_debug", split="test", token=TOKEN)
54
  requests_dset = requests.to_pandas()
55
- model_list= requests_dset[requests_dset['status'] == 'COMPLETED']['model'].tolist()
 
 
56
  if repo_id in model_list:
57
  return 'This model has already been run!'
 
 
58
  else:
59
  # Is the model info correctly filled?
60
  try:
61
  model_info = API.model_info(repo_id=repo_id)
62
  except Exception:
63
  return "Could not find information for model %s" % (model)
64
-
65
  model_size = get_model_size(model_info=model_info)
66
 
67
  print("Adding request")
68
 
69
 
70
-
71
  request_dict = {
72
  "model": repo_id,
73
  "status": "PENDING",
 
5
  from datasets import load_dataset, Dataset
6
  import pandas as pd
7
  import gradio as gr
8
+ from huggingface_hub import HfApi, snapshot_download, ModelInfo, list_models
 
9
  from enum import Enum
10
 
11
 
 
51
  current_time = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
52
  requests= load_dataset("EnergyStarAI/requests_debug", split="test", token=TOKEN)
53
  requests_dset = requests.to_pandas()
54
+ model_list= requests_dset[requests_dset['status'] == 'COMPLETED']['model'].tolist
55
+ task_models = list(hf_api.list_models(filter=task.lower().replace(' ','-'))
56
+ task_model_names = [m.id for m in task_models]
57
  if repo_id in model_list:
58
  return 'This model has already been run!'
59
+ if model not in task_model_names:
60
+ return "This model isn't compatible with the chosen task! Pick a different model-task combination"
61
  else:
62
  # Is the model info correctly filled?
63
  try:
64
  model_info = API.model_info(repo_id=repo_id)
65
  except Exception:
66
  return "Could not find information for model %s" % (model)
67
+
68
  model_size = get_model_size(model_info=model_info)
69
 
70
  print("Adding request")
71
 
72
 
 
73
  request_dict = {
74
  "model": repo_id,
75
  "status": "PENDING",