# NOTE(review): removed stray "Spaces:" / "Build error" lines here --
# they were build-log/extraction artifacts, not valid Python source.
from functools import partial | |
import numpy | |
import os | |
import re | |
import random | |
import signal | |
import csv | |
from PIL import Image | |
import settings | |
import numpy as np | |
from collections import OrderedDict | |
import cv2 | |
# from scipy.misc import imread | |
from multiprocessing import Pool, cpu_count | |
from multiprocessing.pool import ThreadPool | |
from scipy.ndimage.interpolation import zoom | |
import sys | |
import pickle | |
def load_csv(filename, readfields=None):
    '''
    Reads a CSV file into a list of dicts, converting numeric-looking
    field values to int or float.  If readfields is a list, it is
    extended in-place with the CSV's column names.
    '''
    def convert(value):
        # Integer literals such as "12" or "-3".
        if re.match(r'^-?\d+$', value):
            try:
                return int(value)
            except ValueError:
                pass
        # Float literals with an optional exponent ("1.5", "2e+3", "1e-5").
        # Fixed: the exponent group is now optional and accepts '-';
        # previously plain floats like "1.5" were left as strings and
        # "1e-5" never matched.
        if re.match(r'^-?[\.\d]+(?:e[-+]?\d+)?$', value):
            try:
                return float(value)
            except ValueError:
                pass
        return value

    with open(filename) as f:
        reader = csv.DictReader(f)
        result = [{k: convert(v) for k, v in row.items()} for row in reader]
        if readfields is not None:
            readfields.extend(reader.fieldnames)
    return result
class AbstractSegmentation:
    '''
    Interface for segmentation data sets.  Subclasses implement the
    metadata/filename accessors and resolve_segmentation; the base class
    provides name() and segmentation_data() on top of those.
    '''
    def all_names(self, category, j):
        '''All English synonyms for label j within a category.'''
        raise NotImplementedError

    def size(self, split=None):
        '''Number of images in the data set (0 by default).'''
        return 0

    def filename(self, i):
        '''Filename of the ith image.'''
        raise NotImplementedError

    def metadata(self, i):
        '''Lightweight metadata for image i, passed to resolve_segmentation.'''
        return self.filename(i)

    def resolve_segmentation(self, m, categories=None):
        '''
        Maps metadata to a dict of {category: segmentation data}.
        Fixed: now accepts the `categories` keyword that
        segmentation_data() below already passes (the old two-argument
        signature raised TypeError when called through that path).
        '''
        return {}

    def name(self, category, i):
        '''
        Default implementation for name, utilizing all_names.
        '''
        all_names = self.all_names(category, i)
        return all_names[0] if len(all_names) else ''

    def segmentation_data(self, category, i, c=0, full=False):
        '''
        Default implementation for segmentation_data,
        utilizing metadata and resolve_segmentation.
        '''
        segs = self.resolve_segmentation(
            self.metadata(i), categories=[category])
        if category not in segs:
            return 0
        data = segs[category]
        # Unless full data is requested, return only the first channel.
        if not full and len(data.shape) >= 3:
            return data[0]
        return data
class SegmentationData(AbstractSegmentation):
    '''
    Represents and loads a multi-channel segmentation represented with
    a series of csv files: index.csv lists the images together with
    any label data available in each category; category.csv lists the
    categories of segmentations available; and label.csv lists the
    numbers used to describe each label class.  In addition, the categories
    each have a separate c_*.csv file describing a dense coding of labels.
    isImageSet - if True, duplicate rgb images in index.csv will be removed
    (NOTE(review): currently unused -- the de-duplication code in
    __init__ is commented out; confirm before relying on it).
    '''
