import { getColor } from "./colors.mjs" |
import Plotly from "plotly.js/dist/plotly-basic" |
import _ from "lodash" |
// Relative path under which every plot keeps its data: an index file at
// `<DATA_FOLDER>/<plot name>/index.json` plus the per-metric files it lists.
const DATA_FOLDER = "assets/data/plots";
/**
 * Trace defaults for line plots (any `settings.type` other than "bar").
 *
 * The stroke width lives under `line.width`: Plotly scatter traces have no
 * top-level `width` attribute, so a bare `width` key is silently ignored.
 * The caller deep-merges these defaults over a trace that already carries
 * `line.color`, so both end up under the same `line` object.
 */
const LINE_SETTINGS = {
  type: "scatter",
  mode: "lines",
  line: {
    width: 2.5,
  },
};
/**
 * Trace defaults used when a plot's `settings.type` is "bar": narrow,
 * slightly translucent bars with a 1px outline around each bar.
 */
const BAR_SETTINGS = {
  width: 0.5,
  type: "bar",
  opacity: 0.9,
  marker: { line: { width: 1 } },
};
/**
 * Display names for metric ids. A metric is only offered in the metric
 * dropdown if it appears here, and its name becomes the y-axis title.
 */
const TASK_ID_TO_NAME = {
  "agg_score": "Aggregate Score",
  // Benchmark ids of the form `<task>/acc_norm`.
  "commonsense_qa/acc_norm": "Commonsense QA",
  "hellaswag/acc_norm": "HellaSwag",
  "openbookqa/acc_norm": "OpenBook QA",
  "piqa/acc_norm": "PIQA",
  "siqa/acc_norm": "Social IQA",
  "winogrande/acc_norm": "WinoGrande",
  "arc/acc_norm": "ARC",
  "mmlu/acc_norm": "MMLU",
  // Line-level statistics.
  "lines_ended_with_punct": "Lines Ended With Punctuation",
  "lines_chars": "Lines Chars",
  "short_lines": "Short Lines",
};
/**
 * Legend labels for dataset ids. Used for a trace's name when the trace data
 * does not carry its own `label`; unknown ids fall back to the raw id.
 *
 * Each id appears exactly once (duplicate keys in an object literal are
 * silently collapsed and flagged by ESLint's `no-dupe-keys`).
 */
const DATASET_ID_TO_NAME = {
  pii_removed: "Fineweb",
  allenai_c4_en: "C4",
  "tiiuae_falcon-refinedweb_data": "RefinedWeb",
  "red-pajama-v2_jsonl-deduplicated-extract": "RedPajamaV2",
  "dolma-sample": "Dolma1.6",
  dedup_minhash_independent_output: "Independent Dedup MinHash",
  "dedup_minhash_CC-MAIN-2013-48_output": "Full MinHash CC-MAIN-2013-48",
  "dedup_minhash_independent_output_CC-MAIN-2013-48": "Independent MinHash CC-MAIN-2013-48",
  "ind_minhash-CC-MAIN-2019-18": "Independent MinHash CC-MAIN-2019-18",
  "wet-extraction-2019-18": "WET Extraction 2019-18",
};
/**
 * Fallback settings for every plot; the `settings` object from a plot's
 * index.json is deep-merged (lodash `_.merge`) on top of these.
 *
 * - slider: min/max/initial value of the rolling-window slider. A plot may
 *   set `slider` to null to hide the slider entirely.
 * - defaultMetric: metric id shown first, and the one used when no metric
 *   dropdown is rendered.
 * - type: "bar" selects bar trace defaults; any other value draws lines.
 */
const DEFAULT_SETTINGS = {
  slider: {
    max: 30,
    min: 0,
    default: 0,
  },
  defaultMetric: "agg_score",
  type: "line",
};
/**
 * Base Plotly layout for ablation plots: one font family everywhere, no grid,
 * mirrored axis lines with ticks drawn outside, and a transparent vertical
 * legend anchored at the bottom-right (x = 1, y = 0) of the plotting area.
 */
