SpaGAN / models /spabert /datasets /osm_sample_loader.py
JasonTPhillipsJr's picture
Update models/spabert/datasets/osm_sample_loader.py
97bda4d verified
import os
import sys
import numpy as np
import json
import math
import torch
from transformers import RobertaTokenizer, BertTokenizer
from torch.utils.data import Dataset
#sys.path.append('/home/zekun/spatial_bert/spatial_bert/datasets')
#sys.path.append('/content/drive/MyDrive/spaBERT/spabert/datasets')
from models.spabert.datasets.dataset_loader_ver2 import SpatialDataset
#from dataset_loader_ver2 import SpatialDataset
import pdb
class PbfMapDataset(SpatialDataset):
    """Dataset over a JSON-lines file of map entities and their neighbors.

    Every line of the input file is parsed as one JSON object. ``load_data``
    reads ``info.name``, ``info.geometry.coordinates``,
    ``neighbor_info.name_list`` and ``neighbor_info.geometry_list`` from it and
    passes them to ``parse_spatial_context`` (inherited from
    ``SpatialDataset``) to build one sample.

    Args:
        data_file_path: Path to a text file with one JSON object per line.
        tokenizer: Tokenizer to use; when ``None`` a ``BertTokenizer`` for
            ``'bert-base-uncased'`` is loaded.
        max_token_len: Forwarded to ``SpatialDataset.__init__``.
        distance_norm_factor: Forwarded to ``SpatialDataset.__init__``.
        spatial_dist_fill: Fill value handed to ``parse_spatial_context``;
            should be a normalized distance larger than every normalized
            neighbor distance.
        with_type: When true, each sample also carries ``'pivot_type'``, the
            label id produced by ``label_encoder``.
        sep_between_neighbors: Forwarded to ``SpatialDataset.__init__``.
        label_encoder: Object exposing ``transform(list_of_labels)``; required
            when ``with_type`` is true.
        mode: ``'train'`` keeps the first 80% of the lines, ``'test'`` the
            rest, ``None`` keeps everything (used for MLM).
        num_neighbor_limit: If not ``None``, only the first this-many
            neighbors (after random removal) are kept.
        random_remove_neighbor: Removal threshold; a neighbor is kept when a
            uniform draw from ``np.random.uniform`` is ``>=`` this value.
            ``0`` disables removal.
        type_key_str: Key under ``info`` that holds the class type.
    """

    def __init__(self, data_file_path, tokenizer=None, max_token_len=512,
                 distance_norm_factor=0.0001, spatial_dist_fill=10,
                 with_type=True, sep_between_neighbors=False,
                 label_encoder=None, mode=None, num_neighbor_limit=None,
                 random_remove_neighbor=0., type_key_str='class'):
        if tokenizer is None:
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        else:
            self.tokenizer = tokenizer

        self.max_token_len = max_token_len
        # Should be a normalized distance fill, larger than all normalized
        # neighbor distances.
        self.spatial_dist_fill = spatial_dist_fill
        self.with_type = with_type
        self.sep_between_neighbors = sep_between_neighbors
        self.label_encoder = label_encoder
        self.num_neighbor_limit = num_neighbor_limit
        self.random_remove_neighbor = random_remove_neighbor
        # Key name of the class type in the input data dictionary.
        self.type_key_str = type_key_str

        self.read_file(data_file_path, mode)

        super(PbfMapDataset, self).__init__(
            self.tokenizer, max_token_len, distance_norm_factor,
            sep_between_neighbors)

    def read_file(self, data_file_path, mode):
        """Read the raw lines of the data file and keep the slice for ``mode``.

        Sets ``self.data`` (list of raw lines) and ``self.len_data``.

        Raises:
            NotImplementedError: If ``mode`` is not ``'train'``, ``'test'`` or
                ``None``.
        """
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(data_file_path, 'r', encoding='utf-8') as f:
            data = f.readlines()

        split_at = int(len(data) * 0.8)
        if mode == 'train':
            data = data[0:split_at]
        elif mode == 'test':
            data = data[split_at:]
        elif mode is not None:
            raise NotImplementedError(
                "mode must be 'train', 'test' or None, got {!r}".format(mode))
        # mode is None: use the full dataset (for mlm).

        self.len_data = len(data)  # updated data length
        self.data = data

    def _drop_and_limit_neighbors(self, name_list, geometry_list):
        """Apply random neighbor removal, then the neighbor-count cap."""
        if self.random_remove_neighbor != 0:
            draws = np.random.uniform(size=len(name_list))
            # Indices of the neighbors that survive the random removal.
            kept = np.flatnonzero(draws >= self.random_remove_neighbor)
            name_list = [name_list[i] for i in kept]
            geometry_list = [geometry_list[i] for i in kept]

        if self.num_neighbor_limit is not None:
            name_list = name_list[0:self.num_neighbor_limit]
            geometry_list = geometry_list[0:self.num_neighbor_limit]

        return name_list, geometry_list

    def load_data(self, index):
        """Build the sample for line ``index`` of the loaded data.

        Returns:
            The result of ``parse_spatial_context`` for the pivot and its
            neighbors, extended with ``'pivot_type'`` (scalar label-id tensor)
            when ``with_type`` is set, and with ``'ogc_fid'`` when the record
            has one.

        Raises:
            ValueError: If ``with_type`` is true but no ``label_encoder`` was
                supplied.
        """
        # Take one line from the input data according to the index.
        record = json.loads(self.data[index])
        info = record['info']

        # Process pivot.
        pivot_name = info['name']
        pivot_pos = info['geometry']['coordinates']

        neighbor_info = record['neighbor_info']
        neighbor_name_list, neighbor_geometry_list = self._drop_and_limit_neighbors(
            neighbor_info['name_list'], neighbor_info['geometry_list'])

        train_data = self.parse_spatial_context(
            pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list,
            self.spatial_dist_fill)

        if self.with_type:
            if self.label_encoder is None:
                raise ValueError(
                    "with_type=True requires a label_encoder; pass one or set "
                    "with_type=False")
            pivot_type = info[self.type_key_str]
            # Scalar, label_id.
            train_data['pivot_type'] = torch.tensor(
                self.label_encoder.transform([pivot_type])[0])

        if 'ogc_fid' in info:
            train_data['ogc_fid'] = info['ogc_fid']

        return train_data

    def __len__(self):
        """Return the number of lines kept by ``read_file``."""
        return self.len_data

    def __getitem__(self, index):
        """Return the sample built by ``load_data`` for ``index``."""
        return self.load_data(index)
