|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
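"""
Plot timing measurements written by nmt_transformer_infer.py.

This script takes one or more <name>.json files as input (i.e., the output of
`nmt_transformer_infer.py --write_timing`) and, for each input file, creates
plots named <name>.<PLOT_NAME>.<PLOTS_EXT> (PDF by default) in the same
directory as the input file.
"""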
|
import json |
|
import os |
|
import sys |
|
|
|
from matplotlib import pyplot as plt |
|
|
|
|
|
|
|
|
|
|
|
# Extension appended to every saved plot's file name (and hence its format).
PLOTS_EXT: str = "pdf"

# On/off switches for optional plot decorations.
PLOT_TITLE: bool = False
PLOT_XLABEL: bool = True
PLOT_YLABEL: bool = True
PLOT_GRID: bool = True

# Font size shared by the axis labels and the tick labels.
PLOT_LABEL_FONT_SIZE: int = 16
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_timing(lengths, timings, lengths_name, timings_name, fig=None):
    """Scatter-plot ``timings`` against ``lengths`` and return the figure.

    Args:
        lengths: x values, labelled as lengths in tokens. Must have the same
            number of elements as ``timings``.
        timings: y values, labelled as times in seconds.
        lengths_name: Display name of the length series. It is used in the
            x-axis label and in the optional title.
        timings_name: Display name of the timing series. It is used as the
            scatter label, in the y-axis label and in the optional title.
        fig: Existing figure to draw into. When ``None``, a new figure is
            created.

    Returns:
        The matplotlib figure that was drawn into.
    """
    if fig is None:
        fig = plt.figure()
    # Draw on this figure's own axes rather than on pyplot's implicit
    # "current" figure. This way a caller-supplied ``fig`` really receives
    # the plot, even if another figure was created or activated in between.
    ax = fig.gca()

    ax.scatter(lengths, timings, label=timings_name)
    if PLOT_XLABEL:
        ax.set_xlabel(f"{lengths_name} [tokens]", fontsize=PLOT_LABEL_FONT_SIZE)
    if PLOT_YLABEL:
        ax.set_ylabel(f"{timings_name} [sec]", fontsize=PLOT_LABEL_FONT_SIZE)
    if PLOT_GRID:
        ax.grid(True)
    if PLOT_TITLE:
        ax.set_title(f"{timings_name} vs. {lengths_name}")

    # With several series overlaid on one figure, a legend is needed to tell
    # them apart. A single-series plot is left without one.
    if len(ax.collections) > 1:
        ax.legend()

    ax.tick_params(axis="both", labelsize=PLOT_LABEL_FONT_SIZE)
    fig.tight_layout()

    return fig
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Only show usage when it is actually needed (no input files given).
    if len(sys.argv) < 2:
        print("Usage: plot_detailed_timing.py <JSON FILE> <JSON FILE> ...")
        sys.exit(1)

    # (output plot name, x-data key, y-data key, x display name, y display name)
    plot_specs = (
        ("encoder-src_len", "mean_src_length", "encoder", "src length", "encoder"),
        ("sampler-src_len", "mean_src_length", "sampler", "src length", "sampler"),
        ("sampler-tgt_len", "mean_tgt_length", "sampler", "tgt length", "sampler"),
    )

    for timing_fn in sys.argv[1:]:
        print(f"Parsing file = {timing_fn}")
        # Use a context manager so the file handle is closed promptly.
        with open(timing_fn, encoding="utf-8") as timing_file:
            data = json.load(timing_file)

        figures = {
            plot_name: plot_timing(
                lengths=data[lengths_key],
                timings=data[timings_key],
                lengths_name=lengths_name,
                timings_name=timings_name,
            )
            for plot_name, lengths_key, timings_key, lengths_name, timings_name
            in plot_specs
        }

        base_fn = os.path.splitext(timing_fn)[0]
        for name, fig in figures.items():
            plot_fn = f"{base_fn}.{name}.{PLOTS_EXT}"
            print(f"Saving plot = {plot_fn}")
            fig.savefig(plot_fn)
            # Release the figure. Otherwise figures accumulate across input
            # files and memory use grows with the number of inputs.
            plt.close(fig)
|
|