import os, errno, numpy, torch, csv, re, shutil, os, zipfile
from collections import OrderedDict
from torchvision.datasets.folder import default_loader
from torchvision import transforms
from scipy import ndimage
from urllib.request import urlopen

class BrodenDataset(torch.utils.data.Dataset):
    '''
    A multicategory segmentation data set.

    Each item is a tuple of up to three streams:
    (1) The image, as returned by the loader and then by `transform`
        when one is given.
    (2) The multicategory segmentation: an integer array of shape
        (max_segment_depth, sh, sw) holding one layer of label numbers
        per annotation, passed through `transform_segment` when given.
        Unused layers stay zero.
    (3) A bincount of pixels in the segmentation (num_labels), with the
        count for label 0 cleared.  Only returned when include_bincount
        is true.

    Net dissect also assumes that the dataset object has three properties
    with human-readable labels:

    ds.labels = ['red', 'black', 'car', 'tree', 'grid', ...]
    ds.categories = ['color', 'part', 'object', 'texture']
    ds.label_category = [0, 0, 2, 2, 3, ...] # The category for each label
    '''
    def __init__(self, directory='dataset/broden', resolution=384,
            split='train', categories=None,
            transform=None, transform_segment=None,
            download=False, size=None, include_bincount=True,
            broden_version=1, max_segment_depth=6):
        '''
        directory: root directory containing 'broden<version>_<resolution>'.
        resolution: one of 224, 227, 384.
        split: only index rows whose 'split' field equals this are kept.
        categories: optional iterable of category names to keep; names
            not present in category.csv are ignored.
        transform, transform_segment: optional callables applied to the
            image and to the segmentation array respectively.
        download: if true, fetch and unpack the data set when missing.
        size: optional cap on the number of images (first `size` kept).
        include_bincount: whether items carry the per-label pixel count.
        broden_version: version number used in the directory name.
        max_segment_depth: number of segmentation layers per item.
        '''
        assert resolution in [224, 227, 384]
        if download:
            ensure_broden_downloaded(directory, resolution, broden_version)
        self.directory = directory
        self.resolution = resolution
        self.resdir = os.path.join(directory, 'broden%d_%d' %
                (broden_version, resolution))
        self.loader = default_loader
        self.transform = transform
        self.transform_segment = transform_segment
        self.include_bincount = include_bincount
        # The maximum number of multilabel layers that coexist at an image.
        self.max_segment_depth = max_segment_depth
        with open(os.path.join(self.resdir, 'category.csv'),
                encoding='utf-8') as f:
            self.category_info = OrderedDict()
            for row in csv.DictReader(f):
                self.category_info[row['name']] = row
        if categories is not None:
            # Filter out unused categories
            categories = set([c for c in categories if c in self.category_info])
            for cat in list(self.category_info.keys()):
                if cat not in categories:
                    del self.category_info[cat]
        # From here on, categories follows the order of category.csv.
        categories = list(self.category_info.keys())
        self.categories = categories

        # Filter out unneeded images.
        with open(os.path.join(self.resdir, 'index.csv'),
                encoding='utf-8') as f:
            all_images = [decode_index_dict(r) for r in csv.DictReader(f)]
        self.image = [row for row in all_images
            if index_has_any_data(row, categories) and row['split'] == split]
        if size is not None:
            self.image = self.image[:size]
        with open(os.path.join(self.resdir, 'label.csv'),
                encoding='utf-8') as f:
            self.label_info = build_dense_label_array([
                decode_label_dict(r) for r in csv.DictReader(f)])
            self.labels = [l['name'] for l in self.label_info]
        # Build dense remapping arrays for labels, so that you can
        # get dense ranges of labels for each category.
        self.category_map = {}      # cat -> array: label number -> code
        self.category_unmap = {}    # cat -> array: code -> label number
        self.category_label = {}    # cat -> list of c_<cat>.csv rows by code
        for cat in self.categories:
            with open(os.path.join(self.resdir, 'c_%s.csv' % cat),
                    encoding='utf-8') as f:
                c_data = [decode_label_dict(r) for r in csv.DictReader(f)]
            self.category_unmap[cat], self.category_map[cat] = (
                    build_numpy_category_map(c_data))
            self.category_label[cat] = build_dense_label_array(
                    c_data, key='code')
        self.num_labels = len(self.labels)
        # Primary categories for each label is the category in which it
        # appears with the maximum coverage.
        self.label_category = numpy.zeros(self.num_labels, dtype=int)
        for i in range(self.num_labels):
            # Compare (coverage, category index) pairs; a label that has
            # no code in a category counts as coverage 0 there, and ties
            # go to the later category.
            _, self.label_category[i] = max(
               (self.category_label[cat][self.category_map[cat][i]]['coverage']
                    if i < len(self.category_map[cat])
                       and self.category_map[cat][i] else 0, ic)
                for ic, cat in enumerate(categories))

    def __len__(self):
        '''Number of images kept after split/category/size filtering.'''
        return len(self.image)

    def __getitem__(self, idx):
        '''
        Loads the image and segmentation layers for index idx.

        Returns (image, segment, bincount) when include_bincount is
        true, otherwise (image, segment).  Raises IndexError if the
        record has more layers than max_segment_depth.
        '''
        record = self.image[idx]
        # example record: {
        #    'image': 'opensurfaces/25605.jpg', 'split': 'train',
        #    'ih': 384, 'iw': 384, 'sh': 192, 'sw': 192,
        #    'color': ['opensurfaces/25605_color.png'],
        #    'object': [], 'part': [],
        #    'material': ['opensurfaces/25605_material.png'],
        #    'scene': [], 'texture': []}
        image = self.loader(os.path.join(self.resdir, 'images',
            record['image']))
        segment = numpy.zeros(shape=(self.max_segment_depth,
            record['sh'], record['sw']), dtype=int)
        if self.include_bincount:
            bincount = numpy.zeros(shape=(self.num_labels,), dtype=int)
        depth = 0
        for cat in self.categories:
            for layer in record[cat]:
                if depth >= self.max_segment_depth:
                    raise IndexError(
                        'image %r has more than max_segment_depth=%d '
                        'segmentation layers' %
                        (record['image'], self.max_segment_depth))
                if isinstance(layer, int):
                    # A bare integer labels the whole image uniformly.
                    segment[depth,:,:] = layer
                    if self.include_bincount:
                        bincount[layer] += segment.shape[1] * segment.shape[2]
                else:
                    # Otherwise the layer is a png whose first two
                    # channels hold the low and high byte of the label.
                    png = numpy.asarray(self.loader(os.path.join(
                        self.resdir, 'images', layer)))
                    # Widen before the arithmetic: multiplying a uint8
                    # array by 256 overflows (and is an error under
                    # NumPy 2 promotion rules).
                    segment[depth,:,:] = (png[:,:,0].astype(int)
                            + png[:,:,1].astype(int) * 256)
                    if self.include_bincount:
                        bincount += numpy.bincount(segment[depth,:,:].flatten(),
                            minlength=self.num_labels)
                depth += 1
        if self.transform:
            image = self.transform(image)
        if self.transform_segment:
            segment = self.transform_segment(segment)
        if self.include_bincount:
            # Label 0 means "no label", so it is not counted.
            bincount[0] = 0
            return (image, segment, bincount)
        else:
            return (image, segment)

