// omnisealbench/frontend/src/components/LeaderboardTable.tsx
// Last commit: "Model Filters" (cf4e7fe) by Mark Duppenthaler; file size 28.9 kB
import React, { useEffect, useState } from 'react'
import API from '../API'
import LeaderboardFilter from './LeaderboardFilter'
import ModelFilter from './ModelFilter'
/** Props accepted by the LeaderboardTable component. */
interface LeaderboardTableProps {
// Benchmark name: rendered as the table heading and interpolated into the
// fetch path `data/${file}_benchmark.csv`.
file: string
}
/**
 * One row of the loaded table.
 * Every key other than `metric` is treated as a column when the table header is derived.
 */
interface Row {
// Metric identifier for this row; used to look rows up and to match group metrics.
metric: string
// Column values keyed by column name (the component uses these keys as model names).
[key: string]: string | number
}
/** Metric names bucketed two levels deep: group name -> subgroup name -> metric names. */
interface Groups {
  [group: string]: Record<string, string[]>
}
/**
 * Paired lookup tables of mean and standard-deviation values, each keyed by string.
 * NOTE(review): not referenced anywhere in the visible part of this file - confirm it is still needed.
 */
interface GroupStats {
  average: Record<string, number>
  stdDev: Record<string, number>
}
const LeaderboardTable: React.FC<LeaderboardTableProps> = ({ file }) => {
// Component state. Everything below is populated by the data-loading effect
// once the payload for `file` has been parsed.

// Rows of the loaded table (one per metric).
const [tableRows, setTableRows] = useState<Row[]>([])
// Column keys found on the rows, excluding 'metric'; passed to ModelFilter as the model list.
const [tableHeader, setTableHeader] = useState<string[]>([])
// Starts true; set to false when the fetch settles (success or failure).
const [loading, setLoading] = useState(true)
// Error message rendered in place of the table when loading or parsing fails.
const [error, setError] = useState<string | null>(null)
// group -> subgroup -> metric names; a metric's subgroup is its text before the first underscore.
const [groups, setGroups] = useState<Groups>({})
// Expanded flag per group row; every group starts collapsed.
const [openGroups, setOpenGroups] = useState<{ [key: string]: boolean }>({})
// Expanded flag per subgroup row, keyed by group and then subgroup; all start collapsed.
const [openSubGroups, setOpenSubGroups] = useState<{ [key: string]: { [key: string]: boolean } }>(
{}
)
// Metrics currently enabled in LeaderboardFilter; initialised to every metric of the non-Overall groups.
const [selectedMetrics, setSelectedMetrics] = useState<Set<string>>(new Set())
// Models currently enabled in ModelFilter; initialised to every header column.
const [selectedModels, setSelectedModels] = useState<Set<string>>(new Set())
// Unique metric names from the 'Overall' group (text after the first underscore), sorted.
// These define the top-level column groups of the main table.
const [overallMetrics, setOverallMetrics] = useState<string[]>([])
// Load the benchmark payload for `file` and derive all table state from it.
// Re-runs whenever `file` changes.
useEffect(() => {
  // Flipped by the cleanup below, so a response that arrives after `file` has
  // changed (or after unmount) cannot overwrite newer state.
  let cancelled = false

  // A previous run may have left loading=false or an error message behind.
  setLoading(true)
  setError(null)

  API.fetchStaticFile(`data/${file}_benchmark.csv`)
    .then((response) => {
      if (cancelled) return

      // Despite the .csv suffix in the path, the body is parsed as JSON with
      // `rows` and `groups` entries.
      const data = JSON.parse(response)
      const rows: Row[] = data['rows']
      const allGroups = data['groups'] as { [key: string]: string[] }

      // 'Overall' only supplies the column names and is not rendered as a group.
      // Default to [] so a payload without an 'Overall' entry does not throw.
      const { Overall: overallGroup = [], ...metricGroups } = allGroups

      // Column names: the part of each Overall metric after its first underscore.
      const uniqueMetrics = new Set<string>()
      for (const metric of overallGroup) {
        if (metric.includes('_')) {
          uniqueMetrics.add(metric.split('_').slice(1).join('_'))
        }
      }
      setOverallMetrics(Array.from(uniqueMetrics).sort())

      // Build group -> subgroup -> metrics. A metric's subgroup is its text
      // before the first underscore. Groups, subgroups and metrics are all
      // inserted in sorted order so the table renders consistently.
      const groupsData: { [key: string]: { [key: string]: string[] } } = {}
      const sortedGroupEntries = Object.entries(metricGroups).sort(([groupA], [groupB]) =>
        groupA.localeCompare(groupB)
      )
      for (const [group, metrics] of sortedGroupEntries) {
        const subGroups: { [key: string]: string[] } = {}
        for (const metric of [...metrics].sort()) {
          const mainGroup = metric.split('_')[0]
          if (!subGroups[mainGroup]) {
            subGroups[mainGroup] = []
          }
          subGroups[mainGroup].push(metric)
        }
        groupsData[group] = Object.fromEntries(
          Object.entries(subGroups).sort(([subGroupA], [subGroupB]) =>
            subGroupA.localeCompare(subGroupB)
          )
        )
      }

      // Every row key except 'metric' is a model column.
      const allKeys: string[] = Array.from(new Set(rows.flatMap((row) => Object.keys(row))))
      const headers = allKeys.filter((key) => key !== 'metric')

      // All groups and subgroups start collapsed.
      const initialOpenGroups: { [key: string]: boolean } = {}
      const initialOpenSubGroups: { [key: string]: { [key: string]: boolean } } = {}
      for (const [group, subGroups] of Object.entries(groupsData)) {
        initialOpenGroups[group] = false
        const subGroupFlags: { [key: string]: boolean } = {}
        for (const subGroup of Object.keys(subGroups)) {
          subGroupFlags[subGroup] = false
        }
        initialOpenSubGroups[group] = subGroupFlags
      }

      // Start with every metric (of the non-Overall groups) and every model selected.
      setSelectedMetrics(new Set(Object.values(metricGroups).flat()))
      setSelectedModels(new Set(headers))
      setTableHeader(headers)
      setTableRows(rows)
      setGroups(groupsData)
      setOpenGroups(initialOpenGroups)
      setOpenSubGroups(initialOpenSubGroups)
      setLoading(false)
    })
    .catch((err) => {
      if (cancelled) return
      // Rejections are not guaranteed to be Error instances.
      const reason = err instanceof Error ? err.message : String(err)
      setError('Failed to fetch JSON: ' + reason)
      setLoading(false)
    })

  return () => {
    cancelled = true
  }
}, [file])
// Expand a collapsed group row, or collapse an expanded one.
const toggleGroup = (group: string) => {
  setOpenGroups((current) => {
    const wasOpen = current[group]
    return { ...current, [group]: !wasOpen }
  })
}
// Expand or collapse one subgroup row; the other subgroups of the same group keep their state.
const toggleSubGroup = (group: string, subGroup: string) => {
  setOpenSubGroups((current) => {
    // A group that has no entry yet is treated as having no open subgroups.
    const siblings = current[group] || {}
    const wasOpen = current[group]?.[subGroup]
    return {
      ...current,
      [group]: { ...siblings, [subGroup]: !wasOpen },
    }
  })
}
// Find every row metric that belongs to an overall metric name (like "log10_p_value").
//
// A metric belongs when the text after its first underscore either equals
// `metricName` or ends with `_${metricName}`. The underscore boundary keeps
// this in line with findStandaloneMetrics, so a metric such as "x_rabbit_acc"
// is no longer counted under "bit_acc" merely because the characters overlap.
// Metrics without any underscore never match.
//
// Returns the matching metric names in table-row order.
const findAllMetricsForName = (metricName: string): string[] => {
  const boundedSuffix = `_${metricName}`
  const matches: string[] = []
  for (const row of tableRows) {
    const metric = row.metric as string
    const firstUnderscore = metric.indexOf('_')
    if (firstUnderscore === -1) continue
    // Everything after the first underscore (the leading segment is dropped).
    const extractedName = metric.slice(firstUnderscore + 1)
    if (extractedName === metricName || extractedName.endsWith(boundedSuffix)) {
      matches.push(metric)
    }
  }
  return matches
}
// Collect the row metrics that are not tied to any overall metric.
// A metric is tied to an overall metric when it equals that name or ends with `_<name>`.
const findStandaloneMetrics = (): string[] => {
  const belongsToOverall = (metric: string): boolean =>
    overallMetrics.some((overall) => metric === overall || metric.endsWith(`_${overall}`))

  const standalone: string[] = []
  for (const row of tableRows) {
    const metric = row.metric as string
    if (!belongsToOverall(metric)) {
      standalone.push(metric)
    }
  }
  return standalone
}
// Mean and population standard deviation of one column over a set of metric rows.
//
// metricNames: metrics whose rows contribute; names with no matching row are skipped.
// columnKey:   column (model) whose cell values are aggregated.
// Returns { avg, stdDev }; both are NaN when no usable numeric value exists.
const calculateStats = (
  metricNames: string[],
  columnKey: string
): { avg: number; stdDev: number } => {
  // Index the first row of each metric once instead of scanning tableRows per metric.
  const firstRowByMetric = new Map<string, (typeof tableRows)[number]>()
  for (const row of tableRows) {
    if (!firstRowByMetric.has(row.metric)) {
      firstRowByMetric.set(row.metric, row)
    }
  }

  // Number(null) and Number('') are both 0, which would count a blank cell as a
  // real zero and skew the mean and deviation; treat blanks as missing instead.
  const toNumber = (cell: unknown): number => {
    if (cell === null || cell === undefined) return NaN
    if (typeof cell === 'string' && cell.trim() === '') return NaN
    return Number(cell)
  }

  const values: number[] = []
  for (const metricName of metricNames) {
    const row = firstRowByMetric.get(metricName)
    const value = row ? toNumber(row[columnKey]) : NaN
    if (!isNaN(value)) {
      values.push(value)
    }
  }
  if (values.length === 0) return { avg: NaN, stdDev: NaN }

  const avg = values.reduce((sum, val) => sum + val, 0) / values.length
  // Population variance: divide by the count, not count - 1.
  const variance =
    values.reduce((sum, value) => {
      const diff = value - avg
      return sum + diff * diff
    }, 0) / values.length
  return { avg, stdDev: Math.sqrt(variance) }
}
// Narrow a list of metric names to those that are selected and belong to a group
// (or to one subgroup of it).
//
// - No group given: the input list is returned untouched (selection is not applied).
// - Subgroup given and present in the group: keep selected metrics of that subgroup.
// - Otherwise: keep selected metrics of any subgroup of the group.
const filterMetricsByGroupAndSubgroup = (
  metricNames: string[],
  group: string | null = null,
  subgroup: string | null = null
): string[] => {
  if (!group) return metricNames

  // Use the subgroup's metrics when the subgroup exists, else the whole group's.
  const subgroupMetrics = subgroup ? groups[group]?.[subgroup] : undefined
  const candidates = subgroupMetrics
    ? subgroupMetrics
    : Object.values(groups[group] || {}).flat()
  const allowed = new Set(candidates)

  return metricNames.filter((metric) => allowed.has(metric) && selectedMetrics.has(metric))
}
// ---------------------------------------------------------------------------
// Render-time derived data. Each value is computed once per render rather than
// once per table cell, which is where the repeated row scans used to happen.
// ---------------------------------------------------------------------------

// Model columns enabled in the model filter, in header order.
const visibleModels = tableHeader.filter((model) => selectedModels.has(model))

// Overall metric name -> set of row metrics that findAllMetricsForName assigns to it.
const metricsByOverall = new Map<string, Set<string>>(
  overallMetrics.map((name): [string, Set<string>] => [name, new Set(findAllMetricsForName(name))])
)
const belongsTo = (overall: string, metric: string): boolean =>
  metricsByOverall.get(overall)?.has(metric) ?? false

// First row for each metric name (same row that a linear find would return).
const rowByMetric = new Map<string, (typeof tableRows)[number]>()
for (const row of tableRows) {
  if (!rowByMetric.has(row.metric)) {
    rowByMetric.set(row.metric, row)
  }
}

// Metrics not tied to any overall metric; shown in the separate "Other Metrics" table.
const standaloneMetrics = [...findStandaloneMetrics()].sort()

// "avg ± stdDev" with three decimals, or 'N/A' when there was nothing to average.
const formatStats = (stats: { avg: number; stdDev: number }): string =>
  !isNaN(stats.avg) ? `${stats.avg.toFixed(3)} ± ${stats.stdDev.toFixed(3)}` : 'N/A'

// Numeric cells are rounded to at most three decimals; anything else is shown as-is.
const formatCell = (cell: string | number) =>
  !isNaN(Number(cell)) ? Number(Number(cell).toFixed(3)) : cell

// --- Ordering of individual metric rows -------------------------------------
// Metric names follow {category}_{strength}_{overall_metric_name}. Rows are
// ordered by category, then overall metric name, then strength (numerically
// when both strengths parse as numbers, otherwise as strings).

// First overall metric that the name equals or ends with (after an underscore), or ''.
const overallOf = (metric: string): string =>
  overallMetrics.find((overall) => metric.endsWith(`_${overall}`) || metric === overall) ?? ''

// Underscore-separated parts in front of the overall suffix, or null when the suffix is absent.
const prefixParts = (metric: string, overall: string): string[] | null =>
  metric.endsWith(`_${overall}`)
    ? metric.slice(0, metric.length - overall.length - 1).split('_')
    : null

// Strength: the last part before the overall suffix (the whole name when there is no suffix).
const strengthOf = (metric: string, overall: string): string => {
  const parts = prefixParts(metric, overall)
  return parts ? parts[parts.length - 1] : metric
}

// Category: everything before the strength (the whole name when there is no suffix).
const categoryOf = (metric: string, overall: string): string => {
  const parts = prefixParts(metric, overall)
  if (!parts) return metric
  return parts.length > 1 ? parts.slice(0, parts.length - 1).join('_') : ''
}

// Parse strengths such as "10", "0.5", "32k", "2M" or "1b"; NaN when the text is not of that form.
const parseStrength = (text: string): number => {
  const match = text.match(/^(\d+(?:\.\d+)?)([kKmMbB]?)$/)
  if (!match) return NaN
  const [, digits, suffix] = match
  const base = parseFloat(digits)
  switch (suffix.toLowerCase()) {
    case 'k':
      return base * 1e3
    case 'm':
      return base * 1e6
    case 'b':
      return base * 1e9
    default:
      return base
  }
}

const compareMetrics = (a: string, b: string): number => {
  const overallA = overallOf(a)
  const overallB = overallOf(b)

  const categoryA = categoryOf(a, overallA)
  const categoryB = categoryOf(b, overallB)
  if (categoryA !== categoryB) {
    return categoryA.localeCompare(categoryB)
  }
  if (overallA !== overallB) {
    return overallA.localeCompare(overallB)
  }

  const strengthA = strengthOf(a, overallA)
  const strengthB = strengthOf(b, overallB)
  const numA = parseStrength(strengthA)
  const numB = parseStrength(strengthB)
  if (!isNaN(numA) && !isNaN(numB)) {
    return numA - numB
  }
  return strengthA.localeCompare(strengthB)
}

// One "avg ± stdDev" cell per (overall metric, visible model) for a group or subgroup row.
// `keyPrefix` is the group name, or `${group}-${subGroup}` for a subgroup row.
const renderStatCells = (keyPrefix: string, metrics: string[]) =>
  overallMetrics.map((overall) => {
    const matching = metrics.filter((m) => belongsTo(overall, m))
    return (
      <React.Fragment key={`${keyPrefix}-${overall}`}>
        {visibleModels.map((col) => (
          <td key={`${keyPrefix}-${overall}-${col}`} className="font-medium text-center">
            {formatStats(calculateStats(matching, col))}
          </td>
        ))}
      </React.Fragment>
    )
  })

// One row per metric of an expanded subgroup; values appear only under the
// overall metric the row belongs to, other cells stay empty.
const renderMetricRow = (metric: string) => {
  const row = rowByMetric.get(metric)
  if (!row) return null
  return (
    <tr key={metric} className="hover:bg-base-100">
      <td className="pl-10">{metric}</td>
      {overallMetrics.map((oMetric) => {
        const isMatchingMetric = belongsTo(oMetric, metric)
        return (
          <React.Fragment key={`${metric}-${oMetric}`}>
            {visibleModels.map((col) => (
              <td key={`${metric}-${oMetric}-${col}`} className="text-center">
                {isMatchingMetric ? formatCell(row[col]) : null}
              </td>
            ))}
          </React.Fragment>
        )
      })}
    </tr>
  )
}

return (
  <div className="rounded shadow overflow-auto">
    <h3 className="font-bold mb-2">{file}</h3>
    {loading && <div>Loading...</div>}
    {error && <div className="text-red-500">{error}</div>}
    {!loading && !error && (
      <div className="overflow-x-auto">
        <div className="flex flex-col gap-4">
          <ModelFilter
            models={tableHeader}
            selectedModels={selectedModels}
            setSelectedModels={setSelectedModels}
          />
          <LeaderboardFilter
            groups={groups}
            selectedMetrics={selectedMetrics}
            setSelectedMetrics={setSelectedMetrics}
          />
        </div>
        {selectedModels.size === 0 || selectedMetrics.size === 0 ? (
          <div className="text-center p-4 text-lg">
            Please select at least one model and one metric to display the data
          </div>
        ) : (
          <>
            <table className="table w-full">
              <thead>
                {/* Row 1: one spanning header per overall metric */}
                <tr>
                  <th>Group / Subgroup</th>
                  {overallMetrics.map((metric) => (
                    <th key={metric} colSpan={visibleModels.length} className="text-center border-x">
                      {metric}
                    </th>
                  ))}
                </tr>
                {/* Row 2: the visible models repeated under every overall metric */}
                <tr>
                  <th></th>
                  {overallMetrics.map((metric) => (
                    <React.Fragment key={`header-models-${metric}`}>
                      {visibleModels.map((model) => (
                        <th key={`${metric}-${model}`} className="text-center text-xs">
                          {model}
                        </th>
                      ))}
                    </React.Fragment>
                  ))}
                </tr>
              </thead>
              <tbody>
                {Object.entries(groups).map(([group, subGroups]) => {
                  // The "Overall" group is never rendered as a row.
                  if (group === 'Overall') return null

                  // Selected metrics of this group; hide the group when none remain.
                  const visibleGroupMetrics = filterMetricsByGroupAndSubgroup(
                    Object.values(subGroups).flat(),
                    group
                  )
                  if (visibleGroupMetrics.length === 0) return null

                  return (
                    <React.Fragment key={group}>
                      {/* Group row: aggregate stats over the whole group */}
                      <tr
                        className="bg-base-200 cursor-pointer hover:bg-base-300"
                        onClick={() => toggleGroup(group)}
                      >
                        <td className="font-medium">
                          {openGroups[group] ? '▼ ' : '▶ '}
                          {group}
                        </td>
                        {renderStatCells(group, visibleGroupMetrics)}
                      </tr>
                      {/* Subgroup rows are shown only while the group is expanded */}
                      {openGroups[group] &&
                        Object.entries(subGroups).map(([subGroup, metrics]) => {
                          const visibleSubgroupMetrics = filterMetricsByGroupAndSubgroup(
                            metrics,
                            group,
                            subGroup
                          )
                          if (visibleSubgroupMetrics.length === 0) return null

                          const subGroupOpen = openSubGroups[group]?.[subGroup]
                          return (
                            <React.Fragment key={`${group}-${subGroup}`}>
                              {/* Subgroup row: aggregate stats over the subgroup */}
                              <tr
                                className="bg-base-100 cursor-pointer hover:bg-base-200"
                                onClick={() => toggleSubGroup(group, subGroup)}
                              >
                                <td className="pl-6 font-medium">
                                  {subGroupOpen ? '▼ ' : '▶ '}
                                  {subGroup}
                                </td>
                                {renderStatCells(`${group}-${subGroup}`, visibleSubgroupMetrics)}
                              </tr>
                              {/* Individual metric rows, only while the subgroup is expanded */}
                              {subGroupOpen &&
                                [...visibleSubgroupMetrics].sort(compareMetrics).map(renderMetricRow)}
                            </React.Fragment>
                          )
                        })}
                    </React.Fragment>
                  )
                })}
              </tbody>
            </table>
            {/* Separate table for metrics that don't belong to any overall metric */}
            {standaloneMetrics.length === 0 ? null : (
              <div className="mt-8">
                <h4 className="font-bold mb-2">Other Metrics</h4>
                <table className="table w-full">
                  <thead>
                    <tr>
                      <th>Metric</th>
                      {visibleModels.map((model) => (
                        <th key={`standalone-${model}`} className="text-center text-xs">
                          {model}
                        </th>
                      ))}
                    </tr>
                  </thead>
                  <tbody>
                    {standaloneMetrics.map((metric) => {
                      const row = rowByMetric.get(metric)
                      if (!row) return null
                      return (
                        <tr key={`standalone-${metric}`} className="hover:bg-base-100">
                          <td>{metric}</td>
                          {visibleModels.map((col) => (
                            <td key={`standalone-${metric}-${col}`} className="text-center">
                              {formatCell(row[col])}
                            </td>
                          ))}
                        </tr>
                      )
                    })}
                  </tbody>
                </table>
              </div>
            )}
          </>
        )}
      </div>
    )}
  </div>
)
}
export default LeaderboardTable