import os
import re
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt_xsec
from datetime import datetime

# Evaluation log to visualise.  Lines are expected to carry 'Station:',
# 'Equal:', 'Predicted: ', 'Geom:' and optionally 'Polyline:' fields
# (see _parse_prediction_line / output_logs for how they are sliced).
input_log_file = './ewnet_logs_TRANS3_20240708.txt'
# True: draw one cross-section figure per station with each predicted polygon
# filled in.  False: save one stacked-bar chart per selected log line.
flag_all_xsections = True
# Station whose cross-section figure output_logs currently has open
# ('' when no figure is open).
prev_station = ''
now = datetime.now()
# Timestamp (minute resolution) embedded in every output file name.
now_str = now.strftime('%Y%m%d_%H%M')

# color_list[i] is the RGB fill colour used for label_list[i].
label_list = ['pave_layer1', 'pave_layer2', 'pave_layer3', 'pave_layer4',
              'cut_ea', 'cut_rr', 'cut_br', 'cut_ditch',
              'fill_subbed', 'fill_subbody', 'curb', 'above', 'below',
              'pave_int', 'pave_surface', 'pave_subgrade', 'ground',
              'pave_bottom', 'rr', 'br', 'slope', 'struct', 'steps']
color_list = [
    [0.8, 0.8, 0.8], [0.6, 0.6, 0.6], [0.4, 0.4, 0.4], [0.2, 0.2, 0.2],
    [0.8, 0.4, 0.2], [0.8, 0.6, 0.2], [0.8, 0.8, 0.2], [0.6, 0.8, 0.2],
    [0.3, 0.8, 0.3], [0.3, 0.6, 0.3], [0.3, 0.4, 0.3], [0.0, 0.8, 0.0], [0.6, 0.0, 0.0],
    [0.8, 0.0, 0.0], [1.0, 0.0, 0.0], [0.2, 0.2, 0.6], [0.0, 1.0, 0.0],
    [0.2, 0.2, 1.0], [0.4, 0.2, 1.0], [0.6, 0.2, 1.0], [0.2, 0.8, 0.6], [0.8, 0.2, 1.0], [1.0, 0.2, 1.0],
]

# Output folder for every figure and the summary CSV.
os.makedirs('./graph', exist_ok=True)


def draw_colorbox_list():
    """Save ./graph/box_colors.png showing the colour assigned to every label.

    Each label gets one narrow bar segment across six dummy rows, and the
    legend maps the colours back to label names.
    """
    _, ax = plt.subplots(figsize=(12, 7))
    ax.invert_yaxis()
    ax.set_xlim(0, 1.5)
    token_list = ['item1', 'item2', 'item3', 'item4', 'item5', 'item6']
    width = 1.0 / len(label_list)
    widths = [width] * len(token_list)
    for i, (colname, color) in enumerate(zip(label_list, color_list)):
        ax.barh(token_list, widths, left=width * i, height=0.5,
                label=colname, color=color)
    ax.legend()
    plt.savefig('./graph/box_colors.png')
    plt.close()


def _parse_prediction_line(text):
    """Split one log line into ``(prediction, tokens, polyline)``.

    Returns None when the line lacks a 'Geom:' or a 'Predicted: ' field
    (such lines cannot be drawn).  ``prediction`` is the first predicted
    label.  ``tokens`` are the comma-separated 'Geom:' entries with '[', ']',
    ')' and quotes removed; presumably each has the form 'name(ratio' —
    confirm against the log writer.  The predicted label is prepended as
    '<prediction>(0.3', i.e. with a fixed ratio of 0.3.  ``polyline`` is the
    evaluated 'Polyline:' payload, or [] when the line has none.
    """
    geom_index = text.find('Geom:')
    label_index = text.find('Predicted: ')
    if geom_index < 0 or label_index < 0:
        return None
    # str.split always yields at least one element, so [0] is safe.
    prediction = text[label_index + 11:geom_index].split(', ')[0]

    polyline = []
    polyline_index = text.find('Polyline:')
    if polyline_index > 0:
        pred = text[geom_index + 6:polyline_index - 2]
        # NOTE(review): eval() executes arbitrary code taken from the log
        # file.  Only run this on trusted logs; ast.literal_eval would be
        # safer if the payload is always a plain list of coordinate tuples.
        polyline = eval(text[polyline_index + 10:])
    else:
        pred = text[geom_index + 6:]

    for ch in ('[', ']', ')', "'"):
        pred = pred.replace(ch, '')
    tokens = pred.split(',')
    if len(tokens) <= 1:
        # Fall back to space-separated entries.
        tokens = pred.split(' ')
    tokens.insert(0, prediction + '(0.3')
    if not tokens[-1]:
        # Drop the empty entry produced by a trailing separator.
        tokens.pop()
    return prediction, tokens, polyline