def __init__(self, directory, categories=None, require_all=False, isImageSet=False):
    '''
    Loads index.csv / category.csv / label.csv / c_*.csv from `directory`.
    categories: restrict loading to these category names.
    require_all: drop images lacking data in any requested category
    (otherwise any category's data suffices).
    isImageSet: accepted but currently unused (see class docstring).
    '''
    directory = os.path.expanduser(directory)
    self.directory = directory
    # index.csv: one decoded row per image, with per-category channels.
    with open(os.path.join(directory, settings.INDEX_FILE)) as f:
        self.image = [decode_index_dict(r) for r in csv.DictReader(f)]
    # NOTE(review): a commented-out block was removed here; it built an
    # actualFeatIdx map collapsing duplicate rgb rows when isImageSet
    # was True.  Restore from history if that feature is revived.
    with open(os.path.join(directory, 'category.csv')) as f:
        self.category = OrderedDict()
        for row in csv.DictReader(f):
            # NOTE(review): when categories is None this keeps nothing --
            # presumably callers always pass a category list; confirm.
            if categories and row['name'] in categories:
                self.category[row['name']] = row
        categories = self.category.keys()
    with open(os.path.join(directory, 'label.csv')) as f:
        label_data = [decode_label_dict(r) for r in csv.DictReader(f)]
    # Dense array indexed by global label number; gaps are filled with
    # empty rows, and index 0 is the "unlabeled" slot.
    self.label = build_dense_label_array(label_data)
    # Filter out images with insufficient data
    filter_fn = partial(
        index_has_all_data if require_all else index_has_any_data,
        categories=categories)
    self.image = [row for row in self.image if filter_fn(row)]
    # Build dense remapping arrays for labels, so that you can
    # get dense ranges of labels for each category.
    self.category_map = {}    # per category: global label number -> dense code
    self.category_unmap = {}  # per category: dense code -> global label number
    self.category_label = {}  # per category: dense code -> label row
    for cat in self.category:
        with open(os.path.join(directory, 'c_%s.csv' % cat)) as f:
            c_data = [decode_label_dict(r) for r in csv.DictReader(f)]
        self.category_unmap[cat], self.category_map[cat] = (
            build_numpy_category_map(c_data))
        self.category_label[cat] = build_dense_label_array(
            c_data, key='code')
    # One-hot matrix mapping each label to its primary category.
    self.labelcat = self.onehot(self.primary_categories_per_index())
def primary_categories_per_index(ds):
    '''
    Returns an array of primary category numbers for each label, where the
    first category listed in ds.category_names is given category number 0.
    NOTE(review): this definition is shadowed by the later
    primary_categories_per_index(self, categories=None) defined further
    down in this class, so this copy is dead code; kept for reference.
    '''
    catmap = {}
    categories = ds.category_names()
    for cat in categories:
        # Pad each category's dense-index map out to the full label count.
        imap = ds.category_index_map(cat)
        if len(imap) < ds.label_size(None):
            imap = np.concatenate((imap, np.zeros(
                ds.label_size(None) - len(imap), dtype=imap.dtype)))
        catmap[cat] = imap
    result = []
    for i in range(ds.label_size(None)):
        # Choose the category with the greatest pixel coverage for label i.
        maxcov, maxcat = max(
            (ds.coverage(cat, catmap[cat][i]) if catmap[cat][i] else 0, ic)
            for ic, cat in enumerate(categories))
        result.append(maxcat)
    return np.array(result)
def onehot(self, arr, minlength=None):
    '''
    One-hot encodes an integer array along a new trailing axis: the
    output is 1 at position [..., n] wherever the input holds n, and 0
    elsewhere.  minlength gives a floor on the size of the new axis.
    '''
    depth = int(np.amax(arr)) + 1
    if minlength is not None and minlength > depth:
        depth = minlength
    # Row n of the identity matrix is exactly the one-hot vector for n.
    return np.eye(depth)[arr]
def all_names(self, category, j):
    '''All English synonyms for the given label.'''
    index = self.category_unmap[category][j] if category is not None else j
    entry = self.label[index]
    return [entry['name']] + entry['syns']
def size(self, split=None):
    '''The number of images in this data set (optionally in one split).'''
    if split is None:
        return len(self.image)
    return sum(1 for im in self.image if im['split'] == split)
def filename(self, i):
    '''Full path to the ith original (jpeg) image file.'''
    record = self.image[i]
    return os.path.join(self.directory, 'images', record['image'])
def split(self, i):
    '''Which split ('train', 'val', ...) contains item i.'''
    record = self.image[i]
    return record['split']
def metadata(self, i):
    '''Extract metadata for image i, for efficient data loading.'''
    return (self.directory, self.image[i])
