Upload 7 files
Browse files
- config.json +120 -0
- create_earthwork_dataset.py +232 -0
- ena_dataset.py +137 -0
- ena_run_model.py +548 -0
- eval_model.py +75 -0
- extract_ewlog.py +244 -0
- prepare_dataset.py +460 -0
config.json
ADDED
@@ -0,0 +1,120 @@
{
    "layers": [
        {
            "layer": "Nru_Geo_Crs_Center",
            "label": "center"
        },
        {
            "layer": "Nru_Crs_Pave_Surface",
            "label": "pave_surface"
        },
        {
            "layer": "Nru_Crs_Pave_Subgrade",
            "label": "pave_subgrade"
        },
        {
            "layer": "Nru_Geo_Surface",
            "label": "ground"
        },
        {
            "layer": "Nru_Crs_Pave_Bottom",
            "label": "pave_bottom"
        },
        {
            "layer": "Nru_Geo_Underground_1",
            "label": "rr"
        },
        {
            "layer": "Nru_Geo_Underground_2",
            "label": "br"
        },
        {
            "layer": "Nru_Crs_Slope",
            "label": "slope"
        },
        {
            "layer": "Nru_Stru_Bench",
            "label": "struct"
        },
        {
            "layer": "Nru_Stru_Frt_Sodan",
            "label": "struct"
        },
        {
            "layer": "Nru_Stru_Smr",
            "label": "struct"
        },
        {
            "layer": "Nru_Stru_Smr_Ending",
            "label": "struct"
        },
        {
            "layer": "Nru_Stru_Ditch",
            "label": "struct"
        },
        {
            "layer": "Nru_Stru_Ditch_Bench",
            "label": "struct"
        },
        {
            "layer": "Nru_Stru_Frt",
            "label": "struct"
        },
        {
            "layer": "Nru_Crs_Ew_깎기_토사",
            "label": "cut_ea"
        },
        {
            "layer": "Nru_Crs_Ew_깎기_리핑암",
            "label": "cut_rr"
        },
        {
            "layer": "Nru_Crs_Ew_일반발파",
            "label": "cut_br"
        },
        {
            "layer": "Nru_Crs_Ew_대규모발파",
            "label": "cut_br"
        },
        {
            "layer": "Nru_Crs_Ew_중규모진동제어발파",
            "label": "cut_br"
        },
        {
            "layer": "Nru_Crs_Ew_터파기_토사",
            "label": "cut_ditch"
        },
        {
            "layer": "Nru_Crs_Ew_쌓기_노상",
            "label": "fill_subbed"
        },
        {
            "layer": "Nru_Crs_Ew_쌓기_노체",
            "label": "fill_subbody"
        },
        {
            "layer": "Nru_Crs_Pave_Layer-1",
            "label": "pave_layer1"
        },
        {
            "layer": "Nru_Crs_Pave_Layer-2",
            "label": "pave_layer2"
        },
        {
            "layer": "Nru_Crs_Pave_Layer-3",
            "label": "pave_layer3"
        },
        {
            "layer": "Nru_Crs_Pave_Layer-4",
            "label": "pave_layer4"
        },
        {
            "layer": "Nru_Crs_Steps",
            "label": "steps"
        },
        {
            "layer": "Nru_Crs_Curb",
            "label": "curb"
        }
    ]
}
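For reference, a minimal sketch of how this layer-to-label mapping is consumed. The scripts below do the same lookup with get_layer_to_label(); the dict comprehension here is just an equivalent constant-time form, not part of the upload:

import json

with open('config.json', 'r', encoding='utf-8') as f:
    cfg = json.load(f)

# several CAD layers can share one label, e.g. every Nru_Stru_* layer maps to 'struct'
layer_to_label = {entry['layer']: entry['label'] for entry in cfg['layers']}
print(layer_to_label.get('Nru_Crs_Slope', ''))  # -> 'slope'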
create_earthwork_dataset.py
ADDED
@@ -0,0 +1,232 @@
# title: create earthwork train dataset
# author: Taewook Kang
# date: 2024.3.27
# description: create earthwork train dataset
# license: MIT
# reference: https://pyautocad.readthedocs.io/en/latest/_modules/pyautocad/api.html
# version
# 0.1. 2024.3.27. create file
#
import os, math, argparse, json, re, traceback, numpy as np, pandas as pd, trimesh, laspy, shutil
import pyautocad, open3d as o3d, seaborn as sns, win32com.client, pythoncom
import matplotlib.pyplot as plt
from scipy.spatial import distance
from tqdm import trange, tqdm
from math import pi

def get_layer_to_label(cfg, layer):
    # map a CAD layer name to its training label using config.json
    layers = cfg['layers']
    for lay in layers:
        if lay['layer'] == layer:
            return lay['label']
    return ''

def get_entity_from_acad(entity_names=['AcDbLine', 'AcDbPolyline', 'AcDbText']):
    acad = pyautocad.Autocad(create_if_not_exists=True)
    selections = acad.get_selection('Select entities to extract geometry')

    geoms = []
    for entity in tqdm(selections):  # alternatively: tqdm(acad.iter_objects())
        try:
            if entity.EntityName in entity_names:
                geoms.append(entity)
        except Exception as e:
            print(f'error: {e}')
            continue

    if not geoms:
        print("No entities found in the drawing.")
        return []

    return geoms

def get_bbox(polyline):
    xmin, ymin, xmax, ymax = polyline[0][0], polyline[0][1], polyline[0][0], polyline[0][1]
    for x, y in polyline:
        xmin = min(xmin, x)
        ymin = min(ymin, y)
        xmax = max(xmax, x)
        ymax = max(ymax, y)
    return (xmin, ymin, xmax, ymax)

def get_xsections_from_acad(cfg):
    entities = get_entity_from_acad()

    # extract cross section frames
    xsec_list = []
    xsec_entities = []
    for entity in entities:
        if entity.Layer == 'Nru_Frame_Crs_Design' and entity.EntityName == 'AcDbPolyline':
            polyline = []
            vertex_list = entity.Coordinates
            for i in range(0, len(vertex_list), 2):
                polyline.append((vertex_list[i], vertex_list[i+1]))
            if len(polyline) < 2:
                continue
            bbox = get_bbox(polyline)

            xsec = {'bbox': bbox, 'station': '', 'geom': []}
            xsec_list.append(xsec)
        else:
            xsec_entities.append(entity)
    if len(xsec_entities) == 0:
        print("No cross section found in the drawing.")
        return []

    # assign a station to each cross section from the text entity inside its frame
    for xsec in xsec_list:
        for entity in xsec_entities:
            if entity.EntityName != 'AcDbText':
                continue
            pt = (entity.InsertionPoint[0], entity.InsertionPoint[1])
            bbox = xsec['bbox']
            if pt[0] < bbox[0] or pt[1] < bbox[1] or pt[0] > bbox[2] or pt[1] > bbox[3]:
                continue
            xsec_station = entity.TextString
            pattern = r'\d+\+\d+\.\d+'
            match = re.search(pattern, xsec_station)
            if match:
                xsec_station = match.group()
            else:
                xsec_station = '-1+000.00'
            xsec['station'] = xsec_station

    if len(xsec_list) == 0:
        xsec = {'bbox': (-9999999999.0, -9999999999.0, 9999999999.0, 9999999999.0), 'station': '0+000.00', 'geom': []}
        xsec_list.append(xsec)

    xsec_list = sorted(xsec_list, key=lambda x: x['station'])  # sort xsec_list by station string; the format is 'xxx+xxx.xx'

    # extract geometry in each cross section
    for xsec in tqdm(xsec_list):
        for entity in xsec_entities:
            label = get_layer_to_label(cfg, entity.Layer)
            if label == '':
                continue

            closed = False
            polyline = []
            if entity.EntityName == 'AcDbLine':
                polyline = [entity.StartPoint[:2], entity.EndPoint[:2]]  # keep (x, y) only
                closed = False
            elif entity.EntityName == 'AcDbPolyline':
                vertex_list = entity.Coordinates
                for i in range(0, len(vertex_list), 2):
                    polyline.append((vertex_list[i], vertex_list[i+1]))
                closed = entity.Closed
            else:
                continue

            xsec_bbox = xsec['bbox']
            entity_bbox = get_bbox(polyline)
            if entity_bbox[0] < xsec_bbox[0] or entity_bbox[1] < xsec_bbox[1] or entity_bbox[2] > xsec_bbox[2] or entity_bbox[3] > xsec_bbox[3]:
                continue

            geo = {
                'label': label,
                'polyline': polyline,
                'closed': closed,
                'earthwork_feature': []
            }
            xsec['geom'].append(geo)

    return xsec_list

# simple cross-section viewer with prev/next paging
_draw_xsection_index = 0
_xsections = None
_plot_ax = None
def draw_xsections(ax, index):
    xsec = _xsections[index]
    for geo in xsec['geom']:
        station = xsec['station']
        ax.set_title(f'station: {station}')
        polyline = np.array(geo['polyline'])
        ax.plot(polyline[:,0], polyline[:,1], label=geo['label'])
    ax.set_aspect('equal', 'box')

def next_button(event):
    global _draw_xsection_index, _xsections, _plot_ax
    _draw_xsection_index += 1
    if _draw_xsection_index >= len(_xsections):
        _draw_xsection_index = 0
    _plot_ax.clear()
    draw_xsections(_plot_ax, _draw_xsection_index)
    plt.draw()  # refresh the canvas after redrawing

def prev_button(event):
    global _draw_xsection_index, _xsections, _plot_ax
    _draw_xsection_index -= 1
    if _draw_xsection_index < 0:
        _draw_xsection_index = len(_xsections) - 1
    _plot_ax.clear()
    draw_xsections(_plot_ax, _draw_xsection_index)
    plt.draw()

def on_key_press(event):
    if event.key == 'right':
        next_button(None)
    elif event.key == 'left':
        prev_button(None)

def show_xsections(xsections):
    from matplotlib.widgets import Button
    global _draw_xsection_index, _xsections, _plot_ax
    _xsections = xsections

    fig = plt.figure()
    _plot_ax = fig.subplots()
    plt.subplots_adjust(left=0.3, bottom=0.25)
    draw_xsections(_plot_ax, _draw_xsection_index)

    # define prev/next buttons and hook up their callbacks
    axprev = fig.add_axes([0.7, 0.05, 0.1, 0.075])
    bprev = Button(axprev, 'prev', color="white")
    bprev.on_clicked(prev_button)
    axnext = fig.add_axes([0.81, 0.05, 0.1, 0.075])
    bnext = Button(axnext, 'next', color="white")
    bnext.on_clicked(next_button)

    fig.canvas.mpl_connect('key_press_event', on_key_press)

    plt.show()

def main():
    parser = argparse.ArgumentParser(description='create earthwork train dataset')
    parser.add_argument('--config', type=str, default='config.json', help='config file')
    parser.add_argument('--output', type=str, default='output/', help='output directory')
    parser.add_argument('--view', type=str, default='output/chain_chunk_6.json', help='view file')

    args = parser.parse_args()
    try:
        if len(args.view) > 0:
            with open(args.view, 'r') as f:
                xsections = json.load(f)
            show_xsections(xsections)
            return

        cfg = None
        with open(args.config, 'r', encoding='utf-8') as f:
            cfg = json.load(f)

        # continue numbering from the existing chunk files in the output directory
        chunk_index = 0
        file_names = os.listdir(args.output)
        if len(file_names):
            pattern = r'chain_chunk_(\d+)\.json'
            indices = [int(re.match(pattern, file_name).group(1)) for file_name in file_names if re.match(pattern, file_name)]
            chunk_index = max(indices) + 1 if indices else 0

        print(file_names)

        while True:
            xsections = get_xsections_from_acad(cfg)
            if len(xsections) == 0:
                break
            geo_file = os.path.join(args.output, f'chain_chunk_{chunk_index}.json')
            with open(geo_file, 'w') as f:
                json.dump(xsections, f, indent=4)
            print(f'{geo_file} was saved in {args.output}')
            chunk_index += 1
    except Exception as e:
        print(f'error: {e}')
        traceback.print_exc()

if __name__ == '__main__':
    main()
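Note that get_xsections_from_acad() sorts cross sections by the raw station string, which orders '10+000.00' before '2+000.00'. If numeric ordering is needed, a small sketch of a numeric sort key, assuming stations always follow the 'km+meters' format captured by the regex above:

def station_to_meters(station):
    # '3+780.00' -> 3 * 1000 + 780.0 = 3780.0
    km, m = station.split('+')
    return int(km) * 1000 + float(m)

stations = ['10+000.00', '2+000.00', '3+780.00']
print(sorted(stations))                         # lexicographic: '10+000.00' sorts first
print(sorted(stations, key=station_to_meters))  # numeric: ['2+000.00', '3+780.00', '10+000.00']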
ena_dataset.py
ADDED
@@ -0,0 +1,137 @@
# title: ENA dataset utility functions
# author: Taewook Kang, Kyubyung Kang
# date: 2024.3.27
# license: MIT
# reference: https://pyautocad.readthedocs.io/en/latest/_modules/pyautocad/api.html
# version
# 0.1. 2024.3.27. create file
#
import json, os, re, logging, numpy as np
from transformers import BertTokenizer

def load_train_chunk_data(data_dir, sort_fname=False):
    geom_list = []
    fnames = os.listdir(data_dir)
    if sort_fname:
        fnames.sort(key=lambda x: int(re.search(r'\d+', x).group()))
    xsec_count = 0
    for file_name in fnames:
        if not file_name.endswith('.json'):
            continue
        with open(os.path.join(data_dir, file_name), 'r') as f:
            chunk = json.load(f)
            for xsec in chunk:
                xsec_count += 1
                geom = xsec['geom']
                for g in geom:
                    g['station'] = xsec['station']
                    features = g['earthwork_feature']
                    if len(features) == 0:
                        continue
                    geom_list.append(g)
    print(f'Loaded {xsec_count} cross sections')
    return geom_list

def update_feature_dims_token(geom_list):
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)  # load the BERT tokenizer

    feature_dims = []
    max_token = 0
    tokenizer.add_tokens(['padding'])  # register a dedicated padding word
    padding_token_id = tokenizer.convert_tokens_to_ids('padding')
    for geom in geom_list:
        geom['feature_dims'] = []
        for feature in geom['earthwork_feature']:
            # token = tokenizer.tokenize(feature)
            token_ids = tokenizer.convert_tokens_to_ids(feature)
            geom['feature_dims'].append(token_ids)

            word, count = extract_word_and_count(feature)
            if word is None or word in feature_dims:
                continue
            feature_dims.append(word)

        max_token = max(max_token, len(geom['feature_dims']))

    # pad every geometry's token list to the same length
    for geom in geom_list:
        geom['feature_dims'] += [padding_token_id] * (max_token - len(geom['feature_dims']))

    print(f'Max token length: {max_token}')
    return feature_dims

def extract_word_and_count(s):
    # parse a feature string of the form 'word(count)'; the count defaults to 1
    match = re.match(r'(\w+)(?:\((\d+)\))?', s)
    if match:
        word, count = match.groups()
        count = int(count) if count else 1
        return word, count

    return None, None

def update_feature_dims_freq(geom_list, augument=False):
    # collect the distinct feature words over the whole dataset
    feature_dims = []
    for geom in geom_list:
        geom['feature_dims'] = []
        for feature in geom['earthwork_feature']:
            word, count = extract_word_and_count(feature)
            if word is None or count is None:
                continue
            if word in feature_dims:
                continue
            feature_dims.append(word)

    feature_dims.sort()

    # per-geometry word counts, tracking each word's maximum for normalization
    max_feature_dims_count = [0.0] * len(feature_dims)
    for geom in geom_list:
        geom['feature_dims'] = [0.0] * len(feature_dims)
        geom['feature_text'] = ''

        for feature in geom['earthwork_feature']:
            word, count = extract_word_and_count(feature)
            if word is None or count is None:
                continue
            geom['feature_text'] += f'{word}({count}) '
            index = feature_dims.index(word)

            geom['feature_dims'][index] = count
            max_feature_dims_count[index] = max(max_feature_dims_count[index], count)

    # normalize feature_dims using max_feature_dims_count
    for geom in geom_list:
        for i in range(len(geom['feature_dims'])):
            geom['feature_dims'][i] /= max_feature_dims_count[i]

    # augment the feature_dims dataset with squared terms
    if augument:
        for geom in geom_list:
            geom['feature_dims_aug'] = []
            for i in range(len(geom['feature_dims'])):
                geom['feature_dims_aug'].append(geom['feature_dims'][i])
                geom['feature_dims_aug'].append(geom['feature_dims'][i] * geom['feature_dims'][i])

    print(f'feature dims({len(feature_dims)}): {feature_dims}')
    return feature_dims

def update_onehot_encoding(geom_list):
    # collect the label kinds, then one-hot encode each geometry's label
    label_kinds = []
    for geom in geom_list:
        label = geom['label']
        if label not in label_kinds:
            label_kinds.append(label)

    for geom in geom_list:
        label = geom['label']
        onehot = np.zeros(len(label_kinds))
        onehot[label_kinds.index(label)] = 1.0
        geom['label_onehot'] = onehot
    return label_kinds
ena_run_model.py
ADDED
@@ -0,0 +1,548 @@
# title: ENA model runner
# author: Taewook Kang, Kyubyung Kang
# date: 2024.3.27
# description: ENA model test and evaluation
# license: MIT
# version
# 0.1. 2024.3.27. create file
#
import json, os, re, math, random, string, logging
import torch, torch.nn as nn, torch.optim as optim, numpy as np, matplotlib.pyplot as plt, seaborn as sns
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig, BertModel
from sklearn.metrics import confusion_matrix
from collections import defaultdict
from datetime import datetime
from tqdm import tqdm
from ena_dataset import load_train_chunk_data, update_feature_dims_freq, update_onehot_encoding

# write a log file using the logger
logging.basicConfig(filename='./ewnet_logs.txt', level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s', datefmt='%Y%m%d %H:%M')
logger = logging.getLogger('ewnet')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'device: {device}')

# hyperparameters of the model currently under test
hyperparam = None

# MLP model
class EarthworkNetMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout_ratio=0.2):
        super(EarthworkNetMLP, self).__init__()

        models = []
        models.append(nn.Linear(input_dim, hidden_dim[0]))
        models.append(nn.ReLU())
        models.append(nn.BatchNorm1d(hidden_dim[0]))  # batch normalization after activation
        models.append(nn.Dropout(dropout_ratio))

        for i in range(1, len(hidden_dim)):
            models.append(nn.Linear(hidden_dim[i-1], hidden_dim[i]))
            models.append(nn.ReLU())
            models.append(nn.BatchNorm1d(hidden_dim[i]))
            models.append(nn.Dropout(dropout_ratio))

        models.append(nn.Linear(hidden_dim[-1], output_dim))
        self.layers = nn.Sequential(*models)

    def forward(self, x):
        x = self.layers(x)
        return x

# LSTM model
class EarthworkNetLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2, dropout_ratio=0.2):
        super(EarthworkNetLSTM, self).__init__()

        # sequence data, e.g. a token pattern (slope angle): top(0.5), bottom(0.5), top(0.6), bottom(0.6)...
        # time series features = (token_type, curve_angle); label = (label_onehot)
        models = []

        models.append(nn.LSTM(input_dim, hidden_dim[0], num_layers, batch_first=True, dropout=dropout_ratio))
        for i in range(1, len(hidden_dim)):
            models.append(nn.Linear(hidden_dim[i-1], hidden_dim[i]))

        models.append(nn.Linear(hidden_dim[-1], output_dim))
        self.layers = nn.Sequential(*models)

    def forward(self, x):
        for layer in self.layers:
            if type(layer) == torch.nn.modules.rnn.LSTM:
                x, _ = layer(x)  # the LSTM returns (output, (h_n, c_n))
            else:
                x = layer(x)

        return x

# dataset: earthwork_feature -> label
class EarthworkDataset(Dataset):
    def __init__(self, raw_data):
        self.raw_dataset = raw_data

    def __len__(self):
        return len(self.raw_dataset)

    def __getitem__(self, idx):
        features = self.raw_dataset[idx]['feature_dims']  # already tokenized from 'feature_text'
        label = self.raw_dataset[idx]['label_onehot']
        features = torch.tensor(features, dtype=torch.float32).to(device)
        label = torch.tensor(label, dtype=torch.float32).to(device)
        return features, label

def decode_data_to_geom(input_dataset, predictions, labels, input_feature_dims, label_kinds):
    global hyperparam
    match_count = 0
    for i in range(len(input_dataset)):  # batch size
        input_geom_features = input_dataset[i].cpu().numpy()
        prediction_index = predictions[i].item()
        label_index = labels[i].cpu().numpy()

        geom_features = []
        for j in range(len(input_feature_dims)):
            if input_geom_features[j] == 0.0:
                continue
            geom_features.append(f'{input_feature_dims[j]}({input_geom_features[j]:.2f})')

        prediction_label = label_kinds[prediction_index]
        label = label_kinds[label_index]

        match = prediction_label == label
        if match:
            match_count += 1
        logger.debug(f'{hyperparam["model"]} {hyperparam["hidden_dim"]} Equal: {prediction_label == label}, Label: {label}, Predicted: {prediction_label}, Geom: {geom_features}')

    return match_count

def test_mlp_model(model, batch_size, test_raw_dataset, input_feature_dims, label_kinds):
    print(f'test data count: {len(test_raw_dataset)}')
    test_dataset = EarthworkDataset(test_raw_dataset)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

    # test the model
    accuracies = []
    correct = 0
    total = 0
    total_match = 0
    with torch.no_grad():
        for i, (data, labels) in enumerate(test_dataloader):
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            _, labels = torch.max(labels.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            accuracies.append(correct / total)

            match_count = decode_data_to_geom(data, predicted, labels, input_feature_dims, label_kinds)
            total_match += match_count

    average_accuracy = correct / total
    print(f'Match count: {total_match}, Total count: {total}')
    print(f'Accuracy of the network on the test data: {average_accuracy:.4f}')
    return accuracies, average_accuracy

def run_MLP_LSTM(model_file_list, base_model):
    global hyperparam

    # prepare the dataset
    data_dir = './dataset'
    geom_list = load_train_chunk_data(data_dir)
    input_feature_dims = update_feature_dims_freq(geom_list)  # alternatively: update_feature_dims_token(geom_list)
    label_kinds = update_onehot_encoding(geom_list)

    train_raw_dataset = geom_list[:int(len(geom_list) * 0.8)]
    test_raw_dataset = geom_list[int(len(geom_list) * 0.8):]
    print(f'total data count: {len(geom_list)}')
    print(f'train data count: {len(train_raw_dataset)}, test data count: {len(test_raw_dataset)}')

    # load each trained model and evaluate it
    param_layers = [[128], [128, 64, 32], [256, 128, 64]]
    if base_model == 'MLP':
        param_layers = [[128, 64, 32], [64, 128, 64], [64, 128, 64, 32], [32, 64, 32]]
    for index, param_layer in enumerate(param_layers):
        logger.debug(f'model : {base_model}')

        params = {
            'model': base_model,
            'input_dim': len(input_feature_dims),
            'hidden_dim': param_layer,
            'output_dim': len(label_kinds),
            'batch_size': 32,
            'epochs': 150,
            'lr': 0.001
        }
        hyperparam = params
        # create the model and load its checkpoint
        model = EarthworkNetMLP(params['input_dim'], params['hidden_dim'], params['output_dim']).to(device)
        if base_model == 'LSTM':
            model = EarthworkNetLSTM(params['input_dim'], params['hidden_dim'], params['output_dim']).to(device)
        model_file = './' + model_file_list[index]
        model.load_state_dict(torch.load(model_file))
        model.eval()

        accuracies, acc = test_mlp_model(model, params['batch_size'], test_raw_dataset, input_feature_dims, label_kinds)

# generate random training text
def generate_random_text(label_index, length=100):
    base_text = f'This is text for label R{label_index + 1}. '
    random_text_length = max(0, length - len(base_text))  # length of the random tail to generate
    random_text = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(random_text_length))
    return base_text + random_text

# dataset for the transformer model
class EarthworkTransformDataset(Dataset):
    def __init__(self, input_ids, attention_mask, labels):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.labels = labels

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        input_ids_tensor = torch.tensor(self.input_ids[idx]).to(device)
        attention_mask_tensor = torch.tensor(self.attention_mask[idx]).to(device)
        label_tensor = torch.tensor(self.labels[idx]).to(device)
        return input_ids_tensor, attention_mask_tensor, label_tensor

# custom transformer
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, vocab_size=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(vocab_size, d_model)
        position = torch.arange(0, vocab_size, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float()
            * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1), :]
        return self.dropout(x)

class EarthworkNetTransformer(nn.Module):
    def __init__(
        self,
        input_feature_size,
        d_model,
        num_labels,
        nhead=8,
        dim_feedforward=2048,
        dim_fc=[64, 32],
        num_layers=6,
        dropout=0.1,
        activation="relu",
        classifier_dropout=0.1,
    ):
        super().__init__()

        self.d_model = d_model
        # self.pos_encoder = PositionalEncoding(d_model=d_model, dropout=dropout, vocab_size=vocab_size)

        self.input_fc = nn.Linear(input_feature_size, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )

        self.src_mask = None
        self.nhead = nhead
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
            # TBD. output_attentions=True
        )
        # nn.ModuleList registers the classifier head so it moves with .to(device) and appears in the state dict
        self.fc_layers = nn.ModuleList()
        fc_layers_dims = [d_model] + dim_fc + [num_labels]
        for i in range(1, len(fc_layers_dims)):
            fc = nn.Linear(fc_layers_dims[i-1], fc_layers_dims[i])
            self.fc_layers.append(fc)

        self.init_weights()

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        for fc in self.fc_layers:
            fc.bias.data.zero_()
            fc.weight.data.uniform_(-initrange, initrange)

    def forward(self, x, attention_mask):
        # x = self.pos_encoder(x)
        if self.src_mask is None or self.src_mask.size(0) != len(x):
            device = x.device
            mask = self.generate_square_subsequent_mask(len(x)).to(device)
            self.src_mask = mask

        x = x.float()
        x = self.input_fc(x)
        x = self.transformer_encoder(x, mask=self.src_mask)
        for fc in self.fc_layers:
            x = fc(x)

        return x

def run_transform(model_file_list):
    data_dir = './dataset'
    geom_list = load_train_chunk_data(data_dir)
    input_feature_dims = update_feature_dims_freq(geom_list)
    label_kinds = update_onehot_encoding(geom_list)
    num_labels = len(label_kinds)
    max_input_string = max(len(d['feature_text']) for d in geom_list)
    max_input_string = 320  # fixed so that it is divisible by nhead=8 (320 = 8 * 40)

    train_raw_dataset = geom_list[:int(len(geom_list) * 0.8)]
    test_raw_dataset = geom_list[int(len(geom_list) * 0.8):]
    print(f'total data count: {len(geom_list)}')
    print(f'train data count: {len(train_raw_dataset)}, test data count: {len(test_raw_dataset)}')

    # tokenize and pad sequences
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    max_length = max_input_string

    batch_sizes = [32, 64, 128]
    for index, batch_size in enumerate(batch_sizes):
        encoding = {'input_ids': [], 'attention_mask': []}
        for d in train_raw_dataset:
            token_text = tokenizer(d['feature_text'], padding='max_length', truncation=True, max_length=max_length)
            if len(token_text['input_ids']) < max_length:  # fill the rest with the padding token
                token_text['input_ids'] += [tokenizer.pad_token_id] * (max_length - len(token_text['input_ids']))
                token_text['attention_mask'] += [0] * (max_length - len(token_text['attention_mask']))
            encoding['input_ids'].append(token_text['input_ids'])
            encoding['attention_mask'].append(token_text['attention_mask'])

        input_ids = encoding['input_ids']
        attention_mask = encoding['attention_mask']

        label2id = {label: i for i, label in enumerate(sorted(set(d['label'] for d in train_raw_dataset)))}
        id2label = {v: k for k, v in label2id.items()}
        labels = [label2id[d['label']] for d in train_raw_dataset]  # convert labels to numerical format

        # hyperparameters
        logger.debug('model : transformer')

        params = {
            'model': 'transformer',
            'input_dim': len(input_feature_dims),
            'hidden_dim': [64],
            'output_dim': len(label2id),
            'batch_size': batch_size,
            'epochs': 300,
            'lr': 1e-5
        }

        dim_fc = params['hidden_dim']
        epochs = params['epochs']

        # model
        model = EarthworkNetTransformer(input_feature_size=max_length, d_model=512, num_labels=len(label2id), dim_fc=dim_fc).to(device)
        dataset = EarthworkTransformDataset(input_ids, attention_mask, labels)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        # load and test the model
        model_file = './' + model_file_list[index]
        model.load_state_dict(torch.load(model_file))
        model.eval()

        for i, test_raw in enumerate(test_raw_dataset):
            label = test_raw['label']
            input_text = test_raw['feature_text']
            encoding = tokenizer(input_text, return_tensors='pt', padding='max_length', truncation=True, max_length=max_length)
            input_ids = encoding['input_ids'].to(device)
            attention_mask = encoding['attention_mask'].to(device)
            output = model(input_ids, attention_mask)
            predicted_label = id2label[output.argmax().item()]

            feature_dims = input_text.split(' ')
            logger.debug(f'{params["model"]} {params["batch_size"]} Equal: {predicted_label == label}, Label: {label}, Predicted: {predicted_label}, Geom: {feature_dims}')

        print(f'test data count: {len(test_raw_dataset)}')
        encoding = tokenizer([d['feature_text'] for d in test_raw_dataset], padding='max_length', truncation=True, max_length=max_length)
        input_ids = encoding['input_ids']
        attention_mask = encoding['attention_mask']

        # note: the label mapping is rebuilt here from the test labels
        label2id = {label: i for i, label in enumerate(sorted(set(d['label'] for d in test_raw_dataset)))}
        id2label = {v: k for k, v in label2id.items()}
        labels = [label2id[d['label']] for d in test_raw_dataset]  # convert labels to numerical format

        test_dataset = EarthworkTransformDataset(input_ids, attention_mask, labels)
        test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)

        correct = 0
        total = 0
        accuracies = []
        with torch.no_grad():
            for i, (input_ids, attention_mask, labels) in enumerate(tqdm(test_dataloader, desc="test")):
                outputs = model(input_ids, attention_mask)
                _, predicted = torch.max(outputs, 1)
                total += len(labels)
                correct += (predicted == labels).sum().item()
                accuracies.append(correct / total)

        average_accuracy = correct / total
        print(f'Accuracy of the network on the test data: {average_accuracy:.4f}')

# BERT model
class EarthworkBertDataset(Dataset):
    def __init__(self, input_ids, attention_mask, labels):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.labels = labels

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        input_ids_tensor = torch.tensor(self.input_ids[idx]).to(device)
        attention_mask_tensor = torch.tensor(self.attention_mask[idx]).to(device)
        label_tensor = torch.tensor(self.labels[idx]).to(device)
        return input_ids_tensor, attention_mask_tensor, label_tensor

# BERT-based classifier
class EarthworkNetTransformerBert(torch.nn.Module):
    def __init__(self, num_labels):
        super(EarthworkNetTransformerBert, self).__init__()
        self.bert = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=num_labels, output_attentions=True)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        return outputs['logits'], outputs['attentions']

def run_bert(model_file):
    # prepare the dataset
    data_dir = './dataset'
    geom_list = load_train_chunk_data(data_dir)
    input_feature_dims = update_feature_dims_freq(geom_list)
    label_kinds = update_onehot_encoding(geom_list)
    num_labels = len(label_kinds)
    max_input_string = max(len(d['feature_text']) for d in geom_list)

    train_raw_dataset = geom_list[:int(len(geom_list) * 0.8)]
    test_raw_dataset = geom_list[int(len(geom_list) * 0.8):]
    print(f'total data count: {len(geom_list)}')
    print(f'train data count: {len(train_raw_dataset)}, test data count: {len(test_raw_dataset)}')

    # tokenize and pad sequences
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    max_length = max_input_string

    encoding = tokenizer([d['feature_text'] for d in train_raw_dataset], padding=True, truncation=True, max_length=max_length)
    input_ids = encoding['input_ids']
    attention_mask = encoding['attention_mask']

    label2id = {label: i for i, label in enumerate(sorted(set(d['label'] for d in train_raw_dataset)))}
    id2label = {v: k for k, v in label2id.items()}
    labels = [label2id[d['label']] for d in train_raw_dataset]  # convert labels to numerical format

    # initialize the model
    model = EarthworkNetTransformerBert(num_labels=len(label2id)).to(device)

    epochs = 150
    batch_size = 32
    params = {
        'model': 'BERT',
        'input_dim': len(input_feature_dims),
        'hidden_dim': 512,
        'output_dim': len(label2id),
        'batch_size': batch_size,
        'epochs': epochs,
        'lr': 1e-5,
    }

    dataset = EarthworkBertDataset(input_ids, attention_mask, labels)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # load and test the model
    logger.debug('model : bert')

    model_file = './' + model_file
    model.load_state_dict(torch.load(model_file))
    model.eval()

    for i, test_raw in enumerate(test_raw_dataset):
        label = test_raw['label']
        input_text = test_raw['feature_text']
        encoding = tokenizer(input_text, return_tensors='pt', padding=True, truncation=True, max_length=max_length)
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)
        output, att = model(input_ids, attention_mask)
        predicted_label = id2label[output.argmax().item()]

        feature_dims = input_text.split(' ')
        logger.debug(f'{params["model"]} Equal: {predicted_label == label}, Label: {label}, Predicted: {predicted_label}, Geom: {feature_dims}')

        # plot the last attention head of the last layer
        attention_matrix = att[-1]
        attention_layer = attention_matrix[-1]
        attention_mat = attention_layer[-1]
        att_mat = attention_mat.detach().cpu().numpy()
        fig, ax = plt.subplots()
        cax = ax.matshow(att_mat, cmap='viridis')
        fig.colorbar(cax)
        plt.savefig(f'./graph/bert_attention_{i}.png')
        plt.close()

    print(f'test data count: {len(test_raw_dataset)}')
    encoding = tokenizer([d['feature_text'] for d in test_raw_dataset], padding=True, truncation=True, max_length=max_length)
    input_ids = encoding['input_ids']
    attention_mask = encoding['attention_mask']

    label2id = {label: i for i, label in enumerate(sorted(set(d['label'] for d in test_raw_dataset)))}
    id2label = {v: k for k, v in label2id.items()}
    labels = [label2id[d['label']] for d in test_raw_dataset]  # convert labels to numerical format

    test_dataset = EarthworkBertDataset(input_ids, attention_mask, labels)
    test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)

    correct = 0
    total = 0
    accuracies = []
    with torch.no_grad():
        for i, (input_ids, attention_mask, labels) in enumerate(tqdm(test_dataloader, desc="test")):
            outputs, att = model(input_ids, attention_mask)
            _, predicted = torch.max(outputs, 1)
            total += len(labels)
            correct += (predicted == labels).sum().item()
            accuracies.append(correct / total)
            y_score = torch.nn.functional.softmax(outputs, dim=1)

    average_accuracy = correct / total
    print(f'Accuracy of the network on the test data: {average_accuracy:.4f}')


if __name__ == '__main__':
    models = ['earthwork_model_20240503_1650.pth', 'earthwork_model_20240503_1714.pth', 'earthwork_model_20240503_1716.pth', 'earthwork_model_20240503_1718.pth']
    run_MLP_LSTM(models, 'MLP')

    models = ['earthwork_model_20240503_1730.pth', 'earthwork_model_20240503_1732.pth', 'earthwork_model_20240503_1734.pth']
    run_MLP_LSTM(models, 'LSTM')

    models = ['earthwork_trans_model_20240503_2003.pth', 'earthwork_trans_model_20240503_2014.pth', 'earthwork_trans_model_20240503_2021.pth']
    run_transform(models)

    run_bert('earthwork_trans_model_20240504_0103.pth')
eval_model.py
ADDED
@@ -0,0 +1,75 @@
import numpy as np
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, confusion_matrix
from sklearn.metrics import average_precision_score

# toy binary data (overridden by the multi-label data below)
y_true = [0, 1, 1, 0, 1, 1]
y_pred = [0, 0, 1, 0, 0, 1]

# toy multi-label data used by calc_metrics_multi()
y_true = [[0, 1, 1], [0, 1, 1], [1, 0, 1]]
y_pred = [[0, 0, 1], [0, 0, 1], [1, 0, 0]]

class model_metrics:
    def __init__(self):
        self.clear()

    def clear(self):
        self.accuracy = 0.0
        self.recall = 0.0
        self.precision = 0.0
        self.f1 = 0.0
        self.mAP = 0.0
        self.cm = np.asarray([])

        self.count = 0
        self.total_accuracy = 0.0
        self.total_recall = 0.0
        self.total_precision = 0.0
        self.total_f1 = 0.0
        self.total_mAP = 0.0
        self.total_cm = np.asarray([])

    def get_indicators(self):
        return (self.total_accuracy / self.count, self.total_recall / self.count,
                self.total_precision / self.count, self.total_f1 / self.count,
                self.total_mAP / self.count, self.total_cm / self.count)

    def dump(self):
        print(f"Accuracy: {self.accuracy}")
        print(f"Recall: {self.recall}")
        print(f"Precision: {self.precision}")
        print(f"F1 Score: {self.f1}")
        print(f"mAP: {self.mAP}")
        print(f"Confusion Matrix: \n{self.cm}")

        print(f'average accuracy: {self.total_accuracy / self.count}')
        print(f'average recall: {self.total_recall / self.count}')
        print(f'average precision: {self.total_precision / self.count}')
        print(f'average f1: {self.total_f1 / self.count}')
        print(f'average mAP: {self.total_mAP / self.count}')
        print(f'average confusion matrix: \n{self.total_cm / self.count}')

    def calc_metrics(self, y_true, y_pred, y_score):
        self.accuracy = accuracy_score(y_true, y_pred)
        self.recall = recall_score(y_true, y_pred, average='weighted')
        self.precision = precision_score(y_true, y_pred, average='micro')
        self.f1 = f1_score(y_true, y_pred, average='weighted')
        self.cm = confusion_matrix(y_true, y_pred)
        self.count += 1

        self.total_accuracy += self.accuracy
        self.total_recall += self.recall
        self.total_precision += self.precision
        self.total_f1 += self.f1
        self.total_mAP += self.mAP
        self.total_cm = self.cm  # TBD: accumulate rather than overwrite

        return self.accuracy, self.recall, self.precision, self.f1, self.mAP, self.cm

    def calc_metrics_multi(self, y_true, y_pred):
        self.accuracy = accuracy_score(y_true, y_pred)
        self.recall = recall_score(y_true, y_pred, average='micro')
        self.precision = precision_score(y_true, y_pred, average='micro')
        self.f1 = f1_score(y_true, y_pred, average='micro')
        self.mAP = average_precision_score(y_true, y_pred, average='micro')
        self.count += 1

        # keep the running totals consistent with calc_metrics()
        self.total_accuracy += self.accuracy
        self.total_recall += self.recall
        self.total_precision += self.precision
        self.total_f1 += self.f1
        self.total_mAP += self.mAP

        return self.accuracy, self.recall, self.precision, self.f1, self.mAP, self.cm
extract_ewlog.py
ADDED
@@ -0,0 +1,244 @@
1 |
+
import os, re, numpy as np
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import matplotlib.pyplot as plt_xsec
|
4 |
+
from datetime import datetime
|
5 |
+
|
6 |
+
input_log_file = './ewnet_logs_TRANS3_20240708.txt'
|
7 |
+
flag_all_xsections = True
|
8 |
+
prev_station = ''
|
9 |
+
|
10 |
+
now = datetime.now()
|
11 |
+
now_str = now.strftime('%Y%m%d_%H%M')
|
12 |
+
|
13 |
+
label_list = ['pave_layer1', 'pave_layer2', 'pave_layer3', 'pave_layer4', 'cut_ea', 'cut_rr', 'cut_br', 'cut_ditch', 'fill_subbed', 'fill_subbody', 'curb', 'above', 'below', 'pave_int', 'pave_surface', 'pave_subgrade', 'ground', 'pave_bottom', 'rr', 'br', 'slope', 'struct', 'steps']
|
14 |
+
color_list = [[0.8,0.8,0.8],[0.6,0.6,0.6],[0.4,0.4,0.4],[0.2,0.2,0.2],[0.8,0.4,0.2],[0.8,0.6,0.2],[0.8,0.8,0.2],[0.6,0.8,0.2],[0.3,0.8,0.3],[0.3,0.6,0.3],[0.3,0.4,0.3],[0.0,0.8,0.0],[0.6,0.0,0.0],[0.8,0.0,0.0],[1.0,0.0,0.0],[0.2,0.2,0.6],[0.0,1.0,0.0],[0.2,0.2,1.0],[0.4,0.2,1.0],[0.6,0.2,1.0],[0.2,0.8,0.6],[0.8,0.2,1.0],[1.0,0.2,1.0]]
|
15 |
+
|
16 |
+
# make folder
|
17 |
+
if not os.path.exists('./graph'):
|
18 |
+
os.makedirs('./graph')
|
19 |
+
|
20 |
+
def draw_colorbox_list():
|
21 |
+
global label_list, color_list
|
22 |
+
|
23 |
+
fig, ax = plt.subplots(figsize=(9.2, 5))
|
24 |
+
ax.invert_yaxis()
|
25 |
+
ax.set_xlim(0, 1.5)
|
26 |
+
fig.set_size_inches(12, 7)
|
27 |
+
|
28 |
+
token_list = ['item1', 'item2', 'item3', 'item4', 'item5', 'item6']
|
29 |
+
for i, (colname, color) in enumerate(zip(label_list, color_list)):
|
30 |
+
width = 1.0 / len(label_list)
|
31 |
+
widths = [width] * len(token_list)
|
32 |
+
starts = width * i
|
33 |
+
rects = ax.barh(token_list, widths, left=starts, height=0.5, label=colname, color=color)
|
34 |
+
|
35 |
+
text_color = 'white' if np.max(color) < 0.4 else 'black'
|
36 |
+
ax.legend()
|
37 |
+
plt.savefig('./graph/box_colors.png')
|
38 |
+
plt.close()
|
39 |
+
|
40 |
+
def output_graph_matrics(index, tag, text):
|
41 |
+
global label_list, color_list
|
42 |
+
|
43 |
+
prediction = ''
|
44 |
+
tokens = []
|
45 |
+
polyline = []
|
46 |
+
geom_index = text.find('Geom:')
|
47 |
+
if geom_index >= 0:
|
48 |
+
pred_label = ''
|
49 |
+
label_index = text.find('Predicted: ')
|
50 |
+
if label_index >= 0:
|
51 |
+
pred = text[label_index + 11:geom_index]
|
52 |
+
labels = pred.split(', ')
|
53 |
+
if len(labels) > 0:
|
54 |
+
prediction = labels[0]
|
55 |
+
pred_label = labels[0] + '(0.3'
|
56 |
+
|
57 |
+
polyline_index = text.find('Polyline:')
|
58 |
+
if polyline_index > 0:
|
59 |
+
pred = text[geom_index + 6:polyline_index - 2]
|
60 |
+
polyline_text = text[polyline_index + 10:]
|
61 |
+
polyline = eval(polyline_text)
|
62 |
+
else:
|
63 |
+
pred = text[geom_index + 6:]
|
64 |
+
pred = pred.replace('[', '').replace(']', '')
|
65 |
+
pred = pred.replace(')', '').replace("'", '')
|
66 |
+
tokens = pred.split(',')
|
67 |
+
if len(tokens) <= 1:
|
68 |
+
tokens = pred.split(' ')
|
69 |
+
if len(tokens) > 0:
|
70 |
+
tokens.insert(0, pred_label)
|
71 |
+
last = tokens[-1]
|
72 |
+
if len(last) == 0:
|
73 |
+
tokens.pop()
|
74 |
+
else:
|
75 |
+
return
|
76 |
+
|
77 |
+
token_list = [token.split('(')[0] for token in tokens]
|
78 |
+
token_list = [token.replace(' ', '') for token in token_list]
|
79 |
+
ratios = [float(token.split('(')[1]) for token in tokens]
|
80 |
+
results = {token_list[0]: ratios}
|
81 |
+
|
82 |
+
labels = [label.replace(" ", "") for label in list(results.keys())]
|
83 |
+
data = np.array(list(results.values()))
|
84 |
+
data_cum = data.cumsum(axis=1)
|
85 |
+
token_colors = [color_list[label_list.index(label)] for label in token_list]
|
86 |
+
|
87 |
+
global plt_xsec, now_str, flag_all_xsections
|
88 |
+
if flag_all_xsections == False:
|
89 |
+
fig, ax = plt.subplots(figsize=(9.2, 5))
|
90 |
+
ax.invert_yaxis()
|
91 |
+
ax.xaxis.set_visible(False)
|
92 |
+
ax.set_xlim(0, np.sum(data, axis=1).max())
|
93 |
+
fig.set_size_inches(15, 0.5)
|
94 |
+
|
95 |
+
for i, (colname, color) in enumerate(zip(token_list, token_colors)):
|
96 |
+
widths = data[:, i]
|
97 |
+
starts = data_cum[:, i] - widths
|
98 |
+
if i > 0:
|
99 |
+
starts += 0.02
|
100 |
+
rects = ax.barh(labels, widths, left=starts, height=0.5, label=colname, color=color)
|
101 |
+
|
102 |
+
if i != 0:
|
103 |
+
text_color = 'white' if np.max(color) < 0.4 else 'black'
|
104 |
+
ax.bar_label(rects, label_type='center', color=text_color)
|
105 |
+
ax.legend(ncols=len(token_list), bbox_to_anchor=(0, 1), loc='lower right', fontsize='small')
|
106 |
+
|
107 |
+
tag = tag.replace(' ', '_')
|
108 |
+
tag = tag.replace(':', '')
|
109 |
+
|
110 |
+
if text.find('True') > 0:
|
111 |
+
plt.savefig(f'./graph/box_list_{now_str}_{tag}_{index}_T.png')
|
112 |
+
else:
|
113 |
+
plt.savefig(f'./graph/box_list_{now_str}_{tag}_{index}_F.png')
|
114 |
+
plt.close()
|
115 |
+
else:
|
116 |
+
if polyline[0] != polyline[-1]:
|
117 |
+
polyline.append(polyline[0])
|
118 |
+
x, y = zip(*polyline)
|
119 |
+
color = color_list[label_list.index(prediction)]
|
120 |
+
|
121 |
+
plt_xsec.fill(x, y, color=color)
|
122 |
+
centroid_x = sum(x) / len(x)
|
123 |
+
centroid_y = sum(y) / len(y)
|
124 |
+
area = 0.5 * abs(sum(x[i]*y[i+1] - x[i+1]*y[i] for i in range(len(polyline)-1)))
|
125 |
+
|
126 |
+
if prediction.find('pave') < 0:
|
127 |
+
plt_xsec.text(centroid_x, centroid_y, f'{prediction}={area:.2f}', horizontalalignment='center', verticalalignment='center', fontsize=5, color='black')
|
128 |
+
|
129 |
+
return prediction, area, token_list
|
130 |
+
|
output_stations = ['4+440.00000', '3+780.00000', '3+800.00000', '3+880.00000', '3+940.00000']

def output_logs(tag, equal='none'):
    global input_log_file, plt_xsec, now_str, prev_station, flag_all_xsections, output_stations

    text_list = []
    logs = []

    with open(input_log_file, 'r') as file:
        for index, label in enumerate(label_list):
            file.seek(0)
            for line in file:
                if flag_all_xsections == False and line.find(tag) < 0:
                    continue
                tag_model = tag.split(' ')[0]
                if flag_all_xsections == True and line.find(tag_model) < 0:
                    continue
                if flag_all_xsections == False and line.find('Label: ' + label) < 0:
                    continue
                line = line.replace('\n', '')
                if equal == 'none':
                    text_list.append(line)
                elif line.find(equal) > 0:
                    text_list.append(line)
                if flag_all_xsections == False:
                    break
            if flag_all_xsections:
                break

    if len(text_list) == 0:
        return logs

    def extract_station(text):
        sta_index = text.find('Station:') + 9  # start of the station value
        end_index = text.find(',', sta_index)
        return text[sta_index:end_index] if end_index != -1 else text[sta_index:]

    text_list = sorted(text_list, key=extract_station)
    station = ''
    for index, text in enumerate(text_list):
        sta_index = text.find('Station:')
        equal_index = text.find('Equal: ')
        equal_check = 'T' if text.find('True') > 0 else 'F'

        if sta_index > 0 and equal_index > 0:
            station = text[sta_index + 9:equal_index - 2]
            print(station)

        # skip stations that are not in the output list (list.index() raises
        # ValueError for missing items, so a membership test is used instead)
        if len(output_stations) > 0 and station not in output_stations:
            continue

        if prev_station != station:
            if len(prev_station) > 0:
                plt_xsec.savefig(f'./graph/polygon_{now_str}_{tag}_{prev_station}_{equal_check}.png', dpi=300)
                plt_xsec.close()

            plt_xsec.figure()
            plt_xsec.gca().set_xlim([-60, 60])
            plt_xsec.gca().axis('equal')
            plt_xsec.gca().text(0, 0, f'{station}', fontsize=12, color='black')

            prev_station = station

        text = text.replace('\n', '')
        label, area, tokens = output_graph_matrics(index, tag, text)
        log = {
            'index': index,
            'station': station,
            'label': label,
            'area': area,
            'tokens': tokens
        }
        logs.append(log)

        if index == len(text_list) - 1:
            plt_xsec.savefig(f'./graph/polygon_{now_str}_{tag}_{prev_station}_{equal_check}.png', dpi=300)
            plt_xsec.close()

    return logs
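
# Return-value sketch: each entry of `logs` is a dict such as
#   {'index': 0, 'station': '3+780.00000', 'label': 'cut_ea', 'area': 12.3, 'tokens': [...]}
# (the station string comes from output_stations above; the label/area/tokens
# values shown here are hypothetical).
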
def main():
    draw_colorbox_list()

    # open() raises on failure rather than returning None
    summary_log_file = open('./graph/summary_log.csv', 'a')
    summary_log_file.write('model, ground true, length, ground false, length\n')

    tags = ['MLP [128, 64, 32]', 'MLP [64, 128, 64]', 'MLP [64, 128, 64, 32]',
            'LSTM [128]', 'LSTM [128, 64, 32]', 'LSTM [256, 128, 64]',
            'transformer 32', 'transformer 64', 'transformer 128', 'BERT']
    for tag in tags:
        print(tag)
        if len(output_stations) > 0:  # station filter active: only render figures
            logs1 = output_logs(tag)
            continue

        logs1 = output_logs(tag, 'Equal: True')
        logs2 = output_logs(tag, 'Equal: False')
        if len(logs1) == 0 or len(logs2) == 0:
            continue
        area1 = sum([log['area'] for log in logs1])
        area2 = sum([log['area'] for log in logs2])
        log_record = f'{tag}, {area1}, {len(logs1)}, {area2}, {len(logs2)}'
        summary_log_file.write(f'{log_record}\n')

        if flag_all_xsections:
            break

    summary_log_file.close()

if __name__ == '__main__':
    main()
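
# The resulting summary_log.csv then holds one row per model tag, e.g.
# (numbers below are hypothetical):
#   model, ground true, length, ground false, length
#   MLP [128, 64, 32], 1234.5, 42, 1187.3, 40
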
prepare_dataset.py
ADDED
@@ -0,0 +1,460 @@
# title: create earthwork train dataset
# author: Taewook Kang
# date: 2024.3.27
# description: create earthwork train dataset
# license: MIT
# version
#   0.1. 2024.3.27. create file
#
import os, math, argparse, json, re, traceback, numpy as np, pandas as pd, trimesh, laspy, shutil
import logging, matplotlib.pyplot as plt, shapely
from shapely.geometry import Polygon, LineString
from scipy.spatial import distance
from tqdm import trange, tqdm
from math import pi

logging.basicConfig(level=logging.DEBUG, filename='logs.txt',
                    format='%(asctime)s %(levelname)s %(message)s',
                    datefmt='%H:%M:%S')
logger = logging.getLogger("prep")

_precision = 0.00001

def get_bbox(polyline):
    polyline_np = np.array(polyline)
    xmin, ymin = np.amin(polyline_np, axis=0)
    xmax, ymax = np.amax(polyline_np, axis=0)
    return (xmin, ymin, xmax, ymax)

def get_center_point(pline):
    if len(pline) == 0:
        return (0, 0)
    xs = [p[0] for p in pline]
    ys = [p[1] for p in pline]
    return (sum(xs) / len(pline), sum(ys) / len(pline))

def intersect_line(line1, line2):
    (x1, y1), (x2, y2) = line1
    (x3, y3), (x4, y4) = line2

    denominator = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
    if denominator == 0:
        return None  # lines are parallel

    x = ((x1 * y2 - y1 * x2) * (x3 - x4) - (x1 - x2) * (x3 * y4 - y3 * x4)) / denominator
    y = ((x1 * y2 - y1 * x2) * (y3 - y4) - (y1 - y2) * (x3 * y4 - y3 * x4)) / denominator

    # check that (x, y) lies within the x-range of line1 and line2
    if x < min(x1, x2) or x > max(x1, x2) or x < min(x3, x4) or x > max(x3, x4):
        return None

    return (x, y)
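
# Worked example: two segments crossing at (1, 1); parallel segments return None.
#   intersect_line(((0, 0), (2, 2)), ((0, 2), (2, 0)))  # -> (1.0, 1.0)
#   intersect_line(((0, 0), (1, 0)), ((0, 1), (1, 1)))  # -> None (parallel)
# Note that the in-range test only constrains x, which is what
# get_positions_pline() below relies on for its vertical probe lines.
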
def get_positions_pline(base_pline, target_pline):
    target_pos_marks = []
    for i in range(len(target_pline)):
        target = [target_pline[i], (target_pline[i][0], target_pline[i][1] + 1e+10)]  # vertical probe line through the target point
        pos = 0.0
        for j in range(len(base_pline) - 1):
            base = [base_pline[j], base_pline[j + 1]]
            intersect = intersect_line(base, target)
            if intersect is None:
                continue

            if equal(intersect[1], target[0][1]):
                pos = 0.0  # target point sits on the base polyline
                break

            pos = -1.0 if intersect[1] > target[0][1] else 1.0  # -1: below base, +1: above
            break
        target_pos_marks.append(pos)

    return target_pos_marks

def get_below_pline(base_pline, target_pline):
    pos_marks = get_positions_pline(base_pline, target_pline)
    average = sum(pos_marks) / len(pos_marks)
    return average < 0.0
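
# Worked example: a horizontal ground line at y=10 and a target polyline below it.
#   base = [(-10, 10), (10, 10)]
#   get_positions_pline(base, [(0, 5), (2, 3)])  # -> [-1.0, -1.0]
#   get_below_pline(base, [(0, 5), (2, 3)])      # -> True (target lies below base)
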
def get_geometry(xsec, label):
    for geom in xsec['geom']:
        if geom['label'] == label:
            return geom
    return None

def is_point_in_rect(point1, point2, perp):
    return is_point_in_rectangle(point1[0], point1[1], point2[0], point2[1], perp[0], perp[1])

def is_point_in_rectangle(x1, y1, x2, y2, x, y):
    # ensure that x1 <= x2 and y1 <= y2
    x1, x2 = min(x1, x2), max(x1, x2)
    y1, y2 = min(y1, y2), max(y1, y2)

    # check if (x, y) is within the rectangle
    return x1 <= x <= x2 and y1 <= y <= y2

def sign_distance(a, b, c, p):
    # signed distance from point p to the line a*x + b*y + c = 0
    d = math.sqrt(a*a + b*b)
    if d == 0.0:
        lp = 0.0
    else:
        lp = (a * p[0] + b * p[1] + c) / d
    return lp

def equal(a, b):
    return abs(a - b) < _precision

def equal_point(p1, p2):
    return equal(p1[0], p2[0]) and equal(p1[1], p2[1])
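
# Worked example: for the x-axis (a=0, b=1, c=0) the signed distance is just y.
#   sign_distance(0, 1, 0, (3, 2))   # -> 2.0
#   sign_distance(0, 1, 0, (3, -2))  # -> -2.0
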
def get_angle(x1, y1, x2, y2):
    pi = math.acos(-1.0)

    # calculate the quadrant of the line
    dx = x2 - x1
    dy = y2 - y1

    # calculate the angle in radians for lines in the left and right quadrants
    if dx == 0 and dy == 0:
        return -1.0
    if dy < 0 and dx == 0:
        angle_radius = pi + pi / 2
    elif dy > 0 and dx == 0:
        angle_radius = pi / 2
    else:
        angle_radius = math.atan(dy / dx)

    # adjust the angle for the other quadrants
    if dy >= 0 and dx > 0:
        pass
    if dy < 0 and dx > 0:
        angle_radius += 2 * pi
    elif dx < 0:
        angle_radius += pi

    return angle_radius
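
# Worked examples (angles in radians, counter-clockwise from +x, in [0, 2*pi)):
#   get_angle(0, 0, 1, 1)   # -> pi/4    (~0.785)
#   get_angle(0, 0, -1, 0)  # -> pi      (~3.142)
#   get_angle(0, 0, 0, -1)  # -> 3*pi/2  (~4.712)
#   get_angle(0, 0, 0, 0)   # -> -1.0    (degenerate)
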
def line_coefficients(point1, point2):
    x1, y1 = point1
    x2, y2 = point2

    # coefficients of the line A*x + B*y + C = 0 through the two points
    A = y2 - y1
    B = x1 - x2
    C = x2*y1 - x1*y2

    return A, B, C

def sign_point_on_line(point1, point2, new_point):
    if equal(point1[0], new_point[0]) and equal(point1[1], new_point[1]):
        return 0.0
    if equal(point2[0], new_point[0]) and equal(point2[1], new_point[1]):
        return 0.0

    line_A, line_B, line_C = line_coefficients(point1, point2)
    x, y = new_point
    value = line_A * x + line_B * y + line_C
    if math.fabs(value) < _precision:
        return 0.0
    elif value > 0.0:
        return 1.0
    return -1.0
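
# Worked example on the line y = x through (0,0) and (2,2) (A=2, B=-2, C=0):
#   sign_point_on_line((0, 0), (2, 2), (0, 1))  # -> -1.0 (one side of the line)
#   sign_point_on_line((0, 0), (2, 2), (1, 0))  # ->  1.0 (the other side)
#   sign_point_on_line((0, 0), (2, 2), (1, 1))  # ->  0.0 (on the line)
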
def sign_distance_on_line(line_point1, line_point2, point):
    direction = sign_point_on_line(line_point1, line_point2, point)
    if direction == 0:
        return 0.0

    # degenerate line: both end points coincide
    if math.fabs(line_point1[0] - line_point2[0]) < _precision and math.fabs(line_point1[1] - line_point2[1]) < _precision:
        return 0.0

    x = point[0]
    y = point[1]
    x1 = line_point1[0]
    y1 = line_point1[1]
    x2 = line_point2[0]
    y2 = line_point2[1]

    if equal(x1, x2):  # vertical line: x = x1
        a = 1
        b = 0
        c = -x1
    else:
        m = (y2 - y1) / (x2 - x1)
        a = -m
        b = 1
        c = -y1 + (m * x1)

    dist = abs(a * x + b * y + c) / math.sqrt(a * a + b * b)
    dist *= float(direction)

    return dist
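
# Worked example on the same line y = x:
#   sign_distance_on_line((0, 0), (2, 2), (0, 1))  # -> -1/sqrt(2) (~-0.707)
# i.e. the perpendicular distance carrying the sign of sign_point_on_line().
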
def is_point_on_line(point1, point2, perp):
    if is_point_in_rect(point1, point2, perp) == False:
        return False
    direction = sign_point_on_line(point1, point2, perp)
    if math.fabs(direction) < _precision:
        return True
    return False

def is_overlap_line(line, part_seg):
    p1, p2 = line
    p3, p4 = part_seg

    f1 = is_point_on_line(p1, p2, p3)
    f2 = is_point_on_line(p1, p2, p4)

    if (f1 or f2) and f1 != f2:  # a dangling end point alone is not an overlap
        if f1 and (equal_point(p1, p3) or equal_point(p2, p3)):
            return False
        if f2 and (equal_point(p1, p4) or equal_point(p2, p4)):
            return False

    return f1 or f2

def is_on_pline(polyline, base_line):
    p1 = base_line[0]
    p2 = base_line[1]

    for i in range(len(polyline) - 1):
        p3 = polyline[i]
        p4 = polyline[i + 1]
        if is_overlap_line((p1, p2), (p3, p4)) or is_overlap_line((p3, p4), (p1, p2)):
            return True

    return False
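
# Worked example: a base segment lying on the first edge of the polyline.
#   pline = [(0, 0), (10, 0), (10, 10)]
#   is_on_pline(pline, ((2, 0), (5, 0)))  # -> True
#   is_on_pline(pline, ((2, 1), (5, 1)))  # -> False (parallel but offset)
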
def get_match_line_labels(xsec, base_geom, base_line):
    labels = []
    for geom in xsec['geom']:
        if geom == base_geom:
            continue
        geom_label = geom['label']
        base_label = base_geom['label']
        if geom_label == base_label:
            continue
        closed = geom['closed']
        if closed == True:  # only open polylines are considered
            continue
        if geom_label == 'center':
            continue

        polyline = geom['polyline']
        if is_on_pline(polyline, base_line):
            labels.append(geom['label'])
    return labels

def get_seq_feature_tokens(xsec, geom, closed_type):
    polyline = geom['polyline']
    closed = geom['closed']
    if closed != closed_type:
        return []

    lines = []
    for i in range(len(polyline) - 1):
        line = (polyline[i], polyline[i + 1])
        lines.append(line)

    geom_tokens = []
    for line in lines:
        labels = get_match_line_labels(xsec, geom, line)
        if len(labels) == 0:
            continue

        geom_tokens.extend(labels)  # extend handles the single-label case as well

    return geom_tokens
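
# Illustrative: for a closed cut polygon whose boundary segments lie on the
# 'ground' string and on a 'slope' string, the raw token sequence might be
# ['ground', 'ground', 'slope']; summery_feature() below compresses such runs
# to ['ground(2)', 'slope'].
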
def translate_geometry(xsec, cp):
    for geom in xsec['geom']:
        polyline = geom['polyline']
        geom['polyline'] = [(p[0] - cp[0], p[1] - cp[1]) for p in polyline]

    return xsec

def is_closed(polyline):
    if equal_point(polyline[0], polyline[-1]):
        return True
    return False

def summery_feature(features):
    sum_features = []
    if len(features) == 0:
        return sum_features

    index = 0
    while index < len(features):
        f = features[index]
        sum_feature = f
        if type(f) == list:
            sum_feature = summery_feature(f)
            if len(sum_feature) == 1:
                sum_feature = sum_feature[0]
        elif type(f) == str:
            label = f
            # find the last index of the run of equal labels at the same level
            last_index = index
            for i in range(index + 1, len(features)):
                if type(features[i]) == str:
                    if features[i] == label:
                        last_index = i
                    else:
                        break
                else:
                    break
            if last_index != index:
                sum_feature = f'{f}({last_index - index + 1})'
                index = last_index
        else:
            pass

        sum_features.append(sum_feature)
        index += 1

    return sum_features
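
# Worked examples of the run-length summarization:
#   summery_feature(['ground', 'ground', 'slope'])  # -> ['ground(2)', 'slope']
#   summery_feature(['below', ['rr', 'rr']])        # -> ['below', 'rr(2)']
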
def get_intersection_count(xsec, base_geom, target_label):
    pave_top = get_geometry(xsec, 'pave_surface')
    if pave_top is None:
        return 0

    polyline = base_geom['polyline']
    polygon = Polygon(polyline)
    base_p1 = polygon.centroid
    base_p2 = (base_p1.x, base_p1.y + 1e+10)
    vertical_line = LineString([base_p1, base_p2])

    count = 0
    for target_geom in xsec['geom']:
        if base_geom == target_geom:
            continue
        label = target_geom['label']
        if re.search(target_label, label) is None:
            continue

        # check intersection
        target_polyline = target_geom['polyline']
        polyline = LineString(target_polyline)
        ip = shapely.intersection(polyline, vertical_line)  # https://shapely.readthedocs.io/en/stable/reference/shapely.intersection.html
        if ip.is_empty:
            continue
        count += 1

    return count
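
# Illustrative: shooting the vertical ray up from the centroid of a pave_layer3
# polygon that sits under pave_surface, pave_layer1 and pave_layer2 would give
# get_intersection_count(xsec, geom, 'pave_.*') == 3 (geometry-dependent; one
# count per crossed geometry, not per crossing).
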
349 |
+
def update_xsection_feature(xsec):
|
350 |
+
gnd_geom = get_geometry(xsec, 'ground')
|
351 |
+
if gnd_geom == None:
|
352 |
+
return None
|
353 |
+
|
354 |
+
center = get_geometry(xsec, 'center')
|
355 |
+
if center == None or 'polyline' not in center:
|
356 |
+
return None
|
357 |
+
cp = get_center_point(center['polyline'])
|
358 |
+
|
359 |
+
xsec = translate_geometry(xsec, cp)
|
360 |
+
station = xsec['station']
|
361 |
+
|
362 |
+
index = 0
|
363 |
+
while index < len(xsec['geom']):
|
364 |
+
geom = xsec['geom'][index]
|
365 |
+
label = geom['label']
|
366 |
+
polyline = geom['polyline']
|
367 |
+
closed = geom['closed']
|
368 |
+
if len(polyline) <= 2 or closed == False:
|
369 |
+
index += 1
|
370 |
+
continue
|
371 |
+
|
372 |
+
pt1 = polyline[0]
|
373 |
+
pt2 = polyline[-1]
|
374 |
+
if equal_point(pt1, pt2) == False: # closed polyline
|
375 |
+
polyline.append(pt1)
|
376 |
+
|
377 |
+
# noise filtering
|
378 |
+
polygon = Polygon(polyline) # calculate area of polyline as polygon
|
379 |
+
if math.fabs(polygon.area) < _precision:
|
380 |
+
xsec['geom'].pop(index) # remove index element in xsec['geom']
|
381 |
+
continue
|
382 |
+
|
383 |
+
if station == '1+660.00000' and label == 'cut_ditch':
|
384 |
+
label = 'cut_ditch'
|
385 |
+
|
386 |
+
# processing
|
387 |
+
if get_below_pline(gnd_geom['polyline'], polyline):
|
388 |
+
geom['earthwork_feature'].append('below')
|
389 |
+
else:
|
390 |
+
geom['earthwork_feature'].append('above')
|
391 |
+
|
392 |
+
if re.search('pave_.*', label):
|
393 |
+
pave_int_count = get_intersection_count(xsec, geom, 'pave_.*')
|
394 |
+
geom['earthwork_feature'].append(f'pave_int({pave_int_count})')
|
395 |
+
|
396 |
+
tokens = get_seq_feature_tokens(xsec, geom, True)
|
397 |
+
if len(tokens) == 1:
|
398 |
+
geom['earthwork_feature'].append(tokens[0])
|
399 |
+
else:
|
400 |
+
geom['earthwork_feature'].extend(tokens)
|
401 |
+
|
402 |
+
geom['earthwork_feature'] = summery_feature(geom['earthwork_feature'])
|
403 |
+
|
404 |
+
# print(f'{station}. {label} feature: {geom["earthwork_feature"]}')
|
405 |
+
logger.debug(f'{station}. {label} feature: {geom["earthwork_feature"]}')
|
406 |
+
index += 1
|
407 |
+
|
408 |
+
return xsec
|
409 |
+
|
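
# Illustrative result: after this pass a closed geometry might carry
#   geom['earthwork_feature'] == ['above', 'ground(2)', 'slope']
# where the token values depend on the drawing; 'above'/'below' is always
# appended, 'pave_int(n)' only for pave_* labels.
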
def update_xsections_feature(xsections):
    # update the closed flag of each geometry
    for xsec in xsections:
        for geom in xsec['geom']:
            label = geom['label']
            polyline = geom['polyline']
            if len(polyline) < 2:
                continue
            closed = is_closed(polyline)
            if closed == False:  # exception: pavement layers are treated as closed
                closed = False if re.search('pave_layer.*', label) is None else True
            geom['closed'] = closed

    # update features
    out_xsections = []
    for xsec in xsections:
        out_xsec = update_xsection_feature(xsec)
        if out_xsec is None:
            continue
        out_xsections.append(out_xsec)

    return out_xsections

def main():
    parser = argparse.ArgumentParser(description='create earthwork train dataset')
    parser.add_argument('--input', type=str, default='output/', help='input folder')
    parser.add_argument('--output', type=str, default='dataset/', help='output folder')

    args = parser.parse_args()
    try:
        file_names = os.listdir(args.input)
        for file_name in tqdm(file_names):
            if file_name.endswith('.json') == False:
                continue
            print(f'processing {file_name}')
            data = None
            with open(os.path.join(args.input, file_name), 'r') as f:
                data = json.load(f)

            out_xsections = update_xsections_feature(data)

            output_file = os.path.join(args.output, file_name)
            with open(output_file, 'w') as f:
                json.dump(out_xsections, f, indent=4)

    except Exception as e:
        print(f'error: {e}')
        traceback.print_exc()

if __name__ == '__main__':
    main()
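
# Usage (matches the argparse defaults above):
#   python prepare_dataset.py --input output/ --output dataset/
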