const DEFAULT_LAYOUT = (() => {
  const fontFamily = "apple-system, Arial, sans-serif";
  // Returns a new object on every call so no two sections share a font object.
  const fontOfSize = (size) => ({ size, family: fontFamily });

  return {
    font: { family: fontFamily },
    title: {
      text: "Plot Title",
      font: fontOfSize(19),
    },
    xaxis: {
      title: {
        text: "Training tokens (billions)",
        font: fontOfSize(15),
      },
      tickfont: fontOfSize(14),
      showgrid: false,
      mirror: true,
      ticks: "outside",
      showline: true,
    },
    yaxis: {
      title: {
        text: "Agg Score",
        font: fontOfSize(15),
        standoff: 10,
      },
      showgrid: false,
      mirror: true,
      ticks: "outside",
      showline: true,
      tickfont: fontOfSize(14),
    },
    legend: {
      orientation: "v",
      xanchor: "right",
      yanchor: "bottom",
      x: 1,
      y: 0,
      font: fontOfSize(14),
      bgcolor: "rgba(0,0,0,0)",
    },
    margin: { t: 30, b: 50 },
    height: 400,
  };
})();
/**
 * Compute a padded x-axis range spanning every x value of the given traces.
 *
 * Each bound is moved 5% of its own magnitude away from the data: for
 * non-negative bounds this is `[min * 0.95, max * 1.05]`; negative bounds are
 * scaled the other way so the padding still widens the range.
 *
 * @param {Array<{x?: Array<number>}>} traces - traces whose `x` arrays are
 *   scanned; a trace without `x` is skipped. Values are coerced with
 *   `Number()`, and values that coerce to NaN are ignored.
 * @returns {[number, number] | undefined} `[lower, upper]`, or undefined when
 *   there are no numeric x values, so callers can treat it as "no range".
 */
const getAutoRange = (traces) => {
  let minX = Infinity;
  let maxX = -Infinity;
  // Explicit loop instead of Math.min(...values): spreading very long arrays
  // into a call can exceed the engine's argument-count limit.
  for (const trace of traces) {
    for (const rawValue of trace.x ?? []) {
      const value = Number(rawValue);
      if (Number.isNaN(value)) {
        continue;
      }
      minX = Math.min(minX, value);
      maxX = Math.max(maxX, value);
    }
  }
  if (minX > maxX) {
    // No values seen: the old code returned [Infinity, -Infinity] here.
    return undefined;
  }
  // Multiplying a negative lower bound by 0.95 would move it toward zero and
  // clip the data, so negative bounds use the mirrored factor.
  const lower = minX >= 0 ? minX * 0.95 : minX * 1.05;
  const upper = maxX >= 0 ? maxX * 1.05 : maxX * 0.95;
  return [lower, upper];
};
/**
 * Build the DOM scaffolding for one ablation plot.
 *
 * A `<figure class="plotly">` (the Plotly target) and a
 * `<div class="plotly_controls">` are appended to `plotElement`. The controls
 * may hold a metric dropdown and a rolling-window slider.
 *
 * @param {HTMLElement} plotElement - container to append the figure and controls to.
 * @param {Object} indexMapping - the plot's `files` map from index.json; only
 *   its keys that have a display name in TASK_ID_TO_NAME become dropdown options.
 * @param {Object} settings - merged plot settings; reads `defaultMetric` and
 *   `slider` ({min, max, default}, or null/undefined for no slider).
 * @returns {{dropdown: (HTMLSelectElement|undefined),
 *            slider: (HTMLInputElement|undefined),
 *            plot: HTMLElement}}
 *   `dropdown` is undefined when fewer than two metrics are selectable;
 *   `slider` is undefined when `settings.slider` is null/undefined.
 */
