// NOTE(review): this copy of the file appears to be damaged. Every span that looked like an
// HTML/JSX tag has been stripped: generic type arguments are missing (e.g. `Set` with no
// element type, `useState>(new Set())`, `React.FC = (`), several `.map((x) => ( ))`
// callbacks have empty bodies, and leftover cell text (`handleSort(metric, model)} >`,
// `title=` ternaries, sort arrows) sits out of place near the end of the render output.
// It cannot compile as-is. Restore the markup and type arguments from version control
// rather than reconstructing them by hand. The three component imports below are not
// referenced by the surviving code, presumably because their JSX usages were stripped.
import React, { useEffect, useState } from 'react'
import LeaderboardFilter from './LeaderboardFilter'
import LoadingSpinner from './LoadingSpinner'
import IndependentMetricsTable from './IndependentMetricsTable'

/** Props for LeaderboardTable. */
interface LeaderboardTableProps {
  // Parsed benchmark payload; the parsing effect reads its `rows` and `groups` fields.
  benchmarkData: any
  // Model names to show; matched against the non-`metric` keys of each row.
  selectedModels: Set
}

/** One benchmark row: the metric name plus one value per model column. */
interface Row {
  metric: string
  [key: string]: string | number
}

/** Metric names nested by group, then by subgroup. */
interface Groups {
  [group: string]: { [subgroup: string]: string[] }
}

/** Sort direction keyed by overall metric, then by model column. */
interface SortState {
  [overallMetric: string]: { [model: string]: { direction: 'asc' | 'desc' } }
}

/**
 * Picker for which overall metrics are selected. Its markup was lost; only the
 * "Metrics (selected/total)" label text survives.
 */
const OverallMetricFilter: React.FC<{
  overallMetrics: string[]
  selectedOverallMetrics: Set
  setSelectedOverallMetrics: (metrics: Set) => void
}> = ({ overallMetrics, selectedOverallMetrics, setSelectedOverallMetrics }) => {
  // Toggle `metric` in a copy of the selection (the prop Set itself is never mutated)
  // and hand the new Set to the parent.
  const toggleMetric = (metric: string) => {
    const newSelected = new Set(selectedOverallMetrics)
    if (newSelected.has(metric)) {
      newSelected.delete(metric)
    } else {
      newSelected.add(metric)
    }
    setSelectedOverallMetrics(newSelected)
  }
  return (
Metrics ({selectedOverallMetrics.size}/{overallMetrics.length})
{overallMetrics.map((metric) => ( ))}
  )
}

/**
 * Grouped benchmark leaderboard.
 *
 * Rows nest group -> subgroup -> individual metric. Group and subgroup rows are collapsible
 * and show `avg ± stdDev` of their metrics. Columns are, for each selected overall metric,
 * one sub-column per selected model. Clicking a model header (handleSort) reorders groups,
 * subgroups and metric rows by that column. Clicking a row's sort icon (handleColumnSort)
 * reorders the model sub-columns by that row.
 */
