# NOTE(review): lines removed here were web-page chrome ("Spaces: / Sleeping")
# from a Hugging Face Spaces scrape, not part of the Python source.
# Standard library
import json
import math
import os
import pdb
import sys

# Third-party
import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import RobertaTokenizer, BertTokenizer

# Project-local
# sys.path.append('/home/zekun/spatial_bert/spatial_bert/datasets')
# sys.path.append('/content/drive/MyDrive/spaBERT/spabert/datasets')
from models.spabert.datasets.dataset_loader_ver2 import SpatialDataset
# from dataset_loader_ver2 import SpatialDataset
class PbfMapDataset(SpatialDataset):
    """Dataset over a line-delimited JSON file of map entities (OSM/PBF).

    Each line is a JSON dict with an ``info`` record (name, geometry
    coordinates, optionally a class type and ``ogc_fid``) and a
    ``neighbor_info`` record with parallel ``name_list`` /
    ``geometry_list``.  ``__getitem__`` parses one line into the
    model-ready dict produced by ``SpatialDataset.parse_spatial_context``.
    """

    def __init__(self, data_file_path, tokenizer=None, max_token_len=512, distance_norm_factor=0.0001, spatial_dist_fill=10,
                 with_type=True, sep_between_neighbors=False, label_encoder=None, mode=None, num_neighbor_limit=None,
                 random_remove_neighbor=0., type_key_str='class'):
        """
        Args:
            data_file_path: path to the line-delimited JSON data file.
            tokenizer: HuggingFace tokenizer; defaults to ``bert-base-uncased``.
            max_token_len: maximum token sequence length.
            distance_norm_factor: normalization factor for spatial distances
                (forwarded to ``SpatialDataset``).
            spatial_dist_fill: padding distance; should be larger than every
                normalized neighbor distance.
            with_type: if True, attach an encoded ``pivot_type`` label to each
                sample (requires ``label_encoder``).
            sep_between_neighbors: insert separator tokens between neighbors.
            label_encoder: sklearn-style encoder mapping type strings to ids.
            mode: 'train' (first 80% of lines), 'test' (last 20%), or None
                to use the full file (e.g. for MLM pretraining).
            num_neighbor_limit: if set, keep at most this many neighbors.
            random_remove_neighbor: per-neighbor drop probability in [0, 1)
                used as data augmentation; 0 disables dropping.
            type_key_str: key of the class type inside the ``info`` dict.
        """
        if tokenizer is None:
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        else:
            self.tokenizer = tokenizer

        self.max_token_len = max_token_len
        self.spatial_dist_fill = spatial_dist_fill  # should exceed all normalized neighbor distances
        self.with_type = with_type
        self.sep_between_neighbors = sep_between_neighbors
        self.label_encoder = label_encoder
        self.num_neighbor_limit = num_neighbor_limit

        self.read_file(data_file_path, mode)

        self.random_remove_neighbor = random_remove_neighbor
        self.type_key_str = type_key_str  # key name of the class type in the input data dictionary

        super(PbfMapDataset, self).__init__(self.tokenizer, max_token_len, distance_norm_factor, sep_between_neighbors)

    def read_file(self, data_file_path, mode):
        """Read all lines and apply the 80/20 train/test split for ``mode``.

        Raises:
            NotImplementedError: if ``mode`` is not 'train', 'test', or None.
        """
        with open(data_file_path, 'r') as f:
            data = f.readlines()

        if mode == 'train':
            data = data[0:int(len(data) * 0.8)]
        elif mode == 'test':
            data = data[int(len(data) * 0.8):]
        elif mode is None:  # use the full dataset (for mlm)
            pass
        else:
            raise NotImplementedError

        self.len_data = len(data)  # updated data length
        self.data = data

    def load_data(self, index):
        """Parse line ``index`` into a training-sample dict."""
        spatial_dist_fill = self.spatial_dist_fill
        line = self.data[index]  # take one line from the input data according to the index
        line_data_dict = json.loads(line)

        # process pivot
        pivot_name = line_data_dict['info']['name']
        pivot_pos = line_data_dict['info']['geometry']['coordinates']

        neighbor_info = line_data_dict['neighbor_info']
        neighbor_name_list = neighbor_info['name_list']
        neighbor_geometry_list = neighbor_info['geometry_list']

        if self.random_remove_neighbor != 0:
            # Augmentation: drop each neighbor independently with probability
            # ``random_remove_neighbor``; the boolean mask marks the KEPT ones.
            # (Replaces an O(n^2) ``i in np.where(...)[0]`` membership loop.)
            num_neighbors = len(neighbor_name_list)
            keep_mask = np.random.uniform(size=num_neighbors) >= self.random_remove_neighbor
            neighbor_name_list = [name for name, keep in zip(neighbor_name_list, keep_mask) if keep]
            neighbor_geometry_list = [geom for geom, keep in zip(neighbor_geometry_list, keep_mask) if keep]

        if self.num_neighbor_limit is not None:
            neighbor_name_list = neighbor_name_list[0:self.num_neighbor_limit]
            neighbor_geometry_list = neighbor_geometry_list[0:self.num_neighbor_limit]

        train_data = self.parse_spatial_context(pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list, spatial_dist_fill)

        if self.with_type:
            pivot_type = line_data_dict['info'][self.type_key_str]
            train_data['pivot_type'] = torch.tensor(self.label_encoder.transform([pivot_type])[0])  # scalar, label_id

        if 'ogc_fid' in line_data_dict['info']:
            train_data['ogc_fid'] = line_data_dict['info']['ogc_fid']

        return train_data

    def __len__(self):
        return self.len_data

    def __getitem__(self, index):
        return self.load_data(index)
