Spaces:
Running
Running
import React, { useEffect, useState } from 'react' | |
import API from '../API' | |
import LeaderboardFilter from './LeaderboardFilter' | |
/** Props accepted by the LeaderboardTable component. */
interface LeaderboardTableProps {
  /** Benchmark name; interpolated into `data/${file}_benchmark.csv` and shown as the table title. */
  file: string
}

/**
 * One leaderboard row. `metric` holds the metric key; every other property is a per-column
 * value (the table renders those keys as model sub-columns).
 */
interface Row {
  metric: string
  [column: string]: string | number
}

/** Group name -> subgroup name -> metric keys belonging to that subgroup. */
type Groups = Record<string, Record<string, string[]>>

/** Per-key averages and standard deviations. NOTE(review): not referenced elsewhere in this file. */
interface GroupStats {
  average: Record<string, number>
  stdDev: Record<string, number>
}
/** Shape of the JSON payload the component reads for a benchmark. */
interface BenchmarkData {
  rows: Row[]
  groups: { [group: string]: string[] }
  default_selected_metrics?: string[]
}

/** Group that is always listed first and whose metric keys define the metric columns. */
const OVERALL_GROUP = 'Overall'

/** Returns the part of `metric` after its first underscore, or null when it has no underscore. */
const stripFirstSegment = (metric: string): string | null => {
  const idx = metric.indexOf('_')
  return idx === -1 ? null : metric.slice(idx + 1)
}

/**
 * Resolves which overall-metric column a metric key belongs to.
 *
 * A key matches a column when the text after its first underscore equals the column name or
 * ends with `_<column>`. When several columns match (e.g. `p_value` and `log10_p_value`), the
 * longest one wins, so a metric is never shown or averaged under two columns.
 *
 * @returns the owning column name, or undefined when nothing matches.
 */
const resolveOverallMetric = (
  metric: string,
  overallMetrics: readonly string[]
): string | undefined => {
  const tail = stripFirstSegment(metric)
  if (tail === null) return undefined
  let best: string | undefined
  for (const candidate of overallMetrics) {
    const matches = tail === candidate || tail.endsWith(`_${candidate}`)
    if (matches && (best === undefined || candidate.length > best.length)) best = candidate
  }
  return best
}

/** Distinct, sorted metric-column names taken from the Overall group's keys (after the first `_`). */
const extractOverallMetrics = (overallGroup: readonly string[]): string[] => {
  const names = new Set<string>()
  for (const metric of overallGroup) {
    const tail = stripFirstSegment(metric)
    if (tail !== null) names.add(tail)
  }
  return Array.from(names).sort()
}

/**
 * Orders groups ("Overall" first, then alphabetically) and, inside each group, buckets the
 * metric keys by their first `_`-separated segment. Subgroups and their metrics are sorted.
 */
const buildGroups = (rawGroups: { [group: string]: string[] }): Groups => {
  const groupNames = Object.keys(rawGroups).sort((a, b) => {
    if (a === OVERALL_GROUP) return -1
    if (b === OVERALL_GROUP) return 1
    return a.localeCompare(b)
  })
  const result: Groups = {}
  for (const group of groupNames) {
    const bySubgroup: { [subgroup: string]: string[] } = {}
    for (const metric of [...rawGroups[group]].sort()) {
      const subgroup = metric.split('_')[0]
      if (!bySubgroup[subgroup]) bySubgroup[subgroup] = []
      bySubgroup[subgroup].push(metric)
    }
    result[group] = Object.fromEntries(
      Object.entries(bySubgroup).sort(([a], [b]) => a.localeCompare(b))
    )
  }
  return result
}

/** Converts a raw cell to a number; missing or blank cells become NaN (not 0 as `Number('')` would). */
const toNumeric = (cell: string | number | undefined): number => {
  if (typeof cell === 'number') return cell
  if (typeof cell !== 'string' || cell.trim() === '') return NaN
  return Number(cell)
}

