pcuenq (HF staff) committed
Commit deafe70 · Parent(s): 6e6214b

Select pipeline_tag task by default.
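
In short: when a user picks a model, the app now reads the model's `pipeline_tag` from the Hub, maps it through the new `tags_to_tasks_mapping` dict to one of the exporter's internal task names, and preselects that task in the task dropdown, falling back to the first supported task when the tag has no mapping. A minimal sketch of that selection logic follows; the dict entries are copied from the diff below, but `select_default_task` is a hypothetical helper name and the example inputs are illustrative (the actual commit inlines this logic in `on_model_change`):

# Sketch only: dict entries come from the diff below (trimmed for brevity);
# the helper name and example inputs are illustrative, not part of the commit.
tags_to_tasks_mapping = {
    "feature-extraction": "default",
    "text-generation": "causal-lm",
    "fill-mask": "masked-lm",
    "text-classification": "sequence-classification",
}

def select_default_task(pipeline_tag, supported_tasks):
    # Mirror the commit's logic: prefer the task matching the pipeline tag,
    # otherwise fall back to the first task the exporter supports.
    if not supported_tasks:
        return None
    if pipeline_tag in tags_to_tasks_mapping:
        return tags_to_tasks_mapping[pipeline_tag]
    return supported_tasks[0]

print(select_default_task("fill-mask", ["default", "masked-lm", "sequence-classification"]))
# -> "masked-lm"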

Files changed (1):
  1. app.py (+44 -6)
app.py CHANGED
@@ -63,12 +63,26 @@ tasks_mapping = {
     "semantic-segmentation": "Semantic Segmentation",
     "seq2seq-lm": "Text to Text Generation",
     "sequence-classification": "Text Classification",
-    "speech-seq2seq": "Speech to Speech Generation",
+    "speech-seq2seq": "Audio to Audio",
     "token-classification": "Token Classification",
 }
 reverse_tasks_mapping = {v: k for k, v in tasks_mapping.items()}
 tasks_labels = list(tasks_mapping.keys())
 
+# Map pipeline_tag to internal exporters features/tasks
+tags_to_tasks_mapping = {
+    "feature-extraction": "default",
+    "text-generation": "causal-lm",
+    "image-classification": "image-classification",
+    "image-segmentation": "image-segmentation",
+    "fill-mask": "masked-lm",
+    "object-detection": "object-detection",
+    "question-answering": "question-answering",
+    "text2text-generation": "seq2seq-lm",
+    "text-classification": "sequence-classification",
+    "token-classification": "token-classification",
+}
+
 def error_str(error, title="Error", model=None, task=None, framework=None, compute_units=None, precision=None, tolerance=None, destination=None, open_discussion=True):
     if not error: return ""
 
@@ -112,8 +126,18 @@ def get_pr_url(api, repo_id, title):
             and discussion.title == title
         ):
             return f"https://huggingface.co/{repo_id}/discussions/{discussion.num}"
-
-def supported_frameworks(model_id):
+
+def retrieve_model_info(model_id):
+    api = HfApi()
+    model_info = api.model_info(model_id)
+    tags = model_info.tags
+    frameworks = [tag for tag in tags if tag in ["pytorch", "tf"]]
+    return {
+        "pipeline_tag": model_info.pipeline_tag,
+        "frameworks": sorted(["PyTorch" if f == "pytorch" else "TensorFlow" for f in frameworks]),
+    }
+
+def supported_frameworks(model_info):
     """
     Return a list of supported frameworks (`PyTorch` or `TensorFlow`) for a given model_id.
     Only PyTorch and Tensorflow are supported.
@@ -130,6 +154,7 @@ def on_model_change(model):
     error = None
     frameworks = []
     selected_framework = None
+    selected_task = None
 
     try:
         config_file = hf_hub_download(model, filename="config.json")
@@ -144,17 +169,30 @@ def on_model_change(model):
 
         features = FeaturesManager.get_supported_features_for_model_type(model_type)
         tasks = list(features.keys())
-        tasks = [tasks_mapping[task] for task in tasks]
 
-        frameworks = supported_frameworks(model)
+        model_info = retrieve_model_info(model)
+        frameworks = model_info["frameworks"]
         selected_framework = frameworks[0] if len(frameworks) > 0 else None
+
+        pipeline_tag = model_info["pipeline_tag"]
+        # Select the task corresponding to the pipeline tag
+        if tasks:
+            if pipeline_tag in tags_to_tasks_mapping:
+                selected_task = tags_to_tasks_mapping[pipeline_tag]
+            else:
+                selected_task = tasks[0]
+
+            # Convert to UI labels
+            tasks = [tasks_mapping[task] for task in tasks]
+            selected_task = tasks_mapping[selected_task]
+
     except Exception as e:
         error = e
         model_type = None
 
     return (
         gr.update(visible=bool(model_type)), # Settings column
-        gr.update(choices=tasks, value=tasks[0] if tasks else None), # Tasks
+        gr.update(choices=tasks, value=selected_task), # Tasks
         gr.update(visible=len(frameworks)>1, choices=frameworks, value=selected_framework), # Frameworks
         gr.update(value=error_str(error, model=model)), # Error
     )
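
For reference, the new `retrieve_model_info` helper supplies both the frameworks (previously obtained via `supported_frameworks(model)`) and the pipeline tag from a single `HfApi().model_info()` call. A standalone check of what it returns might look like the sketch below; the function body is copied from the diff, while the model id is only an example and the printed values depend on that repo's metadata:

# Standalone check of retrieve_model_info as added in this commit.
# Requires huggingface_hub; "distilbert-base-uncased" is only an example id.
from huggingface_hub import HfApi

def retrieve_model_info(model_id):
    api = HfApi()
    model_info = api.model_info(model_id)
    tags = model_info.tags
    frameworks = [tag for tag in tags if tag in ["pytorch", "tf"]]
    return {
        "pipeline_tag": model_info.pipeline_tag,
        "frameworks": sorted(["PyTorch" if f == "pytorch" else "TensorFlow" for f in frameworks]),
    }

info = retrieve_model_info("distilbert-base-uncased")
print(info["pipeline_tag"])  # e.g. "fill-mask"
print(info["frameworks"])    # e.g. ["PyTorch", "TensorFlow"]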