def build_dense_label_array(label_data, key='number', allow_none=False):
    '''
    Input: sequence of rows (dicts) with 'number' fields (or another
    field name given by key).
    Output: list such that a[number] = the row with the given number.

    Numbers that no row uses are left as None when allow_none is true.
    Otherwise they are filled with a placeholder row that has the same
    fields as the first input row: the key field is set to the index
    and every other field to the empty value of its type ('' for
    strings, 0 for ints, and so on).  Empty input gives an empty list.
    '''
    if not label_data:
        return []
    result = [None] * (max(d[key] for d in label_data) + 1)
    for d in label_data:
        result[d[key]] = d
    # Fill in none
    if not allow_none:
        example = label_data[0]
        for i, d in enumerate(result):
            if d is None:
                # Compare field names by value: names parsed from a csv
                # file are not the same object as the key argument.
                result[i] = dict((c, i if c == key else type(v)())
                        for c, v in example.items())
    return result

def build_numpy_category_map(map_data, key1='code', key2='number'):
    '''
    Input: rows (dicts) that each carry a key1 field and a key2 field.
    Output: a list of two int16 arrays [forward, backward] such that
    forward[row[key1]] = row[key2] and backward[row[key2]] = row[key1].
    Positions that no row mentions stay 0.
    '''
    forward = numpy.zeros(
            max([row[key1] for row in map_data]) + 1, dtype=numpy.int16)
    backward = numpy.zeros(
            max([row[key2] for row in map_data]) + 1, dtype=numpy.int16)
    for row in map_data:
        first, second = row[key1], row[key2]
        forward[first] = second
        backward[second] = first
    return [forward, backward]

