!(async function(){
  // Hover/click lock shared by every chart: false or null when nothing is
  // pinned, otherwise the value returned by the clicked element's hover handler.
  var isLock = false

  // fetch() only rejects on network failure, so check the HTTP status
  // explicitly instead of parsing an error page as if it were CSV.
  var response = await fetch('rotated-accuracy.csv')
  if (!response.ok){
    throw new Error(`Failed to load rotated-accuracy.csv: ${response.status} ${response.statusText}`)
  }
  var csvstr = await response.text()

  // Derive the per-row accuracies first, then keep only rows where both are positive.
  var parsedRows = d3.csvParse(csvstr)
  parsedRows.forEach(d => {
    // Rows with the same slug share one (dataset_size, aVal, minority_percent) combination.
    d.slug = [d.dataset_size, d.aVal, d.minority_percent].join(' ')

    // NOTE(review): the /2000 presumably averages two 1,000-example test sets
    // (digits 1 and 7) into a 0-1 accuracy — confirm against the CSV.
    d.accuracy_orig = (+d.accuracy_test_data_1     + +d.accuracy_test_data_7)/2000
    d.accuracy_rot  = (+d.accuracy_test_data_1_rot + +d.accuracy_test_data_7_rot)/2000
    d.accuracy_dif = d.accuracy_orig - d.accuracy_rot
  })
  var allData = parsedRows.filter(d => d.accuracy_orig > 0 && d.accuracy_rot > 0)

  // One group per parameter combination, summarized by the median of its rows.
  var data = d3.nestBy(allData, d => d.slug)
  data.forEach(slug => {
    slug.accuracy_orig = d3.median(slug, d => d.accuracy_orig)
    slug.accuracy_rot = d3.median(slug, d => d.accuracy_rot)
    slug.accuracy_dif = slug.accuracy_orig - slug.accuracy_rot

    slug.dataset_size = +slug[0].dataset_size
    slug.aVal = +slug[0].aVal
    slug.minority_percent = +slug[0].minority_percent
  })

  // Axis title shown under each metric's chart.
  var metricAxisTitles = {
    dataset_size: 'Training Points',
    aVal: 'Less Privacy',
    minority_percent: 'Percent Rotated In Training Data',
  }

  // Tick label for the i-th value of each metric. Labels are deliberately
  // sparse: dataset_size labels odd indices, minority_percent labels even
  // indices, and aVal ticks are left blank.
  var metricTickLabels = {
    dataset_size: (d, i) => i % 2 == 0 ? '' : d3.format(',')(d.key),
    aVal: () => '',
    minority_percent: (d, i) => i % 2 ? '' : d3.format('.0%')(d.key),
  }

  // For every metric: the aggregated groups bucketed by that metric's value,
  // sorted ascending by the numeric value.
  var byMetrics = ['dataset_size', 'aVal', 'minority_percent'].map(metricStr => {
    var byMetric = d3.nestBy(data, d => d[metricStr])
    byMetric.forEach(d => d.key = +d.key)
    byMetric = _.sortBy(byMetric, d => d.key)

    byMetric.forEach((d, i) => {
      d.metricIndex = i
      d.label = metricTickLabels[metricStr](d, i)
      // Back-reference so each aggregate can find its bucket for this metric.
      d.forEach(e => e['metric_' + metricStr] = d)
    })

    // Start on the sixth value when it exists; otherwise fall back to the
    // last available one instead of leaving the active column undefined.
    byMetric.active = byMetric[Math.min(5, byMetric.length - 1)]
    byMetric.metricStr = metricStr
    byMetric.label = metricAxisTitles[metricStr]

    return byMetric
  })


  // Heat map: one small grid per "percent rotated" value, laid out inline.
  !(() => {
    // NOTE(review): the negative left offset presumably centers the 1100px
    // block over an 850px text column — confirm against the page CSS.
    var blockWidth = 1100
    var columnWidth = 850

    var container = d3.select('.rotated-accuracy-heatmap')
    container.html('')
    container.st({width: blockWidth, position: 'relative', left: (columnWidth - blockWidth)/2})
    container.at({
      role: 'graphics-document',
      'aria-label': `Faceted MNIST models by the percent of rotated digits in training data. Heatmaps show how privacy and training data change accuracy on rotated and original digits.`,
    })

    container.append('div.chart-title').text('Percentage of training data rotated 90° →')

    // byMetrics[2] is the minority_percent grouping; each group becomes a facet.
    var facetSel = container.appendMany('div', byMetrics[2])
    facetSel.st({display: 'inline-block'})
    facetSel.each(drawHeatmap)
  })()
  /**
   * Draws one heatmap facet: a grid with one cell per (dataset_size, aVal)
   * combination, each cell holding a purple bar (original-digit accuracy)
   * and an orange bar (rotated-digit accuracy).
   *
   * Invoked through selection.each, so `this` is the facet's container element.
   * Also installs `sizeData.updateActiveRect(match)`, which highlights the
   * cell whose datum is `match`.
   *
   * @param {Array} sizeData - aggregated groups sharing one minority_percent
   *   value; `sizeData.key` is that value.
   * @param {number} chartIndex - position of this facet among its siblings.
   */
  function drawHeatmap(sizeData, chartIndex){
    var cellSize = 8
    var gridSize = 11
    var pad = .5

    var c = d3.conventions({
      sel: d3.select(this),
      width: cellSize*gridSize,
      height: cellSize*gridSize,
      margin: {left: 5, right: 5, top: 30, bottom: 50},
    })

    // Fully transparent backdrop covering the plot area.
    c.svg.append('rect').at({width: c.width, height: c.height, fillOpacity: 0})

    c.svg.append('text.chart-title')
      .text(d3.format('.0%')(sizeData.key)).at({dy: -4, textAnchor: 'middle', x: c.width/2})
      .st({fontWeight: 300})

    var dataSel = c.svg
      // Leaving the facet releases any click lock.
      .on('mouseleave', () => isLock = false)
      .append('g').translate([.5, .5])
      .appendMany('g.accuracy-rect', sizeData)
      // NOTE(review): rows are placed at gridSize - index, so with 11 rows the
      // grid spans y = cellSize .. gridSize*cellSize rather than starting at 0 —
      // confirm this one-row offset is intentional.
      .translate(d => [
        cellSize*d.metric_dataset_size.metricIndex,
        cellSize*(gridSize - d.metric_aVal.metricIndex)
      ])
      .call(d3.attachTooltip)
      .on('mouseover', (d, i, node, isClickOverride) => {
        updateTooltip(d)

        // While locked, hovering only updates the tooltip; a click passes
        // isClickOverride to force the selection through.
        if (isLock && !isClickOverride) return

        byMetrics[0].setActiveCol(d.metric_dataset_size)
        byMetrics[1].setActiveCol(d.metric_aVal)
        byMetrics[2].setActiveCol(d.metric_minority_percent)

        // clickCb uses the returned datum as the lock token.
        return d
      })
      .on('click', clickCb)
      .st({cursor: 'pointer'})

    // Faint cell background.
    dataSel.append('rect')
      .at({
        width: cellSize - pad,
        height: cellSize - pad,
        fillOpacity: .1
      })

    // Bar heights: accuracy .5 maps to 0 and accuracy 1 to a full cell
    // (minus padding); anything below .5 is floored at 0.
    sizeData.forEach(d => {
      d.y_orig = Math.max(0, (cellSize - pad)*(d.accuracy_orig - .5)*2)
      d.y_rot  = Math.max(0, (cellSize - pad)*(d.accuracy_rot - .5)*2)
    })

    // Original-digit accuracy: right half of the cell, growing up from the bottom.
    dataSel.append('rect')
      .at({
        height: d => d.y_orig,
        y: d => cellSize - d.y_orig,
        width: cellSize/2,
        x: cellSize/2,
        fill: 'purple',
      })
    // Rotated-digit accuracy: left half of the cell.
    dataSel.append('rect')
      .at({
        height: d => d.y_rot,
        y: d => cellSize - d.y_rot,
        width: cellSize/2,
        fill: 'orange',
      })

    sizeData.updateActiveRect = function(match){
      dataSel
        .classed('active', d => match == d)
        .filter(d => match == d)
        .raise()
    }

    // Only the first facet carries the axis groups and axis labels.
    if (chartIndex === 0){
      c.svg.append('g.x.axis').translate([10, c.height])
      c.svg.append('g.y.axis').translate([0, 5])

      util.addAxisLabel(c, 'Training Points →', 'Less Privacy →', 30, -15)
    }

    // The color legend is attached to the ninth facet and positioned relative to it.
    if (chartIndex === 8){
      c.svg.appendMany('g.axis', ['Original Digit Accuracy', 'Rotated Digit Accuracy'])
        .translate((d, i) => [c.width - 230*i - 230 -50, c.height + 30])
        .append('text.axis-label').text(d => d)
        .st({fontSize: 14})
        .parent()
        // Five sample bars per legend entry, in steps of .2 of a cell.
        .appendMany('rect', (d, i) => d3.range(.2, 1.2, .2).map((v, j) => ({i, v, j})))
        .at({
          width: cellSize/2,
          y: d => cellSize - d.v*cellSize - cellSize,
          height: d => d.v*cellSize,
          fill: d => ['purple', 'orange'][d.i],
          x: d => d.j*cellSize*.75 - 35
        })
    }
  }

  // Metric barbell charts: one chart per metric, each comparing original and
  // rotated accuracy across that metric's values.
  !(function(){
    var sel = d3.select('.rotated-accuracy').html('')
      // Accessibility label; "showing how" replaces the earlier "showing up" typo.
      .at({role: 'graphics-document', 'aria-label': `Barbell charts showing how privacy / data / percent underrepresented data all trade-off in complex ways.`})

    sel.appendMany('div', byMetrics)
      .st({display: 'inline-block', width: 300, marginRight: 10, marginBottom: 50, marginTop: 10})
      .each(drawMetricBarbell)
  })()
  /**
   * Draws the barbell chart for one metric: one column per metric value, with
   * a purple circle at the original-digit accuracy, an orange circle at the
   * rotated-digit accuracy, and a gradient bar joining them.
   *
   * Invoked through selection.each, so `this` is the chart's container element.
   * Installs two methods on `byMetric`:
   *   - updateActiveCol(): repositions every column using the values currently
   *     active in the other metrics' charts.
   *   - setActiveCol(col): marks `col` as active (when given) and re-renders.
   *
   * @param {Array} byMetric - one entry of byMetrics: the groups for a single
   *   metric, sorted by key, carrying `label`, `metricStr` and `active`.
   * @param {number} byMetricIndex - chart position; only chart 0 gets the
   *   'Accuracy' y-axis label.
   */
  function drawMetricBarbell(byMetric, byMetricIndex){
    var sel = d3.select(this)

    var c = d3.conventions({
      sel,
      height: 220,
      width: 220,
      margin: {bottom: 10, top: 5},
      layers: 's',
    })
    // Fully transparent backdrop covering the plot area.
    c.svg.append('rect').at({width: c.width, height: c.height, fillOpacity: 0})

    c.y.domain([.5, 1]).interpolate(d3.interpolateRound)
    // x is indexed by column position, not by the metric's raw value.
    c.x.domain([0, byMetric.length - 1]).clamp(1).interpolate(d3.interpolateRound)

    c.xAxis
      .tickValues(d3.range(byMetric.length))
      .tickFormat(i => byMetric[i].label)
    c.yAxis.ticks(5).tickFormat(d => d3.format('.0%')(d))

    d3.drawAxis(c)
    util.addAxisLabel(c, byMetric.label + ' →', byMetricIndex ? '' : 'Accuracy')
    util.ggPlotBg(c, false)

    c.svg.select('.x').raise()
    c.svg.selectAll('.axis').st({pointerEvents: 'none'})

    // Vertical gradients for the connecting bar: purple at the top ...
    c.svg.append('defs').append('linearGradient#purple-to-orange')
      .at({x1: '0%', x2: '0%', y1: '0%', y2: '100%'})
      .append('stop').at({offset: '0%', 'stop-color': 'purple'}).parent()
      .append('stop').at({offset: '100%', 'stop-color': 'orange'})

    // ... and the same stops with the gradient vector flipped (purple at the bottom).
    c.svg.append('defs').append('linearGradient#orange-to-purple')
      .at({x1: '0%', x2: '0%', y2: '0%', y1: '100%'})
      .append('stop').at({offset: '0%', 'stop-color': 'purple'}).parent()
      .append('stop').at({offset: '100%', 'stop-color': 'orange'})

    var colSel = c.svg.appendMany('g', byMetric)
      .translate(d => c.x(d.metricIndex) + .5, 0)
      .st({pointerEvents: 'none'})

    // This path never receives a `d` attribute; it is kept so the generated
    // DOM stays the same as before.
    colSel.append('path')
      .at({stroke: 'url(#purple-to-orange)', strokeWidth: 1})

    var rectSel = colSel.append('rect')
      .at({width: 1, x: -.5})

    var origCircleSel = colSel.append('circle')
      .at({r: 3, fill: 'purple', stroke: '#000', strokeWidth: .5})

    var rotCircleSel = colSel.append('circle')
      .at({r: 3, fill: 'orange', stroke: '#000', strokeWidth: .5})

    // Pixel y for an accuracy, limited to [0, c.height + 3].
    function clampY(d){
      return d3.clamp(0, c.y(d), c.height + 3)
    }

    byMetric.updateActiveCol = function(){
      // Lookup keyed by the *other* metrics' currently active values.
      var findObj = {}
      byMetrics
        .filter(d => d != byMetric)
        .forEach(d => {
          findObj[d.metricStr] = d.active.key
        })

      // In each column, pick the aggregate that matches those values.
      byMetric.forEach(col => {
        col.active = _.find(col, findObj)
      })

      origCircleSel.at({cy: d => clampY(d.active.accuracy_orig)})
      rotCircleSel.at({cy: d => clampY(d.active.accuracy_rot)})

      rectSel.at({
        y: d => Math.min(clampY(d.active.accuracy_orig), clampY(d.active.accuracy_rot)),
        height: d => Math.abs(clampY(d.active.accuracy_orig) - clampY(d.active.accuracy_rot)),
        // Choose the gradient whose purple end sits at the original-accuracy circle.
        fill: d => d.active.accuracy_orig > d.active.accuracy_rot ? 'url(#purple-to-orange)' : 'url(#orange-to-purple)'
      })
    }
    byMetric.updateActiveCol()


    c.svg
      .call(d3.attachTooltip)
      .st({cursor: 'pointer'})
      .on('mousemove', function(d, i, node, isClickOverride){
        var [mx] = d3.mouse(this)
        var metricIndex = Math.round(c.x.invert(mx))

        // Point `active` at the hovered column just long enough to render the
        // tooltip, then restore it so a locked selection is not disturbed.
        var prevActive = byMetric.active
        byMetric.active = byMetric[metricIndex]
        updateTooltip()
        byMetric.active = prevActive

        if (isLock && !isClickOverride) return
        byMetric.setActiveCol(byMetric[metricIndex])

        // clickCb uses the returned column as the lock token.
        return byMetric[metricIndex]
      })
      .on('click', clickCb)
      // 'mouseexit' is not a DOM event, so the previous listener never fired;
      // 'mouseleave' matches the heatmap and releases the lock on leaving.
      .on('mouseleave', () => isLock = false)


    byMetric.setActiveCol = function(col){
      if (col) byMetric.active = col

      c.svg.selectAll('.x .tick')
        .classed('active', i => i == byMetric.active.metricIndex)

      colSel.classed('active', d => d == byMetric.active)

      // Only an explicit selection triggers a cross-chart re-render; the
      // initial no-argument call below just applies the classes.
      if (col) renderActiveCol()
    }
    byMetric.setActiveCol()
  }

  /**
   * Re-syncs every chart with the currently active value of each metric:
   * first the barbell charts, then the highlighted cell in each heatmap facet.
   */
  function renderActiveCol(){
    // Charts that have not been drawn yet have no updateActiveCol; skip them.
    for (const metric of byMetrics){
      if (metric.updateActiveCol) metric.updateActiveCol()
    }

    // The active value of every metric identifies one aggregated group.
    const activeKeys = {}
    for (const metric of byMetrics){
      activeKeys[metric.metricStr] = metric.active.key
    }
    const activeGroup = _.find(data, activeKeys)

    // byMetrics[2] holds the heatmap facets.
    for (const facet of byMetrics[2]){
      if (facet.updateActiveRect) facet.updateActiveRect(activeGroup)
    }
  }

  /**
   * Fills the shared tooltip (`ttSel`) with the accuracies and parameters of
   * one aggregated group.
   *
   * @param {Array} [d] - aggregated group to describe. When omitted, the group
   *   matching the currently active value of every metric is looked up.
   */
  function updateTooltip(d){
    if (!d){
      var findObj = {}
      byMetrics.forEach(d => findObj[d.metricStr] = d.active.key)
      d = _.find(data, findObj)
    }

    // No group exists for this combination: leave the tooltip untouched rather
    // than throwing in the middle of the caller's event handler.
    if (!d) return

    // epsilon is read from the group's first row and rounded to two decimals.
    var epsilon = Math.round(d[0].epsilon*100)/100
    ttSel.html(`
      <div>
        <b>${d3.format('.0%')(d.accuracy_orig)}</b>
        accuracy on
        <span style='padding: 2px; background: purple; color: #fff'>
          original digits
        </span>
      </div>
      <div>
        <b>${d3.format('.0%')(d.accuracy_rot)}</b>
        accuracy on
        <span style='padding: 2px; background: orange; color: #000'>
          rotated digits
        </span>
      </div>
      <br>
      <br>
      <div>Training points: ${d3.format(',')(d.dataset_size)}</div>
      <div>Privacy: ${epsilon} ε</div>
      <div>Rotated in training data: ${d3.format('.0%')(d.minority_percent)} </div>

    `).st({width: 230})

    ttSel.classed('tooltip-footnote', 0)
  }

  /**
   * Click handler shared by the heatmap cells and the barbell charts.
   * Replays the element's own hover handler (mouseover if present, otherwise
   * mousemove) with its click-override flag set, then toggles the lock:
   * clicking the already locked item unlocks, anything else becomes the lock.
   */
  function clickCb(d, i, node){
    const target = d3.select(this)
    const hoverHandler = target.on('mouseover') || target.on('mousemove')

    // The trailing `true` is the handler's isClickOverride argument, which
    // makes it apply the selection even while a lock is held.
    const picked = hoverHandler.call(this, d, i, node, true)

    if (picked == isLock){
      isLock = null
    } else {
      isLock = picked
    }
  }


})()