def _draw_token_bars(index, tag, text, token_list, ratios):
    """Save a one-row stacked bar chart of ``ratios`` for a single log line.

    The file name ends in '_T' when 'True' appears in ``text`` (after
    position 0) and '_F' otherwise.
    """
    data = np.array([ratios])
    data_cum = data.cumsum(axis=1)
    labels = [token_list[0]]
    # Raises ValueError for a token name that is not in label_list.
    token_colors = [color_list[label_list.index(label)] for label in token_list]

    _, ax = plt.subplots(figsize=(15, 0.5))
    ax.invert_yaxis()
    ax.xaxis.set_visible(False)
    ax.set_xlim(0, np.sum(data, axis=1).max())
    for i, (colname, color) in enumerate(zip(token_list, token_colors)):
        widths = data[:, i]
        starts = data_cum[:, i] - widths
        if i > 0:
            # Shift later segments right by 0.02 (presumably a visual gap).
            starts += 0.02
        rects = ax.barh(labels, widths, left=starts, height=0.5,
                        label=colname, color=color)
        if i != 0:
            # Pick a readable text colour against dark fills.
            text_color = 'white' if np.max(color) < 0.4 else 'black'
            ax.bar_label(rects, label_type='center', color=text_color)
    ax.legend(ncols=len(token_list), bbox_to_anchor=(0, 1),
              loc='lower right', fontsize='small')

    safe_tag = tag.replace(' ', '_').replace(':', '')
    suffix = 'T' if text.find('True') > 0 else 'F'
    plt.savefig(f'./graph/box_list_{now_str}_{safe_tag}_{index}_{suffix}.png')
    plt.close()


def _draw_polygon(prediction, polyline):
    """Fill ``polyline`` on the current cross-section figure; return its area.

    The ring is closed if needed and filled with the prediction's colour.
    Labels not containing 'pave' are annotated with '<label>=<area>' at the
    average of the vertex coordinates (closing vertex included).
    Returns 0.0 and draws nothing when there are no vertices.
    """
    if len(polyline) == 0:
        return 0.0
    if polyline[0] != polyline[-1]:
        polyline.append(polyline[0])
    x, y = zip(*polyline)
    plt_xsec.fill(x, y, color=color_list[label_list.index(prediction)])
    centroid_x = sum(x) / len(x)
    centroid_y = sum(y) / len(y)
    # Shoelace formula over the closed ring.
    area = 0.5 * abs(sum(x[i] * y[i + 1] - x[i + 1] * y[i]
                         for i in range(len(x) - 1)))
    if prediction.find('pave') < 0:
        plt_xsec.text(centroid_x, centroid_y, f'{prediction}={area:.2f}',
                      horizontalalignment='center', verticalalignment='center',
                      fontsize=5, color='black')
    return area


def output_graph_matrics(index, tag, text):
    """Render one log line and return ``(prediction, area, token_list)``.

    Args:
        index: position of the line in the caller's sorted list; used in
            the bar-chart file name.
        tag: model tag; sanitised and used in the bar-chart file name.
        text: one log line (see _parse_prediction_line for the fields used).

    Returns:
        None when the line has no 'Geom:' or 'Predicted: ' field.  Otherwise
        a tuple of the first predicted label, the polygon area (0.0 in
        bar-chart mode or when there is no polyline), and the space-stripped
        token names with the predicted label first.

    When flag_all_xsections is True the polygon is filled onto the current
    cross-section figure; otherwise a stacked bar chart image is saved.
    """
    parsed = _parse_prediction_line(text)
    if parsed is None:
        return None
    prediction, tokens, polyline = parsed
    token_list = [token.split('(')[0].replace(' ', '') for token in tokens]

    area = 0.0  # bar-chart mode has no polygon, so no area
    if not flag_all_xsections:
        ratios = [float(token.split('(')[1]) for token in tokens]
        _draw_token_bars(index, tag, text, token_list, ratios)
    else:
        area = _draw_polygon(prediction, polyline)
    return prediction, area, token_list


# Stations to render; an empty list means every station is rendered.
output_stations = ['4+440.00000', '3+780.00000', '3+800.00000', '3+880.00000', '3+940.00000']


def _extract_station(text):
    """Sort key: the text after 'Station:' (+9 chars) up to the next comma."""
    sta_index = text.find('Station:') + 9
    end_index = text.find(',', sta_index)
    return text[sta_index:end_index] if end_index != -1 else text[sta_index:]