def index_has_any_data(row, categories):
    '''
    True if, for at least one of the given categories, the row lists
    at least one truthy entry; False otherwise.
    '''
    return any(entry for name in categories for entry in row[name])

def decode_label_dict(row):
    '''
    Converts one csv row (a dict of strings) of label data into
    typed values:

    'category': 'a(1);b(2)' becomes the dict {'a': 1, 'b': 2};
        an empty field becomes an empty dict.
    'name': kept as a string, even if it looks numeric.
    'syns': split on ';' into a list of strings.
    any other field: int if it is all digits, float if it is digits
        followed by '.' and optional digits, otherwise left as a string.
    '''
    result = {}
    for key, val in row.items():
        if key == 'category':
            # Each fragment looks like name(count); skip empty fragments
            # so that a blank field does not crash on a failed match.
            result[key] = dict((c, int(n))
                for c, n in [re.match(r'^([^(]*)\(([^)]*)\)$', f).groups()
                    for f in val.split(';') if f])
        elif key == 'name':
            result[key] = val
        elif key == 'syns':
            result[key] = val.split(';')
        elif re.match(r'^\d+$', val):
            result[key] = int(val)
        elif re.match(r'^\d+\.\d*$', val):
            result[key] = float(val)
        else:
            result[key] = val
    return result

def decode_index_dict(row):
    '''
    Converts one csv row (a dict of strings) of image index data into
    typed values:

    'image', 'split': kept as strings.
    'sw', 'sh', 'iw', 'ih': converted to int.
    any other field: split on ';' into a list with empty entries
        dropped; entries that are all digits become ints, the rest
        stay strings.
    '''
    result = {}
    for key, val in row.items():
        if key in ['image', 'split']:
            result[key] = val
        elif key in ['sw', 'sh', 'iw', 'ih']:
            result[key] = int(val)
        else:
            item = [s for s in val.split(';') if s]
            for i, v in enumerate(item):
                if re.match(r'^\d+$', v):
                    item[i] = int(v)
            result[key] = item
    return result

class ScaleSegmentation:
    '''
    Callable that resizes a stack of segmentation layers to a fixed
    height and width with order-0 (nearest-neighbor) zooming, so that
    label numbers are copied rather than blended.
    '''
    def __init__(self, target_height, target_width):
        self.target_height = target_height
        self.target_width = target_width

    def __call__(self, seg):
        # Axis 0 (the layer axis) is left alone; axes 1 and 2 are
        # stretched to the target size.
        height_ratio = self.target_height / float(seg.shape[1])
        width_ratio = self.target_width / float(seg.shape[2])
        return ndimage.zoom(seg, (1, height_ratio, width_ratio), order=0)

