import { getColor } from "./colors.mjs";
import Plotly from "plotly.js/dist/plotly-basic";
import _ from "lodash";

// Folder holding one sub-folder per plot; each sub-folder has an index.json plus metric files.
const DATA_FOLDER = "assets/data/plots";

// Per-trace settings for line plots. Plotly scatter traces have no top-level `width`
// attribute; the stroke width must be given as `line.width` to take effect.
const LINE_SETTINGS = {
  line: { width: 2.5 },
  type: "scatter",
  mode: "lines",
};

// Per-trace settings for bar plots (`width` is the bar width here).
const BAR_SETTINGS = {
  width: 0.5,
  type: "bar",
  opacity: 0.9,
  marker: {
    line: {
      width: 1.0,
    },
  },
};

// Metric id (key in index.json "files") -> display name. Only metrics listed here
// are offered in the metric dropdown.
const TASK_ID_TO_NAME = {
  // Ablations
  agg_score: "Aggregate Score",
  "commonsense_qa/acc_norm": "Commonsense QA",
  "hellaswag/acc_norm": "HellaSwag",
  "openbookqa/acc_norm": "OpenBook QA",
  "piqa/acc_norm": "PIQA",
  "siqa/acc_norm": "Social IQA",
  "winogrande/acc_norm": "WinoGrande",
  "arc/acc_norm": "ARC",
  "mmlu/acc_norm": "MMLU",

  // Stats
  lines_ended_with_punct: "Lines Ended With Punctuation",
  lines_chars: "Lines Chars",
  short_lines: "Short Lines",
};

// Dataset id (key in a metric file's "data") -> legend name, used when a trace has no label.
const DATASET_ID_TO_NAME = {
  pii_removed: "Fineweb",
  allenai_c4_en: "C4",
  "tiiuae_falcon-refinedweb_data": "RefinedWeb",
  "red-pajama-v2_jsonl-deduplicated-extract": "RedPajamaV2",
  "dolma-sample": "Dolma1.6",
  dedup_minhash_independent_output: "Independent Dedup MinHash",
  "dedup_minhash_CC-MAIN-2013-48_output": "Full MinHash CC-MAIN-2013-48",
  "dedup_minhash_independent_output_CC-MAIN-2013-48": "Independent MinHash CC-MAIN-2013-48",
  "ind_minhash-CC-MAIN-2019-18": "Independent MinHash CC-MAIN-2019-18",
  "wet-extraction-2019-18": "WET Extraction 2019-18",
};

// Defaults deep-merged under each plot's index.json "settings".
const DEFAULT_SETTINGS = {
  slider: {
    max: 30,
    min: 0,
    default: 0,
  },
  defaultMetric: "agg_score",
  type: "line",
};

const FONT_FAMILY = "apple-system, Arial, sans-serif";

// Base Plotly layout; per-plot values and the metric file's "layout" are merged on top.
const DEFAULT_LAYOUT = {
  font: {
    family: FONT_FAMILY,
  },
  title: {
    text: "Plot Title",
    font: {
      size: 19,
      family: FONT_FAMILY,
    },
  },
  xaxis: {
    title: {
      text: "Training tokens (billions)",
      font: {
        size: 15,
        family: FONT_FAMILY,
      },
    },
    tickfont: {
      size: 14,
      family: FONT_FAMILY,
    },
    showgrid: false,
    mirror: true,
    ticks: "outside",
    showline: true,
  },
  yaxis: {
    title: {
      text: "Agg Score",
      font: {
        size: 15,
        family: FONT_FAMILY,
      },
      standoff: 10,
    },
    showgrid: false,
    mirror: true,
    ticks: "outside",
    showline: true,
    tickfont: {
      size: 14,
      family: FONT_FAMILY,
    },
  },
  legend: {
    orientation: "v",
    xanchor: "right",
    yanchor: "bottom",
    x: 1,
    y: 0,
    font: {
      size: 14,
      family: FONT_FAMILY,
    },
    bgcolor: "rgba(0,0,0,0)",
  },
  margin: {
    t: 30,
    b: 50,
  },
  height: 400,
};

/**
 * Compute an x-axis range padded 5% below the smallest and 5% above the largest x value.
 * @param {Array<{x?: number[]}>} traces - Plotly traces.
 * @returns {[number, number] | undefined} The range, or undefined (let Plotly autorange)
 *   when no trace has x data.
 */
const getAutoRange = (traces) => {
  const xs = traces.flatMap((trace) => trace.x ?? []);
  if (xs.length === 0) {
    return undefined;
  }
  // reduce rather than Math.min(...xs): spreading very long arrays can exceed the
  // engine's argument-count limit.
  const minX = xs.reduce((min, x) => Math.min(min, x), Infinity);
  const maxX = xs.reduce((max, x) => Math.max(max, x), -Infinity);
  return [minX * 0.95, maxX * 1.05];
};