# Index-file columns that describe the image itself rather than holding
# per-category segmentation channels; skipped by resolve_segmentation.
meta_categories = ['image', 'split', 'ih', 'iw', 'sh', 'sw']
def resolve_segmentation(cls, m, categories=None, segm_to_label=None):
    '''
    Resolves a full segmentation, potentially in a different process,
    for efficient multiprocess data loading.
    m: (directory, row) metadata pair as produced by metadata().
    categories: optional whitelist of category names to load.
    segm_to_label: presumably maps "dir/file" of a segmentation png to
    its label; appended into result['seg_label'] -- TODO confirm against
    the pickle loaded by SegmentationPrefetcher.
    Returns (result dict, (sh, sw)).
    NOTE: the first parameter fills the self/class slot; callers may
    invoke this unbound through the class object.
    '''
    directory, row = m
    result = {}
    for cat, d in row.items():
        # Skip non-segmentation metadata columns and unwanted categories.
        if cat in cls.meta_categories:
            continue
        if not wants(cat, categories):
            continue
        # All-integer channels encode whole-image labels; pass through.
        if all(isinstance(data, int) for data in d):
            result[cat] = d
            continue
        out = numpy.empty((len(d), row['sh'], row['sw']), dtype=numpy.int16)
        for i, channel in enumerate(d):
            if isinstance(channel, int):
                out[i] = channel  # single label for the whole channel
            else:
                # Record the label of this segmentation file, keyed by
                # its last two path components.
                segmFilenameSplit = channel.split('/')
                segmFileToLabelName = segmFilenameSplit[-2] + "/" + segmFilenameSplit[-1]
                if 'seg_label' not in result:
                    result['seg_label'] = []
                result['seg_label'].append(segm_to_label[segmFileToLabelName])
                # Decode the png and binarize: channel 0 is zeroed,
                # channel 2 thresholded, and channel 1 set to 1 wherever
                # the (just-thresholded) channel 2 is positive, so mask
                # pixels become label 256 (0 + 1*256) and background 0.
                rgb = cv2.resize(cv2.imread(os.path.join(directory, 'images', channel)), (settings.SEGM_SIZE, settings.SEGM_SIZE))
                rgb[:,:,0] = 0
                rgb[:,:,2] = np.where(rgb[:,:,2]>0, 235, 0)
                rgb[:,:,1] = np.where(rgb[:,:,2]>0, 1, 0)
                # NOTE(review): assumes row['sh'] == row['sw'] ==
                # settings.SEGM_SIZE; otherwise this assignment into a
                # (sh, sw) slot would fail -- confirm.
                out[i] = rgb[:,:,0] + rgb[:,:,1] * 256
        result[cat] = out
    return result, (row['sh'], row['sw'])
def label_size(self, category=None):
    '''
    Number of distinct labels plus zero (i.e. one more than the maximum
    label number), globally or within a category's dense numbering.
    '''
    source = self.label if category is None else self.category_unmap[category]
    return len(source)
def name(self, category, j):
    '''
    English name for the jth label.  With a category, j is that
    category's dense number; with category=None, j is the fully
    unified (global) index number.
    '''
    index = j if category is None else self.category_unmap[category][j]
    return self.label[index]['name']
def frequency(self, category, j):
    '''Number of images in which the label appears.'''
    if category is None:
        return self.label[j]['frequency']
    return self.category_label[category][j]['frequency']
def coverage(self, category, j):
    '''Pixel coverage of the label, in units of whole images.'''
    if category is None:
        return self.label[j]['coverage']
    return self.category_label[category][j]['coverage']
def category_names(self):
    '''Category names, in load order.'''
    return [cat for cat in self.category]
def category_frequency(self, category):
    '''Number of images touched by a category, as a float.'''
    record = self.category[category]
    return float(record['frequency'])
def primary_categories_per_index(self, categories=None):
    '''
    Assigns each label a primary category number (an index into
    `categories`, or into self.category_names() when omitted), picking
    for every label the category with the greatest pixel coverage.
    '''
    if categories is None:
        categories = self.category_names()
    total = self.label_size(None)
    # Dense index of each label within each category (0 = not present),
    # zero-padded out to the global label count.
    catmap = {}
    for cat in categories:
        imap = self.category_index_map(cat)
        if len(imap) < total:
            pad = numpy.zeros(total - len(imap), dtype=imap.dtype)
            imap = numpy.concatenate((imap, pad))
        catmap[cat] = imap
    chosen = []
    for i in range(total):
        # Score each category by its coverage of label i; absent
        # categories score zero.
        scored = [
            (self.coverage(cat, catmap[cat][i]) if catmap[cat][i] else 0, ic)
            for ic, cat in enumerate(categories)]
        chosen.append(max(scored)[1])
    # Return the max-coverage category number for each label.
    return numpy.array(chosen)