const createAblationPlottingElements = (
  plotElement,
  indexMapping,
  settings
) => {
  const plot = document.createElement("figure");
  const controls = document.createElement("div");
  plot.classList.add("plotly");
  controls.classList.add("plotly_controls");
  plotElement.appendChild(plot);
  plotElement.appendChild(controls);

  // Own-property check so inherited names (e.g. "toString") never qualify.
  const metricOptions = Object.keys(indexMapping).filter((metric) =>
    Object.prototype.hasOwnProperty.call(TASK_ID_TO_NAME, metric)
  );

  let dropdown;
  if (metricOptions.length > 1) {
    const dropdownLabel = document.createElement("label");
    dropdownLabel.textContent = "Metric:";
    dropdown = document.createElement("select");
    for (const metric of metricOptions) {
      // new Option(text, value) sets the label as plain text, never as HTML.
      dropdown.appendChild(new Option(TASK_ID_TO_NAME[metric], metric));
    }
    // Assigning a value that matches no option deselects everything and makes
    // `dropdown.value` "", so only preselect the default when it is offered;
    // otherwise the browser keeps the first option selected.
    if (metricOptions.includes(settings.defaultMetric)) {
      dropdown.value = settings.defaultMetric;
    }
    const dropdownContainer = document.createElement("div");
    dropdownContainer.classList.add("plotly_input_container");
    dropdownContainer.appendChild(dropdownLabel);
    dropdownContainer.appendChild(dropdown);
    controls.appendChild(dropdownContainer);
  }

  let slider;
  // `!= null` so both an explicit null and a missing entry mean "no slider".
  if (settings.slider != null) {
    const sliderLabel = document.createElement("label");
    sliderLabel.textContent = "Rolling window:";
    slider = document.createElement("input");
    slider.type = "range";
    slider.min = settings.slider.min;
    slider.max = settings.slider.max;
    slider.value = settings.slider.default;
    // Live readout of the current slider value next to the control.
    const sliderValue = document.createElement("span");
    sliderValue.textContent = slider.value;
    slider.addEventListener("input", () => {
      sliderValue.textContent = slider.value;
    });
    const sliderInputContainer = document.createElement("div");
    sliderInputContainer.classList.add("plotly_slider");
    sliderInputContainer.appendChild(slider);
    sliderInputContainer.appendChild(sliderValue);
    const sliderContainer = document.createElement("div");
    sliderContainer.classList.add("plotly_input_container");
    sliderContainer.appendChild(sliderLabel);
    sliderContainer.appendChild(sliderInputContainer);
    controls.appendChild(sliderContainer);
  }

  return { dropdown, slider, plot };
};
/**
 * Simple moving average over a fixed-size window.
 *
 * `result[k]` is the mean of `data[k .. k + windowSize - 1]`, so the result
 * has `data.length - windowSize + 1` entries and is empty when the window is
 * longer than the data. A window of 1 reproduces the input values.
 *
 * @param {number[]} data - values to smooth.
 * @param {number} windowSize - points per window; 0 or any negative value
 *   disables smoothing.
 * @returns {number[]} the averaged values, or `data` itself (same array) when
 *   smoothing is disabled.
 */
const rollingWindow = function (data, windowSize) {
  if (windowSize <= 0) {
    return data;
  }
  const rollingData = [];
  // `<=` so that the last window, which ends on the final element, is
  // included; with `<` the final data point was never averaged in.
  for (let windowEnd = windowSize; windowEnd <= data.length; windowEnd++) {
    const windowData = data.slice(windowEnd - windowSize, windowEnd);
    const windowSum = windowData.reduce((acc, value) => acc + value, 0);
    rollingData.push(windowSum / windowData.length);
  }
  return rollingData;
};
/**
 * Fetch a URL and parse the response body as JSON.
 * @param {string} url - URL to request.
 * @returns {Promise<any>} the parsed JSON body.
 * @throws {Error} when the response status is not 2xx (fetch itself only
 *   rejects on network failure).
 */
const fetchJson = async (url) => {
  const response = await fetch(url);
  if (!response.ok) {
    throw new Error(`Failed to fetch ${url}: HTTP ${response.status}`);
  }
  return response.json();
};

/**
 * Turn one `plot-<name>` element into an interactive Plotly chart.
 *
 * Steps:
 * - Loads `${DATA_FOLDER}/<name>/index.json`.
 * - Deep-merges its `settings` over DEFAULT_SETTINGS.
 * - Builds the controls.
 * - Fetches the selected metric's data file (listed under `files[metric].file`)
 *   whenever the metric or the rolling window changes.
 *
 * @param {HTMLElement} plotElement - element whose id starts with "plot-".
 * @returns {Promise<void>} resolves after the first render; rejects on any
 *   fetch or data error.
 */
