# Page banner captured together with this source: "Spaces: Sleeping".
######### pull files
import os
from huggingface_hub import hf_hub_download

# Both files are fetched from the same Hub repository. The access token is
# read from the "token" environment variable (None when it is not set).
_PRITHVI_REPO_ID = "ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification"
_hub_token = os.environ.get("token")

# Model configuration file.
config_path = hf_hub_download(
    repo_id=_PRITHVI_REPO_ID,
    filename="multi_temporal_crop_classification_Prithvi_100M.py",
    token=_hub_token,
)
# Model checkpoint.
ckpt = hf_hub_download(
    repo_id=_PRITHVI_REPO_ID,
    filename='multi_temporal_crop_classification_Prithvi_100M.pth',
    token=_hub_token,
)
##########
import argparse | |
from mmcv import Config | |
from mmseg.models import build_segmentor | |
from mmseg.datasets.pipelines import Compose, LoadImageFromFile | |
import rasterio | |
import torch | |
from mmseg.apis import init_segmentor | |
from mmcv.parallel import collate, scatter | |
import numpy as np | |
import glob | |
import os | |
import time | |
import numpy as np | |
import gradio as gr | |
from functools import partial | |
import pdb | |
import matplotlib.pyplot as plt | |
def open_tiff(fname):
    """Open ``fname`` with rasterio in read mode and return what ``read()`` yields."""
    with rasterio.open(fname, "r") as raster:
        return raster.read()
def write_tiff(img_wrt, filename, metadata):
    """
    Write a raster image to disk, one band at a time.

    :param img_wrt: numpy array with the data; a 2D array is treated as a single band,
        otherwise the first axis is iterated as the band axis
    :param filename: path of the output file
    :param metadata: keyword arguments forwarded to ``rasterio.open`` in write mode
    :return: ``filename``, unchanged
    """
    with rasterio.open(filename, "w", **metadata) as dest:
        # Promote a single 2D band to a one-band stack so one loop handles both cases.
        stack = img_wrt[None] if img_wrt.ndim == 2 else img_wrt
        for band_idx in range(stack.shape[0]):
            # The second argument to write() is 1-based here.
            dest.write(stack[band_idx, :, :], band_idx + 1)
    return filename
def get_meta(fname):
    """Open ``fname`` with rasterio in read mode and return its ``meta`` attribute."""
    with rasterio.open(fname, "r") as raster:
        return raster.meta
def preprocess_example(example_list):
    """Join every entry of ``example_list`` onto the absolute path of the current directory.

    Returns a new list; the input list is not modified.
    """
    base_dir = os.path.abspath('')
    return [os.path.join(base_dir, example) for example in example_list]
def inference_segmentor(model, imgs, custom_test_pipeline=None):
    """Inference image(s) with the segmentor.

    Args:
        model (nn.Module): The loaded segmentor.
        imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
            images.
        custom_test_pipeline (list, optional): Pipeline steps to use instead
            of the model's own. When ``None``, ``model.cfg.data.test.pipeline``
            is used with its first step replaced by ``LoadImageFromFile()``.

    Returns:
        (list[Tensor]): The segmentation result.
    """
    cfg = model.cfg
    first_param = next(model.parameters())
    device = first_param.device  # model device

    # build the data pipeline
    if custom_test_pipeline is None:
        test_pipeline = [LoadImageFromFile()] + cfg.data.test.pipeline[1:]
    else:
        test_pipeline = custom_test_pipeline
    test_pipeline = Compose(test_pipeline)

    # prepare data: run every input through the pipeline, then batch them
    imgs = imgs if isinstance(imgs, list) else [imgs]
    data = [test_pipeline({'img_info': {'filename': img}}) for img in imgs]
    data = collate(data, samples_per_gpu=len(imgs))

    if first_param.is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]
    else:
        # On CPU there is no scatter step, so the meta information is
        # unwrapped from its container by hand.
        img_metas = data['img_metas'].data[0]
        img = data['img']
        data = {'img': img, 'img_metas': img_metas}

    with torch.no_grad():
        result = model(return_loss=False, rescale=True, **data)
    return result