def segmentation_data(self, category, i, c=0, full=False, out=None):
    '''
    Returns a 2-d numpy matrix with segmentation data for the ith image,
    restricted to the given category.  By default, maps all label numbers
    to the category-specific dense mapping described in the c_*.csv
    listing; but can be asked to expose the fully unique indexing by
    using full=True.  c selects which channel of the category to read;
    out, if given, is a preallocated (sh, sw) int16 buffer to fill.
    '''
    row = self.image[i]
    data_channels = row.get(category, ())
    if c >= len(data_channels):
        channel = 0  # Deal with unlabeled data in this category
    else:
        channel = data_channels[c]
    if out is None:
        out = numpy.empty((row['sh'], row['sw']), dtype=numpy.int16)
    if isinstance(channel, int):
        # Whole-image label; remap to dense numbering unless full=True.
        if not full:
            channel = self.category_map[category][channel]
        out[:,:] = channel  # Single-label for the whole image
        return out
    # Channel is a png path: decode and binarize the channels so mask
    # pixels encode label 256 (0 + 1*256) and background stays 0.
    # NOTE(review): assumes row['sh'] == row['sw'] == settings.SEGM_SIZE,
    # since the png is resized to SEGM_SIZE but written into out -- confirm.
    png = cv2.resize(cv2.imread(os.path.join(self.directory, 'images', channel)), (settings.SEGM_SIZE, settings.SEGM_SIZE))
    png[:,:,0] = 0
    png[:,:,2] = np.where(png[:,:,2]>0, 235, 0)
    png[:,:,1] = np.where(png[:,:,2]>0, 1, 0)
    if full:
        # Full case: just combine png channels.
        out[...] = png[:,:,0] + png[:,:,1] * 256
    else:
        # Dense case: combine png channels and apply the category map.
        catmap = self.category_map[category]
        out[...] = catmap[png[:,:,0] + png[:,:,1] * 256]
    return out
def full_segmentation_data(self, i,
        categories=None, max_depth=None, out=None):
    '''
    Returns a 3-d numpy tensor with segmentation data for the ith image,
    with multiple layers representing multiple labels for each pixel.
    The depth is variable depending on available data but can be
    limited to max_depth.  out, if given, is a preallocated buffer.
    '''
    row = self.image[i]
    # Collect the non-empty channel lists, honoring the category filter.
    if categories:
        groups = [d for cat, d in row.items() if cat in categories and d]
    else:
        groups = [d for cat, d in row.items() if d and (
            cat not in self.meta_categories)]
    depth = sum(len(c) for c in groups)
    if max_depth is not None:
        depth = min(depth, max_depth)
    # Allocate an array if not already allocated.
    if out is None:
        out = numpy.empty((depth, row['sh'], row['sw']), dtype=numpy.int16)
    i = 0
    # Stack up the result segmentation one channel at a time
    for group in groups:
        for channel in group:
            if isinstance(channel, int):
                out[i] = channel  # whole-image label
            else:
                # png channel: binarize so mask pixels encode label 256.
                # NOTE(review): assumes sh == sw == settings.SEGM_SIZE
                # (png is resized to SEGM_SIZE but written into out).
                png = cv2.resize(cv2.imread(os.path.join(self.directory, 'images', channel)), (settings.SEGM_SIZE, settings.SEGM_SIZE))
                png[:,:,0] = 0
                png[:,:,2] = np.where(png[:,:,2]>0, 235, 0)
                png[:,:,1] = np.where(png[:,:,2]>0, 1, 0)
                out[i] = png[:,:,0] + png[:,:,1] * 256
            i += 1
            if i == depth:
                return out
    # Return above when we get up to depth
    assert False
def category_index_map(self, category):
    '''Fresh array copy of the category's global-to-dense label map.'''
    return numpy.asarray(self.category_map[category]).copy()