/**
 * Build the DOM for one plot: a <figure> for Plotly plus a controls container with an
 * optional metric dropdown (only when more than one displayable metric exists) and an
 * optional rolling-window slider (unless settings.slider is null/undefined).
 * @param {HTMLElement} plotElement - Container the figure and controls are appended to.
 * @param {Object} indexMapping - The index.json "files" object, keyed by metric id.
 * @param {Object} settings - Merged plot settings.
 * @returns {{dropdown: HTMLSelectElement|undefined, slider: HTMLInputElement|undefined, plot: HTMLElement}}
 */
const createAblationPlottingElements = (plotElement, indexMapping, settings) => {
  const plot = document.createElement("figure");
  const controls = document.createElement("div");

  plot.classList.add("plotly");
  controls.classList.add("plotly_controls");

  plotElement.appendChild(plot);
  plotElement.appendChild(controls);

  const metricOptions = Object.keys(indexMapping).filter(
    (metric) => metric in TASK_ID_TO_NAME
  );

  // Dropdown
  let dropdown = undefined;
  if (metricOptions.length > 1) {
    const dropdownLabel = document.createElement("label");
    dropdownLabel.textContent = "Metric:";

    dropdown = document.createElement("select");
    // Build <option>s as elements with textContent instead of an HTML string.
    for (const metric of metricOptions) {
      const option = document.createElement("option");
      option.value = metric;
      option.textContent = TASK_ID_TO_NAME[metric];
      dropdown.appendChild(option);
    }
    dropdown.value = settings.defaultMetric;
    // A default that is not among the options leaves the select empty; fall back to the first one.
    if (dropdown.value !== settings.defaultMetric) {
      dropdown.value = metricOptions[0];
    }

    const dropdownContainer = document.createElement("div");
    dropdownContainer.classList.add("plotly_input_container");
    dropdownContainer.appendChild(dropdownLabel);
    dropdownContainer.appendChild(dropdown);
    controls.appendChild(dropdownContainer);
  }

  let slider = undefined;
  if (settings.slider != null) {
    const sliderLabel = document.createElement("label");
    sliderLabel.textContent = "Rolling window:";
    slider = document.createElement("input");
    slider.type = "range";
    slider.min = settings.slider.min;
    slider.max = settings.slider.max;
    slider.value = settings.slider.default;

    // Live readout of the current slider value.
    const sliderValue = document.createElement("span");
    sliderValue.textContent = slider.value;
    slider.addEventListener("input", () => {
      sliderValue.textContent = slider.value;
    });

    const sliderInputContainer = document.createElement("div");
    sliderInputContainer.classList.add("plotly_slider");
    sliderInputContainer.appendChild(slider);
    sliderInputContainer.appendChild(sliderValue);

    const sliderContainer = document.createElement("div");
    sliderContainer.classList.add("plotly_input_container");
    sliderContainer.appendChild(sliderLabel);
    sliderContainer.appendChild(sliderInputContainer);
    controls.appendChild(sliderContainer);
  }

  return { dropdown, slider, plot };
};

/**
 * Trailing moving average over every full window of `windowSize` consecutive values.
 * Returns `data` itself when windowSize <= 0; otherwise an array of
 * max(0, data.length - windowSize + 1) averages, where element k averages
 * data[k .. k + windowSize - 1].
 * @param {number[]} data
 * @param {number} windowSize
 * @returns {number[]}
 */
const rollingWindow = function (data, windowSize) {
  if (windowSize <= 0) {
    return data;
  }

  const rollingData = [];
  // `end` is exclusive, so `end <= data.length` includes the final full window.
  for (let end = windowSize; end <= data.length; end++) {
    const windowData = data.slice(end - windowSize, end);
    const windowSum = windowData.reduce((acc, value) => acc + value, 0);
    rollingData.push(windowSum / windowSize);
  }
  return rollingData;
};

/**
 * Fetch a URL and parse the body as JSON, rejecting on non-2xx responses.
 * @param {string} url
 * @returns {Promise<any>}
 */
const fetchJson = async (url) => {
  const response = await fetch(url);
  if (!response.ok) {
    throw new Error(`Failed to fetch ${url}: ${response.status} ${response.statusText}`);
  }
  return response.json();
};

/**
 * Pick the metric to show first: the configured default when the index provides it,
 * otherwise the first metric with a display name, otherwise the first listed file.
 * @param {Object} indexMapping - The index.json "files" object.
 * @param {string} preferred - settings.defaultMetric.
 * @returns {string|undefined}
 */
const pickDefaultMetric = (indexMapping, preferred) => {
  const available = Object.keys(indexMapping);
  if (available.includes(preferred)) {
    return preferred;
  }
  return available.find((metric) => metric in TASK_ID_TO_NAME) ?? available[0];
};