class PbfMapDatasetMarginRanking(SpatialDataset):
    """JSON-lines map-entity dataset that pairs each pivot with class-type inputs.

    Every line of the input file is parsed as one JSON object. ``load_data``
    builds the pivot sample with ``parse_spatial_context`` (inherited from
    ``SpatialDataset``) and stores the raw class-type value under
    ``'pivot_type'``. Depending on ``mode`` it additionally attaches:

    * ``'train'``: ``'positive_type_data'`` (the pivot's own type name run
      through ``parse_spatial_context``) and ``'negative_type_data'`` (a
      randomly chosen other entry of ``type_list``).
    * ``'test'``: ``'all_types_data'``, a dict mapping every entry of
      ``type_list`` to its ``parse_spatial_context`` output.

    Args:
        data_file_path: Path to a text file with one JSON object per line.
        type_list: List of class-type names; required for ``'train'`` and
            ``'test'`` modes.
        tokenizer: Tokenizer to use; when ``None`` a ``BertTokenizer`` for
            ``'bert-base-uncased'`` is loaded.
        max_token_len: Forwarded to ``SpatialDataset.__init__``.
        distance_norm_factor: Forwarded to ``SpatialDataset.__init__``.
        spatial_dist_fill: Fill value handed to ``parse_spatial_context``;
            should be a normalized distance larger than every normalized
            neighbor distance.
        sep_between_neighbors: Forwarded to ``SpatialDataset.__init__``.
        mode: ``'train'`` keeps the first 80% of the lines, ``'test'`` the
            rest, ``None`` keeps everything. Note that ``load_data`` only
            supports ``'train'`` and ``'test'``.
        num_neighbor_limit: If not ``None``, only the first this-many
            neighbors (after random removal) are kept.
        random_remove_neighbor: Removal threshold; a neighbor is kept when a
            uniform draw from ``np.random.uniform`` is ``>=`` this value.
            ``0`` disables removal.
        type_key_str: Key under ``info`` that holds the class type.
    """

    def __init__(self, data_file_path, type_list=None, tokenizer=None,
                 max_token_len=512, distance_norm_factor=0.0001,
                 spatial_dist_fill=10, sep_between_neighbors=False, mode=None,
                 num_neighbor_limit=None, random_remove_neighbor=0.,
                 type_key_str='class'):
        if tokenizer is None:
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        else:
            self.tokenizer = tokenizer

        self.type_list = type_list
        # Key name of the class type in the input data dictionary.
        self.type_key_str = type_key_str
        self.max_token_len = max_token_len
        # Should be a normalized distance fill, larger than all normalized
        # neighbor distances.
        self.spatial_dist_fill = spatial_dist_fill
        self.sep_between_neighbors = sep_between_neighbors
        self.num_neighbor_limit = num_neighbor_limit
        self.random_remove_neighbor = random_remove_neighbor
        self.mode = mode

        # NOTE(review): with mode='test' this calls parse_spatial_context
        # before the base initializer below has run -- confirm the base method
        # only needs the attributes already assigned above.
        self.read_file(data_file_path, mode)

        super(PbfMapDatasetMarginRanking, self).__init__(
            self.tokenizer, max_token_len, distance_norm_factor,
            sep_between_neighbors)

    def read_file(self, data_file_path, mode):
        """Read the raw lines of the data file and keep the slice for ``mode``.

        Sets ``self.data`` (list of raw lines) and ``self.len_data``; in
        ``'test'`` mode also precomputes ``self.all_types_data``.

        Raises:
            NotImplementedError: If ``mode`` is not ``'train'``, ``'test'`` or
                ``None``.
        """
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(data_file_path, 'r', encoding='utf-8') as f:
            data = f.readlines()

        split_at = int(len(data) * 0.8)
        if mode == 'train':
            data = data[0:split_at]
        elif mode == 'test':
            data = data[split_at:]
            # The per-type inputs do not depend on the sample; build them once.
            self.all_types_data = self.prepare_all_types_data()
        elif mode is not None:
            raise NotImplementedError(
                "mode must be 'train', 'test' or None, got {!r}".format(mode))
        # mode is None: use the full dataset.

        self.len_data = len(data)  # updated data length
        self.data = data

    def _require_type_list(self):
        """Return ``self.type_list`` or raise a clear error when it is unset."""
        if self.type_list is None:
            raise ValueError(
                "type_list must be provided for mode={!r}".format(
                    getattr(self, 'mode', None)))
        return self.type_list

    def _encode_type_name(self, type_name):
        """Run a bare class-type name through ``parse_spatial_context``."""
        filler_pos = [None, None]  # use filler values
        return self.parse_spatial_context(
            type_name, filler_pos, pivot_dist_fill=0.,
            neighbor_name_list=[], neighbor_geometry_list=[],
            spatial_dist_fill=self.spatial_dist_fill)

    def prepare_all_types_data(self):
        """Return a dict mapping each entry of ``type_list`` to its encoded input.

        Raises:
            ValueError: If ``type_list`` was not provided.
        """
        return {
            type_name: self._encode_type_name(type_name)
            for type_name in self._require_type_list()
        }

    def _drop_and_limit_neighbors(self, name_list, geometry_list):
        """Apply random neighbor removal, then the neighbor-count cap."""
        if self.random_remove_neighbor != 0:
            draws = np.random.uniform(size=len(name_list))
            # Indices of the neighbors that survive the random removal.
            kept = np.flatnonzero(draws >= self.random_remove_neighbor)
            name_list = [name_list[i] for i in kept]
            geometry_list = [geometry_list[i] for i in kept]

        if self.num_neighbor_limit is not None:
            name_list = name_list[0:self.num_neighbor_limit]
            geometry_list = geometry_list[0:self.num_neighbor_limit]

        return name_list, geometry_list

    def load_data(self, index):
        """Build the sample for line ``index`` of the loaded data.

        Returns:
            The ``parse_spatial_context`` output for the pivot, extended with
            ``'ogc_fid'`` (when present in the record), ``'pivot_type'`` (raw
            class-type value) and the mode-specific entries described on the
            class.

        Raises:
            ValueError: In ``'train'`` mode, if ``type_list`` is missing, does
                not contain the pivot's type, or contains no other type to
                sample a negative from.
            NotImplementedError: If the dataset mode is neither ``'train'``
                nor ``'test'``.
        """
        # Take one line from the input data according to the index.
        record = json.loads(self.data[index])
        info = record['info']

        # Process pivot.
        pivot_name = info['name']
        pivot_pos = info['geometry']['coordinates']

        neighbor_info = record['neighbor_info']
        neighbor_name_list, neighbor_geometry_list = self._drop_and_limit_neighbors(
            neighbor_info['name_list'], neighbor_info['geometry_list'])

        train_data = self.parse_spatial_context(
            pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list,
            self.spatial_dist_fill)

        if 'ogc_fid' in info:
            train_data['ogc_fid'] = info['ogc_fid']

        pivot_type = info[self.type_key_str]
        train_data['pivot_type'] = pivot_type

        if self.mode == 'train':
            type_list = self._require_type_list()
            if pivot_type not in type_list:
                raise ValueError(
                    "pivot type {!r} is not in type_list".format(pivot_type))

            # Positive class: the class type string as input to the tokenizer.
            train_data['positive_type_data'] = self._encode_type_name(pivot_type)

            # Negative class: any type other than the pivot's. Filtering (not
            # list.remove) guarantees a duplicated entry can never be drawn
            # as its own negative.
            other_type_list = [t for t in type_list if t != pivot_type]
            if not other_type_list:
                raise ValueError(
                    "type_list needs at least one type other than {!r} to "
                    "sample a negative".format(pivot_type))
            negative_name = np.random.choice(other_type_list)
            train_data['negative_type_data'] = self._encode_type_name(negative_name)
        elif self.mode == 'test':
            # Return data for all class types in type_list.
            train_data['all_types_data'] = self.all_types_data
        else:
            raise NotImplementedError(
                "load_data supports mode 'train' or 'test', got {!r}".format(
                    self.mode))

        return train_data

    def __len__(self):
        """Return the number of lines kept by ``read_file``."""
        return self.len_data

    def __getitem__(self, index):
        """Return the sample built by ``load_data`` for ``index``."""
        return self.load_data(index)