/** Display names for metrics; only metrics listed here are offered in a plot's dropdown. */
const TASK_ID_TO_NAME = {
  'agg_score': 'Aggregate Score',
  'commonsense_qa/acc_norm': 'Commonsense QA',
  'hellaswag/acc_norm': 'HellaSwag',
  'openbookqa/acc_norm': 'OpenBook QA',
  'piqa/acc_norm': 'PIQA',
  'winogrande/acc_norm': 'WinoGrande',
  'arc/acc_norm': 'ARC',
  'mmlu/acc_norm': 'MMLU',
};

// Font family shared by every text element of the default layout.
const FONT_FAMILY = "apple-system, Arial, sans-serif";

/**
 * Base Plotly layout. Each plot deep-merges its width, y-axis title, x-axis range
 * and the data file's own `layout` over a fresh copy, so this object is never mutated.
 */
const DEFAULT_LAYOUT = {
  title: { text: 'Plot Title', font: { size: 19, family: FONT_FAMILY } },
  xaxis: {
    title: { text: 'Training tokens (billions)', font: { size: 15, family: FONT_FAMILY } },
    tickfont: { size: 14, family: FONT_FAMILY },
    showgrid: false,
    mirror: true,
    ticks: 'outside',
    showline: true,
  },
  yaxis: {
    title: { text: "Agg Score", font: { size: 15, family: FONT_FAMILY }, standoff: 10 },
    showgrid: false,
    mirror: true,
    ticks: 'outside',
    showline: true,
    tickfont: { size: 14, family: FONT_FAMILY },
  },
  legend: {
    orientation: 'v',
    xanchor: 'right',
    yanchor: 'bottom',
    x: 1,
    y: 0,
    font: { size: 14, family: FONT_FAMILY },
    bgcolor: 'rgba(0,0,0,0)',
  },
  margin: { t: 30, b: 50 },
  height: 400,
};

/**
 * Trailing moving average.
 *
 * @param {number[]} data - Samples to smooth.
 * @param {number} windowSize - Samples per window; 0 (or any value <= 0) disables smoothing.
 * @returns {number[]} `data` itself when smoothing is disabled. Otherwise the mean of every
 *   full window, i.e. `data.length - windowSize + 1` values (empty when the window is longer
 *   than the data).
 */
const rollingWindow = function (data, windowSize) {
  if (windowSize <= 0) {
    return data;
  }
  const rollingData = [];
  // Window is [end - windowSize, end). `end` runs up to and including data.length so the
  // last sample is part of the final window.
  for (let end = windowSize; end <= data.length; end++) {
    const windowData = data.slice(end - windowSize, end);
    rollingData.push(windowData.reduce((acc, value) => acc + value, 0) / windowSize);
  }
  return rollingData;
};

/**
 * Pick the rolling-window slider maximum from the number of x samples in the first
 * series of the first metric.
 *
 * @param {Object} data - Metric-keyed object (`data.data` or `data.traces` of a plot file).
 * @returns {number} 10 for short series (< 20 samples, or no data at all), otherwise 30.
 */
const getSliderMax = (data) => {
  const firstMetric = Object.values(data ?? {})[0];
  const firstSeries = firstMetric ? Object.values(firstMetric)[0] : undefined;
  const totalSamples = firstSeries?.x?.length ?? 0;
  return totalSamples < 20 ? 10 : 30;
};

/**
 * Build the Plotly traces shown for one metric.
 *
 * @param {Object} data - Parsed plot file. Either `data.traces[metric]` holds ready-made
 *   traces, or `data.data[metric]` maps series keys to `{x, y, label}`.
 * @param {string} metric - Selected metric id.
 * @param {number} windowSize - Rolling window applied to `y` of `data.data` series only.
 * @returns {Object[]} Traces for `Plotly.newPlot` (empty when the metric is missing).
 */
const buildTraces = (data, metric, windowSize) => {
  if ('traces' in data) {
    // Pre-built traces are passed through unchanged; the rolling window does not apply.
    return data.traces?.[metric] ?? [];
  }
  const metricData = data.data?.[metric] ?? {};
  return Object.values(metricData).map((series) => {
    const y = rollingWindow(series.y, windowSize);
    // NOTE(review): each smoothed value is paired with the x of its window's first sample,
    // which shifts smoothed curves left by `windowSize` steps. Confirm this is intended.
    const x = series.x.slice(0, y.length);
    return { x, y, type: 'scatter', mode: 'lines', line: { width: 2.5 }, name: series.label };
  });
};

/**
 * Padded x-axis range covering every point of every trace.
 *
 * @param {Object[]} traces - Traces with an `x` array.
 * @returns {number[]|undefined} `[minX * 0.95, maxX * 1.05]`, or undefined when there is no
 *   finite x value (Plotly then autoranges instead of receiving an Infinity range).
 */
const computeXRange = (traces) => {
  let minX = Infinity;
  let maxX = -Infinity;
  // Explicit loop instead of Math.min(...all) so large traces cannot overflow the argument list.
  for (const trace of traces) {
    for (const x of trace.x ?? []) {
      minX = Math.min(minX, x);
      maxX = Math.max(maxX, x);
    }
  }
  if (!Number.isFinite(minX) || !Number.isFinite(maxX)) {
    return undefined;
  }
  return [minX * 0.95, maxX * 1.05];
};

/**
 * Create the figure and its controls (metric dropdown and optional rolling-window slider)
 * inside `plotElement`.
 *
 * @param {HTMLElement} plotElement - Container the figure and controls are appended to.
 * @param {Object} data - Metric-keyed object; only keys present in TASK_ID_TO_NAME are offered.
 * @param {string} defaultMetric - Initially selected metric (falls back to the first option).
 * @param {number} defaultWindowSize - Initial slider value.
 * @param {*} createSlider - Falsy to omit the slider.
 * @returns {{dropdown: HTMLSelectElement, slider: (HTMLInputElement|undefined), plot: HTMLElement}}
 */