const LeaderboardTable: React.FC = ({ benchmarkData, selectedModels }) => {
  // Raw rows from `benchmarkData.rows`.
  const [tableRows, setTableRows] = useState([])
  // Model column names: every key seen in any row except `metric`.
  const [tableHeader, setTableHeader] = useState([])
  // Parse-failure message; when set, it is rendered instead of the table.
  const [error, setError] = useState(null)
  // Non-Overall groups, each bucketed into subgroups (see the parsing effect).
  const [groupRows, setGroupRows] = useState({})
  // Expanded state of group / subgroup rows; everything starts collapsed.
  const [openGroupRows, setOpenGroupRows] = useState<{ [key: string]: boolean }>({})
  const [openSubGroupRows, setOpenSubGroupRows] = useState<{ [key: string]: { [key: string]: boolean } }>({})
  // Metrics eligible for display; reset to every non-Overall metric on each parse.
  const [selectedMetrics, setSelectedMetrics] = useState>(new Set())
  // Overall metric names (column groups) and which of them are currently shown.
  const [overallMetrics, setOverallMetrics] = useState([])
  const [selectedOverallMetrics, setSelectedOverallMetrics] = useState>(new Set())
  // Header-click sort; handleSort keeps at most one (overallMetric, model) entry.
  const [sortState, setSortState] = useState({})
  // NOTE(review): setColumnSortState is never called in the surviving code, so
  // getColumnSortConfig() always returns null and sortModelColumns never reorders.
  // Check the original markup before deleting this state.
  const [columnSortState, setColumnSortState] = useState({})
  // Add state for row-based column sorting (keys come from getRowSortKey; at most one entry).
  const [selectedRowForSort, setSelectedRowForSort] = useState<{ [rowKey: string]: { direction: 'asc' | 'desc' } }>({})

  // Re-derive all table state whenever benchmarkData changes. Anything that throws (e.g. a
  // payload without `rows`) is reported via `error`. Setters that already ran before the
  // throw (overall metrics) are not rolled back.
  useEffect(() => {
    if (!benchmarkData) {
      return
    }
    try {
      const data = benchmarkData
      const rows: Row[] = data['rows']
      const allGroups = data['groups'] as { [key: string]: string[] }
      // `Overall` only supplies the overall metric names; every other group becomes rows.
      const { Overall: overallGroup, ...groups } = allGroups
      // Overall metric name = everything after the first `_`; entries without `_` are ignored.
      const uniqueMetrics = new Set()
      overallGroup?.forEach((metric) => {
        if (metric.includes('_')) {
          const metricName = metric.split('_').slice(1).join('_')
          uniqueMetrics.add(metricName)
        }
      })
      setOverallMetrics(Array.from(uniqueMetrics).sort())
      setSelectedOverallMetrics(new Set(Array.from(uniqueMetrics)))
      // Groups are sorted by name. Each group's metrics are sorted, then bucketed into
      // subgroups keyed by the metric's first `_`-separated token. Subgroups are sorted by name.
      const groupsData = Object.entries(groups)
        .sort(([groupA], [groupB]) => {
          // NOTE(review): `Overall` was destructured out of `groups` above, so these two
          // checks never match.
          if (groupA === 'Overall') return -1
          if (groupB === 'Overall') return 1
          return groupA.localeCompare(groupB)
        })
        .reduce(
          (acc, [group, metrics]) => {
            const sortedMetrics = [...metrics].sort()
            acc[group] = sortedMetrics.reduce<{ [key: string]: string[] }>((subAcc, metric) => {
              // NOTE(review): `subGroup` is unused; bucketing uses the first token only.
              const [mainGroup, subGroup] = metric.split('_')
              if (!subAcc[mainGroup]) {
                subAcc[mainGroup] = []
              }
              subAcc[mainGroup].push(metric)
              return subAcc
            }, {})
            acc[group] = Object.fromEntries(
              Object.entries(acc[group]).sort(([subGroupA], [subGroupB]) =>
                subGroupA.localeCompare(subGroupB)
              )
            )
            return acc
          },
          {} as { [key: string]: { [key: string]: string[] } }
        )
      // Model columns = union of all row keys except `metric`, in first-seen order.
      const allKeys: string[] = Array.from(new Set(rows.flatMap((row) => Object.keys(row))))
      const headers = allKeys.filter((key) => key !== 'metric')
      // Start with every group and subgroup collapsed.
      const initialOpenGroups: { [key: string]: boolean } = {}
      const initialOpenSubGroups: { [key: string]: { [key: string]: boolean } } = {}
      Object.keys(groupsData).forEach((group) => {
        initialOpenGroups[group] = false
        initialOpenSubGroups[group] = {}
        Object.keys(groupsData[group]).forEach((subGroup) => {
          initialOpenSubGroups[group][subGroup] = false
        })
      })
      const allMetrics = Object.values(groups).flat()
      setSelectedMetrics(new Set(allMetrics))
      setTableHeader(headers)
      setTableRows(rows)
      setGroupRows(groupsData)
      setOpenGroupRows(initialOpenGroups)
      setOpenSubGroupRows(initialOpenSubGroups)
      setError(null)
    } catch (err: any) {
      setError('Failed to parse benchmark data, please try again: ' + err.message)
    }
  }, [benchmarkData])

  // Expand/collapse a group row.
  const toggleGroup = (group: string) => {
    setOpenGroupRows((prev) => ({ ...prev, [group]: !prev[group] }))
  }

  // Expand/collapse one subgroup row inside `group`.
  const toggleSubGroup = (group: string, subGroup: string) => {
    setOpenSubGroupRows((prev) => ({
      ...prev,
      [group]: {
        ...(prev[group] || {}),
        [subGroup]: !prev[group]?.[subGroup],
      },
    }))
  }

  /**
   * Header click on (overallMetric, model): cycles none -> asc -> desc -> none.
   * The new state is built from `{}`, so activating one column clears any other column's sort.
   */
  const handleSort = (overallMetric: string, model: string) => {
    setSortState((prev) => {
      const prevDir = prev[overallMetric]?.[model]?.direction
      let newSortState: SortState = {}
      if (!prevDir) {
        // No sort yet, set to 'asc'
        newSortState[overallMetric] = { [model]: { direction: 'asc' } }
      } else if (prevDir === 'asc') {
        // Was 'asc', set to 'desc'
        newSortState[overallMetric] = { [model]: { direction: 'desc' } }
      }
      // Else revert back to unsorted state
      return newSortState
    })
  }

  // Helper to generate a stable composite key for row-based column sorting:
  // `group||subGroup||metric`, with null parts encoded as ''. Group rows use
  // (group, null, null) and subgroup rows use (group, subGroup, null).
  // NOTE(review): a name containing `||` would corrupt the later split('||').
  function getRowSortKey(group: string | null, subGroup: string | null, metric: string | null) {
    return `${group ?? ''}||${subGroup ?? ''}||${metric ?? ''}`
  }

  // Row sort-icon click: cycles none -> asc -> desc -> none for that row. The new map is
  // built from `{}`, so only one row sort is ever active.
  // NOTE(review): the `delete` branch runs on the fresh empty object and is a no-op;
  // returning `{}` is what actually clears the sort.
  const handleColumnSort = (
    group: string | null,
    subGroup: string | null,
    metric: string | null
  ) => {
    const rowKey = getRowSortKey(group, subGroup, metric)
    setSelectedRowForSort((prev) => {
      const prevDir = prev[rowKey]?.direction
      const newSortState: { [rowKey: string]: { direction: 'asc' | 'desc' } } = {}
      if (!prevDir) {
        newSortState[rowKey] = { direction: 'asc' }
      } else if (prevDir === 'asc') {
        newSortState[rowKey] = { direction: 'desc' }
      } else if (prevDir === 'desc') {
        delete newSortState[rowKey]
      }
      return newSortState
    })
  }

  // Current row-sort entry ({ direction }) for a row, or null when that row is not sorted.
  function getRowColumnSort(group: string | null, subGroup: string | null, metric: string | null) {
    return selectedRowForSort[getRowSortKey(group, subGroup, metric)] || null
  }

  /**
   * Active header sort as { overallMetric, model, direction }, or null. Only visible columns
   * (selected overall metric and selected model) are considered, in display order.
   * NOTE(review): logs on every call. It is called during render (once for the groups and
   * once per expanded subgroup), so remove the console.log before shipping.
   */
  const getSortConfig = () => {
    // Find the first sorted column (overallMetric, model)
    console.log({ sortState })
    for (const overallMetric of overallMetrics) {
      if (!selectedOverallMetrics.has(overallMetric)) continue
      const models = tableHeader.filter((model) => selectedModels.has(model))
      for (const model of models) {
        if (sortState[overallMetric]?.[model]) {
          return { overallMetric, model, direction: sortState[overallMetric][model].direction }
        }
      }
    }
    return null
  }

  // NOTE(review): identical to getSortConfig (minus the logging) and not referenced anywhere
  // in the surviving code; candidate for removal.
  const getRowSortConfig = () => {
    for (const overallMetric of overallMetrics) {
      if (!selectedOverallMetrics.has(overallMetric)) continue
      const models = tableHeader.filter((model) => selectedModels.has(model))
      for (const model of models) {
        if (sortState[overallMetric]?.[model]) {
          return { overallMetric, model, direction: sortState[overallMetric][model].direction }
        }
      }
    }
    return null
  }

  // Whole-column sort stored in columnSortState under the '__col__' sentinel key, for the
  // first visible overall metric that has one, or null.
  const getColumnSortConfig = () => {
    for (const overallMetric of overallMetrics) {
      if (!selectedOverallMetrics.has(overallMetric)) continue
      if (columnSortState[overallMetric]?.['__col__']) {
        return {
          overallMetric,
          direction: columnSortState[overallMetric]['__col__'].direction,
        }
      }
    }
    return null
  }

  /**
   * Order of model sub-columns for one overall metric. If a column sort is active for this
   * metric, models are ranked by their mean value over every row of that metric. Models with
   * no numeric values go last. Otherwise `models` is returned unchanged.
   * NOTE(review): findAllMetricsForName() rescans tableRows inside the comparator on every
   * comparison. Hoist it (and the per-model means) out of the sort. This also logs on every call.
   */
  const sortModelColumns = (models: string[], overallMetric: string): string[] => {
    // Column sort takes precedence; if no column sort, return models in default order
    const columnSortConfig = getColumnSortConfig()
    console.log({ columnSortConfig, overallMetric })
    if (columnSortConfig && columnSortConfig.overallMetric === overallMetric) {
      // Sort by average value for each model in this overallMetric
      return [...models].sort((a, b) => {
        const valsA = tableRows
          .filter((row) => findAllMetricsForName(overallMetric).includes(row.metric as string))
          .map((row) => Number(row[a]))
          .filter((v) => !isNaN(v))
        const valsB = tableRows
          .filter((row) => findAllMetricsForName(overallMetric).includes(row.metric as string))
          .map((row) => Number(row[b]))
          .filter((v) => !isNaN(v))
        const avgA = valsA.length ? valsA.reduce((s, v) => s + v, 0) / valsA.length : NaN
        const avgB = valsB.length ? valsB.reduce((s, v) => s + v, 0) / valsB.length : NaN
        if (isNaN(avgA) && isNaN(avgB)) return 0
        if (isNaN(avgA)) return 1
        if (isNaN(avgB)) return -1
        return columnSortConfig.direction === 'asc' ? avgA - avgB : avgB - avgA
      })
    }
    // No column sort: return models in default order
    return models
  }

  /**
   * Returns a sorted copy of the metric names, ordered by each metric's value in the `model`
   * column. Non-numeric values go last. Names missing from tableRows compare as equal.
   * NOTE(review): `overallMetric` is unused.
   */
  const sortRowsBySubcolumn = (
    rows: string[],
    overallMetric: string,
    model: string,
    direction: 'asc' | 'desc'
  ) => {
    return [...rows].sort((a, b) => {
      const rowA = tableRows.find((r) => r.metric === a)
      const rowB = tableRows.find((r) => r.metric === b)
      if (!rowA || !rowB) return 0
      const valA = Number(rowA[model])
      const valB = Number(rowB[model])
      if (isNaN(valA) && isNaN(valB)) return 0
      if (isNaN(valA)) return 1
      if (isNaN(valB)) return -1
      return direction === 'asc' ? valA - valB : valB - valA
    })
  }

  // Find all metrics matching a particular extracted metric name (like "log10_p_value").
  // A row matches when the part of its name after the first `_` ends with `metricName`.
  // NOTE(review): the suffix test has no `_` boundary. An overall metric that is a suffix of
  // another (e.g. "p_value" vs "log10_p_value") would also claim the longer one's rows.
  // This is called very often during render and sorting; consider memoising per metricName.
  const findAllMetricsForName = (metricName: string): string[] => {
    return tableRows
      .filter((row) => {
        const metric = row.metric as string
        if (metric.includes('_')) {
          const extractedName = metric.split('_').slice(1).join('_')
          return extractedName.endsWith(metricName)
        }
        return false
      })
      .map((row) => row.metric as string)
  }

  // Identify metrics that don't belong to any overall metric group. A metric belongs to a
  // group when its name equals an overall metric or ends with `_<overall>`.
  // NOTE(review): not referenced by the surviving code; presumably it fed the stripped
  // standalone-metrics table. Its matching rule also differs from findAllMetricsForName's.
  const findStandaloneMetrics = (): string[] => {
    // Get all metrics from the table rows
    const allMetrics = tableRows.map((row) => row.metric as string)
    // Filter to only include metrics that aren't part of any of the overall metrics
    return allMetrics.filter((metric) => {
      // Check if this metric is part of any of the overall metrics
      for (const overall of overallMetrics) {
        if (metric.endsWith(`_${overall}`) || metric === overall) {
          return false // This metric belongs to an overall group
        }
      }
      return true
    })
  }

  // Calculate average and standard deviation for a set of metrics for a specific column.
  // This is the population standard deviation (the variance divides by n). Unknown metric
  // names and non-numeric values are skipped. If nothing remains, both results are NaN.
  const calculateStats = (
    metricNames: string[],
    columnKey: string
  ): { avg: number; stdDev: number } => {
    const values = metricNames
      .map((metricName) => {
        const row = tableRows.find((row) => row.metric === metricName)
        return row ? Number(row[columnKey]) : NaN
      })
      .filter((value) => !isNaN(value))

    if (values.length === 0) return { avg: NaN, stdDev: NaN }

    const avg = values.reduce((sum, val) => sum + val, 0) / values.length
    const squareDiffs = values.map((value) => {
      const diff = value - avg
      return diff * diff
    })
    const variance = squareDiffs.reduce((sum, sqrDiff) => sum + sqrDiff, 0) / values.length
    const stdDev = Math.sqrt(variance)

    return { avg, stdDev }
  }

  // Filter metrics by group and/or subgroup, keeping only those in selectedMetrics.
  // NOTE(review): with no `group`, the input is returned as-is, without the selectedMetrics
  // filter. An unknown `subgroup` falls back to the whole-group filter.
  const filterMetricsByGroupAndSubgroup = (
    metricNames: string[],
    group: string | null = null,
    subgroup: string | null = null
  ): string[] => {
    // If no group specified, return all metrics
    if (!group) return metricNames

    // Get all metrics for the specified group
    const groupMetrics = Object.values(groupRows[group] || {}).flat() as string[]

    // If subgroup is specified, further filter to that subgroup
    if (subgroup && groupRows[group]?.[subgroup]) {
      return metricNames.filter(
        (metric) => groupRows[group][subgroup].includes(metric) && selectedMetrics.has(metric)
      )
    }

    // Otherwise return all metrics in the group
    return metricNames.filter(
      (metric) => groupMetrics.includes(metric) && selectedMetrics.has(metric)
    )
  }

  // Before rendering group rows: if a header sort is active, order groups by the mean of
  // their metrics (for that overall metric) in the sorted model's column. NaN groups go last.
  const groupSortConfig = getSortConfig()
  let groupEntries = Object.entries(groupRows).filter(([group]) => group !== 'Overall')
  if (groupSortConfig) {
    groupEntries = groupEntries.sort(([groupA, subGroupsA], [groupB, subGroupsB]) => {
      // For each group, get all metrics in the group for the selected overallMetric
      const allMetricsWithName = findAllMetricsForName(groupSortConfig.overallMetric)
      const getGroupAvg = (subGroups: { [key: string]: string[] }) => {
        const allGroupMetrics = Object.values(subGroups).flat()
        const metricsInGroupForThisMetric = allGroupMetrics.filter((m) =>
          allMetricsWithName.includes(m)
        )
        const stats = calculateStats(metricsInGroupForThisMetric, groupSortConfig.model)
        return stats.avg
      }
      const avgA = getGroupAvg(subGroupsA)
      const avgB = getGroupAvg(subGroupsB)
      if (isNaN(avgA) && isNaN(avgB)) return 0
      if (isNaN(avgA)) return 1
      if (isNaN(avgB)) return -1
      return groupSortConfig.direction === 'asc' ? avgA - avgB : avgB - avgA
    })
  }

  // Compute model order for each overall metric before rendering. An active row sort
  // (selectedRowForSort) wins; otherwise fall back to sortModelColumns().
  const modelOrderByOverallMetric: { [metric: string]: string[] } = {}
  overallMetrics
    .filter((metric) => selectedOverallMetrics.has(metric))
    .forEach((metric) => {
      // Check if there is an active row-based column sort for this metric
      let sortedModels: string[] | null = null
      // Find the active rowKey for this metric in rowColumnSort
      const activeRowKey = Object.keys(selectedRowForSort).find((rowKey) => {
        // rowKey format: group||subGroup||metric
        const [group, subGroup, rowMetric] = rowKey.split('||')
        // If rowMetric is empty, it's a group or subgroup row.
        // NOTE(review): `metric === metric` is always true, so a group/subgroup row sort
        // applies to every overall metric. `group` and `subGroup` are unused here.
        if (rowMetric === '' && metric === metric) return true
        // If rowMetric matches this metric, it's an individual metric row
        if (rowMetric && findAllMetricsForName(metric).includes(rowMetric)) return true
        return false
      })
      if (activeRowKey && selectedRowForSort[activeRowKey]) {
        const direction = selectedRowForSort[activeRowKey].direction
        const [group, subGroup, rowMetric] = activeRowKey.split('||')
        const models = tableHeader.filter((model) => selectedModels.has(model))
        if (!rowMetric) {
          // Group or subgroup row: sort by average for this group/subgroup and metric
          // Find all metrics in this group/subgroup for this overall metric
          let relevantMetrics: string[] = []
          if (group && !subGroup) {
            // Group row
            const groupMetrics = Object.values(groupRows[group] || {}).flat() as string[]
            relevantMetrics = groupMetrics.filter((m: string) =>
              findAllMetricsForName(metric).includes(m)
            )
          } else if (group && subGroup) {
            // Subgroup row
            relevantMetrics = (groupRows[group]?.[subGroup] || []).filter((m: string) =>
              findAllMetricsForName(metric).includes(m)
            )
          }
          sortedModels = [...models].sort((a, b) => {
            const statsA = calculateStats(relevantMetrics, a)
            const statsB = calculateStats(relevantMetrics, b)
            if (isNaN(statsA.avg) && isNaN(statsB.avg)) return 0
            if (isNaN(statsA.avg)) return 1
            if (isNaN(statsB.avg)) return -1
            return direction === 'asc' ? statsA.avg - statsB.avg : statsB.avg - statsA.avg
          })
        } else {
          // Individual metric row: sort by value for that metric
          sortedModels = [...models].sort((a, b) => {
            const rowA = tableRows.find((r) => r.metric === rowMetric)
            const rowB = rowA // same row
            const valA = rowA ? Number(rowA[a]) : NaN
            const valB = rowB ? Number(rowB[b]) : NaN
            if (isNaN(valA) && isNaN(valB)) return 0
            if (isNaN(valA)) return 1
            if (isNaN(valB)) return -1
            return direction === 'asc' ? valA - valB : valB - valA
          })
        }
      }
      modelOrderByOverallMetric[metric] =
        sortedModels ||
        sortModelColumns(
          tableHeader.filter((model) => selectedModels.has(model)),
          metric
        )
    })

  // NOTE(review): debug logging on every render.
  console.log({ modelOrderByOverallMetric })
  // NOTE(review): the JSX below lost all of its element tags (see the note at the top of the
  // file). What survives shows the following render branches. If `error` is set, the error
  // message is shown. If no model or metric is selected, a prompt is shown. Otherwise a
  // standalone-metrics table is followed by the main table: header rows per overall metric
  // and model, then group, subgroup and individual metric rows.
  return (
{error &&
{error}
} {!error && (
{selectedModels.size === 0 || selectedMetrics.size === 0 ? (
Please select at least one model and one metric to display the data
) : ( <> {/* Standalone metrics table */} {/* Main metrics table */}
{overallMetrics
  .filter((metric) => selectedOverallMetrics.has(metric))
  .map((metric) => ( ))}
{overallMetrics
  .filter((metric) => selectedOverallMetrics.has(metric))
  .map((metric) => (
    {modelOrderByOverallMetric[metric].map((model) => {
      // Header cell state: an arrow is shown only for the active header sort.
      const isSorted = sortState[metric]?.[model]?.direction !== undefined
      const direction = sortState[metric]?.[model]?.direction || 'desc'
      return ( )
    })}
  ))}
{/* First render each group row */}
{groupEntries.map(([group, subGroups]) => {
  // Skip the "Overall" group completely (already filtered out of groupEntries; kept as a guard)
  if (group === 'Overall') return null
  // Get all metrics for this group row
  const allGroupMetrics = Object.values(subGroups).flat()
  // Filter to only include selected metrics
  const visibleGroupMetrics = filterMetricsByGroupAndSubgroup(
    allGroupMetrics,
    group
  )
  // Skip this group row if no metrics are selected
  if (visibleGroupMetrics.length === 0) return null
  // Sort subgroups by average if sort config is active
  let subGroupEntries = Object.entries(subGroups)
  if (groupSortConfig) {
    const allMetricsWithName = findAllMetricsForName(
      groupSortConfig.overallMetric
    )
    const getSubGroupAvg = (metrics: string[]) => {
      const metricsInSubGroupForThisMetric = metrics.filter((m) =>
        allMetricsWithName.includes(m)
      )
      const stats = calculateStats(
        metricsInSubGroupForThisMetric,
        groupSortConfig.model
      )
      return stats.avg
    }
    subGroupEntries = subGroupEntries.sort(([, metricsA], [, metricsB]) => {
      const avgA = getSubGroupAvg(metricsA)
      const avgB = getSubGroupAvg(metricsB)
      if (isNaN(avgA) && isNaN(avgB)) return 0
      if (isNaN(avgA)) return 1
      if (isNaN(avgB)) return -1
      return groupSortConfig.direction === 'asc' ? avgA - avgB : avgB - avgA
    })
  }
  return (
    {/* Group row with average stats for the entire group */}
    toggleGroup(group)} >
    {/* For each metric column */}
    {overallMetrics
      .filter((metric) => selectedOverallMetrics.has(metric))
      .map((metric) => {
        // NOTE(review): rowKey is not referenced by the surviving code (possibly a React key
        // in the stripped markup).
        const rowKey = getRowSortKey(group, null, null)
        return (
          {modelOrderByOverallMetric[metric].map((col: string) => {
            const allMetricsWithName = findAllMetricsForName(metric)
            const metricsInGroupForThisMetric = visibleGroupMetrics.filter((m) =>
              allMetricsWithName.includes(m)
            )
            const stats = calculateStats(metricsInGroupForThisMetric, col)
            return ( )
          })}
        )
      })}
    {/* Only render subgroups if group row is open */}
    {openGroupRows[group] &&
      subGroupEntries.map(([subGroup, metrics]) => {
        // Filter to only include selected metrics in this subgroup row
        const visibleSubgroupMetrics = filterMetricsByGroupAndSubgroup(
          metrics,
          group,
          subGroup
        )
        // Skip this subgroup row if no metrics are selected
        if (visibleSubgroupMetrics.length === 0) return null
        return (
          {/* Subgroup row with average stats for the subgroup */}
          toggleSubGroup(group, subGroup)} >
          {/* For each metric column */}
          {overallMetrics
            .filter((metric) => selectedOverallMetrics.has(metric))
            .map((metric) => {
              const rowKey = getRowSortKey(group, subGroup, null)
              return (
                {modelOrderByOverallMetric[metric].map(
                  (col: string) => {
                    const allMetricsWithName = findAllMetricsForName(metric)
                    const metricsInSubgroupForThisMetric =
                      visibleSubgroupMetrics.filter((m) =>
                        allMetricsWithName.includes(m)
                      )
                    const stats = calculateStats(
                      metricsInSubgroupForThisMetric,
                      col
                    )
                    return ( )
                  }
                )}
              )
            })}
          {/* Individual metric rows */}
          {openSubGroupRows[group]?.[subGroup] &&
            (() => {
              // Sorting logic for individual metric rows
              const sortConfig = getSortConfig()
              let sortedMetrics = [...visibleSubgroupMetrics]
              // NOTE(review): debug logging on every render of an expanded subgroup.
              console.log(
                'Sorting metrics for subgroup:',
                group,
                subGroup,
                'with config:',
                sortConfig
              )
              if (sortConfig) {
                // Only sort metrics that match the selected overallMetric and model
                const allMetricsWithName = findAllMetricsForName(
                  sortConfig.overallMetric
                )
                const metricsInSubgroupForThisMetric = sortedMetrics.filter(
                  (m) => allMetricsWithName.includes(m)
                )
                const metricsNotInSubgroupForThisMetric = sortedMetrics.filter(
                  (m) => !allMetricsWithName.includes(m)
                )
                // Only apply subcolumn sort to matching metrics, leave others in original order
                sortedMetrics = [
                  ...sortRowsBySubcolumn(
                    metricsInSubgroupForThisMetric,
                    sortConfig.overallMetric,
                    sortConfig.model,
                    sortConfig.direction
                  ),
                  ...metricsNotInSubgroupForThisMetric,
                ]
              } else {
                // Fallback sort logic (category, overall, strength)
                sortedMetrics = sortedMetrics.sort((a, b) => {
                  // For metrics with format {category}_{strength}_{overall_metric_name},
                  // First sort by category, then by overall_metric_name, then by strength
                  // First extract the overall metric group ('' when none matches)
                  const getOverallMetricGroup = (metric: string) => {
                    for (const overall of overallMetrics) {
                      if (
                        metric.endsWith(`_${overall}`) ||
                        metric === overall
                      ) {
                        return overall
                      }
                    }
                    return ''
                  }
                  const overallA = getOverallMetricGroup(a)
                  const overallB = getOverallMetricGroup(b)

                  // Extract the strength (last part before the overall metric)
                  const stripOverall = (
                    metric: string,
                    overall: string
                  ) => {
                    if (metric.endsWith(`_${overall}`)) {
                      // Remove the overall metric group and any preceding underscore
                      const stripped = metric.slice(
                        0,
                        metric.length - overall.length - 1
                      )
                      const parts = stripped.split('_')
                      return parts.length > 0 ? parts[parts.length - 1] : ''
                    }
                    return metric
                  }

                  // Extract the category (what remains after removing strength and overall_metric_name)
                  const getCategory = (metric: string, overall: string) => {
                    if (metric.endsWith(`_${overall}`)) {
                      const stripped = metric.slice(
                        0,
                        metric.length - overall.length - 1
                      )
                      const parts = stripped.split('_')
                      // Remove the last part (strength) and join the rest (category)
                      return parts.length > 1
                        ? parts.slice(0, parts.length - 1).join('_')
                        : ''
                    }
                    return metric
                  }

                  const categoryA = getCategory(a, overallA)
                  const categoryB = getCategory(b, overallB)

                  // First sort by category
                  if (categoryA !== categoryB) {
                    return categoryA.localeCompare(categoryB)
                  }

                  // Then sort by overall metric name
                  if (overallA !== overallB) {
                    return overallA.localeCompare(overallB)
                  }

                  // Finally sort by strength
                  const subA = stripOverall(a, overallA)
                  const subB = stripOverall(b, overallB)

                  // Try to parse subA and subB as numbers, handling k/m/b suffixes.
                  // Only unsigned decimals with an optional single k/m/b suffix parse;
                  // anything else yields NaN.
                  const parseNumber = (str: string) => {
                    const match = str.match(/^\d+(?:\.\d+)?([kKmMbB]?)$/)
                    if (!match) return NaN
                    let [_, suffix] = match
                    let value = parseFloat(str)
                    switch (suffix?.toLowerCase()) {
                      case 'k':
                        value *= 1e3
                        break
                      case 'm':
                        value *= 1e6
                        break
                      case 'b':
                        value *= 1e9
                        break
                    }
                    return value
                  }

                  const numA = parseNumber(subA)
                  const numB = parseNumber(subB)

                  if (!isNaN(numA) && !isNaN(numB)) {
                    return numA - numB
                  }
                  // Fallback to string comparison if not both numbers
                  return subA.localeCompare(subB)
                })
              }
              return sortedMetrics.map((metric) => {
                const row = tableRows.find((r) => r.metric === metric)
                if (!row) return null
                // Extract the metric name (after the underscore).
                // NOTE(review): metricName is not referenced by the surviving cell text,
                // which renders `{metric}`.
                const metricName = metric.includes('_')
                  ? metric.split('_').slice(1).join('_')
                  : metric
                return (
                  {overallMetrics
                    .filter((oMetric) =>
                      selectedOverallMetrics.has(oMetric)
                    )
                    .map((oMetric) => {
                      const isMatchingMetric =
                        findAllMetricsForName(oMetric).includes(metric)
                      if (!isMatchingMetric) {
                        return (
                          {modelOrderByOverallMetric[oMetric].map(
                            (col) => ( )
                          )}
                        )
                      }
                      return (
                        {modelOrderByOverallMetric[oMetric].map(
                          (col) => {
                            const cell = row[col]
                            return ( )
                          }
                        )}
                      )
                    })}
                )
              })
            })()}
        )
      })}
  )
})}
Attack Category Metrics {metric}
handleSort(metric, model)} > {model} {isSorted ? (direction === 'asc' ? '↑' : '↓') : '⇅'}
{openGroupRows[group] ? '▼ ' : '▶ '} {group} {/* Sort icon: only this triggers sort, and shows default if unsorted */} { e.stopPropagation() handleColumnSort(group, null, null) }} title={ getRowColumnSort(group, null, null) ? getRowColumnSort(group, null, null)?.direction === 'asc' ? 'Sort descending' : 'Clear sort' : 'Sort by this row' } > {getRowColumnSort(group, null, null) ? getRowColumnSort(group, null, null)?.direction === 'asc' ? '↑' : '↓' : '⇅'} {!isNaN(stats.avg) ? `${stats.avg.toFixed(3)} ± ${stats.stdDev.toFixed(3)}` : 'N/A'}
{openSubGroupRows[group]?.[subGroup] ? '▼ ' : '▶ '} {subGroup} { e.stopPropagation() handleColumnSort(group, subGroup, null) }} title={ getRowColumnSort(group, subGroup, null) ? getRowColumnSort(group, subGroup, null)?.direction === 'asc' ? 'Sort descending' : 'Clear sort' : 'Sort by this row' } > {getRowColumnSort(group, subGroup, null) ? getRowColumnSort(group, subGroup, null)?.direction === 'asc' ? '↑' : '↓' : '⇅'} {!isNaN(stats.avg) ? `${stats.avg.toFixed(3)} ± ${stats.stdDev.toFixed(3)}` : 'N/A'}
{metric} { e.stopPropagation() handleColumnSort(group, subGroup, metric) }} title={ getRowColumnSort(group, subGroup, metric) ? getRowColumnSort(group, subGroup, metric) ?.direction === 'asc' ? 'Sort descending' : 'Clear sort' : 'Sort by this row' } > {getRowColumnSort(group, subGroup, metric) ? getRowColumnSort(group, subGroup, metric) ?.direction === 'asc' ? '▲' : '▼' : '⇅'} {!isNaN(Number(cell)) ? Number(Number(cell).toFixed(3)) : cell}
)}
)}
  )
}

export default LeaderboardTable