|
import { getColor } from "./colors.mjs" |
|
import Plotly from "plotly.js/dist/plotly-basic" |
|
import _ from "lodash" |
|
|
|
/** Path prefix for each plot's `index.json` and per-metric data files. */
const DATA_FOLDER = "assets/data/plots";

/**
 * Trace attributes spread into every trace when a plot's `type` is not "bar".
 *
 * NOTE(review): Plotly scatter traces have no top-level `width` attribute
 * (line thickness is `line.width`), so the 2.5 here is likely ignored.
 * Confirm before relying on it.
 */
const LINE_SETTINGS = { width: 2.5, type: "scatter", mode: "lines" };

/** Trace attributes spread into every trace when a plot's `type` is "bar". */
const BAR_SETTINGS = {
  width: 0.5,
  type: "bar",
  opacity: 0.9,
  // Outline drawn around each bar.
  marker: { line: { width: 1 } },
};
|
|
|
/**
 * Display names for metric ids.
 *
 * Only metric ids listed here are offered in the metric dropdown. The
 * display name is also used as the y-axis title of the rendered plot.
 */
const TASK_ID_TO_NAME = {
  agg_score: "Aggregate Score",

  // Per-benchmark scores. The `acc_norm` suffix presumably denotes
  // normalized accuracy (confirm against the evaluation config).
  "commonsense_qa/acc_norm": "Commonsense QA",
  "hellaswag/acc_norm": "HellaSwag",
  "openbookqa/acc_norm": "OpenBook QA",
  "piqa/acc_norm": "PIQA",
  "siqa/acc_norm": "Social IQA",
  "winogrande/acc_norm": "WinoGrande",
  "arc/acc_norm": "ARC",
  "mmlu/acc_norm": "MMLU",

  // Line-level text statistics.
  lines_ended_with_punct: "Lines Ended With Punctuation",
  lines_chars: "Lines Chars",
  short_lines: "Short Lines",
};
|
|
|
/**
 * Legend names for dataset ids (the keys of a metric file's `data` object).
 *
 * Used only when a trace provides no explicit `label`. Unknown ids fall back
 * to the raw key.
 *
 * Each key appears exactly once. The "CC-MAIN-2013-48" entries were
 * previously listed twice with identical values.
 */
const DATASET_ID_TO_NAME = {
  pii_removed: "Fineweb",
  allenai_c4_en: "C4",
  "tiiuae_falcon-refinedweb_data": "RefinedWeb",
  "red-pajama-v2_jsonl-deduplicated-extract": "RedPajamaV2",
  "dolma-sample": "Dolma1.6",
  dedup_minhash_independent_output: "Independent Dedup MinHash",
  "dedup_minhash_CC-MAIN-2013-48_output": "Full MinHash CC-MAIN-2013-48",
  "dedup_minhash_independent_output_CC-MAIN-2013-48": "Independent MinHash CC-MAIN-2013-48",
  "ind_minhash-CC-MAIN-2019-18": "Independent MinHash CC-MAIN-2019-18",
  "wet-extraction-2019-18": "WET Extraction 2019-18",
};
|
|
|
/**
 * Baseline plot settings. A plot's `index.json` `settings` object is
 * deep-merged over these.
 *
 * - `slider`: bounds and initial value of the rolling-window range input.
 * - `defaultMetric`: metric selected initially / used when there is no dropdown.
 * - `type`: "line" draws line traces; "bar" switches to bar traces.
 */
const DEFAULT_SETTINGS = {
  slider: { max: 30, min: 0, default: 0 },
  defaultMetric: "agg_score",
  type: "line",
};
|
|
|
/**
 * Font stack shared by every text element of the plots.
 *
 * The system-font keyword is `-apple-system`, with a leading dash. The
 * previous value `apple-system` matched no font, so browsers always fell
 * back to Arial.
 */
const FONT_FAMILY = "-apple-system, Arial, sans-serif";

/**
 * Base Plotly layout. Each render deep-merges the plot width, the y-axis
 * title, an optional x-range and the metric file's own `layout` over it.
 *
 * Every `font` entry is its own object literal, so no nested object is
 * shared between entries.
 */
