# earthwork-net-model / create_earthwork_dataset.py
# source: Hugging Face upload by mac999 ("Upload 7 files", commit af359c9)
# title: create earthwork train dataset
# author: Taewook Kang
# date: 2024.3.27
# description: create an earthwork training dataset by extracting labelled cross-section geometry from AutoCAD into JSON files
# license: MIT
# reference: https://pyautocad.readthedocs.io/en/latest/_modules/pyautocad/api.html
# version
# 0.1. 2024.3.27. create file
#
import os, math, argparse, json, re, traceback, numpy as np, pandas as pd, trimesh, laspy, shutil
import pyautocad, open3d as o3d, seaborn as sns, win32com.client, pythoncom
import matplotlib.pyplot as plt
from scipy.spatial import distance
from tqdm import trange, tqdm
from math import pi
def get_layer_to_label(cfg, layer):
    """Look up the training label mapped to a CAD layer name.

    Args:
        cfg: configuration dict whose 'layers' entry is a list of
            {'layer': <layer name>, 'label': <label>} mappings.
        layer: CAD layer name of an entity.

    Returns:
        The 'label' of the first mapping whose 'layer' equals ``layer``,
        or '' when no mapping matches.
    """
    matching_labels = (entry['label'] for entry in cfg['layers'] if entry['layer'] == layer)
    return next(matching_labels, '')
def get_entity_from_acad(entity_names=('AcDbLine', 'AcDbPolyline', 'AcDbText')):
    """Prompt for a selection in AutoCAD and keep the supported entities.

    Args:
        entity_names: AutoCAD ``EntityName`` values to keep. A tuple default
            is used so the default value cannot be mutated between calls; any
            container supporting ``in`` may be passed.

    Returns:
        List of selected COM entities whose ``EntityName`` is in
        ``entity_names``. An empty list (never ``None``) is returned when
        nothing usable was selected, so callers can iterate the result
        directly.
    """
    acad = pyautocad.Autocad(create_if_not_exists=True)
    selections = acad.get_selection('Select entities to extract geometry')
    geoms = []
    for entity in tqdm(selections):
        try:
            name = entity.EntityName
        except Exception as e:
            # Property access on a COM object may raise; report and skip it.
            print(f'error: {e}')
            continue
        if name in entity_names:
            geoms.append(entity)
    if not geoms:
        print("No entities found in the drawing.")
    return geoms
def get_bbox(polyline):
    """Return the axis-aligned 2D bounding box of a sequence of points.

    Args:
        polyline: non-empty sequence of points. Only the first two
            coordinates (x, y) of each point are used, so 3D points are
            accepted as well.

    Returns:
        Tuple ``(xmin, ymin, xmax, ymax)``.

    Raises:
        ValueError: if ``polyline`` is empty.
    """
    # len() rather than truthiness so numpy arrays are accepted too.
    if len(polyline) == 0:
        raise ValueError('get_bbox() requires at least one point')
    xs = [pt[0] for pt in polyline]
    ys = [pt[1] for pt in polyline]
    return (min(xs), min(ys), max(xs), max(ys))
def _station_sort_key(station):
    """Return a numeric sort key for a station string such as '1+020.00'.

    Sorting the raw strings is lexicographic ('10+000.00' < '2+000.00'), so
    the two '+'-separated parts are compared as numbers instead. Strings that
    do not parse (e.g. '' for a frame with no text) sort first.
    """
    unparsed = (float('-inf'), float('-inf'))
    head, sep, tail = station.partition('+')
    if not sep:
        return unparsed
    try:
        return (float(head), float(tail))
    except ValueError:
        return unparsed


def _entity_polyline(entity):
    """Return ``(points, closed)`` for a line/polyline entity, else ``None``.

    Points are reduced to 2D ``(x, y)`` tuples.
    """
    if entity.EntityName == 'AcDbLine':
        start, end = entity.StartPoint, entity.EndPoint
        return [(start[0], start[1]), (end[0], end[1])], False
    if entity.EntityName == 'AcDbPolyline':
        coords = entity.Coordinates
        # Coordinates is treated as a flat (x0, y0, x1, y1, ...) sequence.
        points = [(coords[i], coords[i + 1]) for i in range(0, len(coords), 2)]
        return points, entity.Closed
    return None