/** Numeric cells are rounded to 3 decimals; anything else is rendered as-is. */
const formatCell = (cell: string | number | undefined): string | number | undefined => {
  const value = toNumeric(cell)
  return isNaN(value) ? cell : Number(value.toFixed(3))
}

/** "avg ± std" (population std, divides by n) with 3 decimals, or "N/A" when no value is usable. */
const formatStats = (values: readonly number[]): string => {
  if (values.length === 0) return 'N/A'
  const avg = values.reduce((sum, v) => sum + v, 0) / values.length
  if (isNaN(avg)) return 'N/A'
  const variance = values.reduce((sum, v) => sum + (v - avg) ** 2, 0) / values.length
  return `${avg.toFixed(3)} ± ${Math.sqrt(variance).toFixed(3)}`
}

const SIZE_PATTERN = /^(\d+(?:\.\d+)?)([kKmMbB]?)$/
const SIZE_MULTIPLIER: { [suffix: string]: number } = { '': 1, k: 1e3, m: 1e6, b: 1e9 }

/** Parses strengths like "10", "1.5k", "7B" into numbers; NaN when the text has another format. */
const parseSize = (text: string): number => {
  const match = SIZE_PATTERN.exec(text)
  return match ? parseFloat(match[1]) * SIZE_MULTIPLIER[match[2].toLowerCase()] : NaN
}

/**
 * Splits a `{category}_{strength}_{overall}` key given its resolved overall metric.
 * Keys that do not end in `_<overall>` use the whole key for both parts.
 */
const splitMetricKey = (
  metric: string,
  overall: string
): { category: string; strength: string } => {
  const suffix = `_${overall}`
  if (!metric.endsWith(suffix)) return { category: metric, strength: metric }
  const parts = metric.slice(0, metric.length - suffix.length).split('_')
  return { category: parts.slice(0, -1).join('_'), strength: parts[parts.length - 1] }
}

/**
 * Orders metric keys by category, then overall metric, then strength — numerically when both
 * strengths parse as sizes (so 2k < 10k), otherwise lexically.
 */
const compareMetricKeys = (a: string, b: string, overallOf: (metric: string) => string): number => {
  const overallA = overallOf(a)
  const overallB = overallOf(b)
  const keyA = splitMetricKey(a, overallA)
  const keyB = splitMetricKey(b, overallB)
  if (keyA.category !== keyB.category) return keyA.category.localeCompare(keyB.category)
  if (overallA !== overallB) return overallA.localeCompare(overallB)
  const sizeA = parseSize(keyA.strength)
  const sizeB = parseSize(keyB.strength)
  if (!isNaN(sizeA) && !isNaN(sizeB)) return sizeA - sizeB
  return keyA.strength.localeCompare(keyB.strength)
}

/**
 * Collapsible leaderboard for one benchmark file.
 *
 * Columns are the Overall group's metric names, each split into one sub-column per data column.
 * Group and subgroup rows show "avg ± std" over the selected metrics they contain; expanding a
 * subgroup lists the individual metric rows.
 */