def inference_on_file(target_image, model, custom_test_pipeline):
    """Run the segmentor on one input raster and build the images to display.

    Args:
        target_image: Object whose ``name`` attribute is the path of the input
            raster (presumably the temp-file wrapper produced by ``gr.File``;
            confirm against the Gradio version in use).
        model: Loaded segmentor, forwarded to ``inference_segmentor``.
        custom_test_pipeline: Test pipeline forwarded to ``inference_segmentor``.

    Returns:
        tuple: ``(rgb1, rgb2, rgb3, prediction)``. The first three are the
        input bands ``[2, 1, 0]``, ``[8, 7, 6]`` and ``[14, 13, 12]``
        transposed to channel-last order. ``prediction`` is the first channel
        of the model output multiplied by 255, after pixels that equal the
        raster's nodata value in any band have been set to -1.

    Raises:
        Exception: Any failure during inference or file reading is reported
            on stdout and then re-raised, so the caller sees the real cause.
    """
    target_image = target_image.name
    try:
        st = time.time()
        print('Running inference...')
        result = inference_segmentor(model, target_image, custom_test_pipeline)
        print("Output has shape: " + str(result[0].shape))

        ##### one RGB composite per group of input bands
        bands = open_tiff(target_image)
        rgb1 = bands[[2, 1, 0], :, :].transpose((1, 2, 0))
        rgb2 = bands[[8, 7, 6], :, :].transpose((1, 2, 0))
        rgb3 = bands[[14, 13, 12], :, :].transpose((1, 2, 0))

        ##### flag pixels that are nodata in at least one band
        meta = get_meta(target_image)
        nodata_mask = np.where(bands == meta['nodata'], 1, 0)
        nodata_mask = np.max(nodata_mask, axis=0)[None]
        result[0] = np.where(nodata_mask == 1, -1, result[0])

        time_taken = np.round(time.time() - st, 1)
        print(f'Inference completed in {str(time_taken)} seconds')
    except Exception:
        # Report which input failed, then propagate: there is nothing
        # meaningful to return when any of the steps above did not complete.
        print(f'Error on image {target_image}')
        raise
    return rgb1, rgb2, rgb3, result[0][0] * 255
def process_test_pipeline(custom_test_pipeline, bands=None):
    """Adapt a test pipeline description for inference.

    The pipeline is modified in place and also returned.

    Args:
        custom_test_pipeline: Sequence of step mappings, each with a ``'type'``
            entry.
        bands: ``None`` to leave the extracted bands untouched. Otherwise the
            new value for the ``'bands'`` entry of the first ``'BandsExtract'``
            step: either a string holding a Python expression (evaluated with
            ``eval``) or an iterable of band indices (copied into a list).

    Returns:
        The same ``custom_test_pipeline`` object.
    """
    # change extracted bands if necessary
    if bands is not None:
        extract_step = next(
            (step for step in custom_test_pipeline if step['type'] == 'BandsExtract'),
            None,
        )
        if extract_step is not None:
            if isinstance(bands, str):
                # SECURITY: eval() executes arbitrary code. Only pass strings
                # from a trusted source here.
                extract_step['bands'] = eval(bands)
            else:
                extract_step['bands'] = list(bands)

    # adapt collected keys if necessary (first step whose type mentions 'Collect')
    collect_step = next(
        (step for step in custom_test_pipeline if 'Collect' in step['type']),
        None,
    )
    if collect_step is not None:
        collect_step['meta_keys'] = [
            'img_info', 'filename', 'ori_filename', 'img', 'img_shape',
            'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg',
        ]
    return custom_test_pipeline
# Build the segmentor on CPU from the downloaded config and checkpoint. The
# backbone's `pretrained` entry is cleared before the model is built.
config = Config.fromfile(config_path)
config.model.backbone.pretrained = None
model = init_segmentor(config, ckpt, device='cpu')

# Adjust the model's own test pipeline (no band override), then bind the model
# and pipeline so the UI callback only needs the uploaded file.
custom_test_pipeline = process_test_pipeline(model.cfg.data.test.pipeline, None)
func = partial(
    inference_on_file,
    model=model,
    custom_test_pipeline=custom_test_pipeline,
)
with gr.Blocks() as demo:
    # Title and description.
    gr.Markdown(value='# Prithvi multi temporal crop classification')
    gr.Markdown(value='''Prithvi is a first-of-its-kind temporal Vision transformer pretrained by the IBM and NASA team on continental US Harmonised Landsat Sentinel 2 (HLS) data. This demo showcases how the model was finetuned to classify crop and other land use categories using multi temporal data. More detailes can be found [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification).\n
The user needs to provide an HLS geotiff image, including 18 bands for 3 time-step, and each time-step includes the channels described above (Blue, Green, Red, Narrow NIR, SWIR, SWIR 2) in order.
''')

    # File upload and submit button.
    with gr.Row():
        with gr.Column():
            inp = gr.File()
            btn = gr.Button("Submit")

    # One heading per image shown in the row below.
    with gr.Row():
        for heading in ('### T1', '### T2', '### T3', '### Model prediction'):
            gr.Markdown(value=heading)

    # Three RGB views plus the single-channel prediction.
    with gr.Row():
        inp1 = gr.Image(image_mode='RGB')
        inp2 = gr.Image(image_mode='RGB')
        inp3 = gr.Image(image_mode='RGB')
        out = gr.Image(image_mode='L')

    btn.click(
        fn=func,
        inputs=inp,
        outputs=[inp1, inp2, inp3, out],
    )

    with gr.Row():
        gr.Examples(
            examples=[
                "chip_102_345_merged.tif",
                "chip_104_104_merged.tif",
                "chip_109_421_merged.tif",
            ],
            inputs=inp,
            outputs=[inp1, inp2, inp3, out],
            # NOTE(review): `preprocess` looks like a boolean flag in
            # gr.Examples; if so, this function is only used as a truthy
            # value and is never called. Confirm against the installed
            # Gradio version.
            preprocess=preprocess_example,
            fn=func,
            cache_examples=True,
        )

demo.launch()