Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -26,12 +26,12 @@ DEFAULT_PARAMS = {
|
|
26 |
"test_seed": 42, # must be non-negative
|
27 |
},
|
28 |
"image":{
|
29 |
-
"dataset_name": "
|
30 |
"test_size": 0.2, # must be between 0 and 1
|
31 |
"test_seed": 42, # must be non-negative
|
32 |
},
|
33 |
"audio":{
|
34 |
-
"dataset_name": "
|
35 |
"test_size": 0.2, # must be between 0 and 1
|
36 |
"test_seed": 42, # must be non-negative
|
37 |
}
|
@@ -61,19 +61,33 @@ def evaluate_model(task: str, space_url: str):
|
|
61 |
|
62 |
results = response.json()
|
63 |
|
64 |
-
# Check for required keys
|
65 |
-
|
66 |
-
"username", "space_url", "submission_timestamp", "model_description",
|
67 |
-
"
|
68 |
"api_route", "dataset_config"
|
69 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
missing_keys = required_keys - set(results.keys())
|
72 |
if missing_keys:
|
73 |
return None, None, None, gr.Warning(f"API response missing required keys: {', '.join(missing_keys)}")
|
74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
return (
|
76 |
-
|
77 |
results["emissions_gco2eq"],
|
78 |
results["energy_consumed_wh"],
|
79 |
results
|
|
|
26 |
"test_seed": 42, # must be non-negative
|
27 |
},
|
28 |
"image":{
|
29 |
+
"dataset_name": "pyronear/pyro-sdis",
|
30 |
"test_size": 0.2, # must be between 0 and 1
|
31 |
"test_seed": 42, # must be non-negative
|
32 |
},
|
33 |
"audio":{
|
34 |
+
"dataset_name": "rfcx/frugalai",
|
35 |
"test_size": 0.2, # must be between 0 and 1
|
36 |
"test_seed": 42, # must be non-negative
|
37 |
}
|
|
|
61 |
|
62 |
results = response.json()
|
63 |
|
64 |
+
# Check for required keys based on task
|
65 |
+
base_required_keys = {
|
66 |
+
"username", "space_url", "submission_timestamp", "model_description",
|
67 |
+
"energy_consumed_wh", "emissions_gco2eq", "emissions_data",
|
68 |
"api_route", "dataset_config"
|
69 |
}
|
70 |
+
|
71 |
+
# Add task-specific accuracy keys
|
72 |
+
if task == "image":
|
73 |
+
accuracy_keys = {"classification_accuracy", "mean_iou"}
|
74 |
+
else: # text and audio
|
75 |
+
accuracy_keys = {"accuracy"}
|
76 |
+
|
77 |
+
required_keys = base_required_keys | accuracy_keys
|
78 |
|
79 |
missing_keys = required_keys - set(results.keys())
|
80 |
if missing_keys:
|
81 |
return None, None, None, gr.Warning(f"API response missing required keys: {', '.join(missing_keys)}")
|
82 |
|
83 |
+
# Return appropriate accuracy metric based on task
|
84 |
+
if task == "image":
|
85 |
+
accuracy = results["classification_accuracy"] # For display in UI
|
86 |
+
else:
|
87 |
+
accuracy = results["accuracy"]
|
88 |
+
|
89 |
return (
|
90 |
+
accuracy,
|
91 |
results["emissions_gco2eq"],
|
92 |
results["energy_consumed_wh"],
|
93 |
results
|