def build_dense_label_array(label_data, key='number', allow_none=False):
    '''
    Input: set of rows with 'number' fields (or another field name key).
    Output: array such that a[number] = the row with the given number.
    Unless allow_none, gaps are filled with synthesized rows whose key
    field is the index and whose other fields hold their type's empty
    value.
    '''
    result = [None] * (max(d[key] for d in label_data) + 1)
    for d in label_data:
        result[d[key]] = d
    # Fill in none
    if not allow_none:
        example = label_data[0]
        def make_empty(k):
            # Fixed: compare field names with == rather than `is`
            # (string identity only worked by interning accident).
            return {c: k if c == key else type(v)()
                    for c, v in example.items()}
        for i, d in enumerate(result):
            if d is None:
                result[i] = make_empty(i)
    return result
def build_numpy_category_map(map_data, key1='code', key2='number'):
    '''
    Builds a pair of int16 translation arrays from rows carrying two
    numbering fields: the first maps key1 values to key2 values, and
    the second is the inverse (key2 -> key1).  Unlisted slots stay 0.
    '''
    forward = numpy.zeros(max(d[key1] for d in map_data) + 1,
                          dtype=numpy.int16)
    backward = numpy.zeros(max(d[key2] for d in map_data) + 1,
                           dtype=numpy.int16)
    for d in map_data:
        forward[d[key1]] = d[key2]
        backward[d[key2]] = d[key1]
    return [forward, backward]
def decode_label_dict(row):
    '''
    Decodes one label.csv row: parses 'category' ("object(3);part(1)")
    into a {name: count} dict, splits 'syns' on semicolons, keeps 'name'
    verbatim, and converts numeric-looking values to int or float.
    '''
    result = {}
    for key, val in row.items():
        if key == 'category':
            # Fixed: raw strings for the regexes (the unraw '\(' forms
            # are invalid escape sequences in modern Python).
            result[key] = dict((c, int(n))
                for c, n in [re.match(r'^([^(]*)\(([^)]*)\)$', f).groups()
                    for f in val.split(';')])
        elif key == 'name':
            result[key] = val
        elif key == 'syns':
            result[key] = val.split(';')
        elif re.match(r'^\d+$', val):
            result[key] = int(val)
        elif re.match(r'^\d+\.\d*$', val):
            result[key] = float(val)
        else:
            result[key] = val
    return result
def decode_index_dict(row):
    '''
    Decodes one index.csv row: keeps 'image'/'split' as strings,
    converts the size fields to int, and splits every other field on
    semicolons into a list with numeric entries converted to int.
    '''
    result = {}
    for key, val in row.items():
        if key in ['image', 'split']:
            result[key] = val
        elif key in ['sw', 'sh', 'iw', 'ih']:
            result[key] = int(val)
        else:
            item = [s for s in val.split(';') if s]
            for i, v in enumerate(item):
                # Fixed: raw string for the regex (avoids invalid-escape
                # warnings in modern Python).
                if re.match(r'^\d+$', v):
                    item[i] = int(v)
            result[key] = item
    return result
def index_has_any_data(row, categories):
    '''True if the row has at least one truthy channel in any category.'''
    return any(any(row[c]) for c in categories)
def index_has_all_data(row, categories):
    '''True if every requested category has at least one truthy channel.'''
    return all(any(row[c]) for c in categories)
class SegmentationPrefetcher:
    '''
    SegmentationPrefetcher will prefetch a bunch of segmentation
    images using a multiprocessing pool, so you do not have to wait
    around while the files get opened and decoded.  Just request
    batches of images and segmentations calling fetch_batch().
    '''