def get_xsections_from_acad(cfg):
    """Extract labelled cross-section geometry from an AutoCAD selection.

    Polylines on the frame layer (``cfg['frame_layer']``, default
    'Nru_Frame_Crs_Design') delimit the individual cross sections. Each frame
    takes its station from the first text inside it that contains a label
    like '1+020.00'; a frame whose texts hold no such label gets
    '-1+000.00', and a frame with no text keeps ''. Line/polyline entities
    whose layer maps to a label via ``cfg['layers']`` and whose bbox lies
    inside a frame are attached to that frame. If no frame was selected, one
    unbounded section at '0+000.00' collects all labelled geometry.

    Args:
        cfg: configuration dict with a 'layers' list (see get_layer_to_label)
            and an optional 'frame_layer' name.

    Returns:
        List of sections sorted numerically by station, each a dict
        ``{'bbox', 'station', 'geom'}`` where 'geom' is a list of
        ``{'label', 'polyline', 'closed', 'earthwork_feature'}`` dicts.
        Returns [] when the selection contains no non-frame entities.
    """
    frame_layer = cfg.get('frame_layer', 'Nru_Frame_Crs_Design')
    station_pattern = re.compile(r'\d+\+\d+\.\d+')
    entities = get_entity_from_acad() or []  # tolerate a None return

    # Split the selection into section frames and everything else.
    xsec_list = []
    xsec_entities = []
    for entity in entities:
        if entity.Layer == frame_layer and entity.EntityName == 'AcDbPolyline':
            coords = entity.Coordinates
            frame = [(coords[i], coords[i + 1]) for i in range(0, len(coords), 2)]
            if len(frame) < 2:
                continue
            xsec_list.append({'bbox': get_bbox(frame), 'station': '', 'geom': []})
        else:
            xsec_entities.append(entity)
    if not xsec_entities:
        print("No cross section found in the drawing.")
        return []

    # Read each text's position and string once instead of once per frame.
    texts = []
    for entity in xsec_entities:
        if entity.EntityName == 'AcDbText':
            pt = entity.InsertionPoint
            texts.append((pt[0], pt[1], entity.TextString))

    for xsec in xsec_list:
        xmin, ymin, xmax, ymax = xsec['bbox']
        for x, y, text in texts:
            if not (xmin <= x <= xmax and ymin <= y <= ymax):
                continue
            match = station_pattern.search(text)
            if match:
                xsec['station'] = match.group()
                # A real station label wins; later texts must not overwrite it.
                break
            xsec['station'] = '-1+000.00'

    if not xsec_list:
        # No frame selected: one unbounded section collects everything.
        xsec_list.append({
            'bbox': (-9999999999.0, -9999999999.0, 9999999999.0, 9999999999.0),
            'station': '0+000.00',
            'geom': [],
        })
    xsec_list.sort(key=lambda xsec: _station_sort_key(xsec['station']))

    # Resolve label, 2D points and bbox of each labelled entity once.
    shapes = []
    for entity in xsec_entities:
        label = get_layer_to_label(cfg, entity.Layer)
        if label == '':
            continue
        extracted = _entity_polyline(entity)
        if extracted is None:
            continue
        polyline, closed = extracted
        if not polyline:
            continue
        shapes.append((label, polyline, closed, get_bbox(polyline)))

    # Attach every shape that lies entirely inside a section's frame.
    for xsec in tqdm(xsec_list):
        xmin, ymin, xmax, ymax = xsec['bbox']
        for label, polyline, closed, (exmin, eymin, exmax, eymax) in shapes:
            if exmin < xmin or eymin < ymin or exmax > xmax or eymax > ymax:
                continue
            xsec['geom'].append({
                'label': label,
                'polyline': list(polyline),
                'closed': closed,
                'earthwork_feature': [],
            })
    return xsec_list
# Shared state of the interactive cross-section viewer: the index of the
# section currently shown, the list of sections being browsed, and the Axes
# they are drawn on. show_xsections assigns the list and the Axes; the
# prev/next callbacks move the index.
_draw_xsection_index, _xsections, _plot_ax = 0, None, None
def draw_xsections(ax, index, xsections=None):
    """Plot every labelled polyline of one cross section onto ``ax``.

    The Axes is not cleared here; the caller does that when switching
    sections.

    Args:
        ax: matplotlib Axes to draw on.
        index: position of the section to draw.
        xsections: list of cross-section dicts; defaults to the module-level
            ``_xsections`` assigned by show_xsections.
    """
    if xsections is None:
        xsections = _xsections
    xsec = xsections[index]
    # Title and aspect are per-Axes settings: set them once, even when the
    # section has no geometry.
    ax.set_title(f"station: {xsec['station']}")
    for geo in xsec.get('geom', []):
        points = np.asarray(geo['polyline'], dtype=float)
        # Skip empty or malformed polylines instead of failing on points[:, 0].
        if points.ndim != 2 or points.shape[0] == 0 or points.shape[1] < 2:
            continue
        ax.plot(points[:, 0], points[:, 1], label=geo['label'])
    ax.set_aspect('equal', 'box')