const DEFAULT_LAYOUT = {
  font: { family: FONT_FAMILY },
  title: {
    text: "Plot Title",
    font: { size: 19, family: FONT_FAMILY },
  },
  xaxis: {
    title: {
      text: "Training tokens (billions)",
      font: { size: 15, family: FONT_FAMILY },
    },
    tickfont: { size: 14, family: FONT_FAMILY },
    showgrid: false,
    mirror: true,
    ticks: "outside",
    showline: true,
  },
  yaxis: {
    title: {
      text: "Agg Score",
      font: { size: 15, family: FONT_FAMILY },
      standoff: 10,
    },
    showgrid: false,
    mirror: true,
    ticks: "outside",
    showline: true,
    tickfont: { size: 14, family: FONT_FAMILY },
  },
  // Vertical legend pinned inside the bottom-right corner, transparent background.
  legend: {
    orientation: "v",
    xanchor: "right",
    yanchor: "bottom",
    x: 1,
    y: 0,
    font: { size: 14, family: FONT_FAMILY },
    bgcolor: "rgba(0,0,0,0)",
  },
  margin: { t: 30, b: 50 },
  height: 400,
};
|
|
|
/**
 * Compute an x-axis range covering every x value of `traces`, padded by 5%
 * of each endpoint's magnitude.
 *
 * For positive data this equals `[min * 0.95, max * 1.05]`. Padding by the
 * absolute value also widens the range correctly when values are negative.
 * (Scaling a negative minimum by 0.95 would shrink the range and clip
 * data.)
 *
 * The values are scanned in a loop rather than spread into
 * `Math.min(...)`, which throws a RangeError once the argument list gets
 * large. Non-numeric values (NaN after `Number()` coercion) are ignored.
 *
 * @param {Array<{x?: Array<number|string>}>} traces - Plotly traces.
 * @returns {[number, number] | undefined} The padded range, or `undefined`
 *   (let Plotly autorange) when there are no finite-comparable x values.
 */
const getAutoRange = (traces) => {
  const PADDING = 0.05;
  let minX = Infinity;
  let maxX = -Infinity;
  for (const trace of traces) {
    for (const raw of trace.x ?? []) {
      const value = Number(raw);
      if (Number.isNaN(value)) {
        continue;
      }
      if (value < minX) {
        minX = value;
      }
      if (value > maxX) {
        maxX = value;
      }
    }
  }
  if (minX > maxX) {
    return undefined; // no data points
  }
  return [minX - Math.abs(minX) * PADDING, maxX + Math.abs(maxX) * PADDING];
};
|
|
|
|
|
/**
 * Create the figure and control bar for one ablation plot and append both
 * to `plotElement`.
 *
 * Metric dropdown:
 * - Created only when more than one key of `indexMapping` has a display name
 *   in TASK_ID_TO_NAME.
 * - Starts on `settings.defaultMetric`, or on the first available metric if
 *   that default is not offered.
 *
 * Rolling-window slider:
 * - Created unless `settings.slider` is null/undefined.
 * - Shows its current value in a sibling <span>.
 *
 * @param {HTMLElement} plotElement - Container receiving the figure and controls.
 * @param {object} indexMapping - Map whose keys are the candidate metric ids.
 * @param {object} settings - Plot settings (`defaultMetric`, `slider.{min,max,default}`).
 * @returns {{dropdown: (HTMLSelectElement|undefined),
 *            slider: (HTMLInputElement|undefined),
 *            plot: HTMLElement}}
 */