def __init__(self, segmentation, split=None, randomize=False,
        segmentation_shape=None, categories=None, once=False,
        start=None, end=None, batch_size=4, ahead=4, thread=False):
    '''
    Constructor arguments:
    segmentation: The AbstractSegmentation to load.
    split: None for no filtering, or 'train' or 'val' etc.
    randomize: True to randomly shuffle order, or a random seed.
    categories: a list of categories to include in each batch.
    once: stop after a single pass over the data.
    start/end: restrict to the index range [start, end).
    batch_size: number of data items for each batch.
    ahead: the number of data items to prefetch ahead.
    thread: use a thread pool instead of a process pool.
    '''
    self.segmentation = segmentation
    # Mapping from segmentation filename to label, forwarded to workers.
    with open(settings.SEGM_TO_LABEL_PKL, 'rb') as f:
        self.segm_to_label = pickle.load(f)
    self.split = split
    self.randomize = randomize
    self.random = random.Random()
    if randomize is not True:
        self.random.seed(randomize)
    self.categories = categories
    self.once = once
    self.batch_size = batch_size
    self.ahead = ahead
    # Initialize the multiprocessing pool
    n_procs = cpu_count()
    if thread:
        self.pool = ThreadPool(processes=n_procs)
    else:
        # Ignore SIGINT in workers so ^C is handled by the parent only.
        original_sigint_handler = setup_sigint()
        self.pool = Pool(processes=n_procs, initializer=setup_sigint)
        restore_sigint(original_sigint_handler)
    # Prefilter the image indexes of interest
    if start is None:
        start = 0
    if end is None:
        end = segmentation.size()
    # Fixed: materialize as a list -- random.shuffle() cannot shuffle a
    # range object, so the randomize path raised TypeError before.
    self.indexes = list(range(start, end))
    if split:
        self.indexes = [i for i in self.indexes
                        if segmentation.split(i) == split]
    if self.randomize:
        self.random.shuffle(self.indexes)
    self.index = 0
    self.result_queue = []
    self.segmentation_shape = segmentation_shape
    # Get dense catmaps
    self.catmaps = [
        segmentation.category_index_map(cat) if cat != 'image' else None
        for cat in categories]
def next_job(self):
    '''
    Produces the next worker job tuple and advances the index sequence;
    returns None once a single pass completes when once=True.
    '''
    if self.index < 0:
        return None  # exhausted (once=True and a full pass finished)
    j = self.indexes[self.index]
    job = (j,
           self.segmentation.__class__,
           self.segmentation.metadata(j),
           self.segmentation.filename(j),
           self.categories,
           self.segm_to_label,
           self.segmentation_shape)
    self.index += 1
    if self.index >= len(self.indexes):
        if self.once:
            self.index = -1  # mark the sequence finished
        else:
            self.index = 0
            if self.randomize:
                # Reshuffle every time through
                self.random.shuffle(self.indexes)
    return job
def batches(self):
    '''Generator yielding batches until fetch_batch() returns None.'''
    while True:
        batch = self.fetch_batch()
        if batch is None:
            return
        yield batch
def fetch_batch(self):
    '''Returns a single batch as an array of dictionaries.'''
    try:
        self.refill_tasks()
        if not self.result_queue:
            return None
        pending = self.result_queue.pop(0)
        # Wait (effectively forever: one year) for the async result.
        return pending.get(31536000)
    except KeyboardInterrupt:
        print("Caught KeyboardInterrupt, terminating workers")
        self.pool.terminate()
        raise
def fetch_tensor_batch(self, bgr_mean=None, global_labels=False):
    '''Fetches one batch and converts it into per-category tensors.'''
    raw = self.fetch_batch()
    return self.form_caffe_tensors(raw, bgr_mean, global_labels)
def tensor_batches(self, bgr_mean=None, global_labels=False):
    '''Generator over tensor batches until the data is exhausted.'''
    while True:
        tensors = self.fetch_tensor_batch(
            bgr_mean=bgr_mean, global_labels=global_labels)
        if tensors is None:
            return
        yield tensors
def form_caffe_tensors(self, batch, bgr_mean=None, global_labels=False):
    '''
    Assembles a batch in [{'cat': data,..},..] format into a list of
    batch tensors: one per entry of self.categories, in order (the
    'image' column is normalized; label columns are normalized and,
    unless global_labels, remapped through the dense catmaps).
    '''
    if batch is None:
        return None
    columns = [[] for _ in self.categories]
    for record in batch:
        default_shape = (1, record['sh'], record['sw'])
        for c, cat in enumerate(self.categories):
            if cat == 'image':
                # Normalize image with right RGB order and mean
                columns[c].append(normalize_image(record[cat], bgr_mean))
            elif global_labels:
                columns[c].append(normalize_label(
                    record[cat], default_shape, flatten=True))
            else:
                # Remap global label numbers to the category's dense coding.
                catmap = self.catmaps[c]
                columns[c].append(catmap[normalize_label(
                    record[cat], default_shape, flatten=True)])
    return [numpy.concatenate(tuple(t[numpy.newaxis] for t in col))
            for col in columns]
