File size: 4,599 Bytes
e45d058
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
/***************************************************************************************************

 * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

 * SPDX-License-Identifier: BSD-3-Clause

 *

 * Redistribution and use in source and binary forms, with or without

 * modification, are permitted provided that the following conditions are met:

 *

 * 1. Redistributions of source code must retain the above copyright notice, this

 * list of conditions and the following disclaimer.

 *

 * 2. Redistributions in binary form must reproduce the above copyright notice,

 * this list of conditions and the following disclaimer in the documentation

 * and/or other materials provided with the distribution.

 *

 * 3. Neither the name of the copyright holder nor the names of its

 * contributors may be used to endorse or promote products derived from

 * this software without specific prior written permission.

 *

 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"

 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE

 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE

 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE

 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL

 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR

 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER

 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,

 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 *

 **************************************************************************************************/

#pragma once

#include <vector>
#include <iostream>

// Cutlass command line parser
#include "cutlass/util/command_line.h"

/// Options controlling the layout visualization: the extent and stride handed to the layout
/// function, the shape of the printed output, and the output vectorization length.
///
/// After constructing from a command line, check `good` before using the values; it is cleared
/// when the supplied arguments fail validation.
class Options {
public:

  bool help;                        ///< set when the "help" or "h" flag is present
  bool good;                        ///< false if validation of the parsed arguments failed
  std::vector<int> extent;          ///< extent of tile to fill
  std::vector<int> stride;          ///< stride vector for layout function
  std::vector<int> output_shape;    ///< output shape
  int vectorize;                    ///< sequences of consecutive output elements are concatenated into a vector
                                    ///  if, and only if, they were consecutive in source memory

public:

  /// Constructs the default options: extent 32x8, stride {32}, output shape 16x8, no vectorization.
  Options(): 
    help(false),
    good(true),
    extent({32, 8}),
    stride({32}),
    output_shape({16, 8}), 
    vectorize(1) { 

  }

  /// Constructs from command line parser
  ///
  /// Arguments that are absent keep their defaults (extent 32x8, output shape 16x8,
  /// vectorize 1). Note that `stride` is left empty when --stride is not given.
  /// On invalid input an error is written to std::cerr and `good` is set to false.
  Options(cutlass::CommandLine const & cmd_line): help(false), good(true), vectorize(1) {

    if (cmd_line.check_cmd_line_flag("help") ||
        cmd_line.check_cmd_line_flag("h")) {

      help = true;
    }

    if (cmd_line.check_cmd_line_flag("extent")) {
      cmd_line.get_cmd_line_arguments("extent", extent);
    }
    else {
      extent = {32, 8};
    }

    // NOTE(review): no default stride is applied here, unlike the default constructor -
    // presumably an empty stride is meaningful to the caller; confirm before changing.
    if (cmd_line.check_cmd_line_flag("stride")) {
      cmd_line.get_cmd_line_arguments("stride", stride);
    }
    
    int const default_output_shape[] = {16, 8}; 

    if (cmd_line.check_cmd_line_flag("output-shape")) {
      cmd_line.get_cmd_line_arguments("output-shape", output_shape);
    }

    // Pad a missing or partially specified output shape with the default dimensions so that
    // at least two entries exist for the checks and messages below.
    for (int i = int(output_shape.size()); i < 2; ++i) {
      output_shape.push_back(default_output_shape[i]);
    }

    if (cmd_line.check_cmd_line_flag("vectorize")) {
      cmd_line.get_cmd_line_argument("vectorize", vectorize);
    }

    if (vectorize <= 0) {

      // A zero vector length would make the remainder below a division by zero (undefined
      // behavior), and a negative length is meaningless, so reject both up front.
      std::cerr << "Error: --vectorize=" << vectorize 
        << " must be a positive integer" << std::endl;

      good = false;
    }
    else if (output_shape.front() % vectorize) {

      std::cerr << "Error: --vectorize=" << vectorize 
        << " must divide contiguous elements in --output-shape="
        << output_shape.at(0) << "," << output_shape.at(1) << std::endl;

      good = false;
    }
  }

  /// Prints usage statement
  ///
  /// \param out  stream that receives the option summary
  static void print_usage(std::ostream &out) {
    out
      << "  Options:\n"
      << "    --help                              Displays this help message.\n"
      << "    --extent=<extent>                   Specifies the layout-specific extent (as comma-delimited array).\n"
      << "    --stride=<stride>                   Specifies the layout-specific stride vector (comma-delimited array)\n"
      << "    --output-shape=<extent>             Specifies the dimensions of a row-major output matrix. \n"
      << "    --vectorize=<vector length>         If possible, vectorizes the output into vectors of consecutive elements\n";
  }
};