mac999 committed
Commit af359c9 · verified · 1 Parent(s): 50e147d

Upload 7 files

Files changed (7)
  1. config.json +120 -0
  2. create_earthwork_dataset.py +232 -0
  3. ena_dataset.py +137 -0
  4. ena_run_model.py +548 -0
  5. eval_model.py +75 -0
  6. extract_ewlog.py +244 -0
  7. prepare_dataset.py +460 -0
config.json ADDED
@@ -0,0 +1,120 @@
{
    "layers": [
        { "layer": "Nru_Geo_Crs_Center", "label": "center" },
        { "layer": "Nru_Crs_Pave_Surface", "label": "pave_surface" },
        { "layer": "Nru_Crs_Pave_Subgrade", "label": "pave_subgrade" },
        { "layer": "Nru_Geo_Surface", "label": "ground" },
        { "layer": "Nru_Crs_Pave_Bottom", "label": "pave_bottom" },
        { "layer": "Nru_Geo_Underground_1", "label": "rr" },
        { "layer": "Nru_Geo_Underground_2", "label": "br" },
        { "layer": "Nru_Crs_Slope", "label": "slope" },
        { "layer": "Nru_Stru_Bench", "label": "struct" },
        { "layer": "Nru_Stru_Frt_Sodan", "label": "struct" },
        { "layer": "Nru_Stru_Smr", "label": "struct" },
        { "layer": "Nru_Stru_Smr_Ending", "label": "struct" },
        { "layer": "Nru_Stru_Ditch", "label": "struct" },
        { "layer": "Nru_Stru_Ditch_Bench", "label": "struct" },
        { "layer": "Nru_Stru_Frt", "label": "struct" },
        { "layer": "Nru_Crs_Ew_깎기_토사", "label": "cut_ea" },
        { "layer": "Nru_Crs_Ew_깎기_리핑암", "label": "cut_rr" },
        { "layer": "Nru_Crs_Ew_일반발파", "label": "cut_br" },
        { "layer": "Nru_Crs_Ew_대규모발파", "label": "cut_br" },
        { "layer": "Nru_Crs_Ew_중규모진동제어발파", "label": "cut_br" },
        { "layer": "Nru_Crs_Ew_터파기_토사", "label": "cut_ditch" },
        { "layer": "Nru_Crs_Ew_쌓기_노상", "label": "fill_subbed" },
        { "layer": "Nru_Crs_Ew_쌓기_노체", "label": "fill_subbody" },
        { "layer": "Nru_Crs_Pave_Layer-1", "label": "pave_layer1" },
        { "layer": "Nru_Crs_Pave_Layer-2", "label": "pave_layer2" },
        { "layer": "Nru_Crs_Pave_Layer-3", "label": "pave_layer3" },
        { "layer": "Nru_Crs_Pave_Layer-4", "label": "pave_layer4" },
        { "layer": "Nru_Crs_Steps", "label": "steps" },
        { "layer": "Nru_Crs_Curb", "label": "curb" }
    ]
}
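Note: config.json is the CAD-layer-to-training-label map consumed by get_layer_to_label() in create_earthwork_dataset.py below. A minimal sketch of the lookup, assuming config.json is in the working directory:

import json

with open('config.json', encoding='utf-8') as f:
    cfg = json.load(f)

# resolve a CAD layer name to its label; unmapped layers yield ''
label = next((item['label'] for item in cfg['layers'] if item['layer'] == 'Nru_Crs_Slope'), '')
print(label)  # slope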
create_earthwork_dataset.py ADDED
@@ -0,0 +1,232 @@
# title: create earthwork train dataset
# author: Taewook Kang
# date: 2024.3.27
# description: create earthwork train dataset
# license: MIT
# reference: https://pyautocad.readthedocs.io/en/latest/_modules/pyautocad/api.html
# version
#   0.1. 2024.3.27. create file
#
import os, math, argparse, json, re, traceback, numpy as np, pandas as pd, trimesh, laspy, shutil
import pyautocad, open3d as o3d, seaborn as sns, win32com.client, pythoncom
import matplotlib.pyplot as plt
from scipy.spatial import distance
from tqdm import trange, tqdm
from math import pi

def get_layer_to_label(cfg, layer):
    layers = cfg['layers']
    for lay in layers:
        if lay['layer'] == layer:
            return lay['label']
    return ''

def get_entity_from_acad(entity_names=['AcDbLine', 'AcDbPolyline', 'AcDbText']):
    acad = pyautocad.Autocad(create_if_not_exists=True)
    selections = acad.get_selection('Select entities to extract geometry')

    geoms = []
    for entity in tqdm(selections):
        try:
            if entity.EntityName in entity_names:
                geoms.append(entity)
        except Exception as e:
            print(f'error: {e}')
            continue

    if not geoms:
        print("No entities found in the drawing.")
        return

    return geoms

def get_bbox(polyline):
    xmin, ymin, xmax, ymax = polyline[0][0], polyline[0][1], polyline[0][0], polyline[0][1]
    for x, y in polyline:
        xmin = min(xmin, x)
        ymin = min(ymin, y)
        xmax = max(xmax, x)
        ymax = max(ymax, y)
    return (xmin, ymin, xmax, ymax)

def get_xsections_from_acad(cfg):
    entities = get_entity_from_acad()

    # extract cross section frames
    xsec_list = []
    xsec_entities = []
    for entity in entities:
        if entity.Layer == 'Nru_Frame_Crs_Design' and entity.EntityName == 'AcDbPolyline':
            polyline = []
            vertex_list = entity.Coordinates
            for i in range(0, len(vertex_list), 2):
                polyline.append((vertex_list[i], vertex_list[i+1]))
            if len(polyline) < 2:
                continue
            bbox = get_bbox(polyline)

            xsec = {'bbox': bbox, 'station': '', 'geom': []}
            xsec_list.append(xsec)
        else:
            xsec_entities.append(entity)
    if len(xsec_entities) == 0:
        print("No cross section found in the drawing.")
        return []

    # assign a station string to each cross section from the text inside its frame
    for xsec in xsec_list:
        for entity in xsec_entities:
            if entity.EntityName != 'AcDbText':
                continue
            pt = (entity.InsertionPoint[0], entity.InsertionPoint[1])
            bbox = xsec['bbox']
            if pt[0] < bbox[0] or pt[1] < bbox[1] or pt[0] > bbox[2] or pt[1] > bbox[3]:
                continue
            xsec_station = entity.TextString
            pattern = r'\d+\+\d+\.\d+'
            match = re.search(pattern, xsec_station)
            if match:
                xsec_station = match.group()
            else:
                xsec_station = '-1+000.00'
            xsec['station'] = xsec_station

    if len(xsec_list) == 0:
        xsec = {'bbox': (-9999999999.0, -9999999999.0, 9999999999.0, 9999999999.0), 'station': '0+000.00'}
        xsec_list.append(xsec)

    xsec_list = sorted(xsec_list, key=lambda x: x['station'])  # sort by station string, format 'xxx+xxx.xx'

    # extract geometry in each cross section
    for xsec in tqdm(xsec_list):
        for entity in xsec_entities:
            label = get_layer_to_label(cfg, entity.Layer)
            if label == '':
                continue

            closed = False
            polyline = []
            if entity.EntityName == 'AcDbLine':
                # take only the XY components of the 3D start/end points
                polyline = [(entity.StartPoint[0], entity.StartPoint[1]), (entity.EndPoint[0], entity.EndPoint[1])]
                closed = False
            elif entity.EntityName == 'AcDbPolyline':
                vertex_list = entity.Coordinates
                for i in range(0, len(vertex_list), 2):
                    polyline.append((vertex_list[i], vertex_list[i+1]))
                closed = entity.Closed
            else:
                continue

            xsec_bbox = xsec['bbox']
            entity_bbox = get_bbox(polyline)
            if entity_bbox[0] < xsec_bbox[0] or entity_bbox[1] < xsec_bbox[1] or entity_bbox[2] > xsec_bbox[2] or entity_bbox[3] > xsec_bbox[3]:
                continue

            geo = {
                'label': label,
                'polyline': polyline,
                'closed': closed,
                'earthwork_feature': []
            }
            xsec['geom'].append(geo)

    return xsec_list

# simple matplotlib viewer for the extracted cross sections
_draw_xsection_index = 0
_xsections = None
_plot_ax = None

def draw_xsections(ax, index):
    xsec = _xsections[index]
    for geo in xsec['geom']:
        station = xsec['station']
        ax.set_title(f'station: {station}')
        polyline = np.array(geo['polyline'])
        ax.plot(polyline[:, 0], polyline[:, 1], label=geo['label'])
    ax.set_aspect('equal', 'box')

def next_button(event):
    global _draw_xsection_index, _xsections, _plot_ax
    _draw_xsection_index += 1
    if _draw_xsection_index >= len(_xsections):
        _draw_xsection_index = 0
    _plot_ax.clear()
    draw_xsections(_plot_ax, _draw_xsection_index)
    _plot_ax.figure.canvas.draw_idle()  # refresh the canvas after redrawing

def prev_button(event):
    global _draw_xsection_index, _xsections, _plot_ax
    _draw_xsection_index -= 1
    if _draw_xsection_index < 0:
        _draw_xsection_index = len(_xsections) - 1
    _plot_ax.clear()
    draw_xsections(_plot_ax, _draw_xsection_index)
    _plot_ax.figure.canvas.draw_idle()  # refresh the canvas after redrawing

def on_key_press(event):
    if event.key == 'right':
        next_button(None)
    elif event.key == 'left':
        prev_button(None)

def show_xsections(xsections):
    from matplotlib.widgets import Button
    global _draw_xsection_index, _xsections, _plot_ax
    _xsections = xsections

    fig = plt.figure()
    _plot_ax = fig.subplots()
    plt.subplots_adjust(left=0.3, bottom=0.25)
    draw_xsections(_plot_ax, _draw_xsection_index)

    # define prev/next buttons and wire up their callbacks
    axprev = fig.add_axes([0.7, 0.05, 0.1, 0.075])
    bprev = Button(axprev, 'prev', color="white")
    bprev.on_clicked(prev_button)
    axnext = fig.add_axes([0.81, 0.05, 0.1, 0.075])
    bnext = Button(axnext, 'next', color="white")
    bnext.on_clicked(next_button)

    fig.canvas.mpl_connect('key_press_event', on_key_press)

    plt.show()

def main():
    parser = argparse.ArgumentParser(description='create earthwork train dataset')
    parser.add_argument('--config', type=str, default='config.json', help='config file')
    parser.add_argument('--output', type=str, default='output/', help='output directory')
    parser.add_argument('--view', type=str, default='output/chain_chunk_6.json', help='view file')

    args = parser.parse_args()
    try:
        if len(args.view) > 0:
            with open(args.view, 'r') as f:
                xsections = json.load(f)
            show_xsections(xsections)
            return

        cfg = None
        with open(args.config, 'r', encoding='utf-8') as f:
            cfg = json.load(f)

        # continue numbering from any chunk files already in the output directory
        chunk_index = 0
        file_names = os.listdir(args.output)
        if len(file_names):
            pattern = r'chain_chunk_(\d+)\.json'
            indices = [int(re.match(pattern, file_name).group(1)) for file_name in file_names if re.match(pattern, file_name)]
            chunk_index = max(indices) + 1 if indices else 0

        print(file_names)

        while True:
            xsections = get_xsections_from_acad(cfg)
            if len(xsections) == 0:
                break
            geo_file = os.path.join(args.output, f'chain_chunk_{chunk_index}.json')
            with open(geo_file, 'w') as f:
                json.dump(xsections, f, indent=4)
            print(f'{geo_file} was saved in {args.output}')
            chunk_index += 1
    except Exception as e:
        print(f'error: {e}')
        traceback.print_exc()

if __name__ == '__main__':
    main()
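Note: the script attaches to a running AutoCAD session via pyautocad and prompts for an entity selection, so it must run on Windows with a drawing open. A typical invocation under those assumptions:

python create_earthwork_dataset.py --config config.json --output output/ --view ""

Pass an empty --view to extract new chain_chunk_N.json files; pass an existing chunk file (the default is output/chain_chunk_6.json) to browse its cross sections in the matplotlib viewer instead.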
ena_dataset.py ADDED
@@ -0,0 +1,137 @@
# title: ENA dataset utility functions
# author: Taewook Kang, Kyubyung Kang
# date: 2024.3.27
# license: MIT
# reference: https://pyautocad.readthedocs.io/en/latest/_modules/pyautocad/api.html
# version
#   0.1. 2024.3.27. create file
#
import json, os, re, logging, numpy as np
from transformers import BertTokenizer

def load_train_chunk_data(data_dir, sort_fname=False):
    geom_list = []
    fnames = os.listdir(data_dir)
    if sort_fname:
        fnames.sort(key=lambda x: int(re.search(r'\d+', x).group()))
    xsec_count = 0
    for file_name in fnames:
        if not file_name.endswith('.json'):
            continue
        with open(os.path.join(data_dir, file_name), 'r') as f:
            chunk = json.load(f)
            for xsec in chunk:
                xsec_count += 1
                geom = xsec['geom']
                for g in geom:
                    g['station'] = xsec['station']
                    features = g['earthwork_feature']
                    if len(features) == 0:
                        continue
                    geom_list.append(g)
    print(f'Loaded {xsec_count} cross sections')
    return geom_list

def update_feature_dims_token(geom_list):
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)  # load the BERT tokenizer

    feature_dims = []
    max_token = 0
    tokenizer.add_tokens(['padding'])  # add_tokens() returns the number of tokens added, not the id
    padding_token_id = tokenizer.convert_tokens_to_ids('padding')
    for geom in geom_list:
        label = geom['label']
        geom['feature_dims'] = []
        for feature in geom['earthwork_feature']:
            # token = tokenizer.tokenize(feature)
            token_ids = tokenizer.convert_tokens_to_ids(feature)
            geom['feature_dims'].append(token_ids)

            word, count = extract_word_and_count(feature)
            if word in feature_dims:
                continue
            feature_dims.append(word)

        max_token = max(max_token, len(geom['feature_dims']))

    # pad every token sequence to the same length
    for geom in geom_list:
        label = geom['label']
        geom['feature_dims'] += [padding_token_id] * (max_token - len(geom['feature_dims']))

    print(f'Max token length: {max_token}')
    return feature_dims

def extract_word_and_count(s):
    # parse 'word(count)' strings; count defaults to 1 when omitted
    match = re.match(r'(\w+)(?:\((\d+)\))?', s)
    if match:
        word, count = match.groups()
        count = int(count) if count else 1
        return word, count

    return None, None

def update_feature_dims_freq(geom_list, augment=False):
    feature_dims = []
    for geom in geom_list:
        label = geom['label']
        geom['feature_dims'] = []
        for feature in geom['earthwork_feature']:
            word, count = extract_word_and_count(feature)
            if word is None or count is None:
                continue
            if word in feature_dims:
                continue
            feature_dims.append(word)

    feature_dims.sort()

    max_feature_dims_count = [0.0] * len(feature_dims)
    for geom in geom_list:
        label = geom['label']
        geom['feature_dims'] = [0.0] * len(feature_dims)
        geom['feature_text'] = ''

        for feature in geom['earthwork_feature']:
            word, count = extract_word_and_count(feature)
            if word is None or count is None:
                continue
            geom['feature_text'] += f'{word}({count}) '
            index = feature_dims.index(word)

            geom['feature_dims'][index] = count
            max_feature_dims_count[index] = max(max_feature_dims_count[index], count)

    # normalize feature_dims using max_feature_dims_count
    for geom in geom_list:
        label = geom['label']
        for i in range(len(geom['feature_dims'])):
            geom['feature_dims'][i] /= max_feature_dims_count[i]

    # augment the feature_dims dataset with squared terms
    if augment:
        for geom in geom_list:
            label = geom['label']
            geom['feature_dims_aug'] = []
            for i in range(len(geom['feature_dims'])):
                geom['feature_dims_aug'].append(geom['feature_dims'][i])
                geom['feature_dims_aug'].append(geom['feature_dims'][i] * geom['feature_dims'][i])

    print(f'feature dims({len(feature_dims)}): {feature_dims}')
    return feature_dims

def update_onehot_encoding(geom_list):
    label_kinds = []
    for geom in geom_list:
        label = geom['label']
        if label not in label_kinds:
            label_kinds.append(label)

    # one-hot encode each geometry's label against the collected label kinds
    for geom in geom_list:
        label = geom['label']

        onehot = np.zeros(len(label_kinds))
        onehot[label_kinds.index(label)] = 1.0
        geom['label_onehot'] = onehot
    return label_kinds
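Note: earthwork features are stored as 'word(count)' strings; extract_word_and_count() parses them with the regex above, defaulting the count to 1. A quick sanity check:

>>> extract_word_and_count('slope(3)')
('slope', 3)
>>> extract_word_and_count('ground')
('ground', 1)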
ena_run_model.py ADDED
@@ -0,0 +1,548 @@
# title: ENA model runner
# author: Taewook Kang, Kyubyung Kang
# date: 2024.3.27
# description: ENA model test and evaluation
# license: MIT
# version
#   0.1. 2024.3.27. create file
#
import json, os, re, math, random, string, logging
import torch, torch.nn as nn, torch.optim as optim, numpy as np, matplotlib.pyplot as plt, seaborn as sns
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig, BertModel
from sklearn.metrics import confusion_matrix
from collections import defaultdict
from datetime import datetime
from tqdm import tqdm
from ena_dataset import load_train_chunk_data, update_feature_dims_freq, update_onehot_encoding

# write log file using logger
logging.basicConfig(filename='./ewnet_logs.txt', level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s', datefmt='%Y%m%d %H:%M')
logger = logging.getLogger('ewnet')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'device: {device}')

# hyperparameters of the model currently under test
hyperparam = None

# MLP classifier
class EarthworkNetMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout_ratio=0.2):
        super(EarthworkNetMLP, self).__init__()

        models = []
        models.append(nn.Linear(input_dim, hidden_dim[0]))
        models.append(nn.ReLU())
        models.append(nn.BatchNorm1d(hidden_dim[0]))  # batch normalization after activation
        models.append(nn.Dropout(dropout_ratio))

        for i in range(1, len(hidden_dim)):
            models.append(nn.Linear(hidden_dim[i-1], hidden_dim[i]))
            models.append(nn.ReLU())
            models.append(nn.BatchNorm1d(hidden_dim[i]))
            models.append(nn.Dropout(dropout_ratio))

        models.append(nn.Linear(hidden_dim[-1], output_dim))
        self.layers = nn.Sequential(*models)

    def forward(self, x):
        x = self.layers(x)
        return x

# LSTM classifier
class EarthworkNetLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2, dropout_ratio=0.2):
        super(EarthworkNetLSTM, self).__init__()

        # sequence series data, e.g. token pattern (slope angle): top(0.5), bottom(0.5), top(0.6), bottom(0.6)...
        # time series features = (token_type, curve_angle); label = (label_onehot)
        models = []

        models.append(nn.LSTM(input_dim, hidden_dim[0], num_layers, batch_first=True, dropout=dropout_ratio))
        for i in range(1, len(hidden_dim)):
            models.append(nn.Linear(hidden_dim[i-1], hidden_dim[i]))

        models.append(nn.Linear(hidden_dim[-1], output_dim))
        self.layers = nn.Sequential(*models)

    def forward(self, x):
        # the LSTM layer returns (output, state); unpack it, pass plain tensors elsewhere
        for layer in self.layers:
            if type(layer) == torch.nn.modules.rnn.LSTM:
                x, _ = layer(x)
            else:
                x = layer(x)

        return x

# dataset mapping earthwork_feature -> label
class EarthworkDataset(Dataset):
    def __init__(self, raw_data):
        self.raw_dataset = raw_data

    def __len__(self):
        return len(self.raw_dataset)

    def __getitem__(self, idx):
        features = self.raw_dataset[idx]['feature_dims']  # already tokenized from 'feature_text'
        label = self.raw_dataset[idx]['label_onehot']
        features = torch.tensor(features, dtype=torch.float32).to(device)
        label = torch.tensor(label, dtype=torch.float32).to(device)
        return features, label

def decode_data_to_geom(input_dataset, predictions, labels, input_feature_dims, label_kinds):
    global hyperparam
    match_count = 0
    for i in range(len(input_dataset)):  # batch size
        input_geom_features = input_dataset[i].cpu().numpy()
        prediction_index = predictions[i].item()
        label_index = labels[i].cpu().numpy()

        geom_features = []
        for j in range(len(input_feature_dims)):
            if input_geom_features[j] == 0.0:
                continue
            geom_features.append(f'{input_feature_dims[j]}({input_geom_features[j]:.2f})')

        prediction_label = label_kinds[prediction_index]
        label = label_kinds[label_index]

        match = prediction_label == label
        if match:
            match_count += 1
        logger.debug(f'{hyperparam["model"]} {hyperparam["hidden_dim"]} Equal : {prediction_label == label}, Label: {label}, Predicted: {prediction_label}, Geom: {geom_features}')

    return match_count

def test_mlp_model(model, batch_size, test_raw_dataset, input_feature_dims, label_kinds):
    print(f'test data count: {len(test_raw_dataset)}')
    test_dataset = EarthworkDataset(test_raw_dataset)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

    # test model
    accuracies = []
    rmse = 0.0
    correct = 0
    total = 0
    total_match = 0
    with torch.no_grad():
        for i, (data, labels) in enumerate(test_dataloader):
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            _, labels = torch.max(labels.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            accuracies.append(correct / total)

            match_count = decode_data_to_geom(data, predicted, labels, input_feature_dims, label_kinds)
            total_match += match_count

    average_accuracy = correct / total
    print(f'Match count: {total_match}, Total count: {total}')
    print(f'Accuracy of the network on the test data: {average_accuracy:.4f}')
    return accuracies, average_accuracy

def run_MLP_LSTM(model_file_list, base_model):
    global hyperparam

    # prepare train dataset
    data_dir = './dataset'
    geom_list = load_train_chunk_data(data_dir)
    input_feature_dims = update_feature_dims_freq(geom_list)  # or update_feature_dims_token(geom_list)
    label_kinds = update_onehot_encoding(geom_list)

    train_raw_dataset = geom_list[:int(len(geom_list) * 0.8)]
    test_raw_dataset = geom_list[int(len(geom_list) * 0.8):]
    print(f'total data count: {len(geom_list)}')
    print(f'train data count: {len(train_raw_dataset)}, test data count: {len(test_raw_dataset)}')

    # load each trained model variant and evaluate it
    param_layers = [[128], [128, 64, 32], [256, 128, 64]]
    if base_model == 'MLP':
        param_layers = [[128, 64, 32], [64, 128, 64], [64, 128, 64, 32], [32, 64, 32]]
    for index, param_layer in enumerate(param_layers):
        logger.debug(f'model : {base_model}')

        params = {
            'model': base_model,
            'input_dim': len(input_feature_dims),
            'hidden_dim': param_layer,
            'output_dim': len(label_kinds),
            'batch_size': 32,
            'epochs': 150,
            'lr': 0.001
        }
        hyperparam = params
        # create a model matching the saved weights
        model = EarthworkNetMLP(params['input_dim'], params['hidden_dim'], params['output_dim']).to(device)
        if base_model == 'LSTM':
            model = EarthworkNetLSTM(params['input_dim'], params['hidden_dim'], params['output_dim']).to(device)
        model_file = './' + model_file_list[index]
        model.load_state_dict(torch.load(model_file))
        model.eval()

        accuracies, acc = test_mlp_model(model, params['batch_size'], test_raw_dataset, input_feature_dims, label_kinds)

# generate random training data
def generate_random_text(label_index, length=100):
    base_text = f'This is text for label R{label_index + 1}. '
    random_text_length = max(0, length - len(base_text))  # length of the random text to generate
    random_text = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(random_text_length))
    return base_text + random_text

# dataset for the transformer model
class EarthworkTransformDataset(Dataset):
    def __init__(self, input_ids, attention_mask, labels):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.labels = labels

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        input_ids_tensor = torch.tensor(self.input_ids[idx]).to(device)
        attention_mask_tensor = torch.tensor(self.attention_mask[idx]).to(device)
        label_tensor = torch.tensor(self.labels[idx]).to(device)
        return input_ids_tensor, attention_mask_tensor, label_tensor

# custom transformer
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, vocab_size=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(vocab_size, d_model)
        position = torch.arange(0, vocab_size, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float()
            * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1), :]
        return self.dropout(x)

class EarthworkNetTransformer(nn.Module):
    def __init__(
        self,
        input_feature_size,
        d_model,
        num_labels,
        nhead=8,
        dim_feedforward=2048,
        dim_fc=[64, 32],
        num_layers=6,
        dropout=0.1,
        activation="relu",
        classifier_dropout=0.1,
    ):
        super().__init__()

        self.d_model = d_model
        # self.pos_encoder = PositionalEncoding(d_model=d_model, dropout=dropout, vocab_size=vocab_size)

        self.input_fc = nn.Linear(input_feature_size, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )

        self.src_mask = None
        self.nhead = nhead
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
            # TBD. output_attentions=True
        )
        # use ModuleList so the fully connected layers are registered as submodules
        self.fc_layers = nn.ModuleList()
        fc_layers_dims = [d_model] + dim_fc + [num_labels]
        for i in range(1, len(fc_layers_dims)):
            fc = nn.Linear(fc_layers_dims[i-1], fc_layers_dims[i]).to(device)
            self.fc_layers.append(fc)

        self.init_weights()

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        for fc in self.fc_layers:
            fc.bias.data.zero_()
            fc.weight.data.uniform_(-initrange, initrange)

    def forward(self, x, attention_mask):
        # x = self.pos_encoder(x)
        if self.src_mask is None or self.src_mask.size(0) != len(x):
            device = x.device
            mask = self.generate_square_subsequent_mask(len(x)).to(device)
            self.src_mask = mask

        x = x.float()
        x = self.input_fc(x)
        x = self.transformer_encoder(x, mask=self.src_mask)
        for fc in self.fc_layers:
            x = fc(x)

        return x

def run_transform(model_file_list):
    data_dir = './dataset'
    geom_list = load_train_chunk_data(data_dir)
    input_feature_dims = update_feature_dims_freq(geom_list)  # or update_feature_dims_token(geom_list)
    label_kinds = update_onehot_encoding(geom_list)
    num_labels = len(label_kinds)
    max_input_string = max(len(d['feature_text']) for d in geom_list)
    max_input_string = 320  # nhead=8. 320=8*40

    train_raw_dataset = geom_list[:int(len(geom_list) * 0.8)]
    test_raw_dataset = geom_list[int(len(geom_list) * 0.8):]
    print(f'total data count: {len(geom_list)}')
    print(f'train data count: {len(train_raw_dataset)}, test data count: {len(test_raw_dataset)}')

    # tokenize and pad sequences
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    max_length = max_input_string

    batch_sizes = [32, 64, 128]
    for index, batch_size in enumerate(batch_sizes):
        encoding = {'input_ids': [], 'attention_mask': []}
        for d in train_raw_dataset:
            token_text = tokenizer(d['feature_text'], padding='max_length', truncation=True, max_length=max_length)
            if len(token_text['input_ids']) < max_length:  # fill the rest with the padding token
                token_text['input_ids'] += [tokenizer.pad_token_id] * (max_length - len(token_text['input_ids']))
                token_text['attention_mask'] += [0] * (max_length - len(token_text['attention_mask']))
            encoding['input_ids'].append(token_text['input_ids'])
            encoding['attention_mask'].append(token_text['attention_mask'])

        input_ids = encoding['input_ids']
        attention_mask = encoding['attention_mask']

        label2id = {label: i for i, label in enumerate(sorted(set(d['label'] for d in train_raw_dataset)))}
        id2label = {v: k for k, v in label2id.items()}
        labels = [label2id[d['label']] for d in train_raw_dataset]  # convert labels to numerical format

        # hyperparameters
        logger.debug(f'model : transformer')

        params = {
            'model': 'transformer',
            'input_dim': len(input_feature_dims),
            'hidden_dim': [64],
            'output_dim': len(label2id),
            'batch_size': batch_size,
            'epochs': 300,
            'lr': 1e-5
        }

        dim_fc = params['hidden_dim']
        epochs = params['epochs']

        # model
        model = EarthworkNetTransformer(input_feature_size=max_length, d_model=512, num_labels=len(label2id), dim_fc=dim_fc).to(device)
        dataset = EarthworkTransformDataset(input_ids, attention_mask, labels)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        # test the model
        model_file = './' + model_file_list[index]
        model.load_state_dict(torch.load(model_file))
        model.eval()

        for i, test_raw in enumerate(test_raw_dataset):
            label = test_raw['label']
            input_text = test_raw['feature_text']
            encoding = tokenizer(input_text, return_tensors='pt', padding='max_length', truncation=True, max_length=max_length)
            input_ids = encoding['input_ids'].to(device)
            attention_mask = encoding['attention_mask'].to(device)
            output = model(input_ids, attention_mask)
            predicted_label = id2label[output.argmax().item()]

            feature_dims = input_text.split(' ')
            logger.debug(f'{params["model"]} {params["batch_size"]} Equal : {predicted_label == label}, Label: {label}, Predicted: {predicted_label}, Geom: {feature_dims}')

        print(f'test data count: {len(test_raw_dataset)}')
        encoding = tokenizer([d['feature_text'] for d in test_raw_dataset], padding='max_length', truncation=True, max_length=max_length)
        input_ids = encoding['input_ids']
        attention_mask = encoding['attention_mask']

        label2id = {label: i for i, label in enumerate(sorted(set(d['label'] for d in test_raw_dataset)))}
        id2label = {v: k for k, v in label2id.items()}
        labels = [label2id[d['label']] for d in test_raw_dataset]  # convert labels to numerical format

        test_dataset = EarthworkTransformDataset(input_ids, attention_mask, labels)
        test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)

        correct = 0
        total = 0
        accuracies = []
        with torch.no_grad():
            for i, (input_ids, attention_mask, labels) in enumerate(tqdm(test_dataloader, desc="test")):
                outputs = model(input_ids, attention_mask)
                _, predicted = torch.max(outputs, 1)
                total += len(labels)
                correct += (predicted == labels).sum().item()
                accuracies.append(correct / total)

        average_accuracy = correct / total
        print(f'Accuracy of the network on the test data: {average_accuracy:.4f}')

# BERT model
class EarthworkBertDataset(Dataset):
    def __init__(self, input_ids, attention_mask, labels):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.labels = labels

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        input_ids_tensor = torch.tensor(self.input_ids[idx]).to(device)
        attention_mask_tensor = torch.tensor(self.attention_mask[idx]).to(device)
        label_tensor = torch.tensor(self.labels[idx]).to(device)
        return input_ids_tensor, attention_mask_tensor, label_tensor

# BERT-based classifier
class EarthworkNetTransformerBert(torch.nn.Module):
    def __init__(self, num_labels):
        super(EarthworkNetTransformerBert, self).__init__()
        self.bert = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=num_labels, output_attentions=True)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        return outputs['logits'], outputs['attentions']

def run_bert(model_file):
    # prepare train dataset
    data_dir = './dataset'
    geom_list = load_train_chunk_data(data_dir)
    input_feature_dims = update_feature_dims_freq(geom_list)  # or update_feature_dims_token(geom_list)
    label_kinds = update_onehot_encoding(geom_list)
    num_labels = len(label_kinds)
    max_input_string = max(len(d['feature_text']) for d in geom_list)

    train_raw_dataset = geom_list[:int(len(geom_list) * 0.8)]
    test_raw_dataset = geom_list[int(len(geom_list) * 0.8):]
    print(f'total data count: {len(geom_list)}')
    print(f'train data count: {len(train_raw_dataset)}, test data count: {len(test_raw_dataset)}')

    # tokenize and pad sequences
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    max_length = max_input_string

    encoding = tokenizer([d['feature_text'] for d in train_raw_dataset], padding=True, truncation=True, max_length=max_length)
    input_ids = encoding['input_ids']  # TBD. shape is 50?
    attention_mask = encoding['attention_mask']

    label2id = {label: i for i, label in enumerate(sorted(set(d['label'] for d in train_raw_dataset)))}
    id2label = {v: k for k, v in label2id.items()}
    labels = [label2id[d['label']] for d in train_raw_dataset]  # convert labels to numerical format

    # initialize model
    model = EarthworkNetTransformerBert(num_labels=len(label2id)).to(device)

    epochs = 150
    batch_size = 32
    params = {
        'model': 'BERT',
        'input_dim': len(input_feature_dims),
        'hidden_dim': 512,
        'output_dim': len(label2id),
        'batch_size': batch_size,
        'epochs': epochs,
        'lr': 1e-5,
    }

    dataset = EarthworkBertDataset(input_ids, attention_mask, labels)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # test the model
    logger.debug(f'model : bert')

    model_file = './' + model_file
    model.load_state_dict(torch.load(model_file))
    model.eval()

    for i, test_raw in enumerate(test_raw_dataset):
        label = test_raw['label']
        input_text = test_raw['feature_text']
        encoding = tokenizer(input_text, return_tensors='pt', padding=True, truncation=True, max_length=max_length)
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)
        output, att = model(input_ids, attention_mask)
        predicted_label = id2label[output.argmax().item()]

        feature_dims = input_text.split(' ')
        logger.debug(f'{params["model"]} Equal : {predicted_label == label}, Label: {label}, Predicted: {predicted_label}, Geom: {feature_dims}')

        # plot the attention matrix of the last layer's last head
        attention_matrix = att[-1]
        attention_layer = attention_matrix[-1]
        attention_mat = attention_layer[-1]
        att_mat = attention_mat.detach().cpu().numpy()
        fig, ax = plt.subplots()
        cax = ax.matshow(att_mat, cmap='viridis')
        fig.colorbar(cax)
        plt.savefig(f'./graph/bert_attention_{i}.png')
        plt.close()

    print(f'test data count: {len(test_raw_dataset)}')
    encoding = tokenizer([d['feature_text'] for d in test_raw_dataset], padding=True, truncation=True, max_length=max_length)
    input_ids = encoding['input_ids']
    attention_mask = encoding['attention_mask']

    label2id = {label: i for i, label in enumerate(sorted(set(d['label'] for d in test_raw_dataset)))}
    id2label = {v: k for k, v in label2id.items()}
    labels = [label2id[d['label']] for d in test_raw_dataset]  # convert labels to numerical format

    test_dataset = EarthworkBertDataset(input_ids, attention_mask, labels)
    test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)

    correct = 0
    total = 0
    accuracies = []
    with torch.no_grad():
        for i, (input_ids, attention_mask, labels) in enumerate(tqdm(test_dataloader, desc="test")):
            outputs, att = model(input_ids, attention_mask)
            _, predicted = torch.max(outputs, 1)
            total += len(labels)
            correct += (predicted == labels).sum().item()
            accuracies.append(correct / total)
            y_score = torch.nn.functional.softmax(outputs, dim=1)

    average_accuracy = correct / total
    print(f'Accuracy of the network on the test data: {average_accuracy:.4f}')

if __name__ == '__main__':
    models = ['earthwork_model_20240503_1650.pth', 'earthwork_model_20240503_1714.pth', 'earthwork_model_20240503_1716.pth', 'earthwork_model_20240503_1718.pth']
    run_MLP_LSTM(models, 'MLP')

    models = ['earthwork_model_20240503_1730.pth', 'earthwork_model_20240503_1732.pth', 'earthwork_model_20240503_1734.pth']
    run_MLP_LSTM(models, 'LSTM')

    models = ['earthwork_trans_model_20240503_2003.pth', 'earthwork_trans_model_20240503_2014.pth', 'earthwork_trans_model_20240503_2021.pth']
    run_transform(models)

    run_bert('earthwork_trans_model_20240504_0103.pth')
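Note: the MLP variant can be exercised standalone; a minimal sketch with illustrative dimensions (23 features and 10 labels are made up here; the real values come from update_feature_dims_freq() and update_onehot_encoding()):

import torch

model = EarthworkNetMLP(input_dim=23, hidden_dim=[128, 64, 32], output_dim=10)
model.eval()              # BatchNorm1d requires eval mode for a single-sample batch
x = torch.rand(1, 23)     # one normalized feature-frequency vector
logits = model(x)
print(logits.argmax(dim=1))  # predicted label index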
eval_model.py ADDED
@@ -0,0 +1,75 @@
import numpy as np
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, confusion_matrix
from sklearn.metrics import average_precision_score

# sample data: binary single-label case
y_true = [0, 1, 1, 0, 1, 1]
y_pred = [0, 0, 1, 0, 0, 1]

# sample data: multi-label case (overrides the arrays above)
y_true = [[0, 1, 1], [0, 1, 1], [1, 0, 1]]
y_pred = [[0, 0, 1], [0, 0, 1], [1, 0, 0]]

class model_metrics:
    def __init__(self):
        self.clear()

    def clear(self):
        self.accuracy = 0.0
        self.recall = 0.0
        self.precision = 0.0
        self.f1 = 0.0
        self.mAP = 0.0
        self.cm = np.asarray([])

        self.count = 0
        self.total_accuracy = 0.0
        self.total_recall = 0.0
        self.total_precision = 0.0
        self.total_f1 = 0.0
        self.total_mAP = 0.0
        self.total_cm = np.asarray([])

    def get_indicators(self):
        return self.total_accuracy / self.count, self.total_recall / self.count, self.total_precision / self.count, self.total_f1 / self.count, self.total_mAP / self.count, self.total_cm / self.count

    def dump(self):
        print(f"Accuracy: {self.accuracy}")
        print(f"Recall: {self.recall}")
        print(f"Precision: {self.precision}")
        print(f"F1 Score: {self.f1}")
        print(f"mAP: {self.mAP}")
        print(f"Confusion Matrix: \n{self.cm}")

        print(f'average accuracy: {self.total_accuracy / self.count}')
        print(f'average recall: {self.total_recall / self.count}')
        print(f'average precision: {self.total_precision / self.count}')
        print(f'average f1: {self.total_f1 / self.count}')
        print(f'average mAP: {self.total_mAP / self.count}')
        print(f'average confusion matrix: \n{self.total_cm / self.count}')

    def calc_metrics(self, y_true, y_pred, y_score):
        self.accuracy = accuracy_score(y_true, y_pred)
        self.recall = recall_score(y_true, y_pred, average='weighted')
        self.precision = precision_score(y_true, y_pred, average='micro')
        self.cm = confusion_matrix(y_true, y_pred)
        self.count += 1

        # f1 and mAP keep their previous values here; only the other metrics are recomputed
        self.total_accuracy += self.accuracy
        self.total_recall += self.recall
        self.total_precision += self.precision
        self.total_f1 += self.f1
        self.total_mAP += self.mAP
        self.total_cm = self.cm  # TBD

        return self.accuracy, self.recall, self.precision, self.f1, self.mAP, self.cm

    def calc_metrics_multi(self, y_true, y_pred):
        self.accuracy = accuracy_score(y_true, y_pred)
        self.recall = recall_score(y_true, y_pred, average='micro')
        self.precision = precision_score(y_true, y_pred, average='micro')
        self.f1 = f1_score(y_true, y_pred, average='micro')
        self.mAP = average_precision_score(y_true, y_pred, average='micro')
        self.count += 1

        return self.accuracy, self.recall, self.precision, self.f1, self.mAP, self.cm
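Note: a minimal usage sketch with the multi-label sample arrays defined at the top of the file:

m = model_metrics()
m.calc_metrics_multi(y_true, y_pred)  # micro-averaged accuracy/recall/precision/F1/mAP
m.dump()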
extract_ewlog.py ADDED
@@ -0,0 +1,244 @@
import os, re, numpy as np
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt_xsec  # second alias used for cross-section plots (same module object as plt)
from datetime import datetime

input_log_file = './ewnet_logs_TRANS3_20240708.txt'
flag_all_xsections = True
prev_station = ''

now = datetime.now()
now_str = now.strftime('%Y%m%d_%H%M')

label_list = ['pave_layer1', 'pave_layer2', 'pave_layer3', 'pave_layer4', 'cut_ea', 'cut_rr', 'cut_br', 'cut_ditch', 'fill_subbed', 'fill_subbody', 'curb', 'above', 'below', 'pave_int', 'pave_surface', 'pave_subgrade', 'ground', 'pave_bottom', 'rr', 'br', 'slope', 'struct', 'steps']
color_list = [[0.8,0.8,0.8],[0.6,0.6,0.6],[0.4,0.4,0.4],[0.2,0.2,0.2],[0.8,0.4,0.2],[0.8,0.6,0.2],[0.8,0.8,0.2],[0.6,0.8,0.2],[0.3,0.8,0.3],[0.3,0.6,0.3],[0.3,0.4,0.3],[0.0,0.8,0.0],[0.6,0.0,0.0],[0.8,0.0,0.0],[1.0,0.0,0.0],[0.2,0.2,0.6],[0.0,1.0,0.0],[0.2,0.2,1.0],[0.4,0.2,1.0],[0.6,0.2,1.0],[0.2,0.8,0.6],[0.8,0.2,1.0],[1.0,0.2,1.0]]

# make output folder
if not os.path.exists('./graph'):
    os.makedirs('./graph')

def draw_colorbox_list():
    global label_list, color_list

    fig, ax = plt.subplots(figsize=(9.2, 5))
    ax.invert_yaxis()
    ax.set_xlim(0, 1.5)
    fig.set_size_inches(12, 7)

    token_list = ['item1', 'item2', 'item3', 'item4', 'item5', 'item6']
    for i, (colname, color) in enumerate(zip(label_list, color_list)):
        width = 1.0 / len(label_list)
        widths = [width] * len(token_list)
        starts = width * i
        rects = ax.barh(token_list, widths, left=starts, height=0.5, label=colname, color=color)

        text_color = 'white' if np.max(color) < 0.4 else 'black'
    ax.legend()
    plt.savefig('./graph/box_colors.png')
    plt.close()

def output_graph_metrics(index, tag, text):
    global label_list, color_list

    prediction = ''
    area = 0.0
    tokens = []
    polyline = []
    geom_index = text.find('Geom:')
    if geom_index >= 0:
        pred_label = ''
        label_index = text.find('Predicted: ')
        if label_index >= 0:
            pred = text[label_index + 11:geom_index]
            labels = pred.split(', ')
            if len(labels) > 0:
                prediction = labels[0]
                pred_label = labels[0] + '(0.3'  # placeholder ratio so the predicted label gets a bar segment

        polyline_index = text.find('Polyline:')
        if polyline_index > 0:
            pred = text[geom_index + 6:polyline_index - 2]
            polyline_text = text[polyline_index + 10:]
            polyline = eval(polyline_text)
        else:
            pred = text[geom_index + 6:]
        pred = pred.replace('[', '').replace(']', '')
        pred = pred.replace(')', '').replace("'", '')
        tokens = pred.split(',')
        if len(tokens) <= 1:
            tokens = pred.split(' ')
        if len(tokens) > 0:
            tokens.insert(0, pred_label)
            last = tokens[-1]
            if len(last) == 0:
                tokens.pop()
    else:
        return '', 0.0, []

    token_list = [token.split('(')[0] for token in tokens]
    token_list = [token.replace(' ', '') for token in token_list]
    ratios = [float(token.split('(')[1]) for token in tokens]
    results = {token_list[0]: ratios}

    labels = [label.replace(" ", "") for label in list(results.keys())]
    data = np.array(list(results.values()))
    data_cum = data.cumsum(axis=1)
    token_colors = [color_list[label_list.index(label)] for label in token_list]

    global plt_xsec, now_str, flag_all_xsections
    if flag_all_xsections == False:
        fig, ax = plt.subplots(figsize=(9.2, 5))
        ax.invert_yaxis()
        ax.xaxis.set_visible(False)
        ax.set_xlim(0, np.sum(data, axis=1).max())
        fig.set_size_inches(15, 0.5)

        for i, (colname, color) in enumerate(zip(token_list, token_colors)):
            widths = data[:, i]
            starts = data_cum[:, i] - widths
            if i > 0:
                starts += 0.02
            rects = ax.barh(labels, widths, left=starts, height=0.5, label=colname, color=color)

            if i != 0:
                text_color = 'white' if np.max(color) < 0.4 else 'black'
                ax.bar_label(rects, label_type='center', color=text_color)
        ax.legend(ncols=len(token_list), bbox_to_anchor=(0, 1), loc='lower right', fontsize='small')

        tag = tag.replace(' ', '_')
        tag = tag.replace(':', '')

        if text.find('True') > 0:
            plt.savefig(f'./graph/box_list_{now_str}_{tag}_{index}_T.png')
        else:
            plt.savefig(f'./graph/box_list_{now_str}_{tag}_{index}_F.png')
        plt.close()
    else:
        if polyline[0] != polyline[-1]:
            polyline.append(polyline[0])
        x, y = zip(*polyline)
        color = color_list[label_list.index(prediction)]

        plt_xsec.fill(x, y, color=color)
        centroid_x = sum(x) / len(x)
        centroid_y = sum(y) / len(y)
        area = 0.5 * abs(sum(x[i]*y[i+1] - x[i+1]*y[i] for i in range(len(polyline)-1)))  # shoelace formula

        if prediction.find('pave') < 0:
            plt_xsec.text(centroid_x, centroid_y, f'{prediction}={area:.2f}', horizontalalignment='center', verticalalignment='center', fontsize=5, color='black')

    return prediction, area, token_list

output_stations = ['4+440.00000', '3+780.00000', '3+800.00000', '3+880.00000', '3+940.00000']

def output_logs(tag, equal='none'):
    global input_log_file, plt_xsec, now_str, prev_station, flag_all_xsections, output_stations

    text_list = []
    logs = []

    with open(input_log_file, 'r') as file:
        for index, label in enumerate(label_list):
            file.seek(0)
            for line in file:
                if flag_all_xsections == False and line.find(tag) < 0:
                    continue
                tag_model = tag.split(' ')[0]
                if flag_all_xsections == True and line.find(tag_model) < 0:
                    continue
                if flag_all_xsections == False and line.find('Label: ' + label) < 0:
                    continue
                line = line.replace('\n', '')
                if equal == 'none':
                    text_list.append(line)
                elif line.find(equal) > 0:
                    text_list.append(line)
                    if flag_all_xsections == False:
                        break
            if flag_all_xsections:
                break

    if len(text_list) == 0:
        return logs

    def extract_station(text):
        sta_index = text.find('Station:') + 9  # start of station value
        end_index = text.find(',', sta_index)
        return text[sta_index:end_index] if end_index != -1 else text[sta_index:]

    text_list = sorted(text_list, key=extract_station)
    station = ''
    for index, text in enumerate(text_list):
        sta_index = text.find('Station:')
        equal_index = text.find('Equal: ')
        equal_check = 'T' if text.find('True') > 0 else 'F'

        if sta_index > 0 and equal_index > 0:
            station = text[sta_index + 9:equal_index - 2]
            print(station)

        if len(output_stations) and station not in output_stations:
            continue

        # start a new cross-section figure whenever the station changes
        if prev_station != station:
            if len(prev_station) > 0:
                plt_xsec.savefig(f'./graph/polygon_{now_str}_{tag}_{prev_station}_{equal_check}.png', dpi=300)
                plt_xsec.close()

            plt_xsec.figure()
            plt_xsec.gca().set_xlim([-60, 60])
            plt_xsec.gca().axis('equal')
            plt_xsec.gca().text(0, 0, f'{station}', fontsize=12, color='black')

            prev_station = station

        text = text.replace('\n', '')
        label, area, tokens = output_graph_metrics(index, tag, text)
        log = {
            'index': index,
            'station': station,
            'label': label,
            'area': area,
            'tokens': tokens
        }
        logs.append(log)

        if index == len(text_list) - 1:
            plt_xsec.savefig(f'./graph/polygon_{now_str}_{tag}_{prev_station}_{equal_check}.png', dpi=300)
            plt_xsec.close()

    return logs

def main():
    draw_colorbox_list()

    summary_log_file = open('./graph/summary_log.csv', 'a')
    if summary_log_file is None:
        return
    summary_log_file.write(f'model, true area, true count, false area, false count\n')

    tags = ['MLP [128, 64, 32]', 'MLP [64, 128, 64]', 'MLP [64, 128, 64, 32]', 'LSTM [128]', 'LSTM [128, 64, 32]', 'LSTM [256, 128, 64]', 'transformer 32', 'transformer 64', 'transformer 128', 'BERT']
    for tag in tags:
        print(tag)
        if len(output_stations) > 0:
            logs1 = output_logs(tag)
            continue

        logs1 = output_logs(tag, 'Equal: True')
        logs2 = output_logs(tag, 'Equal: False')
        if len(logs1) == 0 or len(logs2) == 0:
            continue
        area1 = area2 = 0
        area1 += sum([log['area'] for log in logs1])
        area2 += sum([log['area'] for log in logs2])
        log_record = f'{tag}, {area1}, {len(logs1)}, {area2}, {len(logs2)}'
        summary_log_file.write(f'{log_record}\n')

        if flag_all_xsections:
            break

    summary_log_file.close()

if __name__ == '__main__':
    main()
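Note: output_logs() scans the log for the markers 'Station:', 'Equal: ', 'Label: ', 'Predicted: ', 'Geom:' and optionally 'Polyline:'. A representative line it can parse (hypothetical values; the real file is ./ewnet_logs_TRANS3_20240708.txt written by ena_run_model.py):

20240708 10:15 DEBUG MLP [128, 64, 32] Station: 3+780.00000, Equal: True, Label: slope, Predicted: slope, Geom: ['above(2)', 'below(1)']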
prepare_dataset.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # title: create earthwwork train dataset
2
+ # author: Taewook Kang
3
+ # date: 2024.3.27
4
+ # description: create earthwork train dataset
5
+ # license: MIT
6
+ # version
7
+ # 0.1. 2024.3.27. create file
8
+ #
9
+ import os, math, argparse, json, re, traceback, numpy as np, pandas as pd, trimesh, laspy, shutil
10
+ import logging, matplotlib.pyplot as plt, shapely
11
+ from shapely.geometry import Polygon, LineString
12
+ from scipy.spatial import distance
13
+ from tqdm import trange, tqdm
14
+ from math import pi
15
+
16
+ logging.basicConfig(level=logging.DEBUG, filename='logs.txt',
17
+ format='%(asctime)s %(levelname)s %(message)s',
18
+ datefmt='%H:%M:%S')
19
+ logger = logging.getLogger("prep")
20
+
21
+ _precision = 0.00001
22
+
23
+ def get_bbox(polyline):
24
+ polyline_np = np.array(polyline)
25
+ xmin, ymin = np.amin(polyline_np, axis=0)
26
+ xmax, ymax = np.amax(polyline_np, axis=0)
27
+ return (xmin, ymin, xmax, ymax)
28
+
29
+ def get_center_point(pline):
30
+ if len(pline) == 0:
31
+ return (0, 0)
32
+ xs = [p[0] for p in pline]
33
+ ys = [p[1] for p in pline]
34
+ return (sum(xs) / len(pline), sum(ys) / len(pline))
35
+
36
+ def intersect_line(line1, line2):
37
+ (x1, y1), (x2, y2) = line1
38
+ (x3, y3), (x4, y4) = line2
39
+
40
+ denominator = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
41
+ if denominator == 0:
42
+ return None # lines are parallel
43
+
44
+ x = ((x1 * y2 - y1 * x2) * (x3 - x4) - (x1 - x2) * (x3 * y4 - y3 * x4)) / denominator
45
+ y = ((x1 * y2 - y1 * x2) * (y3 - y4) - (y1 - y2) * (x3 * y4 - y3 * x4)) / denominator
46
+
47
+ # check (x, y) in line1 and line2
48
+ if x < min(x1, x2) or x > max(x1, x2) or x < min(x3, x4) or x > max(x3, x4):
49
+ return None
50
+
51
+ return (x, y)
52
+
53
+ def get_positions_pline(base_pline, target_pline):
54
+ target_pos_marks = []
55
+ for i in range(len(target_pline)):
56
+ target = [target_pline[i], (target_pline[i][0], target_pline[i][1] + 1e+10)] # vertical line to check below
57
+ pos = 0.0
58
+ for j in range(len(base_pline) - 1):
59
+ base = [base_pline[j], base_pline[j + 1]]
60
+ intersect = intersect_line(base, target)
61
+ if intersect == None:
62
+ continue
63
+
64
+ if equal(intersect[1], target[0][1]):
65
+ pos = 0.0
66
+ break
67
+
68
+ pos = -1.0 if intersect[1] > target[0][1] else 1.0
69
+ break
70
+ target_pos_marks.append(pos)
71
+
72
+ return target_pos_marks
73
+
74
+ def get_below_pline(base_pline, target_pline):
75
+ pos_marks = get_positions_pline(base_pline, target_pline)
76
+ average = sum(pos_marks) / len(pos_marks)
77
+ return average < 0.0
78
+
79
+ def get_geometry(xsec, label):
80
+ for geom in xsec['geom']:
81
+ if geom['label'] == label:
82
+ return geom
83
+ return None
84
+
85
+ def is_point_in_rect(point1, point2, perp):
86
+ return is_point_in_rectangle(point1[0], point1[1], point2[0], point2[1], perp[0], perp[1])
87
+
88
+ def is_point_in_rectangle(x1, y1, x2, y2, x, y):
89
+ # Ensure that x1 <= x2 and y1 <= y2
90
+ x1, x2 = min(x1, x2), max(x1, x2)
91
+ y1, y2 = min(y1, y2), max(y1, y2)
92
+
93
+ # Check if (x, y) is within the rectangle
94
+ return x1 <= x <= x2 and y1 <= y <= y2
95
+
96
+ def sign_distance(a, b, c, p):
97
+ d = math.sqrt(a*a + b*b)
98
+ if d == 0.0:
99
+ lp = 0.0
100
+ else:
101
+ lp = (a * p[0] + b * p[1] + c) / d
102
+ return lp
103
+
104
+ def equal(a, b):
105
+ return abs(a - b) < _precision
106
+
107
+ def equal_point(p1, p2):
108
+ return equal(p1[0], p2[0]) and equal(p1[1], p2[1])
109
+
110
+ def get_angle(x1, y1, x2, y2):
+     # angle of the line (x1, y1)-(x2, y2) in radians, normalized to [0, 2*pi);
+     # returns -1.0 for a degenerate (zero-length) line
+     pi = math.pi
+
+     # Calculate the direction vector of the line.
+     dx = x2 - x1
+     dy = y2 - y1
+
+     # Base angle for vertical lines and the general case
+     if dx == 0 and dy == 0:
+         return -1.0
+     if dy < 0 and dx == 0:
+         angle_radius = pi + pi / 2
+     elif dy > 0 and dx == 0:
+         angle_radius = pi / 2
+     else:
+         angle_radius = math.atan(dy / dx)
+
+     # Adjust the angle for different quadrants
+     if dy >= 0 and dx > 0:
+         pass
+     if dy < 0 and dx > 0:
+         angle_radius += 2 * pi
+     elif dx < 0:
+         angle_radius += pi
+
+     return angle_radius
+
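+ # Illustrative examples (not part of the original source):
+ #   get_angle(0, 0, 1, 1)  -> math.pi / 4   (45 degrees)
+ #   get_angle(0, 0, -1, 0) -> math.pi       (180 degrees)
+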
+ def line_coefficients(point1, point2):
+     # coefficients (A, B, C) of the line A*x + B*y + C = 0 through two points
+     x1, y1 = point1
+     x2, y2 = point2
+
+     A = y2 - y1
+     B = x1 - x2
+     C = x2*y1 - x1*y2
+
+     return A, B, C
+
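+ # Illustrative example (not part of the original source): the line y = x becomes x - y = 0:
+ #   line_coefficients((0, 0), (1, 1)) -> (1, -1, 0)
+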
+ def sign_point_on_line(point1, point2, new_point):
+     # -1.0 or 1.0 depending on which side of the line new_point lies; 0.0 if on the line
+     if equal_point(point1, new_point) or equal_point(point2, new_point):
+         return 0.0
+
+     line_A, line_B, line_C = line_coefficients(point1, point2)
+     x, y = new_point
+     value = line_A * x + line_B * y + line_C
+     if math.fabs(value) < _precision:
+         return 0.0
+     elif value > 0.0:
+         return 1.0
+     return -1.0
+
+ def sign_distance_on_line(line_point1, line_point2, point):
+     # signed perpendicular distance from point to the line through the two line points
+     direction = sign_point_on_line(line_point1, line_point2, point)
+     if direction == 0:
+         return 0.0
+
+     if equal_point(line_point1, line_point2):  # degenerate (zero-length) line
+         return 0.0
+
+     x, y = point
+     x1, y1 = line_point1
+     x2, y2 = line_point2
+
+     if equal(x1, x2):  # vertical line: x = x1
+         a = 1
+         b = 0
+         c = -x1
+     else:
+         m = (y2 - y1) / (x2 - x1)
+         a = -m
+         b = 1
+         c = -y1 + (m * x1)
+
+     dist = abs(a * x + b * y + c) / math.sqrt(a * a + b * b)
+     dist *= float(direction)
+
+     return dist
+
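+ # Illustrative example (not part of the original source): a point two units above
+ # a left-to-right horizontal line gets a negative sign:
+ #   sign_distance_on_line((0, 0), (1, 0), (0.5, 2)) -> -2.0
+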
+ def is_point_on_line(point1, point2, perp):
+     # True if perp lies on the segment (point1, point2)
+     if not is_point_in_rect(point1, point2, perp):
+         return False
+     direction = sign_point_on_line(point1, point2, perp)
+     return math.fabs(direction) < _precision
+
+ def is_overlap_line(line, part_seg):
+     # True if part_seg overlaps line along its length; touching only at an
+     # endpoint does not count as overlap
+     p1, p2 = line
+     p3, p4 = part_seg
+
+     f1 = is_point_on_line(p1, p2, p3)
+     f2 = is_point_on_line(p1, p2, p4)
+
+     if (f1 or f2) and f1 != f2:  # a dangling endpoint is not an overlap
+         if f1 and (equal_point(p1, p3) or equal_point(p2, p3)):
+             return False
+         if f2 and (equal_point(p1, p4) or equal_point(p2, p4)):
+             return False
+
+     return f1 or f2
+
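+ # Illustrative examples (not part of the original source):
+ #   is_overlap_line(((0, 0), (2, 0)), ((1, 0), (3, 0))) -> True   (partial overlap)
+ #   is_overlap_line(((0, 0), (2, 0)), ((2, 0), (3, 0))) -> False  (shares only an endpoint)
+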
+ def is_on_pline(polyline, base_line):
+     # True if base_line overlaps any segment of the polyline
+     p1 = base_line[0]
+     p2 = base_line[1]
+
+     for i in range(len(polyline) - 1):
+         p3 = polyline[i]
+         p4 = polyline[i + 1]
+         if is_overlap_line((p1, p2), (p3, p4)) or is_overlap_line((p3, p4), (p1, p2)):
+             return True
+
+     return False
+
+ def get_match_line_labels(xsec, base_geom, base_line):
+     # collect labels of open polylines (excluding the centerline) that overlap base_line
+     labels = []
+     for geom in xsec['geom']:
+         if geom == base_geom:
+             continue
+         geom_label = geom['label']
+         base_label = base_geom['label']
+         if geom_label == base_label:
+             continue
+         closed = geom['closed']
+         if closed:  # only open polylines are considered
+             continue
+         if geom_label == 'center':
+             continue
+
+         polyline = geom['polyline']
+         if is_on_pline(polyline, base_line):
+             labels.append(geom['label'])
+     return labels
+
+ def get_seq_feature_tokens(xsec, geom, closed_type):
+     # for each segment of the geometry, collect the labels of overlapping polylines
+     polyline = geom['polyline']
+     closed = geom['closed']
+     if closed != closed_type:
+         return []
+
+     lines = []
+     for i in range(len(polyline) - 1):
+         line = (polyline[i], polyline[i + 1])
+         lines.append(line)
+
+     geom_tokens = []
+     for line in lines:
+         labels = get_match_line_labels(xsec, geom, line)
+         if len(labels) == 0:
+             continue
+         geom_tokens.extend(labels)
+
+     return geom_tokens
+
+ def translate_geometry(xsec, cp):
+     # translate every geometry so that cp (the road centerline point) becomes the origin
+     for geom in xsec['geom']:
+         polyline = geom['polyline']
+         geom['polyline'] = [(p[0] - cp[0], p[1] - cp[1]) for p in polyline]
+
+     return xsec
+
+ def is_closed(polyline):
+     return equal_point(polyline[0], polyline[-1])
+
+ def summarize_feature(features):
+     # run-length compress consecutive identical labels, recursing into nested lists
+     sum_features = []
+     if len(features) == 0:
+         return sum_features
+
+     index = 0
+     while index < len(features):
+         f = features[index]
+         sum_feature = f
+         if isinstance(f, list):
+             sum_feature = summarize_feature(f)
+             if len(sum_feature) == 1:
+                 sum_feature = sum_feature[0]
+         elif isinstance(f, str):
+             label = f
+             # find the last index of the run of identical labels starting at index
+             last_index = index
+             for i in range(index + 1, len(features)):
+                 if isinstance(features[i], str) and features[i] == label:
+                     last_index = i
+                 else:
+                     break
+             if last_index != index:
+                 sum_feature = f'{f}({last_index - index + 1})'
+                 index = last_index
+
+         sum_features.append(sum_feature)
+         index += 1
+
+     return sum_features
+
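+ # Illustrative example (not part of the original source):
+ #   summarize_feature(['slope', 'slope', 'ground']) -> ['slope(2)', 'ground']
+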
+ def get_intersection_count(xsec, base_geom, target_label):
+     # count how many geometries matching target_label cross the vertical ray
+     # cast upward from the centroid of base_geom
+     pave_top = get_geometry(xsec, 'pave_surface')
+     if pave_top is None:  # no pavement surface in this cross-section
+         return 0
+
+     polyline = base_geom['polyline']
+     polygon = Polygon(polyline)
+     base_p1 = polygon.centroid
+     base_p2 = (base_p1.x, base_p1.y + 1e+10)
+     vertical_line = LineString([base_p1, base_p2])
+
+     count = 0
+     for target_geom in xsec['geom']:
+         if base_geom == target_geom:
+             continue
+         label = target_geom['label']
+         if re.search(target_label, label) is None:
+             continue
+
+         # check intersection with the vertical ray
+         target_line = LineString(target_geom['polyline'])
+         ip = shapely.intersection(target_line, vertical_line)  # https://shapely.readthedocs.io/en/stable/reference/shapely.intersection.html
+         if ip.is_empty:
+             continue
+         count += 1
+
+     return count
+
+ def update_xsection_feature(xsec):
+     # derive earthwork features (above/below ground, pavement intersections,
+     # neighbouring labels) for each closed geometry in the cross-section
+     gnd_geom = get_geometry(xsec, 'ground')
+     if gnd_geom is None:
+         return None
+
+     center = get_geometry(xsec, 'center')
+     if center is None or 'polyline' not in center:
+         return None
+     cp = get_center_point(center['polyline'])
+
+     xsec = translate_geometry(xsec, cp)
+     station = xsec['station']
+
+     index = 0
+     while index < len(xsec['geom']):
+         geom = xsec['geom'][index]
+         label = geom['label']
+         polyline = geom['polyline']
+         closed = geom['closed']
+         if len(polyline) <= 2 or not closed:
+             index += 1
+             continue
+
+         pt1 = polyline[0]
+         pt2 = polyline[-1]
+         if not equal_point(pt1, pt2):  # close the polyline if it is not already closed
+             polyline.append(pt1)
+
+         # noise filtering
+         polygon = Polygon(polyline)  # calculate the area of the polyline as a polygon
+         if math.fabs(polygon.area) < _precision:
+             xsec['geom'].pop(index)  # remove the degenerate geometry
+             continue
+
+         # processing
+         geom.setdefault('earthwork_feature', [])  # assumption: input JSON may omit this key
+         if get_below_pline(gnd_geom['polyline'], polyline):
+             geom['earthwork_feature'].append('below')
+         else:
+             geom['earthwork_feature'].append('above')
+
+         if re.search('pave_.*', label):
+             pave_int_count = get_intersection_count(xsec, geom, 'pave_.*')
+             geom['earthwork_feature'].append(f'pave_int({pave_int_count})')
+
+         tokens = get_seq_feature_tokens(xsec, geom, True)
+         geom['earthwork_feature'].extend(tokens)
+
+         geom['earthwork_feature'] = summarize_feature(geom['earthwork_feature'])
+
+         logger.debug(f'{station}. {label} feature: {geom["earthwork_feature"]}')
+         index += 1
+
+     return xsec
+
+ def update_xsections_feature(xsections):
+     # update the closed flag of each polyline
+     for xsec in xsections:
+         for geom in xsec['geom']:
+             label = geom['label']
+             polyline = geom['polyline']
+             if len(polyline) < 2:
+                 continue
+             closed = is_closed(polyline)
+             if not closed:  # exception: pavement layers are treated as closed
+                 closed = re.search('pave_layer.*', label) is not None
+             geom['closed'] = closed
+
+     # update features
+     out_xsections = []
+     for xsec in xsections:
+         out_xsec = update_xsection_feature(xsec)
+         if out_xsec is None:
+             continue
+         out_xsections.append(out_xsec)
+
+     return out_xsections
+
+ def main():
+     parser = argparse.ArgumentParser(description='create earthwork train dataset')
+     parser.add_argument('--input', type=str, default='output/', help='input folder')
+     parser.add_argument('--output', type=str, default='dataset/', help='output folder')
+
+     args = parser.parse_args()
+     try:
+         os.makedirs(args.output, exist_ok=True)  # make sure the output folder exists
+
+         file_names = os.listdir(args.input)
+         for file_name in tqdm(file_names):
+             if not file_name.endswith('.json'):
+                 continue
+             print(f'processing {file_name}')
+             data = None
+             with open(os.path.join(args.input, file_name), 'r') as f:
+                 data = json.load(f)
+
+             out_xsections = update_xsections_feature(data)
+
+             output_file = os.path.join(args.output, file_name)
+             with open(output_file, 'w') as f:
+                 json.dump(out_xsections, f, indent=4)
+
+     except Exception as e:
+         print(f'error: {e}')
+         traceback.print_exc()
+
+ if __name__ == '__main__':
+     main()
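+
+ # Example usage (based on the argparse defaults above):
+ #   python create_earthwork_dataset.py --input output/ --output dataset/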