def refill_tasks(self):
    '''
    Keeps up to self.ahead async map jobs outstanding, each covering a
    batch of up to self.batch_size prefetch job tuples from next_job().
    '''
    while len(self.result_queue) < self.ahead:
        jobs = []
        while len(jobs) < self.batch_size:
            job = self.next_job()
            if job is None:
                break
            jobs.append(job)
        if not jobs:
            return  # no more data to schedule
        self.result_queue.append(
            self.pool.map_async(prefetch_worker, jobs))
def close(self):
    '''
    Drains pending async results, then shuts the worker pool down.
    Fixed: the old code referenced `self.poool` (typo) and called
    cancel_join_thread(), which is a multiprocessing.Queue method that
    pools do not have -- so close() always raised AttributeError.
    '''
    while len(self.result_queue):
        result = self.result_queue.pop(0)
        if result is not None:
            result.wait(0.001)
    self.pool.close()
    self.pool.join()
def prefetch_worker(d):
    '''
    Worker entry point: resolves one job tuple produced by
    SegmentationPrefetcher.next_job(), returning a dict of segmentations
    plus metadata keys 'sh', 'sw', 'i', 'fn' and (when requested) the
    grayscale 'image' resized to settings.IMG_SIZE.
    '''
    if d is None:
        return None
    j, typ, m, fn, categories, segm_to_label, segmentation_shape = d
    # Fixed: resolve_segmentation is declared as (cls, m, ...) and typ is
    # the class object, so the class must be passed explicitly when
    # calling it unbound; the old call typ.resolve_segmentation(m, ...)
    # left the `m` parameter unfilled and raised TypeError.
    segs, shape = typ.resolve_segmentation(
        typ, m, categories=categories, segm_to_label=segm_to_label)
    if segmentation_shape is not None:
        # (per-key debug print removed)
        for k, v in segs.items():
            segs[k] = scale_segmentation(v, segmentation_shape)
        shape = segmentation_shape
    # Some additional metadata to provide
    segs['sh'], segs['sw'] = shape
    segs['i'] = j
    segs['fn'] = fn
    if categories is None or 'image' in categories:
        segs['image'] = np.asarray(Image.open(fn).convert('L').resize(
            (settings.IMG_SIZE, settings.IMG_SIZE)))
    return segs
# def convertRGBToGray(rgbImg): | |
# return np.dot(rgbImg[...,:3], [0.114, 0.587, 0.299]) | |
def scale_segmentation(segmentation, dims, crop=False):
    '''
    Zooms a 2d or 3d segmentation to the given (height, width) dims,
    using nearest neighbor.  With crop=True, scales uniformly and
    center-crops instead of stretching.  Returns the input unchanged
    when it already has the requested spatial size.
    '''
    shape = numpy.shape(segmentation)
    if len(shape) < 2 or shape[-2:] == dims:
        return segmentation
    # Work on a (levels, h, w) view; peel the level axis off 2d input.
    peel = (len(shape) == 2)
    if peel:
        segmentation = segmentation[numpy.newaxis]
    levels = segmentation.shape[0]
    result = numpy.zeros((levels, ) + dims,
            dtype=segmentation.dtype)
    ratio = (1,) + tuple(res / float(orig)
            for res, orig in zip(result.shape[1:], segmentation.shape[1:]))
    if not crop:
        safezoom(segmentation, ratio, output=result, order=0)
    else:
        ratio = max(ratio[1:])
        height = int(round(dims[0] / ratio))
        width = int(round(dims[1] / ratio))
        # Fixed: margins are computed on the spatial axes (1 and 2); the
        # old code indexed axis 0 (the level count) and reused `height`
        # when computing the width margin.
        hmargin = (segmentation.shape[1] - height) // 2
        wmargin = (segmentation.shape[2] - width) // 2
        safezoom(segmentation[:, hmargin:hmargin+height,
            wmargin:wmargin+width],
            (1, ratio, ratio), output=result, order=0)
    if peel:
        result = result[0]
    return result
