// Leaderboard table module. Defines the `OverallMetricFilter` helper component and the
// default-exported `LeaderboardTable` component.
//
// NOTE(review): this copy of the file looks damaged. Each `return (` is followed by bare text
// rather than JSX elements, and several generic arguments are missing (`Set` with no element
// type, `useState>(new Set())`, `React.FC` with no props type on LeaderboardTable). Presumably
// the `<...>` spans were stripped by some text-processing step; restore the markup from version
// control before building. The notes below describe only the logic that is still visible.
//
// OverallMetricFilter
//   `toggleMetric` copies `selectedOverallMetrics` into a new Set, deletes the metric if it is
//   already present or adds it otherwise, and passes the copy to `setSelectedOverallMetrics`.
//
// LeaderboardTable({ benchmarkData, selectedModels })
//   Effect (dependency: `benchmarkData`; returns early while it is falsy):
//     - Reads `benchmarkData['rows']` and `benchmarkData['groups']`, and separates the `Overall`
//       group from the remaining groups by destructuring.
//     - Overall metric names: for every `Overall` entry containing '_', everything after the
//       first '_' is kept; the unique names are stored sorted and all of them start selected.
//     - Remaining groups are sorted by name; each group's metrics are sorted and bucketed by the
//       text before the first '_', and the buckets are sorted by key.
//       NOTE(review): the `'Overall'` branches in the group comparator cannot match, because
//       `Overall` was already removed by the destructuring above. `subGroup` taken from
//       `metric.split('_')` is never read.
//     - Table headers are every key found on any row except `metric`; the render code filters
//       them with `selectedModels.has(...)`.
//     - Every group and subgroup starts collapsed (`false`), and `selectedMetrics` starts as all
//       metrics of the non-Overall groups.
//     - Any exception sets `error` to a "Failed to parse benchmark data" message that includes
//       `err.message`.
//   toggleGroup / toggleSubGroup: flip the open flag for one group or one (group, subGroup) pair.
//   findAllMetricsForName(metricName): metrics of rows whose name contains '_' and whose text
//     after the first '_' ends with `metricName`.
//     NOTE(review): this is a plain `endsWith` with no '_' boundary, unlike findStandaloneMetrics.
//     An overall metric name that is a suffix of another (e.g. `p_value` vs `log10_p_value`, if
//     both exist) would therefore also match the other's rows.
//   findStandaloneMetrics(): row metrics that neither equal an overall metric name nor end with
//     '_' followed by one.
//   calculateStats(metricNames, columnKey): looks up each metric's row, coerces the column value
//     with `Number(...)`, drops NaN values, and returns the mean and the population standard
//     deviation (variance is divided by the number of values). Returns NaN for both when no
//     numeric value remains.
//     NOTE(review): `Number('')` and `Number(null)` are 0, so such cells would count as 0 here.
//   filterMetricsByGroupAndSubgroup(metricNames, group, subgroup): with no group, returns
//     `metricNames` unchanged (without consulting `selectedMetrics`). Otherwise keeps the metrics
//     that belong to the subgroup (when it exists in `groups[group]`) or to the whole group, and
//     that are also in `selectedMetrics`.
//   Rendering (as far as still visible):
//     - Shows `error` when set. Otherwise shows a "Please select at least one model and one
//       metric" message when `selectedModels` or `selectedMetrics` is empty.
//     - Otherwise renders a standalone-metrics table (skipped when there are none) and a main
//       table with group rows, subgroup rows for open groups, and metric rows for open subgroups.
//     - Group and subgroup rows show `avg ± stdDev` with three decimals, or 'N/A' when the
//       average is NaN.
//     - Numeric cells are displayed as `Number(Number(cell).toFixed(3))`; other cells as-is.
//     - Metric rows inside a subgroup are ordered by category, then overall metric name, then
//       strength. Strengths that parse as numbers (optional k/m/b suffix) compare numerically,
//       otherwise with `localeCompare`. The overall metric for a row is the first entry of
//       `overallMetrics` that the metric equals or ends with (after a '_').
import React, { useEffect, useState } from 'react' import LeaderboardFilter from './LeaderboardFilter' import LoadingSpinner from './LoadingSpinner' interface LeaderboardTableProps { benchmarkData: any selectedModels: Set } interface Row { metric: string [key: string]: string | number } interface Groups { [group: string]: { [subgroup: string]: string[] } } const OverallMetricFilter: React.FC<{ overallMetrics: string[] selectedOverallMetrics: Set setSelectedOverallMetrics: (metrics: Set) => void }> = ({ overallMetrics, selectedOverallMetrics, setSelectedOverallMetrics }) => { const toggleMetric = (metric: string) => { const newSelected = new Set(selectedOverallMetrics) if (newSelected.has(metric)) { newSelected.delete(metric) } else { newSelected.add(metric) } setSelectedOverallMetrics(newSelected) } return (
Metrics ({selectedOverallMetrics.size}/{overallMetrics.length})
{overallMetrics.map((metric) => ( ))}
) } const LeaderboardTable: React.FC = ({ benchmarkData, selectedModels }) => { const [tableRows, setTableRows] = useState([]) const [tableHeader, setTableHeader] = useState([]) const [error, setError] = useState(null) const [groups, setGroups] = useState({}) const [openGroups, setOpenGroups] = useState<{ [key: string]: boolean }>({}) const [openSubGroups, setOpenSubGroups] = useState<{ [key: string]: { [key: string]: boolean } }>( {} ) const [selectedMetrics, setSelectedMetrics] = useState>(new Set()) const [overallMetrics, setOverallMetrics] = useState([]) const [selectedOverallMetrics, setSelectedOverallMetrics] = useState>(new Set()) useEffect(() => { if (!benchmarkData) { return } try { const data = benchmarkData const rows: Row[] = data['rows'] const allGroups = data['groups'] as { [key: string]: string[] } const { Overall: overallGroup, ...groups } = allGroups const uniqueMetrics = new Set() overallGroup?.forEach((metric) => { if (metric.includes('_')) { const metricName = metric.split('_').slice(1).join('_') uniqueMetrics.add(metricName) } }) setOverallMetrics(Array.from(uniqueMetrics).sort()) setSelectedOverallMetrics(new Set(Array.from(uniqueMetrics))) const groupsData = Object.entries(groups) .sort(([groupA], [groupB]) => { if (groupA === 'Overall') return -1 if (groupB === 'Overall') return 1 return groupA.localeCompare(groupB) }) .reduce( (acc, [group, metrics]) => { const sortedMetrics = [...metrics].sort() acc[group] = sortedMetrics.reduce<{ [key: string]: string[] }>((subAcc, metric) => { const [mainGroup, subGroup] = metric.split('_') if (!subAcc[mainGroup]) { subAcc[mainGroup] = [] } subAcc[mainGroup].push(metric) return subAcc }, {}) acc[group] = Object.fromEntries( Object.entries(acc[group]).sort(([subGroupA], [subGroupB]) => subGroupA.localeCompare(subGroupB) ) ) return acc }, {} as { [key: string]: { [key: string]: string[] } } ) const allKeys: string[] = Array.from(new Set(rows.flatMap((row) => Object.keys(row)))) const headers = 
allKeys.filter((key) => key !== 'metric') const initialOpenGroups: { [key: string]: boolean } = {} const initialOpenSubGroups: { [key: string]: { [key: string]: boolean } } = {} Object.keys(groupsData).forEach((group) => { initialOpenGroups[group] = false initialOpenSubGroups[group] = {} Object.keys(groupsData[group]).forEach((subGroup) => { initialOpenSubGroups[group][subGroup] = false }) }) const allMetrics = Object.values(groups).flat() setSelectedMetrics(new Set(allMetrics)) setTableHeader(headers) setTableRows(rows) setGroups(groupsData) setOpenGroups(initialOpenGroups) setOpenSubGroups(initialOpenSubGroups) setError(null) } catch (err: any) { setError('Failed to parse benchmark data, please try again: ' + err.message) } }, [benchmarkData]) const toggleGroup = (group: string) => { setOpenGroups((prev) => ({ ...prev, [group]: !prev[group] })) } const toggleSubGroup = (group: string, subGroup: string) => { setOpenSubGroups((prev) => ({ ...prev, [group]: { ...(prev[group] || {}), [subGroup]: !prev[group]?.[subGroup], }, })) } // Find all metrics matching a particular extracted metric name (like "log10_p_value") const findAllMetricsForName = (metricName: string): string[] => { return tableRows .filter((row) => { const metric = row.metric as string if (metric.includes('_')) { const extractedName = metric.split('_').slice(1).join('_') return extractedName.endsWith(metricName) } return false }) .map((row) => row.metric as string) } // Identify metrics that don't belong to any overall metric group const findStandaloneMetrics = (): string[] => { // Get all metrics from the table rows const allMetrics = tableRows.map((row) => row.metric as string) // Filter to only include metrics that aren't part of any of the overall metrics return allMetrics.filter((metric) => { // Check if this metric is part of any of the overall metrics for (const overall of overallMetrics) { if (metric.endsWith(`_${overall}`) || metric === overall) { return false // This metric belongs to an 
overall group } } return true }) } // Calculate average and standard deviation for a set of metrics for a specific column const calculateStats = ( metricNames: string[], columnKey: string ): { avg: number; stdDev: number } => { const values = metricNames .map((metricName) => { const row = tableRows.find((row) => row.metric === metricName) return row ? Number(row[columnKey]) : NaN }) .filter((value) => !isNaN(value)) if (values.length === 0) return { avg: NaN, stdDev: NaN } const avg = values.reduce((sum, val) => sum + val, 0) / values.length const squareDiffs = values.map((value) => { const diff = value - avg return diff * diff }) const variance = squareDiffs.reduce((sum, sqrDiff) => sum + sqrDiff, 0) / values.length const stdDev = Math.sqrt(variance) return { avg, stdDev } } // Filter metrics by group and/or subgroup const filterMetricsByGroupAndSubgroup = ( metricNames: string[], group: string | null = null, subgroup: string | null = null ): string[] => { // If no group specified, return all metrics if (!group) return metricNames // Get all metrics for the specified group const groupMetrics = Object.values(groups[group] || {}).flat() // If subgroup is specified, further filter to that subgroup if (subgroup && groups[group]?.[subgroup]) { return metricNames.filter( (metric) => groups[group][subgroup].includes(metric) && selectedMetrics.has(metric) ) } // Otherwise return all metrics in the group return metricNames.filter( (metric) => groupMetrics.includes(metric) && selectedMetrics.has(metric) ) } return (
{error &&
{error}
} {!error && (
{/* */}
{selectedModels.size === 0 || selectedMetrics.size === 0 ? (
Please select at least one model and one metric to display the data
) : ( <> {/* Standalone metrics table */} {(() => { const standaloneMetrics = findStandaloneMetrics() if (standaloneMetrics.length === 0) return null return (
{tableHeader .filter((model) => selectedModels.has(model)) .map((model) => ( ))} {standaloneMetrics.sort().map((metric) => { const row = tableRows.find((r) => r.metric === metric) if (!row) return null return ( {tableHeader .filter((model) => selectedModels.has(model)) .map((col) => { const cell = row[col] return ( ) })} ) })}
Metric {model}
{metric} {!isNaN(Number(cell)) ? Number(Number(cell).toFixed(3)) : cell}
) })()} {/* Main metrics table */}
{overallMetrics .filter((metric) => selectedOverallMetrics.has(metric)) .map((metric) => ( ))} {overallMetrics .filter((metric) => selectedOverallMetrics.has(metric)) .map((metric) => ( {tableHeader .filter((model) => selectedModels.has(model)) .map((model) => ( ))} ))} {/* First render each group */} {Object.entries(groups).map(([group, subGroups]) => { // Skip the "Overall" group completely if (group === 'Overall') return null // Get all metrics for this group const allGroupMetrics = Object.values(subGroups).flat() // Filter to only include selected metrics const visibleGroupMetrics = filterMetricsByGroupAndSubgroup( allGroupMetrics, group ) // Skip this group if no metrics are selected if (visibleGroupMetrics.length === 0) return null return ( {/* Group row with average stats for the entire group */} toggleGroup(group)} > {/* For each metric column */} {overallMetrics .filter((metric) => selectedOverallMetrics.has(metric)) .map((metric) => ( // Render sub-columns for each model {tableHeader .filter((model) => selectedModels.has(model)) .map((col) => { // Find all metrics in this group that match the current metric name const allMetricsWithName = findAllMetricsForName(metric) const metricsInGroupForThisMetric = visibleGroupMetrics.filter((m) => allMetricsWithName.includes(m) ) const stats = calculateStats(metricsInGroupForThisMetric, col) return ( ) })} ))} {/* Only render subgroups if group is open */} {openGroups[group] && Object.entries(subGroups).map(([subGroup, metrics]) => { // Filter to only include selected metrics in this subgroup const visibleSubgroupMetrics = filterMetricsByGroupAndSubgroup( metrics, group, subGroup ) // Skip this subgroup if no metrics are selected if (visibleSubgroupMetrics.length === 0) return null return ( {/* Subgroup row with average stats for the subgroup */} toggleSubGroup(group, subGroup)} > {/* For each metric column */} {overallMetrics .filter((metric) => selectedOverallMetrics.has(metric)) .map((metric) => ( // Render 
sub-columns for each model {tableHeader .filter((model) => selectedModels.has(model)) .map((col) => { // Find all metrics in this subgroup that match the current metric name const allMetricsWithName = findAllMetricsForName(metric) const metricsInSubgroupForThisMetric = visibleSubgroupMetrics.filter((m) => allMetricsWithName.includes(m) ) const stats = calculateStats( metricsInSubgroupForThisMetric, col ) return ( ) })} ))} {/* Individual metric rows */} {openSubGroups[group]?.[subGroup] && // Sort visibleSubgroupMetrics alphabetically by the clean metric name [...visibleSubgroupMetrics] .sort((a, b) => { // For metrics with format {category}_{strength}_{overall_metric_name}, // First sort by category, then by overall_metric_name, then by strength // First extract the overall metric group const getOverallMetricGroup = (metric: string) => { for (const overall of overallMetrics) { if ( metric.endsWith(`_${overall}`) || metric === overall ) { return overall } } return '' } const overallA = getOverallMetricGroup(a) const overallB = getOverallMetricGroup(b) // Extract the strength (last part before the overall metric) const stripOverall = (metric: string, overall: string) => { if (metric.endsWith(`_${overall}`)) { // Remove the overall metric group and any preceding underscore const stripped = metric.slice( 0, metric.length - overall.length - 1 ) const parts = stripped.split('_') return parts.length > 0 ? parts[parts.length - 1] : '' } return metric } // Extract the category (what remains after removing strength and overall_metric_name) const getCategory = (metric: string, overall: string) => { if (metric.endsWith(`_${overall}`)) { const stripped = metric.slice( 0, metric.length - overall.length - 1 ) const parts = stripped.split('_') // Remove the last part (strength) and join the rest (category) return parts.length > 1 ? 
parts.slice(0, parts.length - 1).join('_') : '' } return metric } const categoryA = getCategory(a, overallA) const categoryB = getCategory(b, overallB) // First sort by category if (categoryA !== categoryB) { return categoryA.localeCompare(categoryB) } // Then sort by overall metric name if (overallA !== overallB) { return overallA.localeCompare(overallB) } // Finally sort by strength const subA = stripOverall(a, overallA) const subB = stripOverall(b, overallB) // Try to parse subA and subB as numbers, handling k/m/b suffixes const parseNumber = (str: string) => { const match = str.match(/^(\d+(?:\.\d+)?)([kKmMbB]?)$/) if (!match) return NaN let [_, num, suffix] = match let value = parseFloat(num) switch (suffix.toLowerCase()) { case 'k': value *= 1e3 break case 'm': value *= 1e6 break case 'b': value *= 1e9 break } return value } const numA = parseNumber(subA) const numB = parseNumber(subB) if (!isNaN(numA) && !isNaN(numB)) { return numA - numB } // Fallback to string comparison if not both numbers return subA.localeCompare(subB) }) .map((metric) => { const row = tableRows.find((r) => r.metric === metric) if (!row) return null // Extract the metric name (after the underscore) const metricName = metric.includes('_') ? metric.split('_').slice(1).join('_') : metric return ( {/* For each metric column */} {overallMetrics .filter((oMetric) => selectedOverallMetrics.has(oMetric) ) .map((oMetric) => { // Only show values for the matching metric const isMatchingMetric = findAllMetricsForName(oMetric).includes(metric) if (!isMatchingMetric) { // Fill empty cells for non-matching metrics return ( {tableHeader .filter((model) => selectedModels.has(model) ) .map((col) => ( ))} ) } return ( {tableHeader .filter((model) => selectedModels.has(model)) .map((col) => { const cell = row[col] return ( ) })} ) })} ) })} ) })} ) })}
Attack Category Metrics selectedModels.has(model)).length } className="sticky top-0 bg-base-100 z-10 text-center border-x border-gray-300 border border-gray-700 border" > {metric}
{model}
{openGroups[group] ? '▼ ' : '▶ '} {group} {!isNaN(stats.avg) ? `${stats.avg.toFixed(3)} ± ${stats.stdDev.toFixed(3)}` : 'N/A'}
{openSubGroups[group]?.[subGroup] ? '▼ ' : '▶ '} {subGroup} {!isNaN(stats.avg) ? `${stats.avg.toFixed(3)} ± ${stats.stdDev.toFixed(3)}` : 'N/A'}
{metric} {!isNaN(Number(cell)) ? Number(Number(cell).toFixed(3)) : cell}
)}
)}
) } export default LeaderboardTable