'''
To run dissection:

1. Load up the convolutional model you wish to dissect, and wrap it in
   an InstrumentedModel; then call imodel.retain_layers([layernames,..])
   to instrument the layers of interest.
2. Load the segmentation dataset using the BrodenDataset class;
   use the transform_image argument to normalize images to be
   suitable for the model, or the size argument to truncate the dataset.
3. Choose a directory in which to write the output, and call
   dissect(outdir, model, dataset).

Example:

    from dissect import InstrumentedModel, dissect
    from broden import BrodenDataset
    from torchvision import transforms

    model = InstrumentedModel(load_my_model())
    model.eval()
    model.cuda()
    model.retain_layers(['conv1', 'conv2', 'conv3', 'conv4', 'conv5'])
    bds = BrodenDataset('dataset/broden1_227',
            transform_image=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV)]),
            size=1000)
    dissect('result/dissect', model, bds,
            examples_per_unit=10)
'''

import torch, numpy, os, re, json, shutil, types, tempfile, torchvision
# import warnings
# warnings.simplefilter('error', UserWarning)
from PIL import Image
from xml.etree import ElementTree as et
from collections import OrderedDict, defaultdict
from .progress import verbose_progress, default_progress, print_progress
from .progress import desc_progress
from .runningstats import RunningQuantile, RunningTopK
from .runningstats import RunningCrossCovariance, RunningConditionalQuantile
from .sampler import FixedSubsetSampler
from .actviz import activation_visualization
from .segviz import segment_visualization, high_contrast
from .workerpool import WorkerBase, WorkerPool
from .segmenter import UnifiedParsingSegmenter

def dissect(outdir, model, dataset,
        segrunner=None,
        train_dataset=None,
        model_segmenter=None,
        quantile_threshold=0.005,
        iou_threshold=0.05,
        iqr_threshold=0.01,
        examples_per_unit=100,
        batch_size=100,
        num_workers=24,
        seg_batch_size=5,
        make_images=True,
        make_labels=True,
        make_maxiou=False,
        make_covariance=False,
        make_report=True,
        make_row_images=True,
        make_single_images=False,
        rank_all_labels=False,
        netname=None,
        meta=None,
        merge=None,
        settings=None,
        ):
    '''
    Runs net dissection in-memory, using pytorch, and saves visualizations
    and metadata into outdir: per-layer dissect.json reports and
    bargraph.svg summaries (see generate_report), plus image strips of
    each unit's top-activating images (see generate_images).
    '''
    assert not model.training, 'Run model.eval() before dissection'
    if netname is None:
        netname = type(model).__name__
    if segrunner is None:
        segrunner = ClassifierSegRunner(dataset)
    if train_dataset is None:
        train_dataset = dataset
    make_iqr = (quantile_threshold == 'iqr')
    with torch.no_grad():
        device = next(model.parameters()).device
        levels = None
        labelnames, catnames = None, None
        maxioudata, iqrdata, labeldata, cov = None, None, None, None

        labelnames, catnames = segrunner.get_label_and_category_names()
        label_category = [catnames.index(c) if c in catnames else 0
                for l, c in labelnames]
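        # Labels whose category is missing from catnames fall back to
        # category 0.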

        # First, always collect quantiles and topk information.
        segloader = torch.utils.data.DataLoader(dataset,
                batch_size=batch_size, num_workers=num_workers,
                pin_memory=(device.type == 'cuda'))
        quantiles, topk = collect_quantiles_and_topk(outdir, model,
            segloader, segrunner, k=examples_per_unit)

        # Thresholds can be automatically chosen by maximizing iqr
        if make_iqr:
            # Get thresholds based on an IQR optimization
            segloader = torch.utils.data.DataLoader(train_dataset,
                    batch_size=1, num_workers=num_workers,
                    pin_memory=(device.type == 'cuda'))
            iqrdata = collect_iqr(outdir, model, segloader, segrunner)
            max_iqr, full_iqr_levels = iqrdata[:2]
            max_iqr_agreement = iqrdata[4]
            # qualified_iqr[max_iqr_quantile[layer] > 0.5] = 0
            levels = {layer: full_iqr_levels[layer][
                    max_iqr[layer].max(0)[1],
                    torch.arange(max_iqr[layer].shape[1])].to(device)
                    for layer in full_iqr_levels}
        else:
            levels = {k: qc.quantiles([1.0 - quantile_threshold])[:,0]
                      for k, qc in quantiles.items()}

        quantiledata = (topk, quantiles, levels, quantile_threshold)

        if make_images:
            segloader = torch.utils.data.DataLoader(dataset,
                    batch_size=batch_size, num_workers=num_workers,
                    pin_memory=(device.type == 'cuda'))
            generate_images(outdir, model, dataset, topk, levels, segrunner,
                    row_length=examples_per_unit, batch_size=seg_batch_size,
                    row_images=make_row_images,
                    single_images=make_single_images,
                    num_workers=num_workers)

        if make_maxiou:
            assert train_dataset, "Need training dataset for maxiou."
            segloader = torch.utils.data.DataLoader(train_dataset,
                    batch_size=1, num_workers=num_workers,
                    pin_memory=(device.type == 'cuda'))
            maxioudata = collect_maxiou(outdir, model, segloader,
                    segrunner)

        if make_labels:
            segloader = torch.utils.data.DataLoader(dataset,
                    batch_size=1, num_workers=num_workers,
                    pin_memory=(device.type == 'cuda'))
            iou_scores, iqr_scores, tcs, lcs, ccs, ics = (
                    collect_bincounts(outdir, model, segloader,
                    levels, segrunner))
            labeldata = (iou_scores, iqr_scores, lcs, ccs, ics, iou_threshold,
                    iqr_threshold)

        if make_covariance:
            segloader = torch.utils.data.DataLoader(dataset,
                    batch_size=seg_batch_size,
                    num_workers=num_workers,
                    pin_memory=(device.type == 'cuda'))
            cov = collect_covariance(outdir, model, segloader, segrunner)

        if make_report:
            generate_report(outdir,
                    quantiledata=quantiledata,
                    labelnames=labelnames,
                    catnames=catnames,
                    labeldata=labeldata,
                    maxioudata=maxioudata,
                    iqrdata=iqrdata,
                    covariancedata=cov,
                    rank_all_labels=rank_all_labels,
                    netname=netname,
                    meta=meta,
                    mergedata=merge,
                    settings=settings)

        return quantiledata, labeldata

def generate_report(outdir, quantiledata, labelnames=None, catnames=None,
        labeldata=None, maxioudata=None, iqrdata=None, covariancedata=None,
        rank_all_labels=False, netname='Model', meta=None, settings=None,
        mergedata=None):
    '''
    Creates dissection.json reports and summary bargraph.svg files in the
    specified output directory, and copies a dissection.html interface
    to go along with it.
    '''
    all_layers = []
    # Current source code directory, for html to copy.
    srcdir = os.path.realpath(
       os.path.join(os.getcwd(), os.path.dirname(__file__)))
    # Unpack arguments
    topk, quantiles, levels, quantile_threshold = quantiledata
    top_record = dict(
            netname=netname,
            meta=meta,
            default_ranking='unit',
            quantile_threshold=quantile_threshold)
    if settings is not None:
        top_record['settings'] = settings
    if labeldata is not None:
        iou_scores, iqr_scores, lcs, ccs, ics, iou_threshold, iqr_threshold = (
                labeldata)
        catorder = {'object': -7, 'scene': -6, 'part': -5,
                    'piece': -4,
                    'material': -3, 'texture': -2, 'color': -1}
        for i, cat in enumerate(c for c in catnames if c not in catorder):
            catorder[cat] = i
        catnumber = {n: i for i, n in enumerate(catnames)}
        catnumber['-'] = 0
        top_record['default_ranking'] = 'label'
        top_record['iou_threshold'] = iou_threshold
        top_record['iqr_threshold'] = iqr_threshold
        labelnumber = dict((name[0], num)
                for num, name in enumerate(labelnames))
    # Make a segmentation color dictionary
    segcolors = {}
    for i, name in enumerate(labelnames):
        key = ','.join(str(s) for s in high_contrast[i % len(high_contrast)])
        if key in segcolors:
            segcolors[key] += '/' + name[0]
        else:
            segcolors[key] = name[0]
    top_record['segcolors'] = segcolors
    for layer in topk.keys():
        units, rankings = [], []
        record = dict(layer=layer, units=units, rankings=rankings)
        # For every unit, we always have basic visualization information.
        topa, topi = topk[layer].result()
        lev = levels[layer]
        for u in range(len(topa)):
            units.append(dict(
                unit=u,
                interp=True,
                level=lev[u].item(),
                top=[dict(imgnum=i.item(), maxact=a.item())
                    for i, a in zip(topi[u], topa[u])],
                ))
        rankings.append(dict(name="unit", score=list([
            u for u in range(len(topa))])))
        # TODO: consider including stats and ranking based on quantiles,
        # variance, connectedness here.

        # if we have labeldata, then every unit also gets a bunch of other info
        if labeldata is not None:
            lscore, qscore, cc, ic = [dat[layer]
                    for dat in [iou_scores, iqr_scores, ccs, ics]]
            if iqrdata is not None:
                # If we have IQR thresholds, assign labels based on that
                max_iqr, max_iqr_level = iqrdata[:2]
                best_label = max_iqr[layer].max(0)[1]
                best_score = lscore[best_label, torch.arange(lscore.shape[1])]
                best_qscore = qscore[best_label, torch.arange(lscore.shape[1])]
            else:
                # Otherwise, assign labels based on max iou
                best_score, best_label = lscore.max(0)
                best_qscore = qscore[best_label, torch.arange(qscore.shape[1])]
            record['iou_threshold'] = iou_threshold
            for u, urec in enumerate(units):
                uscore, uqscore, ulabel = (
                        best_score[u], best_qscore[u], best_label[u])
                urec.update(dict(
                    iou=uscore.item(),
                    iou_iqr=uqscore.item(),
                    lc=lcs[ulabel].item(),
                    cc=cc[catnumber[labelnames[ulabel][1]], u].item(),
                    ic=ic[ulabel, u].item(),
                    interp=(uqscore.item() > iqr_threshold and
                        uscore.item() > iou_threshold),
                    iou_labelnum=ulabel.item(),
                    iou_label=labelnames[ulabel.item()][0],
                    iou_cat=labelnames[ulabel.item()][1],
                    ))
        if maxioudata is not None:
            max_iou, max_iou_level, max_iou_quantile = maxioudata
            qualified_iou = max_iou[layer].clone()
            # qualified_iou[max_iou_quantile[layer] > 0.75] = 0
            best_score, best_label = qualified_iou.max(0)
            for u, urec in enumerate(units):
                urec.update(dict(
                    maxiou=best_score[u].item(),
                    maxiou_label=labelnames[best_label[u].item()][0],
                    maxiou_cat=labelnames[best_label[u].item()][1],
                    maxiou_level=max_iou_level[layer][best_label[u], u].item(),
                    maxiou_quantile=max_iou_quantile[layer][
                        best_label[u], u].item()))
        if iqrdata is not None:
            [max_iqr, max_iqr_level, max_iqr_quantile,
                    max_iqr_iou, max_iqr_agreement] = iqrdata
            qualified_iqr = max_iqr[layer].clone()
            qualified_iqr[max_iqr_quantile[layer] > 0.5] = 0
            best_score, best_label = qualified_iqr.max(0)
            for u, urec in enumerate(units):
                urec.update(dict(
                    iqr=best_score[u].item(),
                    iqr_label=labelnames[best_label[u].item()][0],
                    iqr_cat=labelnames[best_label[u].item()][1],
                    iqr_level=max_iqr_level[layer][best_label[u], u].item(),
                    iqr_quantile=max_iqr_quantile[layer][
                        best_label[u], u].item(),
                    iqr_iou=max_iqr_iou[layer][best_label[u], u].item()
                    ))
        if covariancedata is not None:
            score = covariancedata[layer].correlation()
            best_score, best_label = score.max(1)
            for u, urec in enumerate(units):
                urec.update(dict(
                    cor=best_score[u].item(),
                    cor_label=labelnames[best_label[u].item()][0],
                    cor_cat=labelnames[best_label[u].item()][1]
                    ))
        if mergedata is not None:
            # Final step: if the user passed any data to merge into the
            # units, merge them now.  This can be used, for example, to
            # indicate that a unit is not interpretable based on some
            # outside analysis of unit statistics.
            for lrec in mergedata.get('layers', []):
                if lrec['layer'] == layer:
                    break
            else:
                lrec = None
            for u, urec in enumerate(lrec.get('units', []) if lrec else []):
                units[u].update(urec)
        # After populating per-unit info, populate per-layer ranking info
        if labeldata is not None:
            # Collect all labeled units
            labelunits = defaultdict(list)
            all_labelunits = defaultdict(list)
            for u, urec in enumerate(units):
                if urec['interp']:
                    labelunits[urec['iou_labelnum']].append(u)
                all_labelunits[urec['iou_labelnum']].append(u)
            # Sort all units in order with most popular label first.
            label_ordering = sorted(units,
                # Sort by:
                key=lambda r: (-1 if r['interp'] else 0,  # interpretable
                    -len(labelunits[r['iou_labelnum']]),  # label freq, score
                    -max([units[u]['iou']
                        for u in labelunits[r['iou_labelnum']]], default=0),
                    r['iou_labelnum'],                    # label
                    -r['iou']))                           # unit score
            # Add label and iou ranking.
            rankings.append(dict(name="label", score=(numpy.argsort(list(
                ur['unit'] for ur in label_ordering))).tolist()))
            rankings.append(dict(name="max iou", metric="iou", score=list(
                -ur['iou'] for ur in units)))
            # Add ranking for top labels
            # for labelnum in [n for n in sorted(
            #     all_labelunits.keys(), key=lambda x:
            #         -len(all_labelunits[x])) if len(all_labelunits[n])]:
            #     label = labelnames[labelnum][0]
            #     rankings.append(dict(name="%s-iou" % label,
            #         concept=label, metric='iou',
            #         score=(-lscore[labelnum, :]).tolist()))
            # Collate labels by category then frequency.
            record['labels'] = [dict(
                        label=labelnames[label][0],
                        labelnum=label,
                        units=labelunits[label],
                        cat=labelnames[label][1])
                    for label in (sorted(labelunits.keys(),
                        # Sort by:
                        key=lambda l: (catorder.get(          # category
                            labelnames[l][1], 0),
                            -len(labelunits[l]),              # label freq
                            -max([units[u]['iou'] for u in labelunits[l]],
                                default=0) # score
                            ))) if len(labelunits[label])]
            # Total number of interpretable units.
            record['interpretable'] = sum(len(group['units'])
                    for group in record['labels'])
            # Make a bargraph of labels
            os.makedirs(os.path.join(outdir, safe_dir_name(layer)),
                    exist_ok=True)
            catgroups = OrderedDict()
            for _, cat in sorted([(v, k) for k, v in catorder.items()]):
                catgroups[cat] = []
            for rec in record['labels']:
                if rec['cat'] not in catgroups:
                    catgroups[rec['cat']] = []
                catgroups[rec['cat']].append(rec['label'])
            make_svg_bargraph(
                    [rec['label'] for rec in record['labels']],
                    [len(rec['units']) for rec in record['labels']],
                    [(cat, len(group)) for cat, group in catgroups.items()],
                    filename=os.path.join(outdir, safe_dir_name(layer),
                        'bargraph.svg'))
            # Only show the bargraph if it is non-empty.
            if len(record['labels']):
                record['bargraph'] = 'bargraph.svg'
        if maxioudata is not None:
            rankings.append(dict(name="max maxiou", metric="maxiou", score=list(
                    -ur['maxiou'] for ur in units)))
        if iqrdata is not None:
            rankings.append(dict(name="max iqr", metric="iqr", score=list(
                    -ur['iqr'] for ur in units)))
        if covariancedata is not None:
            rankings.append(dict(name="max cor", metric="cor", score=list(
                    -ur['cor'] for ur in units)))

        all_layers.append(record)
    # Now add the same rankings to every layer...
    all_labels = None
    if rank_all_labels:
        all_labels = [name for name, cat in labelnames]
    if labeldata is not None:
        # Count layers+quadrants with a given label, and sort by freq
        counted_labels = defaultdict(int)
        for label in [
                re.sub(r'-(?:t|b|l|r|tl|tr|bl|br)$', '', unitrec['iou_label'])
                for record in all_layers for unitrec in record['units']]:
            counted_labels[label] += 1
        if all_labels is None:
            all_labels = [label for count, label in sorted((-v, k)
                for k, v in counted_labels.items())]
        for record in all_layers:
            layer = record['layer']
            for label in all_labels:
                labelnum = labelnumber[label]
                record['rankings'].append(dict(name="%s-iou" % label,
                    concept=label, metric='iou',
                    score=(-iou_scores[layer][labelnum, :]).tolist()))

    if maxioudata is not None:
        max_iou, max_iou_level, max_iou_quantile = maxioudata
        if all_labels is None:
            counted_labels = defaultdict(int)
            for label in [
                    re.sub(r'-(?:t|b|l|r|tl|tr|bl|br)$', '',
                        unitrec['maxiou_label'])
                    for record in all_layers for unitrec in record['units']]:
                counted_labels[label] += 1
            all_labels = [label for count, label in sorted((-v, k)
                for k, v in counted_labels.items())]
        for record in all_layers:
            layer = record['layer']
            # Qualify each layer's own maxiou, zeroing scores whose best
            # threshold sits above the 0.5 quantile.
            qualified_iou = max_iou[layer].clone()
            qualified_iou[max_iou_quantile[layer] > 0.5] = 0
            for label in all_labels:
                labelnum = labelnumber[label]
                record['rankings'].append(dict(name="%s-maxiou" % label,
                    concept=label, metric='maxiou',
                    score=(-qualified_iou[labelnum, :]).tolist()))

    if iqrdata is not None:
        max_iqr = iqrdata[0]
        if all_labels is None:
            counted_labels = defaultdict(int)
            for label in [
                    re.sub(r'-(?:t|b|l|r|tl|tr|bl|br)$', '',
                        unitrec['iqr_label'])
                    for record in all_layers for unitrec in record['units']]:
                counted_labels[label] += 1
            all_labels = [label for count, label in sorted((-v, k)
                for k, v in counted_labels.items())]
        # qualified_iqr[max_iqr_quantile[layer] > 0.5] = 0
        for record in all_layers:
            layer = record['layer']
            qualified_iqr = max_iqr[layer].clone()
            for label in all_labels:
                labelnum = labelnumber[label]
                record['rankings'].append(dict(name="%s-iqr" % label,
                    concept=label, metric='iqr',
                    score=(-qualified_iqr[labelnum, :]).tolist()))

    if covariancedata is not None:
        if all_labels is None:
            counted_labels = defaultdict(int)
            for label in [
                    re.sub(r'-(?:t|b|l|r|tl|tr|bl|br)$', '',
                        unitrec['cor_label'])
                    for record in all_layers for unitrec in record['units']]:
                counted_labels[label] += 1
            all_labels = [label for count, label in sorted((-v, k)
                for k, v in counted_labels.items())]
        for record in all_layers:
            layer = record['layer']
            score = covariancedata[layer].correlation()
            for label in all_labels:
                labelnum = labelnumber[label]
                record['rankings'].append(dict(name="%s-cor" % label,
                    concept=label, metric='cor',
                    score=(-score[:, labelnum]).tolist()))

    for record in all_layers:
        layer = record['layer']
        # Dump per-layer json inside per-layer directory
        record['dirname'] = '.'
        with open(os.path.join(outdir, safe_dir_name(layer), 'dissect.json'),
                'w') as jsonfile:
            top_record['layers'] = [record]
            json.dump(top_record, jsonfile, indent=1)
        # Copy the per-layer html
        shutil.copy(os.path.join(srcdir, 'dissect.html'),
                os.path.join(outdir, safe_dir_name(layer), 'dissect.html'))
        record['dirname'] = safe_dir_name(layer)

    # Dump all-layer json in parent directory
    with open(os.path.join(outdir, 'dissect.json'), 'w') as jsonfile:
        top_record['layers'] = all_layers
        json.dump(top_record, jsonfile, indent=1)
    # Copy the all-layer html
    shutil.copy(os.path.join(srcdir, 'dissect.html'),
            os.path.join(outdir, 'dissect.html'))
    shutil.copy(os.path.join(srcdir, 'edit.html'),
            os.path.join(outdir, 'edit.html'))


def generate_images(outdir, model, dataset, topk, levels,
        segrunner, row_length=None, gap_pixels=5,
        row_images=True, single_images=False, prefix='',
        batch_size=100, num_workers=24):
    '''
    Creates an image strip file for every unit of every retained layer
    of the model, in the format [outdir]/[layername]/[unitnum]-top.jpg.
    Assumes that the indexes of topk refer to the indexes of dataset.
    Limits each strip to the top row_length images.
    '''
    progress = default_progress()
    needed_images = {}
    if row_images is False:
        row_length = 1
    # Pass 1: needed_images lists all images that are topk for some unit.
    for layer in topk:
        topresult = topk[layer].result()[1].cpu()
        for unit, row in enumerate(topresult):
            for rank, imgnum in enumerate(row[:row_length]):
                imgnum = imgnum.item()
                if imgnum not in needed_images:
                    needed_images[imgnum] = []
                needed_images[imgnum].append((layer, unit, rank))
    levels = {k: v.cpu().numpy() for k, v in levels.items()}
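    # Note: row still holds the last unit's top-image list from the loop
    # above, so this clamps row_length to the number of images available.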
    row_length = len(row[:row_length])
    needed_sample = FixedSubsetSampler(sorted(needed_images.keys()))
    device = next(model.parameters()).device
    segloader = torch.utils.data.DataLoader(dataset,
            batch_size=batch_size, num_workers=num_workers,
            pin_memory=(device.type == 'cuda'),
            sampler=needed_sample)
    vizgrid, maskgrid, origrid, seggrid = [{} for _ in range(4)]
    # Pass 2: populate vizgrid with visualizations of top units.
    pool = None
    for i, batch in enumerate(
            progress(segloader, desc='Making images')):
        # Reverse transformation to get the image in byte form.
        seg, _, byte_im, _ = segrunner.run_and_segment_batch(batch, model,
                want_rgb=True)
        torch_features = model.retained_features()
        scale_offset = getattr(model, 'scale_offset', None)
        if pool is None:
            # Distribute the work across processes: create shared mmaps.
            for layer, tf in torch_features.items():
                [vizgrid[layer], maskgrid[layer], origrid[layer],
                        seggrid[layer]] = [
                    create_temp_mmap_grid((tf.shape[1],
                        byte_im.shape[1], row_length,
                        byte_im.shape[2] + gap_pixels, depth),
                        dtype='uint8',
                        fill=255)
                    for depth in [3, 4, 3, 3]]
            # Pass those mmaps to worker processes.
            pool = WorkerPool(worker=VisualizeImageWorker,
                    memmap_grid_info=[
                        {layer: (g.filename, g.shape, g.dtype)
                            for layer, g in grid.items()}
                        for grid in [vizgrid, maskgrid, origrid, seggrid]])
        byte_im = byte_im.cpu().numpy()
        numpy_seg = seg.cpu().numpy()
        features = {}
        for index in range(len(byte_im)):
            imgnum = needed_sample.samples[index + i*segloader.batch_size]
            for layer, unit, rank in needed_images[imgnum]:
                if layer not in features:
                    features[layer] = torch_features[layer].cpu().numpy()
                pool.add(layer, unit, rank,
                        byte_im[index],
                        features[layer][index, unit],
                        levels[layer][unit],
                        scale_offset[layer] if scale_offset else None,
                        numpy_seg[index])
    pool.join()
    # Pass 3: save image strips as [outdir]/[layer]/[unitnum]-[top/orig].jpg
    pool = WorkerPool(worker=SaveImageWorker)
    for layer, vg in progress(vizgrid.items(), desc='Saving images'):
        os.makedirs(os.path.join(outdir, safe_dir_name(layer),
            prefix + 'image'), exist_ok=True)
        if single_images:
            os.makedirs(os.path.join(outdir, safe_dir_name(layer),
                prefix + 's-image'), exist_ok=True)
        og, sg, mg = origrid[layer], seggrid[layer], maskgrid[layer]
        for unit in progress(range(len(vg)), desc='Units'):
            for suffix, grid in [('top.jpg', vg), ('orig.jpg', og),
                    ('seg.png', sg), ('mask.png', mg)]:
                strip = grid[unit].reshape(
                        (grid.shape[1], grid.shape[2] * grid.shape[3],
                            grid.shape[4]))
                if row_images:
                    filename = os.path.join(outdir, safe_dir_name(layer),
                            prefix + 'image', '%d-%s' % (unit, suffix))
                    pool.add(strip[:,:-gap_pixels,:].copy(), filename)
                    # Image.fromarray(strip[:,:-gap_pixels,:]).save(filename,
                    #        optimize=True, quality=80)
                if single_images:
                    single_filename = os.path.join(outdir, safe_dir_name(layer),
                        prefix + 's-image', '%d-%s' % (unit, suffix))
                    pool.add(strip[:,:strip.shape[1] // row_length
                        - gap_pixels,:].copy(), single_filename)
                    # Image.fromarray(strip[:,:strip.shape[1] // row_length
                    #     - gap_pixels,:]).save(single_filename,
                    #             optimize=True, quality=80)
    pool.join()
    # Delete the shared memory map files
    clear_global_shared_files([g.filename
        for grid in [vizgrid, maskgrid, origrid, seggrid]
        for g in grid.values()])

global_shared_files = {}
def create_temp_mmap_grid(shape, dtype, fill):
    dtype = numpy.dtype(dtype)
    filename = os.path.join(tempfile.mkdtemp(), 'temp-%s-%s.mmap' %
            ('x'.join('%d' % s for s in shape), dtype.name))
    fid = open(filename, mode='w+b')
    original = numpy.memmap(fid, dtype=dtype, mode='w+', shape=shape)
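    # Keep the file handle alive on the memmap object so the backing file
    # is not closed while worker processes still share the mapping.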
    original.fid = fid
    original[...] = fill
    global_shared_files[filename] = original
    return original

def shared_temp_mmap_grid(filename, shape, dtype):
    if filename not in global_shared_files:
        global_shared_files[filename] = numpy.memmap(
                filename, dtype=dtype, mode='r+', shape=shape)
    return global_shared_files[filename]

def clear_global_shared_files(filenames):
    for fn in filenames:
        if fn in global_shared_files:
            del global_shared_files[fn]
        try:
            os.unlink(fn)
        except OSError:
            pass
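
# A minimal sketch of the shared-memmap round trip used by generate_images
# above (the shapes here are hypothetical): the parent process creates a
# grid, workers re-open it by filename, and the parent unlinks the files
# after the workers have joined.
#
#     grid = create_temp_mmap_grid((16, 224, 10, 229, 3),
#             dtype='uint8', fill=255)
#     # pass (grid.filename, grid.shape, grid.dtype) to each worker...
#     same = shared_temp_mmap_grid(grid.filename, grid.shape, grid.dtype)
#     clear_global_shared_files([grid.filename])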

class VisualizeImageWorker(WorkerBase):
    def setup(self, memmap_grid_info):
        self.vizgrid, self.maskgrid, self.origrid, self.seggrid = [
                {layer: shared_temp_mmap_grid(*info)
                    for layer, info in grid.items()}
                for grid in memmap_grid_info]
    def work(self, layer, unit, rank,
            byte_im, acts, level, scale_offset, seg):
        self.origrid[layer][unit,:,rank,:byte_im.shape[0],:] = byte_im
        [self.vizgrid[layer][unit,:,rank,:byte_im.shape[0],:],
         self.maskgrid[layer][unit,:,rank,:byte_im.shape[0],:]] = (
                    activation_visualization(
                        byte_im,
                        acts,
                        level,
                        scale_offset=scale_offset,
                        return_mask=True))
        self.seggrid[layer][unit,:,rank,:byte_im.shape[0],:] = (
                    segment_visualization(seg, byte_im.shape[0:2]))

class SaveImageWorker(WorkerBase):
    def work(self, data, filename):
        Image.fromarray(data).save(filename, optimize=True, quality=80)

def score_tally_stats(label_category, tc, truth, cc, ic):
    pred = cc[label_category]
    total = tc[label_category][:, None]
    truth = truth[:, None]
    epsilon = 1e-20 # avoid division-by-zero
    union = pred + truth - ic
    iou = ic.double() / (union.double() + epsilon)
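    # Build a 2x2 joint table for each (label, unit) pair, normalized
    # below by the pixel total of the label's category:
    #   arr[0, 0] = P(unit on,  label on)
    #   arr[0, 1] = P(unit on,  label off)
    #   arr[1, 0] = P(unit off, label on)
    #   arr[1, 1] = P(unit off, label off)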
    arr = torch.empty(size=(2, 2) + ic.shape, dtype=ic.dtype, device=ic.device)
    arr[0, 0] = ic
    arr[0, 1] = pred - ic
    arr[1, 0] = truth - ic
    arr[1, 1] = total - union
    arr = arr.double() / total.double()
    mi = mutual_information(arr)
    je = joint_entropy(arr)
    iqr = mi / je
    iqr[torch.isnan(iqr)] = 0 # Zero out any 0/0
    return iou, iqr

def collect_quantiles_and_topk(outdir, model, segloader,
        segrunner, k=100, resolution=1024):
    '''
    Collects (estimated) quantile information and (exact) sorted top-K lists
    for every channel in the retained layers of the model.  Returns
    a map of quantiles (one RunningQuantile for each layer) along with
    a map of topk (one RunningTopK for each layer).
    '''
    device = next(model.parameters()).device
    features = model.retained_features()
    cached_quantiles = {
            layer: load_quantile_if_present(os.path.join(outdir,
                safe_dir_name(layer)), 'quantiles.npz',
                device=torch.device('cpu'))
            for layer in features }
    cached_topks = {
            layer: load_topk_if_present(os.path.join(outdir,
                safe_dir_name(layer)), 'topk.npz',
                device=torch.device('cpu'))
            for layer in features }
    if (all(value is not None for value in cached_quantiles.values()) and
        all(value is not None for value in cached_topks.values())):
        return cached_quantiles, cached_topks

    layer_batch_size = 8
    all_layers = list(features.keys())
    layer_batches = [all_layers[i:i+layer_batch_size]
            for i in range(0, len(all_layers), layer_batch_size)]

    quantiles, topks = {}, {}
    progress = default_progress()
    for layer_batch in layer_batches:
        for i, batch in enumerate(progress(segloader, desc='Quantiles')):
            # We don't actually care about the model output.
            model(batch[0].to(device))
            features = model.retained_features()
            # We care about the retained values
            for key in layer_batch:
                value = features[key]
                if topks.get(key, None) is None:
                    topks[key] = RunningTopK(k)
                if quantiles.get(key, None) is None:
                    quantiles[key] = RunningQuantile(resolution=resolution)
                topvalue = value
                if len(value.shape) > 2:
                    topvalue, _ = value.view(*(value.shape[:2] + (-1,))).max(2)
                    # Put the channel index last.
                    value = value.permute(
                            (0,) + tuple(range(2, len(value.shape))) + (1,)
                            ).contiguous().view(-1, value.shape[1])
                quantiles[key].add(value)
                topks[key].add(topvalue)
        # Save GPU memory
        for key in layer_batch:
            quantiles[key].to_(torch.device('cpu'))
            topks[key].to_(torch.device('cpu'))
    for layer in quantiles:
        save_state_dict(quantiles[layer],
                os.path.join(outdir, safe_dir_name(layer), 'quantiles.npz'))
        save_state_dict(topks[layer],
                os.path.join(outdir, safe_dir_name(layer), 'topk.npz'))
    return quantiles, topks
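
# A minimal sketch of how the collected quantiles become per-unit
# activation thresholds (this mirrors the levels computation in dissect
# above; 0.995 corresponds to the default quantile_threshold=0.005):
#
#     quantiles, topk = collect_quantiles_and_topk(outdir, model,
#             segloader, segrunner, k=100)
#     levels = {layer: rq.quantiles([0.995])[:, 0]
#               for layer, rq in quantiles.items()}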

def collect_bincounts(outdir, model, segloader, levels, segrunner):
    '''
    Returns label_counts, category_activation_counts, and intersection_counts,
    across the data set, counting the pixels of intersection between upsampled,
    thresholded model featuremaps, with segmentation classes in the segloader.

    label_counts (independent of model): pixels across the data set that
        are labeled with the given label.
    category_activation_counts (one per layer): for each feature channel,
        pixels across the dataset where the channel exceeds the level
        threshold.  There is one count per category: activations only
        contribute to the categories for which any category labels are
        present on the images.
    intersection_counts (one per layer): for each feature channel and
        label, pixels across the dataset where the channel exceeds
        the level, and the labeled segmentation class is also present.

    This is a performance-sensitive function.  Best performance is
    achieved with a counting scheme which assumes a segloader with
    batch_size 1.
    '''
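    # From these counts, the per-(label, unit) IoU follows as
    #   iou = intersection / (activation + label - intersection),
    # computed in score_tally_stats above along with the IQR
    # (mutual information / joint entropy) of the same 2x2 table.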
    # Load cached data if present
    (iou_scores, iqr_scores,
            total_counts, label_counts, category_activation_counts,
            intersection_counts) = {}, {}, None, None, {}, {}
    found_all = True
    for layer in model.retained_features():
        filename = os.path.join(outdir, safe_dir_name(layer), 'bincounts.npz')
        if os.path.isfile(filename):
            data = numpy.load(filename)
            iou_scores[layer] = torch.from_numpy(data['iou_scores'])
            iqr_scores[layer] = torch.from_numpy(data['iqr_scores'])
            total_counts = torch.from_numpy(data['total_counts'])
            label_counts = torch.from_numpy(data['label_counts'])
            category_activation_counts[layer] = torch.from_numpy(
                    data['category_activation_counts'])
            intersection_counts[layer] = torch.from_numpy(
                    data['intersection_counts'])
        else:
            found_all = False
    if found_all:
        return (iou_scores, iqr_scores,
            total_counts, label_counts, category_activation_counts,
            intersection_counts)

    device = next(model.parameters()).device
    labelcat, categories = segrunner.get_label_and_category_names()
    label_category = [categories.index(c) if c in categories else 0
                for l, c in labelcat]
    num_labels, num_categories = (len(n) for n in [labelcat, categories])

    # One-hot vector of category for each label
    labelcat = torch.zeros(num_labels, num_categories,
            dtype=torch.long, device=device)
    labelcat.scatter_(1, torch.from_numpy(numpy.array(label_category,
        dtype='int64')).to(device)[:,None], 1)
    # Running bincounts
    # activation_counts = {}
    assert segloader.batch_size == 1 # category_activation_counts needs this.
    category_activation_counts = {}
    intersection_counts = {}
    label_counts = torch.zeros(num_labels, dtype=torch.long, device=device)
    total_counts = torch.zeros(num_categories, dtype=torch.long, device=device)
    progress = default_progress()
    scale_offset_map = getattr(model, 'scale_offset', None)
    upsample_grids = {}
    # total_batch_categories = torch.zeros(
    #         labelcat.shape[1], dtype=torch.long, device=device)
    for i, batch in enumerate(progress(segloader, desc='Bincounts')):
        seg, batch_label_counts, _, imshape = segrunner.run_and_segment_batch(
                batch, model, want_bincount=True, want_rgb=True)
        bc = batch_label_counts.cpu()
        batch_label_counts = batch_label_counts.to(device)
        seg = seg.to(device)
        features = model.retained_features()
        # Accumulate bincounts and identify nonzeros
        label_counts += batch_label_counts[0]
        batch_labels = bc[0].nonzero()[:,0]
        batch_categories = labelcat[batch_labels].max(0)[0]
        total_counts += batch_categories * (
                seg.shape[0] * seg.shape[2] * seg.shape[3])
        for key, value in features.items():
            if key not in upsample_grids:
                upsample_grids[key] = upsample_grid(value.shape[2:],
                        seg.shape[2:], imshape,
                        scale_offset=scale_offset_map.get(key, None)
                            if scale_offset_map is not None else None,
                        dtype=value.dtype, device=value.device)
            upsampled = torch.nn.functional.grid_sample(value,
                    upsample_grids[key], padding_mode='border')
            amask = (upsampled > levels[key][None,:,None,None].to(
                upsampled.device))
            ac = amask.int().view(amask.shape[1], -1).sum(1)
            # if key not in activation_counts:
            #     activation_counts[key] = ac
            # else:
            #     activation_counts[key] += ac
            # The fastest approach: sum over each label separately!
            for label in batch_labels.tolist():
                if label == 0:
                    continue  # ignore the background label
                imask = amask * ((seg == label).max(dim=1, keepdim=True)[0])
                ic = imask.int().view(imask.shape[1], -1).sum(1)
                if key not in intersection_counts:
                    intersection_counts[key] = torch.zeros(num_labels,
                            amask.shape[1], dtype=torch.long, device=device)
                intersection_counts[key][label] += ic
            # Count activations within images that have category labels.
            # Note: This only makes sense with batch-size one
            # total_batch_categories += batch_categories
            cc = batch_categories[:,None] * ac[None,:]
            if key not in category_activation_counts:
                category_activation_counts[key] = cc
            else:
                category_activation_counts[key] += cc
    iou_scores = {}
    iqr_scores = {}
    for k in intersection_counts:
        iou_scores[k], iqr_scores[k] = score_tally_stats(
            label_category, total_counts, label_counts,
            category_activation_counts[k], intersection_counts[k])
    for k in intersection_counts:
        numpy.savez(os.path.join(outdir, safe_dir_name(k), 'bincounts.npz'),
                iou_scores=iou_scores[k].cpu().numpy(),
                iqr_scores=iqr_scores[k].cpu().numpy(),
                total_counts=total_counts.cpu().numpy(),
                label_counts=label_counts.cpu().numpy(),
                category_activation_counts=category_activation_counts[k]
                    .cpu().numpy(),
                intersection_counts=intersection_counts[k].cpu().numpy(),
                levels=levels[k].cpu().numpy())
    return (iou_scores, iqr_scores,
            total_counts, label_counts, category_activation_counts,
            intersection_counts)

def collect_cond_quantiles(outdir, model, segloader, segrunner):
    '''
    Returns conditional quantile statistics (one RunningConditionalQuantile
    per layer) together with label_fracs, the fraction of pixels covered
    by each label, across the data set.

    This is a performance-sensitive function.  Best performance is
    achieved with a counting scheme which assumes a segloader with
    batch_size 1.
    '''
    device = next(model.parameters()).device
    cached_cond_quantiles = {
            layer: load_conditional_quantile_if_present(os.path.join(outdir,
                safe_dir_name(layer)), 'cond_quantiles.npz') # on cpu
            for layer in model.retained_features() }
    label_fracs = load_npy_if_present(outdir, 'label_fracs.npy', 'cpu')
    if torch.is_tensor(label_fracs) and all(
            value is not None for value in cached_cond_quantiles.values()):
        return cached_cond_quantiles, label_fracs

    labelcat, categories = segrunner.get_label_and_category_names()
    label_category = [categories.index(c) if c in categories else 0
                for l, c in labelcat]
    num_labels, num_categories = (len(n) for n in [labelcat, categories])

    # One-hot vector of category for each label
    labelcat = torch.zeros(num_labels, num_categories,
            dtype=torch.long, device=device)
    labelcat.scatter_(1, torch.from_numpy(numpy.array(label_category,
        dtype='int64')).to(device)[:,None], 1)
    # Running maxiou
    assert segloader.batch_size == 1 # category_activation_counts needs this.
    conditional_quantiles = {}
    label_counts = torch.zeros(num_labels, dtype=torch.long, device=device)
    pixel_count = 0
    progress = default_progress()
    scale_offset_map = getattr(model, 'scale_offset', None)
    upsample_grids = {}
    common_conditions = set()
    if not torch.is_tensor(label_fracs):
        for i, batch in enumerate(progress(segloader, desc='label fracs')):
            seg, batch_label_counts, im, _ = segrunner.run_and_segment_batch(
                    batch, model, want_bincount=True, want_rgb=True)
            batch_label_counts = batch_label_counts.to(device)
            features = model.retained_features()
            # Accumulate bincounts and identify nonzeros
            label_counts += batch_label_counts[0]
            pixel_count += seg.shape[2] * seg.shape[3]
        label_fracs = (label_counts.cpu().float() / pixel_count)[:, None, None]
        numpy.save(os.path.join(outdir, 'label_fracs.npy'), label_fracs)

    skip_threshold = 1e-4
    skip_labels = set(i.item()
        for i in (label_fracs.view(-1) < skip_threshold).nonzero().view(-1))

    for layer in progress(model.retained_features().keys(), desc='CQ layers'):
        if cached_cond_quantiles.get(layer, None) is not None:
            conditional_quantiles[layer] = cached_cond_quantiles[layer]
            continue

        for i, batch in enumerate(progress(segloader, desc='Condquant')):
            seg, batch_label_counts, _, imshape = (
                    segrunner.run_and_segment_batch(
                         batch, model, want_bincount=True, want_rgb=True))
            bc = batch_label_counts.cpu()
            batch_label_counts = batch_label_counts.to(device)
            features = model.retained_features()
            # Accumulate bincounts and identify nonzeros
            label_counts += batch_label_counts[0]
            pixel_count += seg.shape[2] * seg.shape[3]
            batch_labels = bc[0].nonzero()[:,0]
            batch_categories = labelcat[batch_labels].max(0)[0]
            cpu_seg = None
            value = features[layer]
            if layer not in upsample_grids:
                upsample_grids[layer] = upsample_grid(value.shape[2:],
                        seg.shape[2:], imshape,
                        scale_offset=scale_offset_map.get(layer, None)
                            if scale_offset_map is not None else None,
                        dtype=value.dtype, device=value.device)
            if layer not in conditional_quantiles:
                conditional_quantiles[layer] = RunningConditionalQuantile(
                        resolution=2048)
            upsampled = torch.nn.functional.grid_sample(value,
                    upsample_grids[layer], padding_mode='border').view(
                            value.shape[1], -1)
            conditional_quantiles[layer].add(('all',), upsampled.t())
            cpu_upsampled = None
            for label in batch_labels.tolist():
                if label in skip_labels:
                    continue
                label_key = ('label', label)
                if label_key in common_conditions:
                    imask = (seg == label).max(dim=1)[0].view(-1)
                    intersected = upsampled[:, imask]
                    conditional_quantiles[layer].add(('label', label),
                            intersected.t())
                else:
                    if cpu_seg is None:
                        cpu_seg = seg.cpu()
                    if cpu_upsampled is None:
                        cpu_upsampled = upsampled.cpu()
                    imask = (cpu_seg == label).max(dim=1)[0].view(-1)
                    intersected = cpu_upsampled[:, imask]
                    conditional_quantiles[layer].add(('label', label),
                            intersected.t())
            if num_categories > 1:
                for cat in batch_categories.nonzero()[:,0]:
                    conditional_quantiles[layer].add(('cat', cat.item()),
                            upsampled.t())
            # Move the most common conditions to the GPU.
            if i and not i & (i - 1):  # if i is a power of 2:
                cq = conditional_quantiles[layer]
                common_conditions = set(cq.most_common_conditions(64))
                cq.to_('cpu', [k for k in cq.running_quantiles.keys()
                        if k not in common_conditions])
        # When a layer is done, get it off the GPU
        conditional_quantiles[layer].to_('cpu')

    label_fracs = (label_counts.cpu().float() / pixel_count)[:, None, None]

    for cq in conditional_quantiles.values():
        cq.to_('cpu')

    for layer in conditional_quantiles:
        save_state_dict(conditional_quantiles[layer],
            os.path.join(outdir, safe_dir_name(layer), 'cond_quantiles.npz'))
    numpy.save(os.path.join(outdir, 'label_fracs.npy'), label_fracs)

    return conditional_quantiles, label_fracs


def collect_maxiou(outdir, model, segloader, segrunner):
    '''
    Returns max_iou, max_iou_level, and max_iou_quantile across the
    data set, one map (keyed by layer) for each.

    This is a performance-sensitive function.  Best performance is
    achieved with a counting scheme which assumes a segloader with
    batch_size 1.
    '''
    device = next(model.parameters()).device
    conditional_quantiles, label_fracs = collect_cond_quantiles(
            outdir, model, segloader, segrunner)

    labelcat, categories = segrunner.get_label_and_category_names()
    label_category = [categories.index(c) if c in categories else 0
                for l, c in labelcat]
    num_labels, num_categories = (len(n) for n in [labelcat, categories])

    label_list = [('label', i) for i in range(num_labels)]
    category_list = [('all',)] if num_categories <= 1 else (
            [('cat', i) for i in range(num_categories)])
    max_iou, max_iou_level, max_iou_quantile = {}, {}, {}
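    # Sweep 100 candidate activation thresholds at quantile fractions
    # log-spaced between 1e-3 and 1.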
    fracs = torch.logspace(-3, 0, 100)
    progress = default_progress()
    for layer, cq in progress(conditional_quantiles.items(), desc='Maxiou'):
        levels = cq.conditional(('all',)).quantiles(1 - fracs)
        denoms = 1 - cq.collected_normalize(category_list, levels)
        isects = (1 - cq.collected_normalize(label_list, levels)) * label_fracs
        unions = label_fracs + denoms[label_category, :, :] - isects
        iou = isects / unions
        # TODO: erase any for which threshold is bad
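        # iou has shape (labels, units, fracs); keep the best threshold
        # bucket for every (label, unit) pair.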
        max_iou[layer], level_bucket = iou.max(2)
        max_iou_level[layer] = levels[
                torch.arange(levels.shape[0])[None,:], level_bucket]
        max_iou_quantile[layer] = fracs[level_bucket]
    for layer in model.retained_features():
        numpy.savez(os.path.join(outdir, safe_dir_name(layer), 'max_iou.npz'),
            max_iou=max_iou[layer].cpu().numpy(),
            max_iou_level=max_iou_level[layer].cpu().numpy(),
            max_iou_quantile=max_iou_quantile[layer].cpu().numpy())
    return (max_iou, max_iou_level, max_iou_quantile)

def collect_iqr(outdir, model, segloader, segrunner):
    '''
    Returns max_iqr, max_iqr_level, max_iqr_quantile, max_iqr_iou, and
    max_iqr_agreement, one map (keyed by layer) for each.

    This is a performance-sensitive function.  Best performance is
    achieved with a counting scheme which assumes a segloader with
    batch_size 1.
    '''
    max_iqr, max_iqr_level, max_iqr_quantile, max_iqr_iou = {}, {}, {}, {}
    max_iqr_agreement = {}
    found_all = True
    for layer in model.retained_features():
        filename = os.path.join(outdir, safe_dir_name(layer), 'iqr.npz')
        if os.path.isfile(filename):
            data = numpy.load(filename)
            max_iqr[layer] = torch.from_numpy(data['max_iqr'])
            max_iqr_level[layer] = torch.from_numpy(data['max_iqr_level'])
            max_iqr_quantile[layer] = torch.from_numpy(data['max_iqr_quantile'])
            max_iqr_iou[layer] = torch.from_numpy(data['max_iqr_iou'])
            max_iqr_agreement[layer] = torch.from_numpy(
                    data['max_iqr_agreement'])
        else:
            found_all = False
    if found_all:
        return (max_iqr, max_iqr_level, max_iqr_quantile, max_iqr_iou,
            max_iqr_agreement)

    device = next(model.parameters()).device
    conditional_quantiles, label_fracs = collect_cond_quantiles(
            outdir, model, segloader, segrunner)

    labelcat, categories = segrunner.get_label_and_category_names()
    label_category = [categories.index(c) if c in categories else 0
                for l, c in labelcat]
    num_labels, num_categories = (len(n) for n in [labelcat, categories])

    label_list = [('label', i) for i in range(num_labels)]
    category_list = [('all',)] if num_categories <= 1 else (
            [('cat', i) for i in range(num_categories)])
    full_mi, full_je, full_iqr = {}, {}, {}
    fracs = torch.logspace(-3, 0, 100)
    progress = default_progress()
    for layer, cq in progress(conditional_quantiles.items(), desc='IQR'):
        levels = cq.conditional(('all',)).quantiles(1 - fracs)
        truth = label_fracs.to(device)
        preds = (1 - cq.collected_normalize(category_list, levels)
                )[label_category, :, :].to(device)
        cond_isects = 1 - cq.collected_normalize(label_list, levels).to(device)
        isects = cond_isects * truth
        unions = truth + preds - isects
        arr = torch.empty(size=(2, 2) + isects.shape, dtype=isects.dtype,
                device=device)
        arr[0, 0] = isects
        arr[0, 1] = preds - isects
        arr[1, 0] = truth - isects
        arr[1, 1] = 1 - unions
        arr.clamp_(0, 1)
        mi = mutual_information(arr)
        mi[:,:,-1] = 0  # at the 1.0 quantile should be no MI.
        # Don't trust mi when label_frac is less than 1e-3, because
        # our samples are too small.
        mi[label_fracs.view(-1) < 1e-3, :, :] = 0
        je = joint_entropy(arr)
        iqr = mi / je
        iqr[torch.isnan(iqr)] = 0 # Zero out any 0/0
        full_mi[layer] = mi.cpu()
        full_je[layer] = je.cpu()
        full_iqr[layer] = iqr.cpu()
        del mi, je
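        # agreement = P(both on) + P(both off): the fraction of pixels on
        # which the thresholded unit and the label agree.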
        agreement = isects + arr[1, 1]
        # When optimizing, maximize only over those pairs where the
        # unit is positively correlated with the label, and where the
        # threshold level is positive
        positive_iqr = iqr.clone()
        positive_iqr[agreement <= 0.8] = 0
        positive_iqr[(levels <= 0.0)[None, :, :].expand(positive_iqr.shape)] = 0
        # TODO: erase any for which threshold is bad
        maxiqr, level_bucket = positive_iqr.max(2)
        max_iqr[layer] = maxiqr.cpu()
        max_iqr_level[layer] = levels.to(device)[
                torch.arange(levels.shape[0])[None,:], level_bucket].cpu()
        max_iqr_quantile[layer] = fracs.to(device)[level_bucket].cpu()
        max_iqr_agreement[layer] = agreement[
                torch.arange(agreement.shape[0])[:, None],
                torch.arange(agreement.shape[1])[None, :],
                level_bucket].cpu()

        # Compute the iou that goes with each maximized iqr
        matching_iou = (isects[
                torch.arange(isects.shape[0])[:, None],
                torch.arange(isects.shape[1])[None, :],
                level_bucket] /
            unions[
                torch.arange(unions.shape[0])[:, None],
                torch.arange(unions.shape[1])[None, :],
                level_bucket])
        matching_iou[torch.isnan(matching_iou)] = 0
        max_iqr_iou[layer] = matching_iou.cpu()
    for layer in model.retained_features():
        numpy.savez(os.path.join(outdir, safe_dir_name(layer), 'iqr.npz'),
            max_iqr=max_iqr[layer].cpu().numpy(),
            max_iqr_level=max_iqr_level[layer].cpu().numpy(),
            max_iqr_quantile=max_iqr_quantile[layer].cpu().numpy(),
            max_iqr_iou=max_iqr_iou[layer].cpu().numpy(),
            max_iqr_agreement=max_iqr_agreement[layer].cpu().numpy(),
            full_mi=full_mi[layer].cpu().numpy(),
            full_je=full_je[layer].cpu().numpy(),
            full_iqr=full_iqr[layer].cpu().numpy())
    return (max_iqr, max_iqr_level, max_iqr_quantile, max_iqr_iou,
            max_iqr_agreement)

def mutual_information(arr):
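    # I = sum_{j,k} p(j,k) * log(p(j,k) / (p(j) * p(k))), computed
    # elementwise over any trailing dimensions of arr.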
    total = 0
    for j in range(arr.shape[0]):
        for k in range(arr.shape[1]):
            joint = arr[j,k]
            ind = arr[j,:].sum(dim=0) * arr[:,k].sum(dim=0)
            term = joint * (joint / ind).log()
            term[torch.isnan(term)] = 0
            total += term
    return total.clamp_(0)

def joint_entropy(arr):
    total = 0
    for j in range(arr.shape[0]):
        for k in range(arr.shape[1]):
            joint = arr[j,k]
            term = joint * joint.log()
            term[torch.isnan(term)] = 0
            total += term
    return (-total).clamp_(0)

def information_quality_ratio(arr):
    iqr = mutual_information(arr) / joint_entropy(arr)
    iqr[torch.isnan(iqr)] = 0
    return iqr
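
# A minimal worked example of the IQR score on a toy 2x2 joint
# distribution (hypothetical numbers): a unit active on 30% of pixels
# and a label covering 25%, overlapping on 20%:
#
#     joint = torch.tensor([[0.20, 0.10],
#                           [0.05, 0.65]])
#     iqr = information_quality_ratio(joint)  # scalar in [0, 1]
#
# Both mutual_information and joint_entropy index the 2x2 table with the
# first two dimensions and broadcast over any trailing dimensions, so a
# (2, 2, labels, units) array scores every pair in one call.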

def collect_covariance(outdir, model, segloader, segrunner):
    '''
    Returns a map (keyed by layer) of RunningCrossCovariance objects
    accumulating label_mean, label_variance, unit_mean, unit_variance,
    and cross_covariance across the data set.

    label_mean, label_variance (independent of model):
        treating the label as a one-hot, each label's mean and variance.
    unit_mean, unit_variance (one per layer): for each feature channel,
        the mean and variance of the activations in that channel.
    cross_covariance (one per layer): the cross covariance between the
        labels and the units in the layer.
    '''
    device = next(model.parameters()).device
    cached_covariance = {
            layer: load_covariance_if_present(os.path.join(outdir,
                safe_dir_name(layer)), 'covariance.npz', device=device)
            for layer in model.retained_features() }
    if all(value is not None for value in cached_covariance.values()):
        return cached_covariance
    labelcat, categories = segrunner.get_label_and_category_names()
    label_category = [categories.index(c) if c in categories else 0
                for l, c in labelcat]
    num_labels, num_categories = (len(n) for n in [labelcat, categories])

    # Running covariance
    cov = {}
    progress = default_progress()
    scale_offset_map = getattr(model, 'scale_offset', None)
    upsample_grids = {}
    for i, batch in enumerate(progress(segloader, desc='Covariance')):
        seg, _, _, imshape = segrunner.run_and_segment_batch(batch, model,
                want_rgb=True)
        features = model.retained_features()
        ohfeats = multilabel_onehot(seg, num_labels, ignore_index=0)
        # Accumulate bincounts and identify nonzeros
        for key, value in features.items():
            if key not in upsample_grids:
                upsample_grids[key] = upsample_grid(value.shape[2:],
                        seg.shape[2:], imshape,
                        scale_offset=scale_offset_map.get(key, None)
                            if scale_offset_map is not None else None,
                        dtype=value.dtype, device=value.device)
            upsampled = torch.nn.functional.grid_sample(value,
                    upsample_grids[key].expand(
                        (value.shape[0],) + upsample_grids[key].shape[1:]),
                    padding_mode='border')
            if key not in cov:
                cov[key] = RunningCrossCovariance()
            cov[key].add(upsampled, ohfeats)
    for layer in cov:
        save_state_dict(cov[layer],
                os.path.join(outdir, safe_dir_name(layer), 'covariance.npz'))
    return cov
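
# Illustrative usage of collect_covariance (directory name assumed), given a
# segloader and segrunner like those constructed inside dissect():
#
#     cov = collect_covariance('result/dissect', model, segloader, segrunner)
#     # cov maps each retained layer name to a RunningCrossCovariance.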

def multilabel_onehot(labels, num_labels, dtype=None, ignore_index=None):
    '''
    Converts a multilabel tensor into a onehot tensor.

    The input labels is a tensor of shape (samples, multilabels, y, x).
    The output is a tensor of shape (samples, num_labels, y, x).
    If ignore_index is specified, labels with that index are ignored.
    Each x in labels should be 0 <= x < num_labels, or x == ignore_index.
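
    Illustrative example (shapes chosen only for demonstration):

        labels = torch.tensor([[[[1, 2]]]])    # shape (1, 1, 1, 2)
        onehot = multilabel_onehot(labels, 3)
        # onehot.shape == (1, 3, 1, 2); channel 1 is set at pixel (0, 0),
        # channel 2 at pixel (0, 1), and channel 0 stays zero.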
    '''
    assert ignore_index is None or ignore_index <= 0
    if dtype is None:
        dtype = torch.float
    device = labels.device
    chans = num_labels + (-ignore_index if ignore_index else 0)
    outshape = (labels.shape[0], chans) + labels.shape[2:]
    result = torch.zeros(outshape, device=device, dtype=dtype)
    if ignore_index and ignore_index < 0:
        labels = labels + (-ignore_index)
    result.scatter_(1, labels, 1)
    if ignore_index and ignore_index < 0:
        result = result[:, -ignore_index:]
    elif ignore_index is not None:
        result[:, ignore_index] = 0
    return result

def load_npy_if_present(outdir, filename, device):
    # Note: returns 0 (not None) when the file is absent.
    filepath = os.path.join(outdir, filename)
    if os.path.isfile(filepath):
        data = numpy.load(filepath)
        return torch.from_numpy(data).to(device)
    return 0

def load_npz_if_present(outdir, filename, varnames, device):
    filepath = os.path.join(outdir, filename)
    if os.path.isfile(filepath):
        data = numpy.load(filepath)
        numpy_result = [data[n] for n in varnames]
        return tuple(torch.from_numpy(arr).to(device) for arr in numpy_result)
    return None

def load_quantile_if_present(outdir, filename, device):
    filepath = os.path.join(outdir, filename)
    if os.path.isfile(filepath):
        data = numpy.load(filepath)
        result = RunningQuantile(state=data)
        result.to_(device)
        return result
    return None

def load_conditional_quantile_if_present(outdir, filename):
    filepath = os.path.join(outdir, filename)
    if os.path.isfile(filepath):
        data = numpy.load(filepath)
        result = RunningConditionalQuantile(state=data)
        return result
    return None

def load_topk_if_present(outdir, filename, device):
    filepath = os.path.join(outdir, filename)
    if os.path.isfile(filepath):
        data = numpy.load(filepath)
        result = RunningTopK(state=data)
        result.to_(device)
        return result
    return None

def load_covariance_if_present(outdir, filename, device):
    filepath = os.path.join(outdir, filename)
    if os.path.isfile(filepath):
        data = numpy.load(filepath)
        result = RunningCrossCovariance(state=data)
        result.to_(device)
        return result
    return None

def save_state_dict(obj, filepath):
    dirname = os.path.dirname(filepath)
    os.makedirs(dirname, exist_ok=True)
    dic = obj.state_dict()
    numpy.savez(filepath, **dic)

def upsample_grid(data_shape, target_shape, input_shape=None,
        scale_offset=None, dtype=torch.float, device=None):
    '''Prepares a grid to use with grid_sample to upsample a batch of
    features in data_shape to the target_shape. Can use scale_offset
    and input_shape to center the grid in a nondefault way: scale_offset
    maps feature pixels to input_shape pixels, and it is assumed that
    the target_shape is a uniform downsampling of input_shape.'''
    # Default is that nothing is resized.
    if target_shape is None:
        target_shape = data_shape
    # Make a default scale_offset to fill the image if there isn't one
    if scale_offset is None:
        scale = tuple(float(ts) / ds
                for ts, ds in zip(target_shape, data_shape))
        offset = tuple(0.5 * s - 0.5 for s in scale)
    else:
        scale, offset = zip(*scale_offset)
        # Handle downsampling for different input vs target shape.
        if input_shape is not None:
            scale = tuple(s * (ts - 1) / (ns - 1)
                    for s, ns, ts in zip(scale, input_shape, target_shape))
            offset = tuple(o * (ts - 1) / (ns - 1)
                    for o, ns, ts in zip(offset, input_shape, target_shape))
    # Pytorch needs target coordinates in terms of source coordinates [-1..1]
    ty, tx = (((torch.arange(ts, dtype=dtype, device=device) - o)
                  * (2 / (s * (ss - 1))) - 1)
        for ts, ss, s, o, in zip(target_shape, data_shape, scale, offset))
    # Note: grid_sample expects grid coordinates in (x, y) order, the
    # reverse of the (y, x) data layout.
    grid = torch.stack(
        (tx[None,:].expand(target_shape), ty[:,None].expand(target_shape)),2
       )[None,:,:,:].expand((1, target_shape[0], target_shape[1], 2))
    return grid
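
# Illustrative sketch: upsample a (1, 16, 4, 4) feature map to 8x8 with the
# default centered grid (all shapes here are assumed for demonstration).
#
#     feats = torch.randn(1, 16, 4, 4)
#     grid = upsample_grid((4, 4), (8, 8))
#     up = torch.nn.functional.grid_sample(feats,
#             grid.expand((feats.shape[0],) + grid.shape[1:]),
#             padding_mode='border')
#     # up.shape == (1, 16, 8, 8)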

def safe_dir_name(filename):
    keepcharacters = (' ','.','_','-')
    return ''.join(c
            for c in filename if c.isalnum() or c in keepcharacters).rstrip()
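
# For example, safe_dir_name('features.4/conv') == 'features.4conv':
# path separators and other unsafe characters are simply dropped.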

bargraph_palette = [
    ('#4B4CBF', '#B6B6F2'),
    ('#55B05B', '#B6F2BA'),
    ('#50BDAC', '#A5E5DB'),
    ('#81C679', '#C0FF9B'),
    ('#F0883B', '#F2CFB6'),
    ('#D4CF24', '#F2F1B6'),
    ('#D92E2B', '#F2B6B6'),
    ('#AB6BC6', '#CFAAFF'),
]

def make_svg_bargraph(labels, heights, categories,
        barheight=100, barwidth=12, show_labels=True, filename=None):
    unitheight = float(barheight) / max(max(heights, default=1), 1)
    textheight = barheight if show_labels else 0
    labelsize = float(barwidth)
    gap = float(barwidth) / 4
    textsize = barwidth + gap
    rollup = max(heights, default=1)
    textmargin = float(labelsize) * 2 / 3
    leftmargin = 32
    rightmargin = 8
    svgwidth = len(heights) * (barwidth + gap) + 2 * leftmargin + rightmargin
    svgheight = barheight + textheight

    # create an SVG XML element
    svg = et.Element('svg', width=str(svgwidth), height=str(svgheight),
            version='1.1', xmlns='http://www.w3.org/2000/svg')

    # Draw the bar graph
    basey = svgheight - textheight
    x = leftmargin
    # Add units scale on left
    if len(heights):
        for h in [1, (max(heights) + 1) // 2, max(heights)]:
            et.SubElement(svg, 'text', x='0', y='0',
                style=('font-family:sans-serif;font-size:%dpx;' +
                'text-anchor:end;alignment-baseline:hanging;' +
                'transform:translate(%dpx, %dpx);') %
                (textsize, x - gap, basey - h * unitheight)).text = str(h)
        et.SubElement(svg, 'text', x='0', y='0',
                style=('font-family:sans-serif;font-size:%dpx;' +
                'text-anchor:middle;' +
                'transform:translate(%dpx, %dpx) rotate(-90deg)') %
                (textsize, x - gap - textsize,
                    basey - max(heights) * unitheight / 2)
                ).text = 'units'
    # Draw big category background rectangles
    for catindex, (cat, catcount) in enumerate(categories):
        if not catcount:
            continue
        et.SubElement(svg, 'rect', x=str(x), y=str(basey - rollup * unitheight),
                width=(str((barwidth + gap) * catcount - gap)),
                height = str(rollup*unitheight),
                fill=bargraph_palette[catindex % len(bargraph_palette)][1])
        x += (barwidth + gap) * catcount
    # Draw small bars as well as 45degree text labels
    x = leftmargin
    catindex = -1
    catcount = 0
    for label, height in zip(labels, heights):
        while not catcount and catindex + 1 < len(categories):
            catindex += 1
            catcount = categories[catindex][1]
            color = bargraph_palette[catindex % len(bargraph_palette)][0]
        et.SubElement(svg, 'rect', x=str(x), y=str(basey-(height * unitheight)),
                width=str(barwidth), height=str(height * unitheight),
                fill=color)
        x += barwidth
        if show_labels:
            et.SubElement(svg, 'text', x='0', y='0',
                style=('font-family:sans-serif;font-size:%dpx;text-anchor:end;'+
                'transform:translate(%dpx, %dpx) rotate(-45deg);') %
                (labelsize, x, basey + textmargin)).text = readable(label)
        x += gap
        catcount -= 1
    # Text labels for each category
    x = leftmargin
    for cat, catcount in categories:
        if not catcount:
            continue
        et.SubElement(svg, 'text', x='0', y='0',
            style=('font-family:sans-serif;font-size:%dpx;text-anchor:end;'+
            'transform:translate(%dpx, %dpx) rotate(-90deg);') %
            (textsize, x + (barwidth + gap) * catcount - gap,
                basey - rollup * unitheight + gap)).text = '%d %s' % (
                    catcount, readable(cat + ('s' if catcount != 1 else '')))
        x += (barwidth + gap) * catcount
    # Output - this is the bare svg.
    result = et.tostring(svg)
    if filename:
        with open(filename, 'wb') as f:
            # When writing to a file a special header is needed.
            f.write(''.join([
                '<?xml version="1.0" standalone="no"?>\n',
                '<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"\n',
                '"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">\n']
                ).encode('utf-8'))
            f.write(result)
    return result
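
# Illustrative sketch (data assumed): three bars across two categories.
#
#     svg = make_svg_bargraph(['sky', 'tree', 'car'], [5, 3, 1],
#             [('object', 2), ('part', 1)], filename='bargraph.svg')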

readable_replacements = [(re.compile(r[0]), r[1]) for r in [
    (r'-[sc]$', ''),
    (r'_', ' '),
    ]]

def readable(label):
    for pattern, subst in readable_replacements:
        label = re.sub(pattern, subst, label)
    return label
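
# For example, readable('swimming_pool-s') == 'swimming pool': a trailing
# '-s' or '-c' suffix is dropped and underscores become spaces.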

def reverse_normalize_from_transform(transform):
    '''
    Crawls the transforms attached to a dataset looking for a
    Normalize transform, and returns a corresponding ReverseNormalize,
    or None if no normalization is found.
    '''
    if isinstance(transform, torchvision.transforms.Normalize):
        return ReverseNormalize(transform.mean, transform.std)
    t = getattr(transform, 'transform', None)
    if t is not None:
        return reverse_normalize_from_transform(t)
    transforms = getattr(transform, 'transforms', None)
    if transforms is not None:
        for t in reversed(transforms):
            result = reverse_normalize_from_transform(t)
            if result is not None:
                return result
    return None
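
# Illustrative sketch (mean/std values assumed): recovering a Normalize
# buried inside a Compose.
#
#     tf = torchvision.transforms.Compose([
#         torchvision.transforms.ToTensor(),
#         torchvision.transforms.Normalize([0.5] * 3, [0.5] * 3)])
#     rn = reverse_normalize_from_transform(tf)   # a ReverseNormalize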

class ReverseNormalize:
    '''
    Applies the reverse of torchvision.transforms.Normalize.
    '''
    def __init__(self, mean, stdev):
        mean = numpy.array(mean)
        stdev = numpy.array(stdev)
        self.mean = torch.from_numpy(mean)[None,:,None,None].float()
        self.stdev = torch.from_numpy(stdev)[None,:,None,None].float()
    def __call__(self, data):
        device = data.device
        return data.mul(self.stdev.to(device)).add_(self.mean.to(device))
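
# Illustrative sketch (statistics assumed): undo a Normalize on a batch.
#
#     rn = ReverseNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
#     raw = rn(normalized_batch)   # normalized_batch is an (n, 3, y, x) tensor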

class ImageOnlySegRunner:
    def __init__(self, dataset, recover_image=None):
        if recover_image is None:
            recover_image = reverse_normalize_from_transform(dataset)
        self.recover_image = recover_image
        self.dataset = dataset
    def get_label_and_category_names(self):
        return [('-', '-')], ['-']
    def run_and_segment_batch(self, batch, model,
            want_bincount=False, want_rgb=False):
        [im] = batch
        device = next(model.parameters()).device
        if want_rgb:
            rgb = self.recover_image(im.clone()
                ).permute(0, 2, 3, 1).mul_(255).clamp(0, 255).byte()
        else:
            rgb = None
        # Stubs for seg and bc
        seg = torch.zeros(im.shape[0], 1, 1, 1, dtype=torch.long)
        bc = torch.ones(im.shape[0], 1, dtype=torch.long)
        # Run the model.
        model(im.to(device))
        return seg, bc, rgb, im.shape[2:]
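
# Illustrative sketch (directory name assumed): dissect with images only,
# skipping label statistics.
#
#     segrunner = ImageOnlySegRunner(dataset)
#     dissect('result/imgonly', model, dataset, segrunner=segrunner,
#             make_labels=False)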

class ClassifierSegRunner:
    def __init__(self, dataset, recover_image=None):
        # The dataset contains explicit segmentations
        if recover_image is None:
            recover_image = reverse_normalize_from_transform(dataset)
        self.recover_image = recover_image
        self.dataset = dataset
    def get_label_and_category_names(self):
        catnames = self.dataset.categories
        label_and_cat_names = [(readable(label),
            catnames[self.dataset.label_category[i]])
                for i, label in enumerate(self.dataset.labels)]
        return label_and_cat_names, catnames
    def run_and_segment_batch(self, batch, model,
            want_bincount=False, want_rgb=False):
        '''
        Runs the dissected model on one batch of the dataset, and
        returns a multilabel semantic segmentation for the data.
        Given a batch of size (n, c, y, x) the segmentation should
        be a (long integer) tensor of size (n, d, y//r, x//r) where
        d is the maximum number of simultaneous labels given to a pixel,
        and where r is some (optional) resolution reduction factor.
        In the segmentation returned, the label `0` is reserved for
        the background "no-label".

        In addition to the segmentation, bc, rgb, and shape are returned
        where bc is a per-image bincount counting returned label pixels,
        rgb is a viewable (n, y, x, rgb) byte image tensor for the data
        for visualizations (reversing normalizations, for example), and
        shape is the (y, x) size of the data.  If want_bincount or
        want_rgb are False, those return values may be None.
        '''
        im, seg, bc = batch
        device = next(model.parameters()).device
        if want_rgb:
            rgb = self.recover_image(im.clone()
                ).permute(0, 2, 3, 1).mul_(255).clamp(0, 255).byte()
        else:
            rgb = None
        # Run the model.
        model(im.to(device))
        return seg, bc, rgb, im.shape[2:]

class GeneratorSegRunner:
    def __init__(self, segmenter):
        # The segmentations are given by an algorithm
        if segmenter is None:
            segmenter = UnifiedParsingSegmenter(segsizes=[256], segdiv='quad')
        self.segmenter = segmenter
        self.num_classes = len(segmenter.get_label_and_category_names()[0])
    def get_label_and_category_names(self):
        return self.segmenter.get_label_and_category_names()
    def run_and_segment_batch(self, batch, model,
            want_bincount=False, want_rgb=False):
        '''
        Runs the dissected model on one batch of the dataset, and
        returns a multilabel semantic segmentation for the data.
        Given a batch of size (n, c, y, x) the segmentation should
        be a (long integer) tensor of size (n, d, y//r, x//r) where
        d is the maximum number of simultaneous labels given to a pixel,
        and where r is some (optional) resolution reduction factor.
        In the segmentation returned, the label `0` is reserved for
        the background "no-label".

        In addition to the segmentation, bc, rgb, and shape are returned
        where bc is a per-image bincount counting returned label pixels,
        rgb is a viewable (n, y, x, rgb) byte image tensor for the data
        for visualizations (reversing normalizations, for example), and
        shape is the (y, x) size of the data.  If want_bincount or
        want_rgb are False, those return values may be None.
        '''
        device = next(model.parameters()).device
        z_batch = batch[0]
        tensor_images = model(z_batch.to(device))
        seg = self.segmenter.segment_batch(tensor_images, downsample=2)
        if want_bincount:
            index = torch.arange(z_batch.shape[0],
                    dtype=torch.long, device=device)
            bc = (seg + index[:, None, None, None] * self.num_classes).view(-1
                ).bincount(minlength=z_batch.shape[0] * self.num_classes)
            bc = bc.view(z_batch.shape[0], self.num_classes)
        else:
            bc = None
        if want_rgb:
            images = ((tensor_images + 1) / 2 * 255)
            rgb = images.permute(0, 2, 3, 1).clamp(0, 255).byte()
        else:
            rgb = None
        return seg, bc, rgb, tensor_images.shape[2:]
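
# Illustrative sketch (names assumed): dissecting a GAN generator, where the
# dataset yields z vectors and segmentations come from a segmentation model.
#
#     segrunner = GeneratorSegRunner(None)  # defaults to UnifiedParsingSegmenter
#     dissect('result/gandissect', model, zdataset, segrunner=segrunner)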