def scatter_batch(seg, num_labels, omit_zero=True, dtype=torch.uint8):
    '''
    Utility for scattering segmentations into a one-hot representation.

    seg is an index tensor whose dimension 1 holds label numbers.  The
    result has shape (seg.shape[0], num_labels) + seg.shape[2:], lives
    on seg's device, and has a 1 at [b, label, ...] wherever seg holds
    that label at [b, :, ...].  When omit_zero is true, channel 0 is
    cleared afterwards.
    '''
    out_shape = (seg.shape[0], num_labels) + tuple(seg.shape[2:])
    onehot = torch.zeros(out_shape, dtype=dtype, device=seg.device)
    onehot.scatter_(1, seg, 1)
    if omit_zero:
        onehot[:, 0] = 0
    return onehot

def ensure_broden_downloaded(directory, resolution, broden_version=1):
    '''
    Makes sure directory/broden<version>_<resolution>/index.csv exists,
    downloading and unzipping the data set if it does not.

    The zip file is kept under directory/download and reused if it is
    already there.  Raises AssertionError for an unsupported resolution
    or if index.csv is still missing after unzipping.
    '''
    assert resolution in [224, 227, 384]
    baseurl = 'http://netdissect.csail.mit.edu/data/'
    dirname = 'broden%d_%d' % (broden_version, resolution)
    if os.path.isfile(os.path.join(directory, dirname, 'index.csv')):
        return # Already downloaded
    # The archive is named after the directory it unpacks to, so that
    # the requested version is the one fetched.
    # NOTE(review): assumes the server names archives for versions other
    # than 1 the same way - confirm.
    zipfilename = '%s.zip' % dirname
    download_dir = os.path.join(directory, 'download')
    os.makedirs(download_dir, exist_ok=True)
    full_zipfilename = os.path.join(download_dir, zipfilename)
    if not os.path.exists(full_zipfilename):
        # baseurl already ends with a slash.
        url = baseurl + zipfilename
        print('Downloading %s' % url)
        # Stream to a temporary name and rename when complete, so that
        # an interrupted download is never mistaken for a finished zip.
        partial_zipfilename = full_zipfilename + '.part'
        try:
            with urlopen(url) as data, \
                    open(partial_zipfilename, 'wb') as f:
                shutil.copyfileobj(data, f)
            os.replace(partial_zipfilename, full_zipfilename)
        finally:
            if os.path.exists(partial_zipfilename):
                os.remove(partial_zipfilename)
    print('Unzipping %s' % zipfilename)
    with zipfile.ZipFile(full_zipfilename, 'r') as zip_ref:
        zip_ref.extractall(directory)
    assert os.path.isfile(os.path.join(directory, dirname, 'index.csv'))

def test_broden_dataset():
    '''
    Testing code.

    Builds the 384-resolution data set from 'dataset/broden' (which
    must already exist on disk), prints the name and primary category
    of the first few labels, then iterates once over all batches and
    prints their shapes.
    '''
    bds = BrodenDataset('dataset/broden', resolution=384,
            transform=transforms.Compose([
                        transforms.Resize(224),
                        transforms.ToTensor()]),
            transform_segment=transforms.Compose([
                        ScaleSegmentation(224, 224)
                        ]),
            include_bincount=True)
    loader = torch.utils.data.DataLoader(bds, batch_size=100, num_workers=24)
    # Label 0 is skipped; the range is capped in case there are fewer
    # than 20 labels.
    for i in range(1, min(20, bds.num_labels)):
        print(bds.labels[i], bds.categories[bds.label_category[i]])
    for i, (im, seg, bc) in enumerate(loader):
        print(i, im.shape, seg.shape, seg.max(), bc.shape)

# Run the testing code when this file is executed directly.
if __name__ == '__main__':
    test_broden_dataset()