Spaces:
Running
Running
import React, { useEffect, useState } from 'react' | |
import LeaderboardFilter from './LeaderboardFilter' | |
import LoadingSpinner from './LoadingSpinner' | |
import IndependentMetricsTable from './IndependentMetricsTable' | |
/** Props accepted by the leaderboard table component. */
type LeaderboardTableProps = {
  // Raw benchmark payload; its `rows` and `groups` fields are read without validation,
  // which is why it is typed loosely.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  benchmarkData: any
  // Names of the model columns the user chose to display.
  selectedModels: Set<string>
}
/**
 * One result row: the `metric` name plus one value per additional key
 * (the remaining keys are used as the table's column names).
 */
type Row = {
  metric: string
  [key: string]: string | number
}
/** Two-level metric hierarchy: group name -> subgroup name -> metric names. */
type Groups = Record<string, Record<string, string[]>>
/** Active column sort: overall metric -> model -> requested direction. */
type SortState = Record<string, Record<string, { direction: 'asc' | 'desc' }>>
/**
 * Checkbox grid for choosing which overall metrics are shown as column groups.
 *
 * The component is controlled: it only renders `selectedOverallMetrics` and reports
 * each change by passing a brand-new Set to `setSelectedOverallMetrics`, leaving the
 * Set it was given untouched.
 */
const OverallMetricFilter: React.FC<{
  overallMetrics: string[]
  selectedOverallMetrics: Set<string>
  setSelectedOverallMetrics: (metrics: Set<string>) => void
}> = (props) => {
  const { overallMetrics, selectedOverallMetrics, setSelectedOverallMetrics } = props
  const selectedCount = selectedOverallMetrics.size
  const totalCount = overallMetrics.length

  // Flip one metric's membership in a copy of the current selection.
  const handleToggle = (metric: string) => {
    const next = new Set(selectedOverallMetrics)
    // Set#delete returns false when the metric was absent, in which case we select it.
    if (!next.delete(metric)) {
      next.add(metric)
    }
    setSelectedOverallMetrics(next)
  }

  // One labelled checkbox per metric; the full name is exposed as a tooltip because it may be truncated.
  const renderOption = (metric: string) => (
    <label key={metric} className="flex items-center gap-2 text-sm">
      <input
        type="checkbox"
        className="form-checkbox h-4 w-4"
        checked={selectedOverallMetrics.has(metric)}
        onChange={() => handleToggle(metric)}
      />
      <span className="truncate" title={metric}>
        {metric}
      </span>
    </label>
  )

  return (
    <div className="w-full">
      <fieldset className="fieldset w-full p-4 rounded border border-gray-700 bg-base-200">
        <legend className="fieldset-legend font-semibold">
          Metrics ({selectedCount}/{totalCount})
        </legend>
        <div className="grid grid-cols-2 md:grid-cols-4 lg:grid-cols-6 gap-1 max-h-48 overflow-y-auto pr-2">
          {overallMetrics.map(renderOption)}
        </div>
      </fieldset>
    </div>
  )
}
/** Sort direction used by both the header (row-ordering) and row (column-ordering) sort controls. */
type SortDirection = 'asc' | 'desc'

/**
 * Active row-driven column sort, keyed by `getRowSortKey(group, subGroup, metric)`.
 * `handleColumnSort` keeps at most one entry in this map.
 */
type RowSortState = { [rowKey: string]: { direction: SortDirection } }

/**
 * Comparator for numeric sort keys that always places NaN (missing or non-numeric
 * data) last, whichever direction is requested.
 */
const compareNaNLast = (a: number, b: number, direction: SortDirection): number => {
  const aMissing = Number.isNaN(a)
  const bMissing = Number.isNaN(b)
  if (aMissing && bMissing) return 0
  if (aMissing) return 1
  if (bMissing) return -1
  return direction === 'asc' ? a - b : b - a
}

/** Matches strength tokens such as `10`, `2.5`, `100k`, `1M` or `3b`; group 1 is the suffix. */
const STRENGTH_PATTERN = /^\d+(?:\.\d+)?([kKmMbB]?)$/
const STRENGTH_MULTIPLIERS: { [suffix: string]: number } = { k: 1e3, m: 1e6, b: 1e9 }