def safezoom(array, ratio, output=None, order=0):
    '''
    Like scipy's zoom, but does not crash when the first dimension of
    the array has size 1 (common for segmentations), and tolerates
    float16 input by computing in float32.
    '''
    original_dtype = array.dtype
    if original_dtype == numpy.float16:
        array = array.astype(numpy.float32)
    if array.shape[0] == 1:
        # Zoom the single level as 2d; write into the caller's first
        # level when an output buffer was supplied.
        if output is not None:
            output = output[0, ...]
        zoomed = zoom(array[0, ...], ratio[1:], output=output, order=order)
        if output is None:
            output = zoomed[numpy.newaxis]
    else:
        zoomed = zoom(array, ratio, output=output, order=order)
        if output is None:
            output = zoomed
    return output.astype(original_dtype)
def setup_sigint():
    '''
    Makes the process ignore SIGINT and returns the previous handler so
    it can be restored later; no-op (returns None) when called off the
    main thread, where signal handlers cannot be installed.
    '''
    import threading
    # Modernized: use the public main_thread() API instead of testing
    # against the private threading._MainThread class.
    if threading.current_thread() is not threading.main_thread():
        return None
    return signal.signal(signal.SIGINT, signal.SIG_IGN)
def restore_sigint(original):
    '''
    Restores a SIGINT handler previously saved by setup_sigint (falling
    back to the OS default when None); no-op off the main thread.
    '''
    import threading
    # Modernized: public main_thread() API replaces the private
    # threading._MainThread check.
    if threading.current_thread() is not threading.main_thread():
        return
    if original is None:
        original = signal.SIG_DFL
    signal.signal(signal.SIGINT, original)
def wants(what, option):
    '''True when `option` is None (no filtering) or contains `what`.'''
    return option is None or what in option
def normalize_image(rgb_image, bgr_mean):
    """
    Prepares an image array for batching by prepending a channel axis.
    NOTE: the original Caffe preprocessing (float cast, RGB->BGR swap,
    mean subtraction, HWC->CHW transpose) is commented out in this
    version, so bgr_mean is currently ignored and pixel values pass
    through unchanged.
    """
    return np.expand_dims(rgb_image, axis=0)
### Original code | |
# def normalize_image(rgb_image, bgr_mean): | |
# """ | |
# Load input image and preprocess for Caffe: | |
# - cast to float | |
# - switch channels RGB -> BGR | |
# - subtract mean | |
# - transpose to channel x height x width order | |
# """ | |
# img = numpy.array(rgb_image, dtype=numpy.float32) | |
# if (img.ndim == 2): | |
# img = numpy.repeat(img[:,:,None], 3, axis = 2) | |
# img = img[:,:,::-1] | |
# if bgr_mean is not None: | |
# img -= bgr_mean | |
# img = img.transpose((2,0,1)) | |
# return img | |
def normalize_label(label_data, shape, flatten=False):
    """
    Normalizes 0/1/2/3-dimensional label_data against a default `shape`
    of the form (1, y, x).  Scalars (and, with flatten=True, the first
    element of 1-d or 3-d input) fill the default shape; a 1-d list
    without flatten broadcasts each entry to its own plane; 2-d data is
    broadcast into `shape`; 3-d data passes through unchanged.
    """
    rank = len(numpy.shape(label_data))
    if rank <= 2:
        if rank == 1:
            if flatten:
                # Collapse to the first scalar (or 0 when empty).
                label_data = label_data[0] if len(label_data) else 0
            else:
                # One full plane per listed label value.
                return (numpy.ones(shape, dtype=numpy.int16) *
                        numpy.asarray(label_data, dtype=numpy.int16)
                        [:, numpy.newaxis, numpy.newaxis])
        return numpy.full(shape, label_data, dtype=numpy.int16)
    if rank == 3:
        if flatten:
            label_data = label_data[0]
        else:
            return label_data
    return label_data[numpy.newaxis]
if __name__ == '__main__':
    # Smoke test: load the broden dataset and pull a single batch.
    data = SegmentationData('broden1_227')
    pd = SegmentationPrefetcher(
        data, categories=data.category_names() + ['image'], once=True)
    # Fixed: use the builtin next() -- generators have no .next() method
    # in Python 3 (that spelling was Python-2-only).
    bs = next(pd.batches())