class PbfMapDatasetMarginRanking(SpatialDataset):
    """Margin-ranking variant of the PBF map dataset.

    In 'train' mode each sample carries, besides the pivot context, an
    encoding of its true class (``positive_type_data``) and of one randomly
    sampled wrong class (``negative_type_data``).  In 'test' mode each sample
    carries encodings of every class in ``type_list`` (``all_types_data``).
    """

    def __init__(self, data_file_path, type_list=None, tokenizer=None, max_token_len=512, distance_norm_factor=0.0001,
                 spatial_dist_fill=10, sep_between_neighbors=False, mode=None, num_neighbor_limit=None,
                 random_remove_neighbor=0., type_key_str='class'):
        """
        Args:
            data_file_path: path to the line-delimited JSON data file.
            type_list: list of all class-type strings used for ranking.
            tokenizer: HuggingFace tokenizer; defaults to ``bert-base-uncased``.
            max_token_len: maximum token sequence length.
            distance_norm_factor: normalization factor for spatial distances.
            spatial_dist_fill: padding distance; should be larger than every
                normalized neighbor distance.
            sep_between_neighbors: insert separator tokens between neighbors.
            mode: 'train' (first 80% of lines), 'test' (last 20%), or None
                to use the full file.
            num_neighbor_limit: if set, keep at most this many neighbors.
            random_remove_neighbor: per-neighbor drop probability in [0, 1);
                0 disables dropping.
            type_key_str: key of the class type inside the ``info`` dict.
        """
        if tokenizer is None:
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        else:
            self.tokenizer = tokenizer

        self.type_list = type_list
        self.type_key_str = type_key_str  # key name of the class type in the input data dictionary
        self.max_token_len = max_token_len
        self.spatial_dist_fill = spatial_dist_fill  # should exceed all normalized neighbor distances
        self.sep_between_neighbors = sep_between_neighbors
        self.num_neighbor_limit = num_neighbor_limit

        # NOTE(review): read_file (in 'test' mode) calls prepare_all_types_data,
        # which uses parse_spatial_context BEFORE super().__init__() runs below;
        # this presumably works only if parse_spatial_context does not depend on
        # state set in SpatialDataset.__init__ — confirm against the base class.
        self.read_file(data_file_path, mode)

        self.random_remove_neighbor = random_remove_neighbor
        self.mode = mode

        super(PbfMapDatasetMarginRanking, self).__init__(self.tokenizer, max_token_len, distance_norm_factor, sep_between_neighbors)

    def read_file(self, data_file_path, mode):
        """Read all lines, split 80/20 by ``mode``; in 'test' mode also
        precompute the per-class encodings used by ``load_data``.

        Raises:
            NotImplementedError: if ``mode`` is not 'train', 'test', or None.
        """
        with open(data_file_path, 'r') as f:
            data = f.readlines()

        if mode == 'train':
            data = data[0:int(len(data) * 0.8)]
        elif mode == 'test':
            data = data[int(len(data) * 0.8):]
            self.all_types_data = self.prepare_all_types_data()
        elif mode is None:  # use the full dataset (for mlm)
            pass
        else:
            raise NotImplementedError

        self.len_data = len(data)  # updated data length
        self.data = data

    def prepare_all_types_data(self):
        """Encode every class type in ``type_list`` as a neighbor-free
        pivot context; returns {type_name: encoded_data}."""
        spatial_dist_fill = self.spatial_dist_fill
        type_data_dict = dict()
        for type_name in self.type_list:
            type_pos = [None, None]  # use filler values
            type_data = self.parse_spatial_context(type_name, type_pos, pivot_dist_fill=0.,
                                                   neighbor_name_list=[], neighbor_geometry_list=[],
                                                   spatial_dist_fill=spatial_dist_fill)
            type_data_dict[type_name] = type_data
        return type_data_dict

    def load_data(self, index):
        """Parse line ``index`` into a sample dict, augmented with
        positive/negative class encodings ('train') or all-class encodings
        ('test').

        Raises:
            NotImplementedError: if ``self.mode`` is neither 'train' nor 'test'.
        """
        spatial_dist_fill = self.spatial_dist_fill
        line = self.data[index]  # take one line from the input data according to the index
        line_data_dict = json.loads(line)

        # process pivot
        pivot_name = line_data_dict['info']['name']
        pivot_pos = line_data_dict['info']['geometry']['coordinates']

        neighbor_info = line_data_dict['neighbor_info']
        neighbor_name_list = neighbor_info['name_list']
        neighbor_geometry_list = neighbor_info['geometry_list']

        if self.random_remove_neighbor != 0:
            # Augmentation: drop each neighbor independently with probability
            # ``random_remove_neighbor``; the boolean mask marks the KEPT ones.
            # (Replaces an O(n^2) ``i in np.where(...)[0]`` membership loop.)
            num_neighbors = len(neighbor_name_list)
            keep_mask = np.random.uniform(size=num_neighbors) >= self.random_remove_neighbor
            neighbor_name_list = [name for name, keep in zip(neighbor_name_list, keep_mask) if keep]
            neighbor_geometry_list = [geom for geom, keep in zip(neighbor_geometry_list, keep_mask) if keep]

        if self.num_neighbor_limit is not None:
            neighbor_name_list = neighbor_name_list[0:self.num_neighbor_limit]
            neighbor_geometry_list = neighbor_geometry_list[0:self.num_neighbor_limit]

        train_data = self.parse_spatial_context(pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list, spatial_dist_fill)

        if 'ogc_fid' in line_data_dict['info']:
            train_data['ogc_fid'] = line_data_dict['info']['ogc_fid']

        pivot_type = line_data_dict['info'][self.type_key_str]
        train_data['pivot_type'] = pivot_type  # raw class string (no label encoder here)

        if self.mode == 'train':
            # positive class: the pivot's own type, encoded as a pivot-only context
            positive_name = pivot_type  # class type string as input to tokenizer
            positive_pos = [None, None]  # use filler values
            positive_type_data = self.parse_spatial_context(positive_name, positive_pos, pivot_dist_fill=0.,
                                                            neighbor_name_list=[], neighbor_geometry_list=[],
                                                            spatial_dist_fill=spatial_dist_fill)
            train_data['positive_type_data'] = positive_type_data

            # negative class: one type sampled uniformly from the other types
            other_type_list = self.type_list.copy()
            other_type_list.remove(pivot_type)
            negative_name = np.random.choice(other_type_list)
            negative_pos = [None, None]  # use filler values
            negative_type_data = self.parse_spatial_context(negative_name, negative_pos, pivot_dist_fill=0.,
                                                            neighbor_name_list=[], neighbor_geometry_list=[],
                                                            spatial_dist_fill=spatial_dist_fill)
            train_data['negative_type_data'] = negative_type_data
        elif self.mode == 'test':
            # return data for all class types in type_list
            train_data['all_types_data'] = self.all_types_data
        else:
            raise NotImplementedError

        return train_data

    def __len__(self):
        return self.len_data

    def __getitem__(self, index):
        return self.load_data(index)