/** Parse a strength token into a number, expanding k/m/b suffixes; NaN when it is not numeric. */
const parseStrength = (token: string): number => {
  const match = STRENGTH_PATTERN.exec(token)
  if (!match) return NaN
  const suffix = (match[1] ?? '').toLowerCase()
  return parseFloat(token) * (STRENGTH_MULTIPLIERS[suffix] ?? 1)
}

/** Population mean and standard deviation of `values`; both NaN when the list is empty. */
const meanAndStdDev = (values: number[]): { avg: number; stdDev: number } => {
  if (values.length === 0) return { avg: NaN, stdDev: NaN }
  const avg = values.reduce((sum, value) => sum + value, 0) / values.length
  const variance =
    values.reduce((sum, value) => sum + (value - avg) * (value - avg), 0) / values.length
  return { avg, stdDev: Math.sqrt(variance) }
}

/** Render aggregate stats as "avg ± stdDev" with three decimals, or "N/A" when there is no data. */
const formatStats = ({ avg, stdDev }: { avg: number; stdDev: number }): string =>
  !isNaN(avg) ? `${avg.toFixed(3)} ± ${stdDev.toFixed(3)}` : 'N/A'

/** Round numeric cell values to three decimals; non-numeric values are shown as-is. */
const formatCellValue = (cell: string | number | undefined) =>
  !isNaN(Number(cell)) ? Number(Number(cell).toFixed(3)) : cell

/**
 * Split a metric named `<category>_<strength>_<overall>` into category and strength.
 * When the name does not end with `_<overall>`, the whole name is used for both.
 */
const splitCategoryAndStrength = (
  metric: string,
  overall: string
): { category: string; strength: string } => {
  const suffix = `_${overall}`
  if (!metric.endsWith(suffix)) return { category: metric, strength: metric }
  const parts = metric.slice(0, metric.length - suffix.length).split('_')
  return { category: parts.slice(0, -1).join('_'), strength: parts[parts.length - 1] ?? '' }
}

/**
 * Leaderboard of benchmark results.
 *
 * Columns: one column group per selected overall metric (derived from the `Overall`
 * entry of `benchmarkData.groups`), with one sub-column per selected model.
 * Rows: collapsible group -> subgroup -> individual metric rows. Group and subgroup
 * rows show `avg ± stdDev` over their visible metrics for each column. Metrics that
 * report no overall metric are listed in a separate `IndependentMetricsTable`.
 *
 * Two independent sort controls exist, each cycling unsorted -> asc -> desc -> unsorted:
 *  - clicking a model header orders groups, subgroups and metric rows by that column;
 *  - clicking a row's sort icon orders the model sub-columns by that row's values.
 */