def next_button(event):
    """Show the next cross section in the viewer, wrapping to the first.

    Args:
        event: matplotlib click event (unused; ``None`` when called from
            on_key_press).
    """
    global _draw_xsection_index
    if not _xsections:
        return
    _draw_xsection_index = (_draw_xsection_index + 1) % len(_xsections)
    _plot_ax.clear()
    draw_xsections(_plot_ax, _draw_xsection_index)
    # Ask the canvas to repaint so the change is visible without waiting for
    # another GUI event (needed for keyboard navigation).
    _plot_ax.figure.canvas.draw_idle()
def prev_button(event):
    """Show the previous cross section in the viewer, wrapping to the last.

    Args:
        event: matplotlib click event (unused; ``None`` when called from
            on_key_press).
    """
    global _draw_xsection_index
    if not _xsections:
        return
    _draw_xsection_index = (_draw_xsection_index - 1) % len(_xsections)
    _plot_ax.clear()
    draw_xsections(_plot_ax, _draw_xsection_index)
    # Ask the canvas to repaint so the change is visible without waiting for
    # another GUI event (needed for keyboard navigation).
    _plot_ax.figure.canvas.draw_idle()
def on_key_press(event):
    """Map the right/left arrow keys to the next/prev viewer buttons.

    Any other key is ignored.
    """
    pressed = event.key
    if pressed not in ('right', 'left'):
        return
    step = next_button if pressed == 'right' else prev_button
    step(None)
def show_xsections(xsections):
    """Open an interactive matplotlib viewer for a list of cross sections.

    'prev'/'next' buttons and the left/right arrow keys step through the
    sections.

    Args:
        xsections: list of cross-section dicts, as produced by
            get_xsections_from_acad or loaded from a chain_chunk JSON file.
            Nothing is shown when the list is empty or None.
    """
    global _draw_xsection_index, _xsections, _plot_ax
    if not xsections:
        print("No cross section to show.")
        return
    from matplotlib.widgets import Button

    _xsections = xsections
    # Start at the first section; an index left over from an earlier call
    # could be out of range for this list.
    _draw_xsection_index = 0
    fig = plt.figure()
    _plot_ax = fig.subplots()
    plt.subplots_adjust(left=0.3, bottom=0.25)
    draw_xsections(_plot_ax, _draw_xsection_index)

    # Navigation buttons and keyboard shortcuts.
    axprev = fig.add_axes([0.7, 0.05, 0.1, 0.075])
    bprev = Button(axprev, 'prev', color="white")
    bprev.on_clicked(prev_button)
    axnext = fig.add_axes([0.81, 0.05, 0.1, 0.075])
    bnext = Button(axnext, 'next', color="white")
    bnext.on_clicked(next_button)
    fig.canvas.mpl_connect('key_press_event', on_key_press)
    plt.show()
def main():
    """Command-line entry point.

    With a non-empty ``--view`` path, that saved chunk file is opened in the
    interactive viewer. With ``--view ""``, cross sections are repeatedly
    extracted from AutoCAD selections and saved as
    ``chain_chunk_<n>.json`` under ``--output``. Numbering continues after
    the highest chunk index already present, and extraction stops when a
    selection yields no cross sections. Any error is printed with its
    traceback.
    """
    parser = argparse.ArgumentParser(description='create earthwork train dataset')
    parser.add_argument('--config', type=str, default='config.json', help='config file')
    parser.add_argument('--output', type=str, default='output/', help='output directory')
    # NOTE(review): this default forces view mode on a plain run; pass
    # --view "" to extract from AutoCAD. Kept unchanged for compatibility.
    parser.add_argument('--view', type=str, default='output/chain_chunk_6.json', help='view file')
    args = parser.parse_args()
    try:
        if args.view:
            with open(args.view, 'r', encoding='utf-8') as f:
                xsections = json.load(f)
            show_xsections(xsections)
            return

        with open(args.config, 'r', encoding='utf-8') as f:
            cfg = json.load(f)

        # os.listdir fails when the output folder does not exist yet.
        os.makedirs(args.output, exist_ok=True)
        file_names = os.listdir(args.output)
        if file_names:
            print(file_names)
        # fullmatch so names like 'chain_chunk_3.json.bak' are not counted.
        chunk_pattern = re.compile(r'chain_chunk_(\d+)\.json')
        indices = [int(m.group(1)) for m in map(chunk_pattern.fullmatch, file_names) if m]
        chunk_index = max(indices) + 1 if indices else 0

        while True:
            xsections = get_xsections_from_acad(cfg)
            if not xsections:
                break
            geo_file = os.path.join(args.output, f'chain_chunk_{chunk_index}.json')
            with open(geo_file, 'w', encoding='utf-8') as f:
                json.dump(xsections, f, indent=4)
            print(f'{geo_file} was saved in {args.output}')
            chunk_index += 1
    except Exception as e:
        # Top-level boundary: report the failure with its traceback.
        print(f'error: {e}')
        traceback.print_exc()
# Run the dataset tool only when executed as a script, not on import.
if __name__ == "__main__":
    main()