/**
 * Turn a metric file into Plotly traces: any prebuilt `traces` from the file, followed by
 * one trace per entry of `data` (smoothed with rollingWindow and coloured by its position).
 * @param {Object} metricData - Parsed metric file ({traces?, data?: {[id]: {x, y, label?}}}).
 * @param {number} windowSize - Rolling-window size from the slider.
 * @param {string} plotType - "bar" selects BAR_SETTINGS, anything else LINE_SETTINGS.
 * @returns {Object[]} Plotly traces.
 */
const buildTraces = (metricData, windowSize, plotType) => {
  const plotSettings = plotType === "bar" ? BAR_SETTINGS : LINE_SETTINGS;
  // Copy so the parsed file's own array is not mutated.
  const traces = [...(metricData?.traces ?? [])];

  Object.entries(metricData?.data ?? {}).forEach(([key, traceData], index) => {
    const y = rollingWindow(traceData.y, windowSize);
    // Smoothing shortens y; pair each average with the x of its window's first point.
    const x = traceData.x.slice(0, y.length);
    const trace = _.merge(
      {},
      {
        x,
        y,
        name: traceData.label ?? DATASET_ID_TO_NAME[key] ?? key,
        marker: {
          color: getColor(index),
        },
        line: {
          color: getColor(index),
        },
      },
      plotSettings
    );
    traces.push(trace);
  });

  return traces;
};

/**
 * Set up a single `plot-<name>` element: load its index.json, build controls, wire the
 * dropdown/slider/resize handlers once, and draw the initial plot.
 * index.json is read as {settings?, files: {[metricId]: {file}}}.
 * @param {HTMLElement} plotElement
 * @returns {Promise<void>}
 */
const initPlot = async (plotElement) => {
  const plotName = plotElement.id.replace("plot-", "");
  const indexData = await fetchJson(`${DATA_FOLDER}/${plotName}/index.json`);
  const settings = _.merge({}, DEFAULT_SETTINGS, indexData.settings);
  const indexMapping = indexData.files ?? {};
  settings.defaultMetric = pickDefaultMetric(indexMapping, settings.defaultMetric);

  const { dropdown, slider, plot } = createAblationPlottingElements(
    plotElement,
    indexMapping,
    settings
  );
  plot.id = `graph-${plotName}`;

  // Shared plot, filled in by updatePlot.
  Plotly.newPlot(plot, []);

  // Incremented per update so a slow, superseded fetch cannot overwrite a newer plot.
  let latestRequestId = 0;

  const updatePlot = async () => {
    const requestId = ++latestRequestId;
    const metricName = dropdown?.value ?? settings.defaultMetric;
    const sliderValue = Number.parseInt(slider?.value ?? "0", 10);

    try {
      const metricFile = indexMapping[metricName]?.file;
      if (metricFile === undefined) {
        throw new Error(`No data file listed for metric "${metricName}"`);
      }
      const metricData = await fetchJson(`${DATA_FOLDER}/${plotName}/${metricFile}`);
      if (requestId !== latestRequestId) {
        return; // A newer update has started; drop this stale result.
      }

      const traces = buildTraces(metricData, sliderValue, settings.type);
      const layout = _.merge(
        {},
        DEFAULT_LAYOUT,
        {
          width: plot.parentElement.offsetWidth,
          yaxis: { title: { text: TASK_ID_TO_NAME[metricName] } },
          xaxis: {
            range: settings.autoSetXRange ? getAutoRange(traces) : undefined,
          },
        },
        metricData.layout
      );

      Plotly.react(plot, traces, layout);
    } catch (error) {
      console.error(`Failed to update plot "${plotName}":`, error);
    }
  };

  if (dropdown !== undefined) {
    dropdown.addEventListener("change", () => updatePlot());
  }

  // Debounce the slider so dragging does not trigger a fetch per step.
  if (slider !== undefined) {
    let timeoutId;
    slider.addEventListener("input", () => {
      clearTimeout(timeoutId);
      timeoutId = setTimeout(() => updatePlot(), 500);
    });
  }

  // Registered once per plot (not per update) so listeners do not accumulate.
  window.addEventListener("resize", () => {
    // Below 768px wide the plot is not shown, so resizing it is skipped.
    if (window.innerWidth < 768) {
      return;
    }
    Plotly.relayout(plot, {
      width: plot.parentElement.offsetWidth,
    });
  });

  // Initial plot
  await updatePlot();
};

/**
 * Initialise every element whose id starts with "plot-". Each plot loads independently;
 * failures are logged per plot instead of surfacing as unhandled promise rejections.
 */
export const init_ablation_plot = function () {
  const plotElements = document.querySelectorAll('[id^="plot-"]');
  plotElements.forEach((plotElement) => {
    initPlot(plotElement).catch((error) => {
      console.error(`Failed to initialise plot "${plotElement.id}":`, error);
    });
  });
};