File size: 3,342 Bytes
2d8da09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#!/usr/bin/env python3
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes one or more XXXX.json timing files as input
(i.e., the output of nmt_transformer_infer.py --write_timing)
and saves plots named XXXX.PLOT_NAME.<ext> next to each input file,
where <ext> is set by PLOTS_EXT (e.g., XXXX.encoder-src_len.pdf).
"""
import json
import os
import sys

from matplotlib import pyplot as plt

# =============================================================================#
# Control Variables
# =============================================================================#

# File extension appended to every saved plot filename (BASE.PLOT_NAME.EXT).
PLOTS_EXT = "pdf"
# When True, plot_timing adds a "<timings> vs. <lengths>" title to each plot.
PLOT_TITLE = False
# When True, label the x axis with "<lengths_name> [tokens]".
PLOT_XLABEL = True
# When True, label the y axis with "<timings_name> [sec]".
PLOT_YLABEL = True
# Font size used for the axis labels and the tick labels (the title, if
# enabled, uses matplotlib's default size).
PLOT_LABEL_FONT_SIZE = 16
# When True, draw a background grid on each plot.
PLOT_GRID = True

# =============================================================================#
# Helper functions
# =============================================================================#


def plot_timing(lengths, timings, lengths_name, timings_name, fig=None):
    """Scatter-plot ``timings`` against ``lengths`` and return the figure.

    Axis labels, grid, title and font sizes are controlled by the module-level
    PLOT_* constants.

    Args:
        lengths: x values; the x-axis label marks them as "[tokens]".
        timings: y values; the y-axis label marks them as "[sec]". Presumably
            the same number of entries as ``lengths`` (required by scatter).
        lengths_name: human-readable name of the x quantity, used in labels.
        timings_name: human-readable name of the y quantity; also used as the
            label of the scatter series.
        fig: optional existing matplotlib Figure to draw into. A new figure is
            created when omitted.

    Returns:
        The figure that was drawn into (``fig`` if given, else the new one).
    """
    if fig is None:
        fig = plt.figure()
    # Draw through the figure's own axes instead of pyplot's implicit
    # "current figure", so a caller-supplied ``fig`` is really the one
    # drawn into even when it is not the current figure.
    ax = fig.gca()

    ax.scatter(lengths, timings, label=timings_name)
    if PLOT_XLABEL:
        ax.set_xlabel(f"{lengths_name} [tokens]", fontsize=PLOT_LABEL_FONT_SIZE)
    if PLOT_YLABEL:
        ax.set_ylabel(f"{timings_name} [sec]", fontsize=PLOT_LABEL_FONT_SIZE)
    if PLOT_GRID:
        ax.grid(True)
    if PLOT_TITLE:
        ax.set_title(f"{timings_name} vs. {lengths_name}")

    ax.tick_params(axis="both", labelsize=PLOT_LABEL_FONT_SIZE)
    fig.tight_layout()

    return fig


# =============================================================================#
# Main script
# =============================================================================#
if __name__ == "__main__":
    # One entry per plot produced for every input file:
    # (plot name used in the output filename,
    #  JSON key holding the x values, JSON key holding the y values,
    #  x quantity name, y quantity name).
    plot_specs = [
        ("encoder-src_len", "mean_src_length", "encoder", "src length", "encoder"),
        ("sampler-src_len", "mean_src_length", "sampler", "src length", "sampler"),
        ("sampler-tgt_len", "mean_tgt_length", "sampler", "tgt length", "sampler"),
    ]

    timing_fns = sys.argv[1:]
    if not timing_fns:
        # Only show usage when no input was given, and signal failure.
        print("Usage: plot_detailed_timing.py <JSON FILE> <JSON FILE> ...", file=sys.stderr)
        sys.exit(1)

    for timing_fn in timing_fns:
        # load data (context manager so the file handle is closed promptly)
        print(f"Parsing file = {timing_fn}")
        with open(timing_fn) as timing_file:
            data = json.load(timing_file)

        # plot and save data next to the input file
        base_fn = os.path.splitext(timing_fn)[0]
        for name, lengths_key, timings_key, lengths_name, timings_name in plot_specs:
            fig = plot_timing(
                lengths=data[lengths_key],
                timings=data[timings_key],
                lengths_name=lengths_name,
                timings_name=timings_name,
            )
            plot_fn = f"{base_fn}.{name}.{PLOTS_EXT}"
            print(f"Saving plot = {plot_fn}")
            fig.savefig(plot_fn)
            # Release the figure; otherwise pyplot keeps every figure alive
            # and memory grows with the number of input files.
            plt.close(fig)