lvwerra HF staff commited on
Commit
d4506c4
·
verified ·
1 Parent(s): 7b512d6

memory-layout-widget (#12)

Browse files

- minor changes (d3e46cc381099a421ec9d3f04c9abf38709e8ab4)

Files changed (3) hide show
  1. dist/main.bundle.js +0 -0
  2. dist/main.bundle.js.map +0 -0
  3. src/memory.js +36 -23
dist/main.bundle.js CHANGED
The diff for this file is too large to render. See raw diff
 
dist/main.bundle.js.map CHANGED
The diff for this file is too large to render. See raw diff
 
src/memory.js CHANGED
@@ -83,7 +83,7 @@ export function activationMemory(
83
  if (recomputation === "none" || recomputation === "selective") {
84
 
85
  data = {
86
- name: "activationMemory",
87
  children: [
88
  ...Array.from({ length: L }, (_, index) => ({
89
  name: `Layer ${index + 1}`,
@@ -101,7 +101,7 @@ export function activationMemory(
101
  };
102
  } else if (recomputation === "full") {
103
  data = {
104
- name: "activationMemory",
105
  children: [
106
  { name: 'LayerInput', value: s * b * h * bytesPerValue * L },
107
  { name: 'Dropout', value: inputDropout },
@@ -139,7 +139,7 @@ export function paramGradsOpt(h, L, s, v, k = 8, dp = 1, zero = 0, mixed = true)
139
  const bytesPerParameter = mixed ? 2 : 4;
140
 
141
  const data = {
142
- name: "ParametersGradientOps",
143
  children: [
144
  { name: 'Parameters', value: zero >= 3 ? bytesPerParameter * n / dp : bytesPerParameter * n },
145
  { name: 'Gradients', value: zero >= 2 ? bytesPerParameter * n / dp : bytesPerParameter * n },
@@ -220,23 +220,33 @@ export function updateGraph() {
220
 
221
  const color = d => {
222
  switch (d.data.name) {
223
- case 'Parameters': return '#117fc9'; // Blue
224
- case 'Gradients': return '#ffad5c'; // Orange
225
- case 'OptimizerAverages': return '#f67d8e'; // Red
226
- case 'activationMemory': return '#ffad5c'; // Orange
227
- case 'fixed100GB': return '#bae2b4'; // Green
228
- case 'Attention': return '#f67d8e'; // Red
229
- case 'Feedforward': return '#4aacef'; // Light Blue
230
- case 'LayerNorm': return '#fb8b28'; // Dark Orange
231
- case 'Dropout': return '#4ead4e'; // Dark Green
232
- case 'Projection': return '#d94361'; // Dark Red
233
- case 'Cross Entropy': return '#b492d3'; // Violet
234
- case 'Total': return '#bae2b4'; // Green
235
- case 'root': return '#f3f3f3'; // Light Grey
236
- default: return '#a0c4ff'; // Lighter Blue (for unexpected cases)
237
- }
238
- };
239
-
 
 
 
 
 
 
 
 
 
 
240
 
241
  if (d3.select('#tooltip').empty()) {
242
  d3.select('body')
@@ -272,7 +282,7 @@ export function updateGraph() {
272
  .attr("height", d => d.y1 - d.y0)
273
  .attr("fill", d => color(d))
274
  .attr("stroke", d => d.depth === 1 ? color(d) : "white")
275
- .attr("stroke-width", 0.5);
276
 
277
  const fontSize = 10;
278
  const padding = 2;
@@ -290,16 +300,19 @@ export function updateGraph() {
290
  if (d.depth === 1 || d.depth === 2) {
291
  node.attr("transform", `translate(${padding},${fontSize + padding})`)
292
  .attr("font-weight", "bold")
 
293
  .text(`${name}: ${value}`);
294
  } else {
295
  // Child nodes
296
  node.attr("transform", `translate(${padding},${fontSize + padding})`)
297
  .text(name[0].toUpperCase()) // Display only the first letter
 
298
  .append("title") // Add title for hover effect
299
  .text(`${name}: ${value}`);
300
  }
301
  });
302
 
 
303
  // Adjust legend positioning
304
  const legendData = root.children[0].children.concat(root.children[0]);
305
  const legend = svg.append("g")
@@ -325,10 +338,10 @@ export function updateGraph() {
325
  .attr("y", 9.5)
326
  .attr("dy", "0.32em")
327
  .text(d => `${d.data.name}: ${formatBytes(d.value)}`);
328
-
329
  console.log('Treemap nodes created');
330
  }
331
-
332
  function formatBytes(bytes) {
333
  const sizes = ['Bytes', 'KB', 'MB', 'GB', 'TB', 'PB'];
334
  if (bytes === 0) return '0 Bytes';
 
83
  if (recomputation === "none" || recomputation === "selective") {
84
 
85
  data = {
86
+ name: "Activation Memory",
87
  children: [
88
  ...Array.from({ length: L }, (_, index) => ({
89
  name: `Layer ${index + 1}`,
 
101
  };
102
  } else if (recomputation === "full") {
103
  data = {
104
+ name: "Activation Memory",
105
  children: [
106
  { name: 'LayerInput', value: s * b * h * bytesPerValue * L },
107
  { name: 'Dropout', value: inputDropout },
 
139
  const bytesPerParameter = mixed ? 2 : 4;
140
 
141
  const data = {
142
+ name: "Parameters / Gradients / Optimizer States",
143
  children: [
144
  { name: 'Parameters', value: zero >= 3 ? bytesPerParameter * n / dp : bytesPerParameter * n },
145
  { name: 'Gradients', value: zero >= 2 ? bytesPerParameter * n / dp : bytesPerParameter * n },
 
220
 
221
  const color = d => {
222
  switch (d.data.name) {
223
+ // Root and Total (container levels)
224
+ case 'root': return 'rgb(225, 225, 225)'; // Light Grey
225
+ case 'Total': return 'rgb(225, 225, 225)'; // Light Grey
226
+
227
+ // Give distinct colors to the main section containers
228
+ case 'Activation Memory': return 'rgb(78, 165, 183)'; // Orange
229
+ case 'Parameters / Gradients / Optimizer States': return 'rgb(232, 137, 171)'; // Teal Blue
230
+
231
+ // Parameters / Gradients / Optimizer States branch
232
+ case 'Parameters': return 'rgb(206, 192, 250)'; // Blue
233
+ case 'Gradients': return 'rgb(227, 138, 66)'; // Orange
234
+ case 'OptimizerAverages': return 'rgb(78, 165, 183)'; // Pink
235
+
236
+ // activationMemory branch - Layer components
237
+ case 'Attention': return 'rgb(206, 192, 250)'; // Purple
238
+ case 'Feedforward': return 'rgb(171, 232, 241)'; // Light Blue
239
+ case 'LayerNorm': return 'rgb(232, 137, 171)'; // Light Green
240
+
241
+ // activationMemory branch - other components
242
+ case 'Dropout': return 'rgb(67, 145, 108)'; // Dark Green
243
+ case 'Projection': return 'rgb(174, 214, 251)'; // Sky Blue
244
+ case 'Cross Entropy': return 'rgb(232, 137, 171)'; // Pink
245
+
246
+ // Default for any Layer nodes and unexpected cases
247
+ default: return 'rgb(227, 138, 66)'; // Light Grey
248
+ };
249
+ };
250
 
251
  if (d3.select('#tooltip').empty()) {
252
  d3.select('body')
 
282
  .attr("height", d => d.y1 - d.y0)
283
  .attr("fill", d => color(d))
284
  .attr("stroke", d => d.depth === 1 ? color(d) : "white")
285
+ .attr("stroke-width", 1);
286
 
287
  const fontSize = 10;
288
  const padding = 2;
 
300
  if (d.depth === 1 || d.depth === 2) {
301
  node.attr("transform", `translate(${padding},${fontSize + padding})`)
302
  .attr("font-weight", "bold")
303
+ .attr("font-size", 12)
304
  .text(`${name}: ${value}`);
305
  } else {
306
  // Child nodes
307
  node.attr("transform", `translate(${padding},${fontSize + padding})`)
308
  .text(name[0].toUpperCase()) // Display only the first letter
309
+ .attr("font-weight", "bold")
310
  .append("title") // Add title for hover effect
311
  .text(`${name}: ${value}`);
312
  }
313
  });
314
 
315
+ /*
316
  // Adjust legend positioning
317
  const legendData = root.children[0].children.concat(root.children[0]);
318
  const legend = svg.append("g")
 
338
  .attr("y", 9.5)
339
  .attr("dy", "0.32em")
340
  .text(d => `${d.data.name}: ${formatBytes(d.value)}`);
341
+ */
342
  console.log('Treemap nodes created');
343
  }
344
+
345
  function formatBytes(bytes) {
346
  const sizes = ['Bytes', 'KB', 'MB', 'GB', 'TB', 'PB'];
347
  if (bytes === 0) return '0 Bytes';