const setupAblationPlot = async (plotElement) => {
  const plotName = plotElement.id.replace("plot-", "");
  const indexData = await fetchJson(`${DATA_FOLDER}/${plotName}/index.json`);
  const settings = _.merge({}, DEFAULT_SETTINGS, indexData.settings);
  const indexMapping = indexData.files ?? {};
  const { dropdown, slider, plot } = createAblationPlottingElements(
    plotElement,
    indexMapping,
    settings
  );
  plot.id = `graph-${plotName}`;

  const hasMetric = (metric) =>
    Boolean(metric) &&
    Object.prototype.hasOwnProperty.call(indexMapping, metric);

  // Metric to draw: the dropdown selection if it has a data file, else the
  // configured default, else the first metric listed in index.json.
  const resolveMetricName = () =>
    [dropdown?.value, settings.defaultMetric].find(hasMetric) ??
    Object.keys(indexMapping)[0];

  // Incremented on every update; a response whose id is no longer the latest
  // is dropped so a slow fetch cannot overwrite a newer selection.
  let latestRequestId = 0;

  const updatePlot = async () => {
    const requestId = ++latestRequestId;
    const metricName = resolveMetricName();
    const dataFile =
      metricName === undefined ? undefined : indexMapping[metricName]?.file;
    if (dataFile === undefined) {
      throw new Error(
        `No data file listed for metric "${metricName}" in plot "${plotName}"`
      );
    }
    const windowSize = Number.parseInt(slider?.value ?? 0, 10);
    const metricData = await fetchJson(`${DATA_FOLDER}/${plotName}/${dataFile}`);
    if (requestId !== latestRequestId) {
      return;
    }

    // Traces stored verbatim in the data file come first; one trace per
    // dataset in `data` is appended after them.
    const traces = metricData?.traces?.[metricName] ?? [];
    const plotSettings = settings.type === "bar" ? BAR_SETTINGS : LINE_SETTINGS;
    const datasets = Object.entries(metricData?.data ?? {});
    for (const [index, [key, traceData]] of datasets.entries()) {
      const y = rollingWindow(traceData.y, windowSize);
      // Smoothing can shorten y; trim x so both arrays stay the same length.
      const x = traceData.x.slice(0, y.length);
      const color = getColor(index);
      traces.push(
        _.merge(
          {},
          {
            x,
            y,
            name: traceData.label ?? DATASET_ID_TO_NAME[key] ?? key,
            marker: { color },
            line: { color },
          },
          plotSettings
        )
      );
    }

    // Precedence (lowest to highest): shared defaults, per-update values,
    // then the layout stored in the data file.
    const layout = _.merge(
      {},
      DEFAULT_LAYOUT,
      {
        width: plot.parentElement.offsetWidth,
        yaxis: { title: { text: TASK_ID_TO_NAME[metricName] ?? metricName } },
        xaxis: {
          range: settings.autoSetXRange ? getAutoRange(traces) : undefined,
        },
      },
      metricData?.layout
    );
    Plotly.react(plot, traces, layout);
  };

  const refresh = () => {
    updatePlot().catch((error) => {
      console.error(`Failed to update plot "${plotName}":`, error);
    });
  };

  if (dropdown !== undefined) {
    dropdown.addEventListener("change", refresh);
  }
  if (slider !== undefined) {
    // Debounce: refetch only once the slider has been idle for 500 ms.
    let timeoutId;
    slider.addEventListener("input", () => {
      clearTimeout(timeoutId);
      timeoutId = setTimeout(refresh, 500);
    });
  }

  Plotly.newPlot(plot, []);
  // Registered once per plot rather than inside updatePlot, so listeners do
  // not accumulate with every metric or slider change.
  window.addEventListener("resize", () => {
    if (window.innerWidth < 768) {
      return;
    }
    Plotly.relayout(plot, {
      width: plot.parentElement.offsetWidth,
    });
  });

  await updatePlot();
};

/**
 * Initialise every element whose id starts with "plot-" as an ablation plot.
 * Plots are set up independently. A failure in one is logged to the console
 * and does not affect the others.
 */
export const init_ablation_plot = function () {
  const plotElements = document.querySelectorAll('[id^="plot-"]');
  plotElements.forEach((plotElement) => {
    setupAblationPlot(plotElement).catch((error) => {
      console.error(`Failed to initialise plot "${plotElement.id}":`, error);
    });
  });
};