Select pipeline_tag task by default.
Browse files
app.py
CHANGED
@@ -63,12 +63,26 @@ tasks_mapping = {
|
|
63 |
"semantic-segmentation": "Semantic Segmentation",
|
64 |
"seq2seq-lm": "Text to Text Generation",
|
65 |
"sequence-classification": "Text Classification",
|
66 |
-
"speech-seq2seq": "
|
67 |
"token-classification": "Token Classification",
|
68 |
}
|
69 |
reverse_tasks_mapping = {v: k for k, v in tasks_mapping.items()}
|
70 |
tasks_labels = list(tasks_mapping.keys())
|
71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
def error_str(error, title="Error", model=None, task=None, framework=None, compute_units=None, precision=None, tolerance=None, destination=None, open_discussion=True):
|
73 |
if not error: return ""
|
74 |
|
@@ -112,8 +126,18 @@ def get_pr_url(api, repo_id, title):
|
|
112 |
and discussion.title == title
|
113 |
):
|
114 |
return f"https://huggingface.co/{repo_id}/discussions/{discussion.num}"
|
115 |
-
|
116 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
"""
|
118 |
Return a list of supported frameworks (`PyTorch` or `TensorFlow`) for a given model_id.
|
119 |
Only PyTorch and Tensorflow are supported.
|
@@ -130,6 +154,7 @@ def on_model_change(model):
|
|
130 |
error = None
|
131 |
frameworks = []
|
132 |
selected_framework = None
|
|
|
133 |
|
134 |
try:
|
135 |
config_file = hf_hub_download(model, filename="config.json")
|
@@ -144,17 +169,30 @@ def on_model_change(model):
|
|
144 |
|
145 |
features = FeaturesManager.get_supported_features_for_model_type(model_type)
|
146 |
tasks = list(features.keys())
|
147 |
-
tasks = [tasks_mapping[task] for task in tasks]
|
148 |
|
149 |
-
|
|
|
150 |
selected_framework = frameworks[0] if len(frameworks) > 0 else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
except Exception as e:
|
152 |
error = e
|
153 |
model_type = None
|
154 |
|
155 |
return (
|
156 |
gr.update(visible=bool(model_type)), # Settings column
|
157 |
-
gr.update(choices=tasks, value=
|
158 |
gr.update(visible=len(frameworks)>1, choices=frameworks, value=selected_framework), # Frameworks
|
159 |
gr.update(value=error_str(error, model=model)), # Error
|
160 |
)
|
|
|
63 |
"semantic-segmentation": "Semantic Segmentation",
|
64 |
"seq2seq-lm": "Text to Text Generation",
|
65 |
"sequence-classification": "Text Classification",
|
66 |
+
"speech-seq2seq": "Audio to Audio",
|
67 |
"token-classification": "Token Classification",
|
68 |
}
|
69 |
reverse_tasks_mapping = {v: k for k, v in tasks_mapping.items()}
|
70 |
tasks_labels = list(tasks_mapping.keys())
|
71 |
|
72 |
+
# Map pipeline_tag to internal exporters features/tasks
|
73 |
+
tags_to_tasks_mapping = {
|
74 |
+
"feature-extraction": "default",
|
75 |
+
"text-generation": "causal-lm",
|
76 |
+
"image-classification": "image-classification",
|
77 |
+
"image-segmentation": "image-segmentation",
|
78 |
+
"fill-mask": "masked-lm",
|
79 |
+
"object-detection": "object-detection",
|
80 |
+
"question-answering": "question-answering",
|
81 |
+
"text2text-generation": "seq2seq-lm",
|
82 |
+
"text-classification": "sequence-classification",
|
83 |
+
"token-classification": "token-classification",
|
84 |
+
}
|
85 |
+
|
86 |
def error_str(error, title="Error", model=None, task=None, framework=None, compute_units=None, precision=None, tolerance=None, destination=None, open_discussion=True):
|
87 |
if not error: return ""
|
88 |
|
|
|
126 |
and discussion.title == title
|
127 |
):
|
128 |
return f"https://huggingface.co/{repo_id}/discussions/{discussion.num}"
|
129 |
+
|
130 |
+
def retrieve_model_info(model_id):
|
131 |
+
api = HfApi()
|
132 |
+
model_info = api.model_info(model_id)
|
133 |
+
tags = model_info.tags
|
134 |
+
frameworks = [tag for tag in tags if tag in ["pytorch", "tf"]]
|
135 |
+
return {
|
136 |
+
"pipeline_tag": model_info.pipeline_tag,
|
137 |
+
"frameworks": sorted(["PyTorch" if f == "pytorch" else "TensorFlow" for f in frameworks]),
|
138 |
+
}
|
139 |
+
|
140 |
+
def supported_frameworks(model_info):
|
141 |
"""
|
142 |
Return a list of supported frameworks (`PyTorch` or `TensorFlow`) for a given model_id.
|
143 |
Only PyTorch and Tensorflow are supported.
|
|
|
154 |
error = None
|
155 |
frameworks = []
|
156 |
selected_framework = None
|
157 |
+
selected_task = None
|
158 |
|
159 |
try:
|
160 |
config_file = hf_hub_download(model, filename="config.json")
|
|
|
169 |
|
170 |
features = FeaturesManager.get_supported_features_for_model_type(model_type)
|
171 |
tasks = list(features.keys())
|
|
|
172 |
|
173 |
+
model_info = retrieve_model_info(model)
|
174 |
+
frameworks = model_info["frameworks"]
|
175 |
selected_framework = frameworks[0] if len(frameworks) > 0 else None
|
176 |
+
|
177 |
+
pipeline_tag = model_info["pipeline_tag"]
|
178 |
+
# Select the task corresponding to the pipeline tag
|
179 |
+
if tasks:
|
180 |
+
if pipeline_tag in tags_to_tasks_mapping:
|
181 |
+
selected_task = tags_to_tasks_mapping[pipeline_tag]
|
182 |
+
else:
|
183 |
+
selected_task = tasks[0]
|
184 |
+
|
185 |
+
# Convert to UI labels
|
186 |
+
tasks = [tasks_mapping[task] for task in tasks]
|
187 |
+
selected_task = tasks_mapping[selected_task]
|
188 |
+
|
189 |
except Exception as e:
|
190 |
error = e
|
191 |
model_type = None
|
192 |
|
193 |
return (
|
194 |
gr.update(visible=bool(model_type)), # Settings column
|
195 |
+
gr.update(choices=tasks, value=selected_task), # Tasks
|
196 |
gr.update(visible=len(frameworks)>1, choices=frameworks, value=selected_framework), # Frameworks
|
197 |
gr.update(value=error_str(error, model=model)), # Error
|
198 |
)
|