"""Run burn-scar segmentation inference on a folder of GeoTIFF images.

For every ``*merged.<input_type>`` file found in the input folder this script
normalises the first six bands, splits the image into 224x224 pieces with
``custom.split_and_pad``, runs the fine-tuned segmentor on them, stitches the
prediction back together with ``custom.merge_and_unpad`` and writes a
single-band int16 raster. Afterwards it prints the average Dice score
returned by ``custom.compute_metrics`` for the ground-truth and prediction
folders.
"""
# NOTE: the third-party import order below is kept as in the original.
# Several names are unused in this script, but importing their modules may
# register mmseg/mmcv components as a side effect, so they are retained.
import argparse
from mmcv import Config
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                         wrap_fp16_model)
from mmseg.models import build_segmentor
import matplotlib.pyplot as plt
import mmcv
import torch
from mmcv.parallel import collate, scatter
from mmseg.datasets.pipelines import Compose
from mmseg.datasets import build_dataloader, build_dataset, load_flood_test_data
import rasterio
import torch.nn.functional as F
from torchvision import transforms
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmseg.apis import multi_gpu_test, single_gpu_test, init_segmentor
from . import custom  # custom preprocessing for hls
import numpy as np
import glob
import os
import time
import traceback

# Ground-truth folder used for metric evaluation when -gt_dir is not given
# (this was previously hard-coded inside main()).
DEFAULT_GT_DIR = "/home/workdir/hls-foundation/data/burn_scars/validation"


def parse_args():
    """Parse the command-line options of the inference script.

    Returns:
        argparse.Namespace with ``config``, ``ckpt``, ``input``, ``output``,
        ``input_type`` and ``gt_dir`` attributes.
    """
    parser = argparse.ArgumentParser(description="Inference on burn scar fine-tuned model")
    parser.add_argument('-config', help='path to model configuration file')
    parser.add_argument('-ckpt', help='path to model checkpoint')
    parser.add_argument('-input', help='path to input images folder for inference')
    parser.add_argument('-output', help='directory path to save output images')
    parser.add_argument('-input_type', help='file type of input images', default="tif")
    parser.add_argument('-gt_dir',
                        help='directory with ground-truth masks used for metric evaluation',
                        default=DEFAULT_GT_DIR)
    return parser.parse_args()


def open_tiff(fname):
    """Read every band of the raster at ``fname`` into a numpy array.

    Returns whatever ``rasterio``'s ``src.read()`` returns for the file
    (all bands stacked along the first axis).
    """
    with rasterio.open(fname, "r") as src:
        data = src.read()
    return data


def write_tiff(img_wrt, filename, metadata):
    """
    It writes a raster image to file.

    :param img_wrt: numpy array containing the data (can be 2D for single band or 3D for multiple bands)
    :param filename: file path to the output file
    :param metadata: metadata to use to write the raster to disk
    :return:
    """
    with rasterio.open(filename, "w", **metadata) as dest:
        if len(img_wrt.shape) == 2:
            # Promote a single 2D band to (1, H, W) so the loop below is uniform.
            img_wrt = img_wrt[None]
        for i in range(img_wrt.shape[0]):
            # rasterio band indices are 1-based.
            dest.write(img_wrt[i, :, :], i + 1)


def get_meta(fname):
    """Return the rasterio ``meta`` dictionary of the raster at ``fname``."""
    with rasterio.open(fname, "r") as src:
        meta = src.meta
    return meta


def preprocess_image(data, means, stds, nodata=-9999):
    """Turn a raw band stack into the input dictionary expected by the model.

    Pixels equal to ``nodata`` are set to 0, the first six bands are converted
    to tensors, stacked and normalised with ``means``/``stds``.

    Args:
        data: array read by ``open_tiff``; at least six bands are indexed.
        means: per-band means passed to ``transforms.Normalize``.
        stds: per-band standard deviations passed to ``transforms.Normalize``.
        nodata: value treated as missing data.

    Returns:
        dict with ``"img"`` (tensor of shape (1, 6, H, W)), ``"img_metas"``
        (a one-element list of meta dicts) and ``"gt_semantic_seg"`` (a
        label tensor; filled with -1 when no label is supplied).
    """
    data = np.where(data == nodata, 0, data)
    data = data.astype(np.float32)
    # NOTE(review): a two-element input is unpacked as (image, label); a
    # regular multi-band raster takes the else-branch with a dummy label.
    if len(data) == 2:
        (x, y) = data
    else:
        x = data
        y = np.full((x.shape[-2], x.shape[-1]), -1)
    im, label = x.copy(), y.copy()
    label = label.astype(np.float64)

    # Band order per the original comments:
    # red, green, blue, NIR narrow, SWIR 1, SWIR 2.
    dim = x.shape[-1]
    label = label.squeeze()

    to_tensor = transforms.ToTensor()
    norm = transforms.Normalize(means, stds)
    bands = torch.stack([to_tensor(im[band]).squeeze() for band in range(6)])
    ims = torch.stack([norm(bands)])  # add the batch axis -> (1, 6, H, W)
    label = to_tensor(label).squeeze()

    _img_metas = {
        'ori_shape': (dim, dim),
        'img_shape': (dim, dim),
        'pad_shape': (dim, dim),
        'scale_factor': [1., 1., 1., 1.],
        'flip': False,  # needs flip direction specified
    }
    img_metas = [_img_metas] * 1
    return {"img": ims, "img_metas": img_metas, "gt_semantic_seg": label}