const createPlottingElements = (plotElement, data, defaultMetric, defaultWindowSize, createSlider) => {
  // Create plot
  const plot = document.createElement('figure');
  const controls = document.createElement('div');
  plot.classList.add('plotly');
  controls.classList.add('plotly_controls');
  plotElement.appendChild(plot);
  plotElement.appendChild(controls);

  const metricOptions = Object.keys(data).filter(
    (metric) => Object.prototype.hasOwnProperty.call(TASK_ID_TO_NAME, metric),
  );

  // Dropdown
  const dropdownLabel = document.createElement('label');
  dropdownLabel.textContent = 'Metric:';
  const dropdown = document.createElement('select');
  // Build <option>s through the DOM (textContent) rather than innerHTML string-building.
  for (const metric of metricOptions) {
    const option = document.createElement('option');
    option.value = metric;
    option.textContent = TASK_ID_TO_NAME[metric];
    dropdown.appendChild(option);
  }
  dropdown.value = metricOptions.includes(defaultMetric) ? defaultMetric : (metricOptions[0] ?? '');

  const dropdownContainer = document.createElement('div');
  dropdownContainer.classList.add('plotly_input_container');
  dropdownContainer.appendChild(dropdownLabel);
  dropdownContainer.appendChild(dropdown);
  controls.appendChild(dropdownContainer);

  if (!createSlider) {
    return { dropdown, slider: undefined, plot };
  }

  // Slider
  const sliderLabel = document.createElement('label');
  sliderLabel.textContent = 'Rolling window:';
  const slider = document.createElement('input');
  slider.type = 'range';
  slider.min = 0;
  slider.max = getSliderMax(data);
  slider.value = defaultWindowSize ?? 0;

  // Live readout of the current slider value.
  const sliderValue = document.createElement('span');
  sliderValue.textContent = slider.value;
  slider.addEventListener('input', () => {
    sliderValue.textContent = slider.value;
  });

  const sliderInputContainer = document.createElement('div');
  sliderInputContainer.classList.add('plotly_slider');
  sliderInputContainer.appendChild(slider);
  sliderInputContainer.appendChild(sliderValue);

  const sliderContainer = document.createElement('div');
  sliderContainer.classList.add('plotly_input_container');
  sliderContainer.appendChild(sliderLabel);
  sliderContainer.appendChild(sliderInputContainer);
  controls.appendChild(sliderContainer);

  return { dropdown, slider, plot };
};

/**
 * Load `data/plots/<name>.json` for one `plot-<name>` element, build its controls and draw it.
 * Uses the globals `_` (lodash) and `Plotly`.
 *
 * @param {HTMLElement} plotElement - Element whose id starts with "plot-".
 * @returns {Promise<void>} Rejects when the data cannot be fetched or parsed.
 */
const setupPlot = async (plotElement) => {
  const plotName = plotElement.id.slice('plot-'.length);
  const response = await fetch(`data/plots/${plotName}.json`);
  if (!response.ok) {
    throw new Error(`Failed to fetch data/plots/${plotName}.json (HTTP ${response.status})`);
  }
  const data = await response.json();

  const { dropdown, slider, plot } = createPlottingElements(
    plotElement,
    data.data ?? data.traces,
    data.defaultMetric ?? "agg_score",
    data.defaultWindowSize ?? 0,
    data.createSlider ?? 1,
  );
  plot.id = `graph-${plotName}`;

  const updatePlot = () => {
    const metric = dropdown.value;
    const windowSize = Number.parseInt(slider?.value ?? 0, 10);
    const traces = buildTraces(data, metric, windowSize);
    const xRange = computeXRange(traces);
    const layout = _.merge(
      {},
      DEFAULT_LAYOUT,
      {
        width: plot.parentElement.offsetWidth,
        yaxis: { title: { text: TASK_ID_TO_NAME[metric] } },
        xaxis: xRange ? { range: xRange } : {},
      },
      data.layout,
    );
    Plotly.newPlot(plot, traces, layout);
  };

  dropdown.addEventListener('change', updatePlot);

  if (slider) {
    // Debounce the slider so dragging does not redraw on every intermediate value.
    let timeoutId;
    slider.addEventListener('input', () => {
      clearTimeout(timeoutId);
      timeoutId = setTimeout(updatePlot, 500);
    });
  }

  // Initial plot
  updatePlot();

  // Registered once per plot (not on every redraw) so resize listeners do not accumulate.
  window.addEventListener('resize', () => {
    // Plots are not shown below 768px wide, so there is nothing to resize there.
    if (window.innerWidth < 768) {
      return;
    }
    // Plotly does not follow the container width by itself; push it explicitly.
    Plotly.relayout(plot, { width: plot.parentElement.offsetWidth });
  });
};

/** Set up every element whose id starts with "plot-"; each plot loads independently. */
const init_plot = function () {
  const plotElements = document.querySelectorAll('[id^="plot-"]');
  plotElements.forEach((plotElement) => {
    // Fire-and-forget per plot; a failure is reported without blocking the other plots.
    setupPlot(plotElement).catch((error) => {
      console.error(`Could not initialise plot "${plotElement.id}"`, error);
    });
  });
};

// Kept at the end of the file so every definition above is initialised before init_plot can run.
// Skipped outside a browser so the pure helpers can be loaded without a DOM.
if (typeof document !== 'undefined') {
  if (document.readyState === 'loading') {
    document.addEventListener('DOMContentLoaded', init_plot);
  } else {
    // DOMContentLoaded has already fired (script loaded late), so initialise immediately.
    init_plot();
  }
}