def _read_log_lines(tag, equal):
    """Collect the newline-stripped log lines relevant to ``tag``.

    In all-cross-section mode every line containing the model family (the
    first word of ``tag``) is kept.  Otherwise, for each label, only the
    first line containing both ``tag`` and 'Label: <label>' is considered.
    A candidate is kept when ``equal`` is 'none' or occurs in it after
    position 0.
    """
    def wanted(line):
        return equal == 'none' or line.find(equal) > 0

    text_list = []
    with open(input_log_file, 'r') as file:
        if flag_all_xsections:
            tag_model = tag.split(' ')[0]
            for line in file:
                if line.find(tag_model) < 0:
                    continue
                line = line.replace('\n', '')
                if wanted(line):
                    text_list.append(line)
        else:
            for label in label_list:
                file.seek(0)
                for line in file:
                    if line.find(tag) < 0 or line.find('Label: ' + label) < 0:
                        continue
                    line = line.replace('\n', '')
                    if wanted(line):
                        text_list.append(line)
                    break  # only the first matching line per label
    return text_list


def _open_xsection(station):
    """Start a new cross-section figure labelled with ``station``."""
    plt_xsec.figure()
    plt_xsec.gca().set_xlim([-60, 60])
    plt_xsec.gca().axis('equal')
    plt_xsec.gca().text(0, 0, f'{station}', fontsize=12, color='black')


def _save_xsection(tag, station, equal_check):
    """Save the current cross-section figure at 300 dpi and close it."""
    plt_xsec.savefig(f'./graph/polygon_{now_str}_{tag}_{station}_{equal_check}.png', dpi=300)
    plt_xsec.close()


def output_logs(tag, equal='none'):
    """Render every selected log line for ``tag`` and return per-line records.

    Lines are sorted by station.  A new cross-section figure is opened for
    each station and saved when the station changes and after the last
    line.  Every figure opened here is saved and closed before returning.

    Args:
        tag: model tag, e.g. 'LSTM [128]'.
        equal: substring a line must contain, or 'none' to keep all lines.

    Returns:
        A list of dicts with keys 'index', 'station', 'label', 'area' and
        'tokens' (one per rendered line); [] when nothing matched.
    """
    global prev_station
    logs = []
    text_list = _read_log_lines(tag, equal)
    if not text_list:
        return logs

    prev_station = ''   # no figure is open at the start of a call
    station = ''
    figure_equal_check = 'F'  # Equal flag of the last line drawn on the open figure
    for index, text in enumerate(sorted(text_list, key=_extract_station)):
        sta_index = text.find('Station:')
        equal_index = text.find('Equal: ')
        equal_check = 'T' if text.find('True') > 0 else 'F'
        if sta_index > 0 and equal_index > 0:
            station = text[sta_index + 9:equal_index - 2]
            print(station)
            if output_stations and station not in output_stations:
                continue
        if prev_station != station:
            if prev_station:
                _save_xsection(tag, prev_station, figure_equal_check)
            _open_xsection(station)
            prev_station = station
            figure_equal_check = equal_check

        result = output_graph_matrics(index, tag, text)
        if result is None:
            continue  # line has no drawable prediction
        label, area, tokens = result
        figure_equal_check = equal_check
        logs.append({
            'index': index,
            'station': station,
            'label': label,
            'area': area,
            'tokens': tokens,
        })

    # Save the last figure even when the final line(s) were filtered out.
    if prev_station:
        _save_xsection(tag, prev_station, figure_equal_check)
    prev_station = ''
    return logs


def main():
    """Draw the colour legend, render each model tag, and append area
    totals to ./graph/summary_log.csv."""
    draw_colorbox_list()
    tags = ['MLP [128, 64, 32]', 'MLP [64, 128, 64]', 'MLP [64, 128, 64, 32]',
            'LSTM [128]', 'LSTM [128, 64, 32]', 'LSTM [256, 128, 64]',
            'transformer 32', 'transformer 64', 'transformer 128', 'BERT']
    with open('./graph/summary_log.csv', 'a') as summary_log_file:
        summary_log_file.write('model, ground true, length, ground false, length\n')
        for tag in tags:
            print(tag)
            if output_stations:
                # Station-selection mode: draw figures only, no summary row.
                # NOTE(review): the flag_all_xsections break below is never
                # reached in this mode, so every tag is processed.
                output_logs(tag)
                continue
            logs1 = output_logs(tag, 'Equal: True')
            logs2 = output_logs(tag, 'Equal: False')
            if not logs1 or not logs2:
                continue
            area1 = sum(log['area'] for log in logs1)
            area2 = sum(log['area'] for log in logs2)
            summary_log_file.write(f'{tag}, {area1}, {len(logs1)}, {area2}, {len(logs2)}\n')
            if flag_all_xsections:
                break


if __name__ == '__main__':
    main()