def load_model(config, ckpt):
    """Build the segmentor from ``config``, load ``ckpt`` and put it in eval mode.

    The model is wrapped in ``MMDataParallel`` on device 0.
    """
    print('Loading configuration...')
    cfg = Config.fromfile(config)
    print('Building model...')
    model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
    print('Loading checkpoint...')
    load_checkpoint(model, ckpt, map_location='cpu')
    print('Evaluating model...')
    model = MMDataParallel(model, device_ids=[0])
    model.eval()
    return model


def inference_on_file(model, target_image, output_image, means, stds):
    """Predict a single raster and write the prediction to ``output_image``.

    Failures are reported (message plus traceback) and swallowed so that the
    caller can continue with the next input; ``KeyboardInterrupt`` and
    ``SystemExit`` are no longer swallowed.
    """
    try:
        st = time.time()
        data_orig = open_tiff(target_image)
        meta = get_meta(target_image)
        nodata = meta['nodata'] if meta['nodata'] is not None else -9999
        data = preprocess_image(data_orig, means, stds, nodata)

        # Insert a singleton axis -> (1, 6, 1, H, W) and cut into 224x224 pieces.
        small_fixed_size_arrs = custom.split_and_pad(data['img'][:, :, None, :, :], (1, 6, 1, 224, 224))
        single_chip_batch = [torch.vstack([torch.tensor(t) for t in small_fixed_size_arrs])]

        print('Running inference...')
        with torch.no_grad():
            result = model(single_chip_batch, data['img_metas'], return_loss=False, rescale=False)
        print("Result: Unique Values: ", np.unique(result))
        print("Output has shape: " + str(result[0].shape))

        # TODO: post-process (e.g. morphological operations)
        result = custom.merge_and_unpad(result, (data_orig.shape[-2], data_orig.shape[-1]), (224, 224))
        print("Result: Unique Values: ", np.unique(result))

        # Save file to disk as a single-band int16 raster. The output nodata
        # value must match the value written into masked pixels below.
        meta["count"] = 1
        meta["dtype"] = "int16"
        meta["compress"] = "lzw"
        meta["nodata"] = nodata
        print('Saving output...')
        # Propagate the input's nodata mask (taken from band 0) to the output.
        result = np.where(data_orig[0] == nodata, nodata, result)
        write_tiff(result, output_image, meta)
        et = time.time()
        print(f'Inference completed in {str(np.round(et - st, 1))} seconds. Output available at: ' + output_image)
    except Exception as err:
        print(f'Error on image {target_image} \nContinue to next input')
        print(f'Reason: {err!r}')
        traceback.print_exc()


def main():
    """Predict every matching image in ``-input`` and report the Dice score."""
    args = parse_args()
    model = load_model(args.config, args.ckpt)
    image_pattern = "*merged"
    # Sorted so images are processed in a reproducible order.
    target_images = sorted(glob.glob(os.path.join(args.input, image_pattern + "." + args.input_type)))
    print('Identified images to predict on: ' + str(len(target_images)))
    os.makedirs(args.output, exist_ok=True)

    means, stds = custom.calculate_band_statistics(args.input, image_pattern, bands=[0, 1, 2, 3, 4, 5])

    # "<stem>_merged.<ext>" -> "<stem>_pred.<ext>"
    suffix = f"_{image_pattern[1:]}."
    for i, target_image in enumerate(target_images):
        print(f'Working on Image {i}')
        stem = os.path.basename(target_image).split(suffix)[0]
        output_image = os.path.join(args.output, stem + '_pred.' + args.input_type)
        inference_on_file(model, target_image, output_image, means, stds)

    print("Running metric eval")
    gt_dir = args.gt_dir
    pred_dir = args.output
    avg_dice_score = custom.compute_metrics(gt_dir, pred_dir)
    print("Average Dice score:", avg_dice_score)


if __name__ == "__main__":
    main()