const LeaderboardTable: React.FC<LeaderboardTableProps> = ({ file }) => {
  const [tableRows, setTableRows] = useState<Row[]>([])
  const [tableHeader, setTableHeader] = useState<string[]>([])
  const [loading, setLoading] = useState(true)
  const [error, setError] = useState<string | null>(null)
  const [groups, setGroups] = useState<Groups>({})
  const [openGroups, setOpenGroups] = useState<{ [key: string]: boolean }>({})
  const [openSubGroups, setOpenSubGroups] = useState<{ [key: string]: { [key: string]: boolean } }>(
    {}
  )
  const [selectedMetrics, setSelectedMetrics] = useState<Set<string>>(new Set())
  const [defaultSelectedMetrics, setDefaultSelectedMetrics] = useState<string[]>([])
  // Metric column names: distinct suffixes of the Overall group's metric keys.
  const [overallMetrics, setOverallMetrics] = useState<string[]>([])

  useEffect(() => {
    // Prevents a slow response for a previous `file` (or after unmount) from overwriting state.
    let cancelled = false
    // Reset so an earlier error does not hide the table for the newly requested file.
    setLoading(true)
    setError(null)
    API.fetchStaticFile(`data/${file}_benchmark.csv`)
      .then((response) => {
        if (cancelled) return
        // NOTE(review): despite the .csv extension, the payload is parsed as JSON.
        const data: BenchmarkData = JSON.parse(response)
        if (!Array.isArray(data?.rows) || typeof data.groups !== 'object' || data.groups === null) {
          throw new Error('unexpected benchmark data format')
        }
        const groupsData = buildGroups(data.groups)
        const headers = Array.from(new Set(data.rows.flatMap((row) => Object.keys(row)))).filter(
          (key) => key !== 'metric'
        )
        const defaults = data.default_selected_metrics ?? []
        // Every group and subgroup starts collapsed.
        const initialOpenGroups: { [key: string]: boolean } = {}
        const initialOpenSubGroups: { [key: string]: { [key: string]: boolean } } = {}
        for (const [group, subGroups] of Object.entries(groupsData)) {
          initialOpenGroups[group] = false
          initialOpenSubGroups[group] = {}
          for (const subGroup of Object.keys(subGroups)) {
            initialOpenSubGroups[group][subGroup] = false
          }
        }
        setOverallMetrics(extractOverallMetrics(data.groups[OVERALL_GROUP] ?? []))
        setSelectedMetrics(new Set(defaults))
        setDefaultSelectedMetrics(defaults)
        setTableHeader(headers)
        setTableRows(data.rows)
        setGroups(groupsData)
        setOpenGroups(initialOpenGroups)
        setOpenSubGroups(initialOpenSubGroups)
        setLoading(false)
      })
      .catch((err: unknown) => {
        if (cancelled) return
        setError('Failed to fetch JSON: ' + (err instanceof Error ? err.message : String(err)))
        setLoading(false)
      })
    return () => {
      cancelled = true
    }
  }, [file])

  // First row per metric key (same as Array.find), so cell lookups are O(1) during render.
  const rowByMetric = React.useMemo(() => {
    const lookup = new Map<string, Row>()
    for (const row of tableRows) {
      if (!lookup.has(row.metric)) lookup.set(row.metric, row)
    }
    return lookup
  }, [tableRows])

  // The single overall-metric column each row's metric is displayed and averaged under.
  const overallByMetric = React.useMemo(() => {
    const lookup = new Map<string, string>()
    for (const row of tableRows) {
      const owner = resolveOverallMetric(row.metric, overallMetrics)
      if (owner !== undefined) lookup.set(row.metric, owner)
    }
    return lookup
  }, [tableRows, overallMetrics])

  const toggleGroup = (group: string) => {
    setOpenGroups((prev) => ({ ...prev, [group]: !prev[group] }))
  }

  const toggleSubGroup = (group: string, subGroup: string) => {
    setOpenSubGroups((prev) => ({
      ...prev,
      [group]: {
        ...(prev[group] || {}),
        [subGroup]: !prev[group]?.[subGroup],
      },
    }))
  }

  /** Keeps only the metric keys the user currently has selected. */
  const selectedOnly = (metrics: string[]): string[] =>
    metrics.filter((metric) => selectedMetrics.has(metric))

  /** "avg ± std" over the given metric keys that belong to column `oMetric`, for data column `col`. */
  const statsLabel = (metricKeys: string[], oMetric: string, col: string): string =>
    formatStats(
      metricKeys
        .filter((metric) => overallByMetric.get(metric) === oMetric)
        .map((metric) => toNumeric(rowByMetric.get(metric)?.[col]))
        .filter((value) => !isNaN(value))
    )

  /** One aggregate cell per (overall metric, data column) pair for a group or subgroup row. */
  const renderStatCells = (keyPrefix: string, metricKeys: string[]) =>
    overallMetrics.map((oMetric) => (
      <React.Fragment key={`${keyPrefix}-${oMetric}`}>
        {tableHeader.map((col) => (
          <td key={`${keyPrefix}-${oMetric}-${col}`} className="font-medium text-center">
            {statsLabel(metricKeys, oMetric, col)}
          </td>
        ))}
      </React.Fragment>
    ))

  /** A single metric row; values appear only under the metric's own overall column. */
  const renderMetricRow = (metric: string) => {
    const row = rowByMetric.get(metric)
    if (!row) return null
    const owner = overallByMetric.get(metric)
    return (
      <tr key={metric} className="hover:bg-base-100">
        <td className="pl-10">{metric}</td>
        {overallMetrics.map((oMetric) => (
          <React.Fragment key={`${metric}-${oMetric}`}>
            {tableHeader.map((col) => (
              <td key={`${metric}-${oMetric}-${col}`} className="text-center">
                {oMetric === owner ? formatCell(row[col]) : null}
              </td>
            ))}
          </React.Fragment>
        ))}
      </tr>
    )
  }

  /** Copy of `metrics` in category / overall metric / strength order. */
  const sortMetricKeys = (metrics: string[]): string[] => {
    const overallOf = (metric: string) => overallByMetric.get(metric) ?? ''
    return [...metrics].sort((a, b) => compareMetricKeys(a, b, overallOf))
  }

  return (
    <div className="rounded shadow overflow-auto">
      <h3 className="font-bold mb-2">{file}</h3>
      {loading && <div>Loading...</div>}
      {error && <div className="text-red-500">{error}</div>}
      {!loading && !error && (
        <div className="overflow-x-auto">
          <LeaderboardFilter
            groups={groups}
            selectedMetrics={selectedMetrics}
            setSelectedMetrics={setSelectedMetrics}
            defaultSelectedMetrics={defaultSelectedMetrics}
          />
          <table className="table w-full">
            <thead>
              <tr>
                <th>Group / Subgroup</th>
                {overallMetrics.map((metric) => (
                  <th key={metric} colSpan={tableHeader.length} className="text-center border-x">
                    {metric}
                  </th>
                ))}
              </tr>
              <tr>
                <th></th>
                {overallMetrics.map((metric) => (
                  <React.Fragment key={`header-models-${metric}`}>
                    {tableHeader.map((model) => (
                      <th key={`${metric}-${model}`} className="text-center text-xs">
                        {model}
                      </th>
                    ))}
                  </React.Fragment>
                ))}
              </tr>
            </thead>
            <tbody>
              {Object.entries(groups).map(([group, subGroups]) => {
                const visibleGroupMetrics = selectedOnly(Object.values(subGroups).flat())
                // Groups with no selected metric are hidden entirely.
                if (visibleGroupMetrics.length === 0) return null
                return (
                  <React.Fragment key={group}>
                    <tr
                      className="bg-base-200 cursor-pointer hover:bg-base-300"
                      onClick={() => toggleGroup(group)}
                    >
                      <td className="font-medium">
                        {openGroups[group] ? '▼ ' : '▶ '}
                        {group}
                      </td>
                      {renderStatCells(group, visibleGroupMetrics)}
                    </tr>
                    {openGroups[group] &&
                      Object.entries(subGroups).map(([subGroup, metrics]) => {
                        const visibleSubgroupMetrics = selectedOnly(metrics)
                        if (visibleSubgroupMetrics.length === 0) return null
                        const isOpen = openSubGroups[group]?.[subGroup]
                        return (
                          <React.Fragment key={`${group}-${subGroup}`}>
                            <tr
                              className="bg-base-100 cursor-pointer hover:bg-base-200"
                              onClick={() => toggleSubGroup(group, subGroup)}
                            >
                              <td className="pl-6 font-medium">
                                {isOpen ? '▼ ' : '▶ '}
                                {subGroup}
                              </td>
                              {renderStatCells(`${group}-${subGroup}`, visibleSubgroupMetrics)}
                            </tr>
                            {isOpen &&
                              sortMetricKeys(visibleSubgroupMetrics).map((metric) =>
                                renderMetricRow(metric)
                              )}
                          </React.Fragment>
                        )
                      })}
                  </React.Fragment>
                )
              })}
            </tbody>
          </table>
        </div>
      )}
    </div>
  )
}

export default LeaderboardTable