const LeaderboardTable: React.FC<LeaderboardTableProps> = ({ benchmarkData, selectedModels }) => {
  const [tableRows, setTableRows] = useState<Row[]>([])
  const [tableHeader, setTableHeader] = useState<string[]>([])
  const [error, setError] = useState<string | null>(null)
  const [groupRows, setGroupRows] = useState<Groups>({})
  const [openGroupRows, setOpenGroupRows] = useState<{ [key: string]: boolean }>({})
  const [openSubGroupRows, setOpenSubGroupRows] = useState<{
    [key: string]: { [key: string]: boolean }
  }>({})
  const [selectedMetrics, setSelectedMetrics] = useState<Set<string>>(new Set())
  const [overallMetrics, setOverallMetrics] = useState<string[]>([])
  const [selectedOverallMetrics, setSelectedOverallMetrics] = useState<Set<string>>(new Set())
  // Header sort: the (overall metric, model) column that orders the rows; at most one entry.
  const [sortState, setSortState] = useState<SortState>({})
  // Row sort: the group/subgroup/metric row that orders the model sub-columns.
  const [selectedRowForSort, setSelectedRowForSort] = useState<RowSortState>({})

  // Parse the benchmark payload into table state whenever it changes.
  useEffect(() => {
    if (!benchmarkData) {
      return
    }
    try {
      const rows: Row[] = benchmarkData['rows']
      const allGroups = benchmarkData['groups'] as { [key: string]: string[] }
      // The "Overall" group only defines the metric columns; it is never rendered as a row group.
      const { Overall: overallGroup, ...groups } = allGroups

      // "Overall_<name>" -> "<name>"; each distinct name becomes a column group.
      const uniqueMetrics = new Set<string>()
      overallGroup?.forEach((metric) => {
        if (metric.includes('_')) {
          uniqueMetrics.add(metric.split('_').slice(1).join('_'))
        }
      })
      setOverallMetrics(Array.from(uniqueMetrics).sort())
      setSelectedOverallMetrics(new Set(uniqueMetrics))

      // group -> (metric-name prefix before the first "_" -> metrics); every level sorted.
      const groupsData: Groups = {}
      for (const [group, metrics] of Object.entries(groups).sort(([a], [b]) =>
        a.localeCompare(b)
      )) {
        const bySubGroup: { [subGroup: string]: string[] } = {}
        for (const metric of [...metrics].sort()) {
          const subGroup = metric.split('_')[0] ?? metric
          const bucket = bySubGroup[subGroup] ?? []
          bucket.push(metric)
          bySubGroup[subGroup] = bucket
        }
        groupsData[group] = Object.fromEntries(
          Object.entries(bySubGroup).sort(([a], [b]) => a.localeCompare(b))
        )
      }

      const headers = Array.from(new Set(rows.flatMap((row) => Object.keys(row)))).filter(
        (key) => key !== 'metric'
      )

      // Every group and subgroup starts collapsed.
      const initialOpenGroups: { [key: string]: boolean } = {}
      const initialOpenSubGroups: { [key: string]: { [key: string]: boolean } } = {}
      for (const [group, subGroups] of Object.entries(groupsData)) {
        initialOpenGroups[group] = false
        initialOpenSubGroups[group] = Object.fromEntries(
          Object.keys(subGroups).map((subGroup) => [subGroup, false])
        )
      }

      setSelectedMetrics(new Set(Object.values(groups).flat()))
      setTableHeader(headers)
      setTableRows(rows)
      setGroupRows(groupsData)
      setOpenGroupRows(initialOpenGroups)
      setOpenSubGroupRows(initialOpenSubGroups)
      setError(null)
    } catch (err: unknown) {
      const reason = err instanceof Error ? err.message : String(err)
      setError('Failed to parse benchmark data, please try again: ' + reason)
    }
  }, [benchmarkData])

  const toggleGroup = (group: string) => {
    setOpenGroupRows((prev) => ({ ...prev, [group]: !prev[group] }))
  }

  const toggleSubGroup = (group: string, subGroup: string) => {
    setOpenSubGroupRows((prev) => ({
      ...prev,
      [group]: {
        ...(prev[group] || {}),
        [subGroup]: !prev[group]?.[subGroup],
      },
    }))
  }

  // Header click: cycle this (overallMetric, model) column unsorted -> asc -> desc -> unsorted.
  // Sorting one column clears any other column's sort.
  const handleSort = (overallMetric: string, model: string) => {
    setSortState((prev): SortState => {
      const prevDir = prev[overallMetric]?.[model]?.direction
      if (!prevDir) return { [overallMetric]: { [model]: { direction: 'asc' } } }
      if (prevDir === 'asc') return { [overallMetric]: { [model]: { direction: 'desc' } } }
      return {}
    })
  }

  // Stable composite key identifying a group, subgroup or metric row for row-based column sorting.
  function getRowSortKey(group: string | null, subGroup: string | null, metric: string | null) {
    return `${group ?? ''}||${subGroup ?? ''}||${metric ?? ''}`
  }

  // Row sort-icon click: cycle unsorted -> asc -> desc -> unsorted; only one row drives the sort.
  const handleColumnSort = (group: string | null, subGroup: string | null, metric: string | null) => {
    const rowKey = getRowSortKey(group, subGroup, metric)
    setSelectedRowForSort((prev): RowSortState => {
      const prevDir = prev[rowKey]?.direction
      if (!prevDir) return { [rowKey]: { direction: 'asc' } }
      if (prevDir === 'asc') return { [rowKey]: { direction: 'desc' } }
      return {}
    })
  }

  // Current row-based column sort for a row, or null when that row is not driving the sort.
  function getRowColumnSort(group: string | null, subGroup: string | null, metric: string | null) {
    return selectedRowForSort[getRowSortKey(group, subGroup, metric)] || null
  }

  // Models and overall metrics currently shown, in their default (header / alphabetical) order.
  const visibleModels = tableHeader.filter((model) => selectedModels.has(model))
  const visibleOverallMetrics = overallMetrics.filter((metric) =>
    selectedOverallMetrics.has(metric)
  )

  // The active header sort restricted to visible columns, or null.
  const getSortConfig = (): {
    overallMetric: string
    model: string
    direction: SortDirection
  } | null => {
    for (const overallMetric of visibleOverallMetrics) {
      for (const model of visibleModels) {
        const entry = sortState[overallMetric]?.[model]
        if (entry) return { overallMetric, model, direction: entry.direction }
      }
    }
    return null
  }

  // Overall metric a row metric reports: the metric equals it or ends with `_<overall>`.
  // The longest match wins, so "x_1k_log10_p_value" is attributed only to "log10_p_value"
  // and never additionally to a shorter overall metric such as "p_value".
  const overallMetricOf = (metric: string): string | null => {
    let best: string | null = null
    for (const overall of overallMetrics) {
      const reports = metric === overall || metric.endsWith(`_${overall}`)
      if (reports && (best === null || overall.length > best.length)) {
        best = overall
      }
    }
    return best
  }

  // Per-render indexes so cell rendering does not rescan `tableRows` for every cell.
  const rowByMetric = new Map<string, Row>()
  const overallByMetric = new Map<string, string>()
  for (const row of tableRows) {
    // Keep the first row for a duplicated metric name, matching Array.prototype.find.
    if (rowByMetric.has(row.metric)) continue
    rowByMetric.set(row.metric, row)
    const overall = overallMetricOf(row.metric)
    if (overall !== null) overallByMetric.set(row.metric, overall)
  }

  // True when `metric` is a table row that reports `overallMetric`.
  const isMetricOf = (metric: string, overallMetric: string) =>
    overallByMetric.get(metric) === overallMetric

  // Row metrics that report no overall metric; rendered by IndependentMetricsTable.
  const standaloneMetrics = tableRows
    .map((row) => row.metric)
    .filter((metric) => !overallByMetric.has(metric))

  // Mean / std-dev of column `columnKey` over the named metric rows, ignoring non-numeric values.
  const calculateStats = (
    metricNames: string[],
    columnKey: string
  ): { avg: number; stdDev: number } =>
    meanAndStdDev(
      metricNames
        .map((metricName) => {
          const row = rowByMetric.get(metricName)
          return row ? Number(row[columnKey]) : NaN
        })
        .filter((value) => !isNaN(value))
    )

  // Mean of column (overallMetric, model) over those `metrics` that report overallMetric.
  const columnAverage = (metrics: string[], overallMetric: string, model: string) =>
    calculateStats(
      metrics.filter((metric) => isMetricOf(metric, overallMetric)),
      model
    ).avg

  // Restrict `metricNames` to a group (or one of its subgroups) and to the user's metric selection.
  // Without a group the list is returned unchanged.
  const filterMetricsByGroupAndSubgroup = (
    metricNames: string[],
    group: string | null = null,
    subgroup: string | null = null
  ): string[] => {
    if (!group) return metricNames
    // A known subgroup narrows to that subgroup; otherwise the whole group is used.
    const subgroupMetrics = subgroup ? groupRows[group]?.[subgroup] : undefined
    const allowed = new Set(subgroupMetrics ?? Object.values(groupRows[group] ?? {}).flat())
    return metricNames.filter((metric) => allowed.has(metric) && selectedMetrics.has(metric))
  }

  const headerSortConfig = getSortConfig()

  // Order items by their average in the header-sorted column; unchanged when no header sort.
  const sortByHeaderColumn = <T,>(items: T[], metricsOf: (item: T) => string[]): T[] => {
    if (!headerSortConfig) return items
    const { overallMetric, model, direction } = headerSortConfig
    return items
      .map((item) => ({ item, avg: columnAverage(metricsOf(item), overallMetric, model) }))
      .sort((a, b) => compareNaNLast(a.avg, b.avg, direction))
      .map(({ item }) => item)
  }

  const groupEntries = sortByHeaderColumn(
    Object.entries(groupRows).filter(([group]) => group !== 'Overall'),
    ([, subGroups]) => Object.values(subGroups).flat()
  )

  // Active row-driven column sorts, decoded from their `group||subGroup||metric` keys.
  const rowSorts = Object.entries(selectedRowForSort).map(([rowKey, { direction }]) => {
    const [group = '', subGroup = '', rowMetric = ''] = rowKey.split('||')
    return { group, subGroup, rowMetric, direction }
  })

  // Model order for one overall-metric column group. A group/subgroup row sort applies to every
  // column group; a metric row sort applies only to the column group that metric reports.
  const orderModelsFor = (overallMetric: string): string[] => {
    const active = rowSorts.find(
      ({ rowMetric }) => rowMetric === '' || isMetricOf(rowMetric, overallMetric)
    )
    if (!active) return visibleModels
    const { group, subGroup, rowMetric, direction } = active
    let valueOf: (model: string) => number
    if (rowMetric === '') {
      // Group or subgroup row: order by that row's average for this overall metric.
      let candidates: string[] = []
      if (group && !subGroup) {
        candidates = Object.values(groupRows[group] ?? {}).flat()
      } else if (group && subGroup) {
        candidates = groupRows[group]?.[subGroup] ?? []
      }
      valueOf = (model) => columnAverage(candidates, overallMetric, model)
    } else {
      // Individual metric row: order by the row's own value in each model column.
      const row = rowByMetric.get(rowMetric)
      valueOf = (model) => (row ? Number(row[model]) : NaN)
    }
    return visibleModels
      .map((model) => ({ model, value: valueOf(model) }))
      .sort((a, b) => compareNaNLast(a.value, b.value, direction))
      .map(({ model }) => model)
  }

  const modelOrderByOverallMetric = new Map(
    visibleOverallMetrics.map((metric): [string, string[]] => [metric, orderModelsFor(metric)])
  )
  const modelsFor = (overallMetric: string): string[] =>
    modelOrderByOverallMetric.get(overallMetric) ?? []

  // Default metric-row order: category, then overall metric name, then strength
  // (numerically when both strengths parse as numbers, otherwise lexically).
  const compareByDefaultOrder = (a: string, b: string): number => {
    const overallA = overallMetricOf(a) ?? ''
    const overallB = overallMetricOf(b) ?? ''
    const partsA = splitCategoryAndStrength(a, overallA)
    const partsB = splitCategoryAndStrength(b, overallB)
    if (partsA.category !== partsB.category) {
      return partsA.category.localeCompare(partsB.category)
    }
    if (overallA !== overallB) {
      return overallA.localeCompare(overallB)
    }
    const numA = parseStrength(partsA.strength)
    const numB = parseStrength(partsB.strength)
    if (!isNaN(numA) && !isNaN(numB)) {
      return numA - numB
    }
    return partsA.strength.localeCompare(partsB.strength)
  }

  // Order a subgroup's metric rows. With a header sort, rows reporting the sorted overall metric
  // are ordered by the sorted model's value and all other rows keep their order after them.
  const orderMetricRows = (metrics: string[]): string[] => {
    if (!headerSortConfig) return [...metrics].sort(compareByDefaultOrder)
    const { overallMetric, model, direction } = headerSortConfig
    const inColumn = metrics.filter((metric) => isMetricOf(metric, overallMetric))
    const others = metrics.filter((metric) => !isMetricOf(metric, overallMetric))
    const sortedInColumn = [...inColumn].sort((a, b) => {
      const rowA = rowByMetric.get(a)
      const rowB = rowByMetric.get(b)
      if (!rowA || !rowB) return 0
      return compareNaNLast(Number(rowA[model]), Number(rowB[model]), direction)
    })
    return [...sortedInColumn, ...others]
  }

  // Sort icon for group, subgroup and metric rows. It stops propagation so clicking it does not
  // also expand/collapse the row.
  const renderRowSortToggle = (
    group: string,
    subGroup: string | null,
    metric: string | null,
    ascGlyph: string,
    descGlyph: string
  ) => {
    const current = getRowColumnSort(group, subGroup, metric)
    let title = 'Sort by this row'
    let glyph = '⇅'
    if (current) {
      title = current.direction === 'asc' ? 'Sort descending' : 'Clear sort'
      glyph = current.direction === 'asc' ? ascGlyph : descGlyph
    }
    return (
      <span
        className="ml-1 cursor-pointer"
        onClick={(e) => {
          e.stopPropagation()
          handleColumnSort(group, subGroup, metric)
        }}
        title={title}
      >
        {glyph}
      </span>
    )
  }

  // Aggregate (avg ± stdDev) cells for a group or subgroup row, one per visible column.
  const renderAggregateCells = (keyPrefix: string, visibleMetrics: string[]) =>
    visibleOverallMetrics.map((overallMetric) => {
      const metricsForColumn = visibleMetrics.filter((metric) => isMetricOf(metric, overallMetric))
      return (
        <React.Fragment key={`${keyPrefix}-${overallMetric}`}>
          {modelsFor(overallMetric).map((model) => (
            <td
              key={`${keyPrefix}-${overallMetric}-${model}`}
              className="font-medium text-center border-gray-700 border"
            >
              {formatStats(calculateStats(metricsForColumn, model))}
            </td>
          ))}
        </React.Fragment>
      )
    })

  // One individual metric row; cells outside the metric's own overall-metric column group stay empty.
  const renderMetricRow = (group: string, subGroup: string, metric: string) => {
    const row = rowByMetric.get(metric)
    if (!row) return null
    return (
      <tr key={metric} className="hover:bg-base-100">
        <td className="sticky left-0 bg-base-100 z-10 pl-10 border-gray-700 border cursor-pointer select-none flex items-center gap-1">
          <span className="flex-1">{metric}</span>
          {renderRowSortToggle(group, subGroup, metric, '▲', '▼')}
        </td>
        {visibleOverallMetrics.map((overallMetric) => {
          const reportsColumn = isMetricOf(metric, overallMetric)
          return (
            <React.Fragment key={`${metric}-${overallMetric}`}>
              {modelsFor(overallMetric).map((model) => (
                <td
                  key={`${metric}-${overallMetric}-${model}`}
                  className="text-center border-gray-700 border"
                >
                  {reportsColumn ? formatCellValue(row[model]) : null}
                </td>
              ))}
            </React.Fragment>
          )
        })}
      </tr>
    )
  }

  return (
    <div className="rounded">
      {error && <div className="text-red-500">{error}</div>}
      {!error && (
        <div className="flex flex-col gap-8">
          <div className="flex flex-col gap-4">
            <OverallMetricFilter
              overallMetrics={overallMetrics}
              selectedOverallMetrics={selectedOverallMetrics}
              setSelectedOverallMetrics={setSelectedOverallMetrics}
            />
            <LeaderboardFilter
              groups={groupRows}
              selectedMetrics={selectedMetrics}
              setSelectedMetrics={setSelectedMetrics}
            />
          </div>
          {selectedModels.size === 0 || selectedMetrics.size === 0 ? (
            <div className="text-center p-4 text-lg">
              Please select at least one model and one metric to display the data
            </div>
          ) : (
            <>
              {/* Metrics that report no overall metric get their own table */}
              <IndependentMetricsTable
                independentMetrics={standaloneMetrics}
                tableHeader={tableHeader}
                selectedModels={selectedModels}
                tableRows={tableRows}
              />
              {/* Main metrics table */}
              <div className="overflow-x-auto max-h-[80vh] overflow-y-auto">
                <table className="table w-full min-w-max border-gray-700 border">
                  <thead>
                    <tr>
                      <th className="sticky left-0 top-0 bg-base-100 z-20 border-gray-700 border">
                        Attack Category Metrics
                      </th>
                      {visibleOverallMetrics.map((metric) => (
                        <th
                          key={`header-metric-${metric}`}
                          className="bg-base-100 z-10 text-center text-xs border-gray-700 border"
                          colSpan={modelsFor(metric).length}
                        >
                          {metric}
                        </th>
                      ))}
                    </tr>
                    <tr>
                      <th className="sticky left-0 bg-base-100 z-10 border-gray-700 border"></th>
                      {visibleOverallMetrics.map((metric) => (
                        <React.Fragment key={`header-models-${metric}`}>
                          {modelsFor(metric).map((model) => {
                            const direction = sortState[metric]?.[model]?.direction
                            return (
                              <th
                                key={`${metric}-${model}`}
                                className="sticky top-12 bg-base-100 z-10 text-center text-xs border-gray-700 border border-bottom-solid border-b-gray-700 border-b-3 cursor-pointer select-none"
                                onClick={() => handleSort(metric, model)}
                              >
                                {model}
                                <span className="ml-1">
                                  {direction === undefined ? '⇅' : direction === 'asc' ? '↑' : '↓'}
                                </span>
                              </th>
                            )
                          })}
                        </React.Fragment>
                      ))}
                    </tr>
                  </thead>
                  <tbody>
                    {groupEntries.map(([group, subGroups]) => {
                      const visibleGroupMetrics = filterMetricsByGroupAndSubgroup(
                        Object.values(subGroups).flat(),
                        group
                      )
                      // Hide groups in which no metric is selected.
                      if (visibleGroupMetrics.length === 0) return null
                      const subGroupEntries = sortByHeaderColumn(
                        Object.entries(subGroups),
                        ([, metrics]) => metrics
                      )
                      return (
                        <React.Fragment key={group}>
                          {/* Group row with aggregate stats for the whole group */}
                          <tr
                            className="bg-base-200 cursor-pointer hover:bg-base-300"
                            onClick={() => toggleGroup(group)}
                          >
                            <td className="sticky left-0 bg-base-200 z-10 font-medium border-gray-700 border cursor-pointer select-none flex items-center gap-1">
                              <span>{openGroupRows[group] ? '▼ ' : '▶ '}</span>
                              <span className="flex-1">{group}</span>
                              {renderRowSortToggle(group, null, null, '↑', '↓')}
                            </td>
                            {renderAggregateCells(group, visibleGroupMetrics)}
                          </tr>
                          {openGroupRows[group] &&
                            subGroupEntries.map(([subGroup, metrics]) => {
                              const visibleSubgroupMetrics = filterMetricsByGroupAndSubgroup(
                                metrics,
                                group,
                                subGroup
                              )
                              // Hide subgroups in which no metric is selected.
                              if (visibleSubgroupMetrics.length === 0) return null
                              return (
                                <React.Fragment key={`${group}-${subGroup}`}>
                                  {/* Subgroup row with aggregate stats for the subgroup */}
                                  <tr
                                    className="bg-base-100 cursor-pointer hover:bg-base-200"
                                    onClick={() => toggleSubGroup(group, subGroup)}
                                  >
                                    <td className="sticky left-0 bg-base-100 z-10 pl-6 font-medium border-gray-700 border cursor-pointer select-none flex items-center gap-1">
                                      <span>
                                        {openSubGroupRows[group]?.[subGroup] ? '▼ ' : '▶ '}
                                      </span>
                                      <span className="flex-1">{subGroup}</span>
                                      {renderRowSortToggle(group, subGroup, null, '↑', '↓')}
                                    </td>
                                    {renderAggregateCells(
                                      `${group}-${subGroup}`,
                                      visibleSubgroupMetrics
                                    )}
                                  </tr>
                                  {/* Individual metric rows */}
                                  {openSubGroupRows[group]?.[subGroup] &&
                                    orderMetricRows(visibleSubgroupMetrics).map((metric) =>
                                      renderMetricRow(group, subGroup, metric)
                                    )}
                                </React.Fragment>
                              )
                            })}
                        </React.Fragment>
                      )
                    })}
                  </tbody>
                </table>
              </div>
            </>
          )}
        </div>
      )}
    </div>
  )
}
// The leaderboard table component is this module's default export.
export { LeaderboardTable as default }