const createAblationPlottingElements = (
  plotElement,
  indexMapping,
  settings
) => {
  const plot = document.createElement("figure");
  plot.classList.add("plotly");
  const controls = document.createElement("div");
  controls.classList.add("plotly_controls");
  plotElement.appendChild(plot);
  plotElement.appendChild(controls);

  // Own-property check: `in` would also accept inherited names such as "toString".
  const metricOptions = Object.keys(indexMapping).filter((metric) =>
    Object.prototype.hasOwnProperty.call(TASK_ID_TO_NAME, metric)
  );

  let dropdown = undefined;
  if (metricOptions.length > 1) {
    const dropdownLabel = document.createElement("label");
    dropdownLabel.textContent = "Metric:";
    dropdown = document.createElement("select");
    // Build options through the DOM so names are never parsed as HTML.
    for (const metric of metricOptions) {
      dropdown.appendChild(new Option(TASK_ID_TO_NAME[metric], metric));
    }
    // Selecting a value with no matching <option> leaves the select empty
    // (value ""), which no data file can be resolved for. Fall back to the
    // first metric instead.
    dropdown.value = metricOptions.includes(settings.defaultMetric)
      ? settings.defaultMetric
      : metricOptions[0];

    const dropdownContainer = document.createElement("div");
    dropdownContainer.classList.add("plotly_input_container");
    dropdownContainer.appendChild(dropdownLabel);
    dropdownContainer.appendChild(dropdown);
    controls.appendChild(dropdownContainer);
  }

  let slider = undefined;
  if (settings.slider != null) {
    const sliderLabel = document.createElement("label");
    sliderLabel.textContent = "Rolling window:";
    slider = document.createElement("input");
    slider.type = "range";
    slider.min = settings.slider.min;
    slider.max = settings.slider.max;
    slider.value = settings.slider.default;

    // Live read-out of the current window size next to the slider.
    const sliderValue = document.createElement("span");
    sliderValue.textContent = slider.value;
    slider.addEventListener("input", () => {
      sliderValue.textContent = slider.value;
    });

    const sliderInputContainer = document.createElement("div");
    sliderInputContainer.classList.add("plotly_slider");
    sliderInputContainer.appendChild(slider);
    sliderInputContainer.appendChild(sliderValue);

    const sliderContainer = document.createElement("div");
    sliderContainer.classList.add("plotly_input_container");
    sliderContainer.appendChild(sliderLabel);
    sliderContainer.appendChild(sliderInputContainer);
    controls.appendChild(sliderContainer);
  }

  return { dropdown, slider, plot };
};
|
|
|
/**
 * Simple moving average over consecutive windows of `windowSize` values.
 *
 * Element `i` of the result is the mean of `data[i .. i + windowSize - 1]`.
 * The result therefore has `data.length - windowSize + 1` elements, or none
 * when the window is longer than the data.
 *
 * The previous loop stopped one window early, so the last data point was
 * never averaged, and a window of 1 silently dropped the final sample.
 *
 * @param {number[]} data - Values to smooth.
 * @param {number} windowSize - Window length; 0 (or less) disables smoothing.
 * @returns {number[]} The averages, or `data` itself when smoothing is disabled.
 */
const rollingWindow = function (data, windowSize) {
  if (windowSize <= 0) {
    return data;
  }
  const averages = [];
  // `end` is exclusive, so the final window ends exactly at data.length.
  for (let end = windowSize; end <= data.length; end++) {
    const windowValues = data.slice(end - windowSize, end);
    const sum = windowValues.reduce((acc, value) => acc + value, 0);
    averages.push(sum / windowValues.length);
  }
  return averages;
};
|
|
|
/**
 * Fetch `url` and parse the response body as JSON.
 *
 * @param {string} url - Resource to load.
 * @returns {Promise<any>} The parsed body.
 * @throws {Error} If the HTTP status is not 2xx. `fetch` only rejects on
 *   network failure, so without this check a 404 page would surface as a
 *   confusing JSON parse error.
 */
const fetchJson = async (url) => {
  const response = await fetch(url);
  if (!response.ok) {
    throw new Error(`Request for ${url} failed with HTTP ${response.status}`);
  }
  return response.json();
};

/**
 * Build the Plotly traces for one metric.
 *
 * Starts from any pre-built traces the metric file lists under
 * `traces[metricName]`. Then appends one styled, rolling-averaged trace per
 * entry of the file's `data` object.
 *
 * @param {object} metricData - Parsed metric file.
 * @param {string} metricName - Metric id, used to pick pre-built traces.
 * @param {string} plotType - "bar" for bar traces; anything else draws lines.
 * @param {number} windowSize - Rolling-average window (0 disables smoothing).
 * @returns {object[]} Plotly trace objects.
 */
