import os
import json

import yaml
import torch
import numpy as np
import rasterio
import geopandas as gpd
from rasterio.windows import Window
from rasterio.transform import rowcol
from rasterio.features import shapes
from shapely.geometry import shape
from pyproj import Transformer
from torchvision import transforms
from dotenv import load_dotenv

from messis.messis import LogConfusionMatrix

load_dotenv()

class InferenceDataLoader:
    def __init__(self, features_path, labels_path, field_ids_path, stats_path, window_size=224, n_timesteps=3, fold_indices=None, debug=False):
        self.features_path = features_path
        self.labels_path = labels_path
        self.field_ids_path = field_ids_path
        self.stats_path = stats_path
        self.window_size = window_size
        self.n_timesteps = n_timesteps
        self.fold_indices = fold_indices if fold_indices is not None else []
        self.debug = debug

        self.means, self.stds = self.load_stats()

        # Projects WGS84 (lon, lat) click coordinates into the rasters' UTM CRS
        self.transformer = Transformer.from_crs("EPSG:4326", "EPSG:32632", always_xy=True)

    def load_stats(self):
        """Load the normalization statistics for the dataset from a YAML file."""
        if self.debug:
            print(f"Loading mean/std stats from {self.stats_path}")
        assert os.path.exists(self.stats_path), f"Mean/std stats file not found at {self.stats_path}"

        with open(self.stats_path, 'r') as file:
            stats = yaml.safe_load(file)

        mean_list, std_list, n_list = [], [], []
        for fold in self.fold_indices:
            key = f'fold_{fold}'
            if key not in stats:
                raise ValueError(f"Mean/std stats for fold {fold} not found in {self.stats_path}")
            if self.debug:
                print(f"Stats with selected test fold {fold}: {stats[key]} over {self.n_timesteps} timesteps.")
            mean_list.append(torch.tensor(stats[key]['mean']))
            std_list.append(torch.tensor(stats[key]['std']))
            n_list.append(stats[key]['n_chips'])

        # Combine the per-fold statistics channel by channel: means are averaged,
        # stds are pooled via the standard pooled-variance formula
        # s_p^2 = sum_i (n_i - 1) * s_i^2 / (sum_i n_i - k).
        means, stds = [], []
        for channel in range(mean_list[0].shape[0]):
            means.append(torch.stack([mean_list[i][channel] for i in range(len(mean_list))]).mean())
            variances = torch.stack([std_list[i][channel] ** 2 for i in range(len(std_list))])
            n = torch.tensor([n_list[i] for i in range(len(n_list))], dtype=torch.float32)
            combined_variance = torch.sum(variances * (n - 1)) / (torch.sum(n) - len(n_list))
            stds.append(torch.sqrt(combined_variance))

        # The per-band stats are repeated once per timestep, since the input
        # stacks the bands of all timesteps along the channel axis.
        return means * self.n_timesteps, stds * self.n_timesteps

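    # For reference, a minimal sketch of the stats YAML layout load_stats()
    # assumes (one entry per fold, per-band mean/std lists; the values below
    # are purely illustrative, not real dataset statistics):
    #
    #   fold_0:
    #     mean: [0.08, 0.10, 0.12, 0.28, 0.22, 0.18]
    #     std: [0.03, 0.03, 0.04, 0.08, 0.07, 0.06]
    #     n_chips: 1024
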
    def identify_window(self, path, lon, lat):
        """Identify the window_size x window_size window centered on the clicked coordinates (lon, lat) in the specified GeoTIFF."""
        with rasterio.open(path) as src:
            utm_x, utm_y = self.transformer.transform(lon, lat)
            if self.debug:
                print("Source Transform", src.transform)
                print(f"UTM X: {utm_x}, UTM Y: {utm_y}")

            # rowcol returns (row, col) pixel indices for the given map coordinates
            try:
                row, col = rowcol(src.transform, utm_x, utm_y)
            except ValueError:
                raise ValueError("Coordinates out of bounds for this raster.")

            if self.debug:
                print(f"Row: {row}, Column: {col}")

            half_window_size = self.window_size // 2

            row_off = row - half_window_size
            col_off = col - half_window_size

            # Clamp the window so it lies fully inside the raster
            # (rows run along the height, columns along the width)
            if row_off < 0:
                row_off = 0
            if col_off < 0:
                col_off = 0
            if row_off + self.window_size > src.height:
                row_off = src.height - self.window_size
            if col_off + self.window_size > src.width:
                col_off = src.width - self.window_size

            window = Window(col_off, row_off, self.window_size, self.window_size)
            window_transform = src.window_transform(window)
            if self.debug:
                print(f"Window: {window}")
                print(f"Window Transform: {window_transform}")
            crs = src.crs

        return window, window_transform, crs

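    # Worked example of the clamping above: with the default 224 px window,
    # half_window_size is 112, so a click at row 50 gives row_off = 50 - 112 = -62,
    # which is clamped to 0. The window then hugs the top edge of the raster
    # instead of being centered on the click.
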
    def extract_window(self, path, window):
        """Extract the data in the specified window from the GeoTIFF."""
        with rasterio.open(path) as src:
            window_data = src.read(window=window)

        if self.debug:
            print(f"Extracted window data from {path}")
            print(f"Min: {window_data.min()}, Max: {window_data.max()}")

        return window_data

    def prepare_data_for_model(self, features_data):
        """Prepare the window data for model inference."""
        features_data = torch.tensor(features_data, dtype=torch.float32)

        # Normalize all n_timesteps * 6 stacked bands in one pass
        normalize = transforms.Normalize(mean=self.means, std=self.stds)
        features_data = normalize(features_data)

        # Reshape (n_timesteps * 6, H, W) -> (6, n_timesteps, H, W)
        height, width = features_data.shape[-2:]
        features_data = features_data.view(self.n_timesteps, 6, height, width).permute(1, 0, 2, 3)

        # Add the batch dimension -> (1, 6, n_timesteps, H, W)
        features_data = features_data.unsqueeze(0)

        return features_data

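    # Shape walk-through for the defaults used by perform_inference() below
    # (n_timesteps=9, 6 bands per timestep, 224x224 window):
    #   raster window:   (54, 224, 224)   9 timesteps x 6 bands stacked
    #   after view:      (9, 6, 224, 224)
    #   after permute:   (6, 9, 224, 224)
    #   after unsqueeze: (1, 6, 9, 224, 224)
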
    def get_data(self, lon, lat):
        """Extract, normalize, and prepare the data for inference, including labels and field IDs."""
        window, features_transform, features_crs = self.identify_window(self.features_path, lon, lat)

        # Read the same window from all three rasters so they stay pixel-aligned
        features_data = self.extract_window(self.features_path, window)
        label_data = self.extract_window(self.labels_path, window)
        field_ids_data = self.extract_window(self.field_ids_path, window)

        prepared_features_data = self.prepare_data_for_model(features_data)

        label_data = torch.tensor(label_data, dtype=torch.long)
        field_ids_data = torch.tensor(field_ids_data, dtype=torch.long)

        return prepared_features_data, label_data, field_ids_data, features_transform, features_crs


def crop_predictions_to_gdf(field_ids, targets, predictions, transform, crs, class_names):
    """
    Convert the field_ids, targets, and predictions tensors into field polygons with the corresponding class references.

    :param field_ids: PyTorch tensor of shape (1, 224, 224) identifying the individual fields
    :param targets: PyTorch tensor of shape (1, 224, 224) with the target classes
    :param predictions: PyTorch tensor of shape (1, 224, 224) with the predicted classes
    :param transform: Affine transform for georeferencing
    :param crs: Coordinate reference system (CRS) of the data
    :param class_names: Mapping from class indices to class names
    :return: GeoDataFrame with polygons, predicted class labels, and target class labels
    """
    field_array = field_ids.squeeze().cpu().numpy().astype(np.int32)
    target_array = targets.squeeze().cpu().numpy().astype(np.int8)
    pred_array = predictions.squeeze().cpu().numpy().astype(np.int8)

    polygons = []
    field_values = []
    target_values = []
    pred_values = []

    # shapes() yields one polygon per connected region of equal field ID
    for geom, field_value in shapes(field_array, transform=transform):
        polygons.append(shape(geom))
        field_values.append(field_value)

        # Take the class of the first pixel in the field (targets and majority
        # predictions are constant within a field)
        target_value = target_array[field_array == field_value][0]
        pred_value = pred_array[field_array == field_value][0]

        target_values.append(target_value)
        pred_values.append(pred_value)

    gdf = gpd.GeoDataFrame({
        'geometry': polygons,
        'field_id': field_values,
        'target': target_values,
        'prediction': pred_values
    }, crs=crs)

    gdf['prediction_class'] = gdf['prediction'].apply(lambda x: class_names[x])
    gdf['target_class'] = gdf['target'].apply(lambda x: class_names[x])

    gdf['correct'] = gdf['target'] == gdf['prediction']

    # Drop sliver polygons; the threshold is in CRS units (m^2 for UTM)
    gdf = gdf[gdf.geometry.area > 250]

    return gdf


def perform_inference(lon, lat, model, config, debug=False):
    # Choose between local COGs and the hosted copies on the Hugging Face Hub
    if os.environ.get('USE_LOCAL_DATA') == "True":
        features_path = "./data/stacked_features_cog.tif"
        labels_path = "./data/labels_cog.tif"
        field_ids_path = "./data/field_ids_cog.tif"
    else:
        features_path = "https://huggingface.co/datasets/crop-classification/zueri-crop-2/resolve/main/stacked_features_cog.tif"
        labels_path = "https://huggingface.co/datasets/crop-classification/zueri-crop-2/resolve/main/labels_cog.tif"
        field_ids_path = "https://huggingface.co/datasets/crop-classification/zueri-crop-2/resolve/main/field_ids_cog.tif"

    stats_path = "./data/chips_stats.yaml"
    dataset_info_path = "./data/dataset_info.json"

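    # Minimal sketch of the .env entry read above via load_dotenv() (an
    # assumption; setting the variable in the shell environment works too):
    #
    #   USE_LOCAL_DATA=True
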
    loader = InferenceDataLoader(features_path, labels_path, field_ids_path, stats_path, n_timesteps=9, fold_indices=[0], debug=debug)

    satellite_data, label_data, field_ids_data, features_transform, features_crs = loader.get_data(lon, lat)

    if debug:
        print(satellite_data.shape)
        print(label_data.shape)
        print(field_ids_data.shape)

    with open(dataset_info_path, 'r') as file:
        dataset_info = json.load(file)

    class_names = dataset_info['tier3']

    # Collect the tiers flagged as metrics tiers in the model config
    tiers_dict = {k: v for k, v in config.hparams.get('heads_spec').items() if v.get('is_metrics_tier', False)}
    tiers = list(tiers_dict.keys())

    model.eval()
    with torch.no_grad():
        output = model(satellite_data)['tier3_refinement_head']

    # Aggregate the pixel-wise predictions into per-field majority votes;
    # index 2 holds the tier3 outputs
    pixelwise_outputs_stacked, majority_outputs_stacked = LogConfusionMatrix.get_pixelwise_and_majority_outputs(output, tiers, field_ids=field_ids_data, dataset_info=dataset_info)
    majority_tier3_predictions = majority_outputs_stacked[2]

    gdf = crop_predictions_to_gdf(field_ids_data, label_data, majority_tier3_predictions, features_transform, features_crs, class_names)

    # Keep only the columns needed for display and give them friendly names
    gdf = gdf[['prediction_class', 'target_class', 'correct', 'geometry']]
    gdf.columns = ['Prediction', 'Target', 'Correct', 'geometry']

    return gdf
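
# A minimal usage sketch, assuming a trained messis model and its config are
# already available (how they are obtained is outside this module; `model` and
# `config` below are placeholders, not a documented API):
#
#   lon, lat = 8.5417, 47.3769  # example WGS84 point near Zurich; adjust to a
#                               # location covered by the rasters
#   gdf = perform_inference(lon, lat, model, config, debug=True)
#   print(gdf.head())
#   gdf.to_file("prediction.geojson", driver="GeoJSON")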