const buildTraces = (metricData, metricName, plotType, windowSize) => {
  // Copy so the fetched payload is not mutated.
  const traces = [...(metricData?.traces?.[metricName] ?? [])];
  const plotSettings = plotType === "bar" ? BAR_SETTINGS : LINE_SETTINGS;
  Object.entries(metricData?.data ?? {}).forEach(([key, traceData], index) => {
    const color = getColor(index);
    const y = rollingWindow(traceData.y, windowSize);
    traces.push({
      ...plotSettings,
      // Smoothing shortens y; keep x the same length by trimming its tail.
      x: traceData.x.slice(0, y.length),
      y,
      name: traceData.label ?? DATASET_ID_TO_NAME[key] ?? key,
      // Merge instead of overwrite. BAR_SETTINGS.marker carries the bar
      // outline, and spreading it after `marker: { color }` used to discard
      // the per-trace bar color.
      marker: { ...plotSettings.marker, color },
      line: { ...plotSettings.line, color },
    });
  });
  return traces;
};

/**
 * Set up one `plot-<name>` element:
 * - load `<DATA_FOLDER>/<name>/index.json`;
 * - create the element's figure and controls;
 * - re-render the chart whenever a control changes.
 *
 * @param {HTMLElement} plotElement - Element whose id starts with "plot-".
 * @returns {Promise<void>} Resolves after the first render.
 */
const setupAblationPlot = async (plotElement) => {
  const plotName = plotElement.id.replace("plot-", "");
  const indexData = await fetchJson(`${DATA_FOLDER}/${plotName}/index.json`);
  const settings = _.merge({}, DEFAULT_SETTINGS, indexData.settings);
  const indexMapping = indexData.files;
  const { dropdown, slider, plot } = createAblationPlottingElements(
    plotElement,
    indexMapping,
    settings
  );
  plot.id = `graph-${plotName}`;

  // Bumped on every render request. A fetched response is drawn only if no
  // newer request started meanwhile, so a slow earlier response cannot
  // overwrite the currently selected metric.
  let latestRequestId = 0;

  const updatePlot = async () => {
    latestRequestId += 1;
    const requestId = latestRequestId;
    // `||` rather than `??`: an empty <select> value must also fall back.
    const metricName = dropdown?.value || settings.defaultMetric;
    const windowSize = Number.parseInt(slider?.value ?? "0", 10);
    const metricFile = indexMapping?.[metricName]?.file;
    if (metricFile === undefined) {
      throw new Error(
        `Plot "${plotName}" has no data file for metric "${metricName}"`
      );
    }
    const metricData = await fetchJson(
      `${DATA_FOLDER}/${plotName}/${metricFile}`
    );
    if (requestId !== latestRequestId) {
      return; // superseded by a newer request
    }
    const traces = buildTraces(metricData, metricName, settings.type, windowSize);
    const layout = _.merge(
      {},
      DEFAULT_LAYOUT,
      {
        width: plot.parentElement.offsetWidth,
        yaxis: { title: { text: TASK_ID_TO_NAME[metricName] } },
        xaxis: {
          range: settings.autoSetXRange ? getAutoRange(traces) : undefined,
        },
      },
      metricData.layout
    );
    await Plotly.react(plot, traces, layout);
  };

  // Event-driven re-renders have no caller to propagate errors to, so report them here.
  const requestRender = () => {
    updatePlot().catch((error) => {
      console.error(`Failed to update plot "${plotName}":`, error);
    });
  };

  if (dropdown !== undefined) {
    dropdown.addEventListener("change", requestRender);
  }
  if (slider !== undefined) {
    // Debounce: dragging fires many input events; render once the slider
    // has been still for 500 ms.
    let timeoutId;
    slider.addEventListener("input", () => {
      clearTimeout(timeoutId);
      timeoutId = setTimeout(requestRender, 500);
    });
  }

  Plotly.newPlot(plot, []);

  // Registered once per plot (not once per render) so listeners do not
  // accumulate. Narrow viewports (< 768px) are left untouched.
  window.addEventListener("resize", () => {
    if (window.innerWidth < 768) {
      return;
    }
    Plotly.relayout(plot, {
      width: plot.parentElement.offsetWidth,
    });
  });

  await updatePlot();
};

/**
 * Initialise every element whose id starts with "plot-" as an ablation plot.
 *
 * Plots load independently and in parallel. A failure in one is logged and
 * does not affect the others.
 */
export const init_ablation_plot = function () {
  const plotElements = document.querySelectorAll('[id^="plot-"]');
  plotElements.forEach((plotElement) => {
    setupAblationPlot(plotElement).catch((error) => {
      console.error(`Failed to initialise plot "${plotElement.id}":`, error);
    });
  });
};