Spaces:
Sleeping
Sleeping
/*************************************************************************************************** | |
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
* SPDX-License-Identifier: BSD-3-Clause | |
* | |
* Redistribution and use in source and binary forms, with or without | |
* modification, are permitted provided that the following conditions are met: | |
* | |
* 1. Redistributions of source code must retain the above copyright notice, this | |
* list of conditions and the following disclaimer. | |
* | |
* 2. Redistributions in binary form must reproduce the above copyright notice, | |
* this list of conditions and the following disclaimer in the documentation | |
* and/or other materials provided with the distribution. | |
* | |
* 3. Neither the name of the copyright holder nor the names of its | |
* contributors may be used to endorse or promote products derived from | |
* this software without specific prior written permission. | |
* | |
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | |
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
* | |
**************************************************************************************************/ | |
/* \file | |
\brief | |
*/ | |
///////////////////////////////////////////////////////////////////////////////////////////////// | |
namespace cutlass { | |
namespace profiler { | |
///////////////////////////////////////////////////////////////////////////////////////////////// | |
template <typename T> | |
static T lexical_cast(std::string const &str) { | |
std::stringstream ss; | |
T value; | |
ss << str; | |
ss >> value; | |
return value; | |
} | |
///////////////////////////////////////////////////////////////////////////////////////////////// | |
std::ostream & KernelArgument::ValueIterator::print(std::ostream &out) const { | |
out << "[" << (void *)this << " " << argument->qualified_name() << "] "; | |
if (this->null_argument) { | |
out << "<null>"; | |
} | |
else { | |
out << "<not null>"; | |
} | |
return out; | |
} | |
KernelArgument::~KernelArgument() { | |
} | |
////////////////////////////////////////////////////////////////////////////////////////////////// | |
ScalarArgument::ScalarValue::ScalarValue( | |
std::string const &value_, | |
ScalarArgument const *argument_, | |
bool not_null_ | |
): | |
KernelArgument::Value(argument_, not_null_), | |
value(value_) { | |
} | |
std::ostream &ScalarArgument::ScalarValue::print(std::ostream &out) const { | |
out << argument->qualified_name() << ": "; | |
if (not_null) { | |
out << value; | |
} | |
else { | |
out << "<null>"; | |
} | |
return out; | |
} | |
ScalarArgument::ScalarValueIterator::ScalarValueIterator( | |
ScalarArgument const *argument_ | |
): | |
KernelArgument::ValueIterator(argument_) { | |
if (argument_) { | |
value_it = argument_->values.begin(); | |
} | |
} | |
void ScalarArgument::ScalarValueIterator::operator++() { | |
if (this->null_argument) { | |
this->null_argument = false; | |
} | |
else { | |
++value_it; | |
} | |
} | |
bool ScalarArgument::ScalarValueIterator::operator==(ValueIterator const &it) const { | |
if (it.type() != ArgumentTypeID::kScalar) { | |
throw std::runtime_error("Cannot compare ScalarValueIterator with iterator of different type"); | |
} | |
auto const & scalar_it = static_cast<ScalarValueIterator const &>(it); | |
return value_it == scalar_it.value_it; | |
} | |
/// Gets the value pointed to | |
std::unique_ptr<KernelArgument::Value> ScalarArgument::ScalarValueIterator::at() const { | |
if (this->null_argument) { | |
return std::unique_ptr<KernelArgument::Value>( | |
new ScalarArgument::ScalarValue( | |
std::string(), | |
static_cast<ScalarArgument const *>(argument), | |
false)); | |
} | |
else { | |
return std::unique_ptr<KernelArgument::Value>( | |
new ScalarArgument::ScalarValue( | |
*value_it, | |
static_cast<ScalarArgument const *>(argument))); | |
} | |
} | |
std::unique_ptr<KernelArgument::ValueIterator> ScalarArgument::begin() const { | |
return std::unique_ptr<KernelArgument::ValueIterator>(new ScalarValueIterator(this)); | |
} | |
std::unique_ptr<KernelArgument::ValueIterator> ScalarArgument::end() const { | |
ScalarValueIterator *it = new ScalarValueIterator(this); | |
it->value_it = this->values.end(); | |
it->null_argument = false; | |
return std::unique_ptr<ValueIterator>(it); | |
} | |
////////////////////////////////////////////////////////////////////////////////////////////////// | |
IntegerArgument::IntegerValue::IntegerValue( | |
int64_t value_, | |
IntegerArgument const *argument_, | |
bool not_null_ | |
): KernelArgument::Value(argument_, not_null_), value(value_) { | |
} | |
/// Pretty printer for debugging | |
std::ostream &IntegerArgument::IntegerValue::print(std::ostream &out) const { | |
out << argument->qualified_name() << ": "; | |
if (not_null) { | |
out << value; | |
} | |
else { | |
out << "<null>"; | |
} | |
return out; | |
} | |
IntegerArgument::IntegerValueIterator::IntegerValueIterator(IntegerArgument const *argument_): | |
KernelArgument::ValueIterator(argument_) { | |
if (argument_) { | |
range_it = argument_->ranges.begin(); | |
if (range_it != argument_->ranges.end()) { | |
value_it = range_it->begin(); | |
} | |
} | |
} | |
void IntegerArgument::IntegerValueIterator::operator++() { | |
if (this->null_argument) { | |
this->null_argument = false; | |
} | |
else { | |
++value_it; | |
if (value_it == range_it->end()) { | |
++range_it; | |
if (range_it != static_cast<IntegerArgument const *>(argument)->ranges.end()) { | |
value_it = range_it->begin(); | |
} | |
} | |
} | |
} | |
bool IntegerArgument::IntegerValueIterator::operator==(ValueIterator const &it) const { | |
if (it.type() != ArgumentTypeID::kInteger) { | |
throw std::runtime_error("Cannot compare IntegerValueIterator with iterator of different type"); | |
} | |
auto const & integer_iterator = static_cast<IntegerValueIterator const &>(it); | |
if (this->null_argument) { | |
return it.null_argument; | |
} | |
else { | |
if (range_it != integer_iterator.range_it) { | |
return false; | |
} | |
if (range_it == static_cast<IntegerArgument const *>(argument)->ranges.end() && | |
range_it == integer_iterator.range_it) { | |
return true; | |
} | |
return value_it == integer_iterator.value_it; | |
} | |
} | |
std::unique_ptr<KernelArgument::Value> IntegerArgument::IntegerValueIterator::at() const { | |
if (this->null_argument) { | |
return std::unique_ptr<KernelArgument::Value>( | |
new IntegerArgument::IntegerValue( | |
0, static_cast<IntegerArgument const *>(argument), false)); | |
} | |
else { | |
return std::unique_ptr<KernelArgument::Value>( | |
new IntegerArgument::IntegerValue( | |
*value_it, static_cast<IntegerArgument const *>(argument))); | |
} | |
} | |
std::unique_ptr<KernelArgument::ValueIterator> IntegerArgument::begin() const { | |
return std::unique_ptr<KernelArgument::ValueIterator>(new IntegerValueIterator(this)); | |
} | |
std::unique_ptr<KernelArgument::ValueIterator> IntegerArgument::end() const { | |
IntegerValueIterator *it = new IntegerValueIterator(this); | |
it->range_it = this->ranges.end(); | |
it->null_argument = false; | |
return std::unique_ptr<ValueIterator>(it); | |
} | |
////////////////////////////////////////////////////////////////////////////////////////////////// | |
TensorArgument::TensorValue::TensorValue( | |
TensorDescription const &desc_, | |
TensorArgument const *argument_, | |
bool not_null_ | |
): | |
KernelArgument::Value(argument_, not_null_), | |
desc(desc_) { | |
} | |
/// Pretty printer for debugging | |
std::ostream &TensorArgument::TensorValue::print(std::ostream &out) const { | |
out << argument->qualified_name() << ": " << to_string(desc.element) << ": " << to_string(desc.layout); | |
return out; | |
} | |
TensorArgument::TensorValueIterator::TensorValueIterator( | |
TensorArgument const *argument_ | |
): | |
KernelArgument::ValueIterator(argument_) { | |
if (argument_) { | |
value_it = argument_->values.begin(); | |
} | |
} | |
void TensorArgument::TensorValueIterator::operator++() { | |
if (this->null_argument) { | |
this->null_argument = false; | |
} | |
else { | |
++value_it; | |
} | |
} | |
bool TensorArgument::TensorValueIterator::operator==(ValueIterator const &it) const { | |
if (it.type() != ArgumentTypeID::kTensor) { | |
throw std::runtime_error("Cannot compare TensorValueIterator with iterator of different type"); | |
} | |
auto const & tensor_it = static_cast<TensorValueIterator const &>(it); | |
return value_it == tensor_it.value_it; | |
} | |
/// Gets the value pointed to | |
std::unique_ptr<KernelArgument::Value> TensorArgument::TensorValueIterator::at() const { | |
if (this->null_argument) { | |
return std::unique_ptr<KernelArgument::Value>( | |
new TensorArgument::TensorValue( | |
TensorDescription(), static_cast<TensorArgument const *>(argument), false)); | |
} | |
else { | |
return std::unique_ptr<KernelArgument::Value>( | |
new TensorArgument::TensorValue( | |
*value_it, static_cast<TensorArgument const *>(argument))); | |
} | |
} | |
std::unique_ptr<KernelArgument::ValueIterator> TensorArgument::begin() const { | |
return std::unique_ptr<KernelArgument::ValueIterator>(new TensorValueIterator(this)); | |
} | |
std::unique_ptr<KernelArgument::ValueIterator> TensorArgument::end() const { | |
TensorValueIterator *it = new TensorValueIterator(this); | |
it->value_it = this->values.end(); | |
it->null_argument = false; | |
return std::unique_ptr<ValueIterator>(it); | |
} | |
////////////////////////////////////////////////////////////////////////////////////////////////// | |
EnumeratedTypeArgument::EnumeratedTypeValue::EnumeratedTypeValue( | |
std::string const & element_, | |
EnumeratedTypeArgument const *argument_, | |
bool not_null_ | |
): | |
KernelArgument::Value(argument_, not_null_), | |
element(element_) { | |
} | |
/// Pretty printer for debugging | |
std::ostream &EnumeratedTypeArgument::EnumeratedTypeValue::print(std::ostream &out) const { | |
out << argument->qualified_name() << ": " << element; | |
return out; | |
} | |
EnumeratedTypeArgument::EnumeratedTypeValueIterator::EnumeratedTypeValueIterator( | |
EnumeratedTypeArgument const *argument_ | |
): | |
KernelArgument::ValueIterator(argument_) { | |
if (argument_) { | |
value_it = argument_->values.begin(); | |
} | |
} | |
void EnumeratedTypeArgument::EnumeratedTypeValueIterator::operator++() { | |
if (this->null_argument) { | |
this->null_argument = false; | |
} | |
else { | |
++value_it; | |
} | |
} | |
bool EnumeratedTypeArgument::EnumeratedTypeValueIterator::operator==(ValueIterator const &it) const { | |
if (it.type() != ArgumentTypeID::kEnumerated) { | |
throw std::runtime_error("Cannot compare EnumeratedTypeValueIterator with iterator of different type"); | |
} | |
auto const & enumerated_type_it = static_cast<EnumeratedTypeValueIterator const &>(it); | |
return value_it == enumerated_type_it.value_it; | |
} | |
/// Gets the value pointed to | |
std::unique_ptr<KernelArgument::Value> EnumeratedTypeArgument::EnumeratedTypeValueIterator::at() const { | |
if (this->null_argument) { | |
return std::unique_ptr<KernelArgument::Value>( | |
new EnumeratedTypeValue( | |
std::string(), static_cast<EnumeratedTypeArgument const *>(argument), false)); | |
} | |
else { | |
return std::unique_ptr<KernelArgument::Value>( | |
new EnumeratedTypeValue( | |
*value_it, static_cast<EnumeratedTypeArgument const *>(argument))); | |
} | |
} | |
std::unique_ptr<KernelArgument::ValueIterator> EnumeratedTypeArgument::begin() const { | |
return std::unique_ptr<KernelArgument::ValueIterator>(new EnumeratedTypeValueIterator(this)); | |
} | |
std::unique_ptr<KernelArgument::ValueIterator> EnumeratedTypeArgument::end() const { | |
EnumeratedTypeValueIterator *it = new EnumeratedTypeValueIterator(this); | |
it->value_it = this->values.end(); | |
it->null_argument = false; | |
return std::unique_ptr<ValueIterator>(it); | |
} | |
////////////////////////////////////////////////////////////////////////////////////////////////// | |
ProblemSpace::Iterator::Iterator() { | |
} | |
ProblemSpace::Iterator::Iterator(ProblemSpace const &problem_space) { | |
for (auto const & arg_ptr : problem_space.arguments) { | |
construct_(arg_ptr.get()); | |
} | |
} | |
ProblemSpace::Iterator::Iterator(Iterator && it) { | |
iterators = std::move(it.iterators); | |
} | |
/// Helper for recursively constructing iterators | |
void ProblemSpace::Iterator::construct_(KernelArgument const *argument) { | |
iterators.emplace_back(argument->begin()); | |
} | |
/// Given a set of ranges, iterate over the points within their Cartesian product. No big deal. | |
void ProblemSpace::Iterator::operator++() { | |
// Define a pair of iterator into the vector of iterators. | |
IteratorVector::iterator iterator_it = iterators.begin(); | |
IteratorVector::iterator next_iterator = iterator_it; | |
// Advance the first argument. | |
++(**iterator_it); | |
// Maintain a pair of iterators over consecutive arguments. | |
++next_iterator; | |
// Carry logic | |
while (next_iterator != iterators.end() && | |
**iterator_it == *((*iterator_it)->argument->end())) { // Did an iterator reach the end of its range? | |
(*iterator_it) = (*iterator_it)->argument->begin(); // Reset that iterator, | |
++(**next_iterator); // and increment the next argument's iterator. | |
iterator_it = next_iterator; // Advance to the next argument | |
++next_iterator; | |
} | |
} | |
/// Moves iterator to end | |
void ProblemSpace::Iterator::move_to_end() { | |
if (!iterators.empty()) { | |
std::unique_ptr<KernelArgument::ValueIterator> new_iter = iterators.back()->argument->end(); | |
std::swap(iterators.back(), new_iter); | |
} | |
} | |
ProblemSpace::Problem ProblemSpace::Iterator::at() const { | |
Problem problem; | |
for (std::unique_ptr<KernelArgument::ValueIterator> const & it : iterators) { | |
problem.emplace_back(it->at()); | |
} | |
return problem; | |
} | |
/// Equality operator | |
bool ProblemSpace::Iterator::operator==(Iterator const &it) const { | |
// This would be an opportunity for auto, but explicitly denoting references to | |
// owning smart pointers to dynamic polymorphic objects seems like a kindness to the reader. | |
IteratorVector::const_iterator first_it = iterators.begin(); | |
IteratorVector::const_iterator second_it = it.iterators.begin(); | |
int idx = 0; | |
for (; first_it != iterators.end(); ++first_it, ++second_it, ++idx) { | |
KernelArgument::ValueIterator const *my_it = first_it->get(); | |
KernelArgument::ValueIterator const *their_it = second_it->get(); | |
if (*my_it != *their_it) { | |
return false; | |
} | |
} | |
return true; | |
} | |
std::ostream &ProblemSpace::Iterator::print(std::ostream &out) const { | |
for (std::unique_ptr<KernelArgument::ValueIterator> const & iter_ptr : iterators) { | |
out << " [iter " << (iter_ptr->null_argument ? "null" : "<not null>") | |
<< ", type: " << to_string(iter_ptr->argument->description->type) << "]" << std::endl; | |
} | |
return out; | |
} | |
///////////////////////////////////////////////////////////////////////////////////////////////// | |
ProblemSpace::ProblemSpace(ArgumentDescriptionVector const &schema, CommandLine const &cmdline) { | |
// Clone the arguments | |
for (ArgumentDescription const & arg_desc : schema) { | |
clone_(arguments, &arg_desc); | |
} | |
// Parse values from the command line | |
for (auto & arg : arguments) { | |
parse_(arg.get(), cmdline); | |
} | |
} | |
/// Returns the index of an argument by name | |
size_t ProblemSpace::argument_index(char const *name) const { | |
return argument_index_map.at(name); | |
} | |
/// Helper for recursively cloning | |
void ProblemSpace::clone_( | |
KernelArgumentVector &kernel_args, | |
ArgumentDescription const *arg_desc) { | |
KernelArgument *kernel_arg = nullptr; | |
switch (arg_desc->type) { | |
case ArgumentTypeID::kScalar: | |
kernel_arg = new ScalarArgument(arg_desc); | |
break; | |
case ArgumentTypeID::kInteger: | |
kernel_arg = new IntegerArgument(arg_desc); | |
break; | |
case ArgumentTypeID::kTensor: | |
kernel_arg = new TensorArgument(arg_desc); | |
break; | |
case ArgumentTypeID::kStructure: | |
{ | |
throw std::runtime_error("ArgumentTypeID::kStructure not supported"); | |
} | |
break; | |
case ArgumentTypeID::kEnumerated: | |
kernel_arg = new EnumeratedTypeArgument(arg_desc); | |
break; | |
default: break; | |
} | |
if (kernel_arg) { | |
size_t idx = kernel_args.size(); | |
for (auto const &alias : arg_desc->aliases) { | |
argument_index_map.insert(std::make_pair(alias, idx)); | |
} | |
kernel_args.emplace_back(kernel_arg); | |
} | |
} | |
/// Parses a command line | |
void ProblemSpace::parse_(KernelArgument *arg, CommandLine const &cmdline) { | |
switch (arg->description->type) { | |
case ArgumentTypeID::kScalar: | |
{ | |
auto * scalar = static_cast<ScalarArgument *>(arg); | |
for (auto const &alias : arg->description->aliases) { | |
if (cmdline.check_cmd_line_flag(alias.c_str())) { | |
std::vector<std::vector<std::string>> tokens; | |
cmdline.get_cmd_line_argument_ranges(alias.c_str(), tokens); | |
for (auto const & vec : tokens) { | |
if (!vec.empty()) { | |
scalar->values.push_back(vec.front()); | |
} | |
} | |
break; | |
} | |
} | |
} | |
break; | |
case ArgumentTypeID::kInteger: | |
{ | |
auto *integer = static_cast<IntegerArgument *>(arg); | |
for (auto const &alias : arg->description->aliases) { | |
if (cmdline.check_cmd_line_flag(alias.c_str())) { | |
std::vector<std::vector<std::string> > tokens; | |
cmdline.get_cmd_line_argument_ranges(alias.c_str(), tokens); | |
for (auto &range_tokens : tokens) { | |
if (!range_tokens.empty()) { | |
Range range; | |
if (range_tokens.front() == "rand") { | |
range.mode = Range::Mode::kRandom; | |
} | |
else if (range_tokens.front() == "randlg2") { | |
range.mode = Range::Mode::kRandomLog2; | |
} | |
switch (range.mode) { | |
case Range::Mode::kSequence: | |
{ | |
range.first = lexical_cast<int64_t>(range_tokens.front()); | |
if (range_tokens.size() > 1) { | |
range.last = lexical_cast<int64_t>(range_tokens.at(1)); | |
} | |
else { | |
range.last = range.first; | |
} | |
if (range_tokens.size() > 2) { | |
range.increment = lexical_cast<int64_t>(range_tokens.at(2)); | |
} | |
else { | |
range.increment = 1; | |
} | |
} | |
break; | |
case Range::Mode::kRandom: // fall-through | |
case Range::Mode::kRandomLog2: | |
{ | |
if (range_tokens.size() < 4) { | |
throw std::runtime_error( | |
"Range of mode 'rand' must have four tokens showing " | |
"the minimum, maximum, and number of iterations. For example, " | |
"rand:16:128:1000"); | |
} | |
range.minimum = lexical_cast<int64_t>(range_tokens.at(1)); | |
range.maximum = lexical_cast<int64_t>(range_tokens.at(2)); | |
range.first = 1; | |
range.last = lexical_cast<int64_t>(range_tokens.at(3)); | |
range.increment = 1; | |
if (range_tokens.size() > 4) { | |
range.divisible = lexical_cast<int64_t>(range_tokens.at(4)); | |
} | |
} | |
break; | |
default: | |
throw std::runtime_error("Unsupported range mode."); | |
break; | |
} | |
integer->ranges.push_back(range); | |
} | |
} | |
break; | |
} | |
} | |
} | |
break; | |
case ArgumentTypeID::kTensor: | |
{ | |
auto *tensor = static_cast<TensorArgument *>(arg); | |
for (auto const &alias : arg->description->aliases) { | |
if (cmdline.check_cmd_line_flag(alias.c_str())) { | |
std::vector<std::vector<std::string>> tokens; | |
cmdline.get_cmd_line_argument_ranges(alias.c_str(), tokens); | |
for (auto const & tensor_tokens : tokens) { | |
if (!tensor_tokens.empty()) { | |
TensorArgument::TensorDescription tensor_desc; | |
tensor_desc.element = cutlass::library::from_string<library::NumericTypeID>(tensor_tokens.front()); | |
// Layout | |
if (tensor_tokens.size() > 1) { | |
tensor_desc.layout = cutlass::library::from_string<library::LayoutTypeID>(tensor_tokens.at(1)); | |
} | |
// Stride | |
for (size_t i = 2; i < tensor_tokens.size(); ++i) { | |
tensor_desc.stride.push_back(lexical_cast<int>(tensor_tokens.at(i))); | |
} | |
tensor->values.push_back(tensor_desc); | |
} | |
} | |
break; | |
} | |
} | |
} | |
break; | |
case ArgumentTypeID::kStructure: | |
{ | |
throw std::runtime_error("Structure arguments not supported"); | |
} | |
break; | |
case ArgumentTypeID::kEnumerated: | |
{ | |
auto *enumerated_type = static_cast<EnumeratedTypeArgument *>(arg); | |
for (auto const &alias : arg->description->aliases) { | |
if (cmdline.check_cmd_line_flag(alias.c_str())) { | |
std::vector<std::string> tokens; | |
cmdline.get_cmd_line_arguments(alias.c_str(), tokens); | |
for (auto const & token : tokens) { | |
enumerated_type->values.push_back(token); | |
} | |
break; | |
} | |
} | |
} | |
break; | |
default: | |
break; | |
} | |
} | |
///////////////////////////////////////////////////////////////////////////////////////////////// | |
ProblemSpace::Iterator ProblemSpace::begin() const { | |
return ProblemSpace::Iterator(*this); | |
} | |
ProblemSpace::Iterator ProblemSpace::end() const { | |
ProblemSpace::Iterator it(*this); | |
it.move_to_end(); | |
return it; | |
} | |
/// Gets all argument names as an ordered vector | |
std::vector<std::string> ProblemSpace::argument_names() const { | |
Problem problem = this->begin().at(); | |
std::vector<std::string> names; | |
names.reserve(problem.size()); | |
for (auto const & arg : problem) { | |
names.push_back(arg->argument->description->aliases.front()); | |
} | |
return names; | |
} | |
///////////////////////////////////////////////////////////////////////////////////////////////// | |
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. | |
bool arg_as_int(int64_t &int_value, KernelArgument::Value const *value_ptr) { | |
if (value_ptr->not_null) { | |
if (value_ptr->argument->description->type == ArgumentTypeID::kInteger) { | |
int_value = static_cast<IntegerArgument::IntegerValue const *>(value_ptr)->value; | |
} | |
else if (value_ptr->argument->description->type == ArgumentTypeID::kScalar) { | |
std::stringstream ss; | |
ss << static_cast<ScalarArgument::ScalarValue const *>(value_ptr)->value; | |
ss >> int_value; | |
} | |
else { | |
throw std::runtime_error( | |
"arg_as_int64_t() - illegal cast. Problem space argument must be integer or scalar"); | |
} | |
return true; | |
} | |
return false; | |
} | |
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. | |
bool arg_as_int(int &int_value, KernelArgument::Value const *value_ptr) { | |
int64_t value64; | |
bool obtained = arg_as_int(value64, value_ptr); | |
if (obtained) { | |
int_value = int(value64); | |
return true; | |
} | |
return false; | |
} | |
/// Lexically casts an argument to an int | |
bool arg_as_int( | |
int &int_value, | |
char const *name, | |
ProblemSpace const &problem_space, | |
ProblemSpace::Problem const &problem) { | |
size_t idx = problem_space.argument_index(name); | |
KernelArgument::Value const *value_ptr = problem.at(idx).get(); | |
return arg_as_int(int_value, value_ptr); | |
} | |
/// Lexically casts an argument to an int64 | |
bool arg_as_int( | |
int64_t &int_value, | |
char const *name, | |
ProblemSpace const &problem_space, | |
ProblemSpace::Problem const &problem) { | |
size_t idx = problem_space.argument_index(name); | |
KernelArgument::Value const *value_ptr = problem.at(idx).get(); | |
return arg_as_int(int_value, value_ptr); | |
} | |
///////////////////////////////////////////////////////////////////////////////////////////////// | |
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. | |
bool arg_as_NumericTypeID( | |
library::NumericTypeID &numeric_type, | |
KernelArgument::Value const *value_ptr) { | |
if (value_ptr->not_null) { | |
if (value_ptr->argument->description->type == ArgumentTypeID::kEnumerated) { | |
numeric_type = library::from_string<library::NumericTypeID>( | |
static_cast<EnumeratedTypeArgument::EnumeratedTypeValue const *>(value_ptr)->element); | |
if (numeric_type == library::NumericTypeID::kInvalid) { | |
throw std::runtime_error( | |
"arg_as_NumericTypeID() - illegal cast."); | |
} | |
} | |
else { | |
throw std::runtime_error( | |
"arg_as_NumericTypeID() - illegal cast."); | |
} | |
return true; | |
} | |
return false; | |
} | |
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. | |
bool arg_as_NumericTypeID( | |
library::NumericTypeID &numeric_type, | |
char const *name, | |
ProblemSpace const &problem_space, | |
ProblemSpace::Problem const &problem) { | |
size_t idx = problem_space.argument_index(name); | |
KernelArgument::Value const *value_ptr = problem.at(idx).get(); | |
return arg_as_NumericTypeID(numeric_type, value_ptr); | |
} | |
///////////////////////////////////////////////////////////////////////////////////////////////// | |
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. | |
bool arg_as_RasterOrder( | |
library::RasterOrder &raster_order, | |
KernelArgument::Value const *value_ptr) { | |
if (value_ptr->not_null) { | |
if (value_ptr->argument->description->type == ArgumentTypeID::kEnumerated) { | |
raster_order = library::from_string<library::RasterOrder>( | |
static_cast<EnumeratedTypeArgument::EnumeratedTypeValue const *>(value_ptr)->element); | |
if (raster_order == library::RasterOrder::kInvalid) { | |
throw std::runtime_error( | |
"arg_as_RasterOrder() - illegal cast."); | |
} | |
} | |
else { | |
throw std::runtime_error( | |
"arg_as_RasterOrder() - illegal cast."); | |
} | |
return true; | |
} | |
return false; | |
} | |
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. | |
bool arg_as_RasterOrder( | |
library::RasterOrder &raster_order, | |
char const *name, | |
ProblemSpace const &problem_space, | |
ProblemSpace::Problem const &problem) { | |
size_t idx = problem_space.argument_index(name); | |
KernelArgument::Value const *value_ptr = problem.at(idx).get(); | |
return arg_as_RasterOrder(raster_order, value_ptr); | |
} | |
///////////////////////////////////////////////////////////////////////////////////////////////// | |
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. | |
bool arg_as_LayoutTypeID( | |
library::LayoutTypeID &layout_type, | |
KernelArgument::Value const *value_ptr) { | |
if (value_ptr->not_null) { | |
if (value_ptr->argument->description->type == ArgumentTypeID::kEnumerated) { | |
layout_type = library::from_string<library::LayoutTypeID>( | |
static_cast<EnumeratedTypeArgument::EnumeratedTypeValue const *>(value_ptr)->element); | |
if (layout_type == library::LayoutTypeID::kInvalid) { | |
throw std::runtime_error( | |
"arg_as_LayoutTypeID() - illegal cast."); | |
} | |
} | |
else { | |
throw std::runtime_error( | |
"arg_as_LayoutTypeID() - illegal cast."); | |
} | |
return true; | |
} | |
return false; | |
} | |
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. | |
bool arg_as_LayoutTypeID( | |
library::LayoutTypeID &layout_type, | |
char const *name, | |
ProblemSpace const &problem_space, | |
ProblemSpace::Problem const &problem) { | |
size_t idx = problem_space.argument_index(name); | |
KernelArgument::Value const *value_ptr = problem.at(idx).get(); | |
return arg_as_LayoutTypeID(layout_type, value_ptr); | |
} | |
///////////////////////////////////////////////////////////////////////////////////////////////// | |
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. | |
bool arg_as_OpcodeClassID( | |
library::OpcodeClassID &opcode_class, | |
KernelArgument::Value const *value_ptr) { | |
if (value_ptr->not_null) { | |
if (value_ptr->argument->description->type == ArgumentTypeID::kEnumerated) { | |
opcode_class = library::from_string<library::OpcodeClassID>( | |
static_cast<EnumeratedTypeArgument::EnumeratedTypeValue const *>(value_ptr)->element); | |
if (opcode_class == library::OpcodeClassID::kInvalid) { | |
throw std::runtime_error( | |
"arg_as_OpcodeClassID() - illegal cast."); | |
} | |
} | |
else { | |
throw std::runtime_error( | |
"arg_as_OpcodeClassID() - illegal cast."); | |
} | |
return true; | |
} | |
return false; | |
} | |
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. | |
bool arg_as_OpcodeClassID( | |
library::OpcodeClassID &opcode_class, | |
char const *name, | |
ProblemSpace const &problem_space, | |
ProblemSpace::Problem const &problem) { | |
size_t idx = problem_space.argument_index(name); | |
KernelArgument::Value const *value_ptr = problem.at(idx).get(); | |
return arg_as_OpcodeClassID(opcode_class, value_ptr); | |
} | |
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. | |
bool arg_as_SplitKModeID( | |
library::SplitKMode &split_k_mode, | |
KernelArgument::Value const *value_ptr) { | |
if (value_ptr->not_null) { | |
if (value_ptr->argument->description->type == ArgumentTypeID::kEnumerated) { | |
split_k_mode = library::from_string<library::SplitKMode>( | |
static_cast<EnumeratedTypeArgument::EnumeratedTypeValue const *>(value_ptr)->element); | |
if (split_k_mode == library::SplitKMode::kInvalid) { | |
throw std::runtime_error( | |
"arg_as_SplitKModeID() - illegal cast."); | |
} | |
} | |
else { | |
throw std::runtime_error( | |
"arg_as_SplitKModeID() - illegal cast."); | |
} | |
return true; | |
} | |
return false; | |
} | |
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. | |
bool arg_as_SplitKModeID( | |
library::SplitKMode &split_k_mode, | |
char const *name, | |
ProblemSpace const &problem_space, | |
ProblemSpace::Problem const &problem) { | |
size_t idx = problem_space.argument_index(name); | |
KernelArgument::Value const *value_ptr = problem.at(idx).get(); | |
return arg_as_SplitKModeID(split_k_mode, value_ptr); | |
} | |
///////////////////////////////////////////////////////////////////////////////////////////////// | |
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. | |
bool arg_as_ConvModeID( | |
library::ConvModeID &conv_mode, | |
KernelArgument::Value const *value_ptr) { | |
if (value_ptr->not_null) { | |
if (value_ptr->argument->description->type == ArgumentTypeID::kEnumerated) { | |
conv_mode = library::from_string<library::ConvModeID>( | |
static_cast<EnumeratedTypeArgument::EnumeratedTypeValue const *>(value_ptr)->element); | |
if (conv_mode == library::ConvModeID::kInvalid) { | |
throw std::runtime_error( | |
"arg_as_ConvModeID() - illegal cast."); | |
} | |
} | |
else { | |
throw std::runtime_error( | |
"arg_as_ConvModeID() - illegal cast."); | |
} | |
return true; | |
} | |
return false; | |
} | |
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. | |
bool arg_as_ConvModeID( | |
library::ConvModeID &conv_mode, | |
char const *name, | |
ProblemSpace const &problem_space, | |
ProblemSpace::Problem const &problem) { | |
size_t idx = problem_space.argument_index(name); | |
KernelArgument::Value const *value_ptr = problem.at(idx).get(); | |
return arg_as_ConvModeID(conv_mode, value_ptr); | |
} | |
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. | |
bool arg_as_ProviderID( | |
library::Provider &provider, | |
KernelArgument::Value const *value_ptr) { | |
if (value_ptr->not_null) { | |
if (value_ptr->argument->description->type == ArgumentTypeID::kEnumerated) { | |
provider = library::from_string<library::Provider>( | |
static_cast<EnumeratedTypeArgument::EnumeratedTypeValue const *>(value_ptr)->element); | |
if (provider == library::Provider::kInvalid) { | |
throw std::runtime_error( | |
"arg_as_ProviderID() - illegal cast."); | |
} | |
} | |
else { | |
throw std::runtime_error( | |
"arg_as_ProviderID() - illegal cast."); | |
} | |
return true; | |
} | |
return false; | |
} | |
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. | |
bool arg_as_ProviderID( | |
library::Provider &provider, | |
char const *name, | |
ProblemSpace const &problem_space, | |
ProblemSpace::Problem const &problem) { | |
size_t idx = problem_space.argument_index(name); | |
KernelArgument::Value const *value_ptr = problem.at(idx).get(); | |
return arg_as_ProviderID(provider, value_ptr); | |
} | |
///////////////////////////////////////////////////////////////////////////////////////////////// | |
/// Lexically casts an argument to a given type stored in a byte array. Returns true if not null. | |
bool arg_as_scalar( | |
std::vector<uint8_t> &bytes, | |
library::NumericTypeID numeric_type, | |
KernelArgument::Value const *value_ptr) { | |
if (value_ptr->not_null) { | |
if (value_ptr->argument->description->type == ArgumentTypeID::kInteger) { | |
int64_t int_value = static_cast<IntegerArgument::IntegerValue const *>(value_ptr)->value; | |
// TODO - convert int64_t => destination type | |
} | |
else if (value_ptr->argument->description->type == ArgumentTypeID::kScalar) { | |
std::string const &str_value = static_cast<ScalarArgument::ScalarValue const *>(value_ptr)->value; | |
return lexical_cast(bytes, numeric_type, str_value); | |
} | |
else { | |
throw std::runtime_error( | |
"arg_as_int() - illegal cast. Problem space argument must be integer or scalar"); | |
} | |
return true; | |
} | |
return false; | |
} | |
/// Lexically casts an argument to a given type and returns a byte array | |
bool arg_as_scalar( | |
std::vector<uint8_t> &bytes, | |
library::NumericTypeID numeric_type, | |
char const *name, | |
ProblemSpace const &problem_space, | |
ProblemSpace::Problem const &problem) { | |
size_t idx = problem_space.argument_index(name); | |
KernelArgument::Value const *value_ptr = problem.at(idx).get(); | |
return arg_as_scalar(bytes, numeric_type, value_ptr); | |
} | |
///////////////////////////////////////////////////////////////////////////////////////////////// | |
/// Returns true if a tensor description satisfies a `tensor` value | |
bool tensor_description_satisfies( | |
library::TensorDescription const &tensor_desc, | |
TensorArgument::TensorValue const *value_ptr) { | |
if (value_ptr->not_null) { | |
if (value_ptr->desc.element != library::NumericTypeID::kUnknown && | |
value_ptr->desc.element != tensor_desc.element) { | |
return false; | |
} | |
if (value_ptr->desc.layout != library::LayoutTypeID::kUnknown && | |
value_ptr->desc.layout != tensor_desc.layout) { | |
return false; | |
} | |
} | |
return true; | |
} | |
/// Returns true if a tensor description satisfies a `tensor` value | |
bool tensor_description_satisfies( | |
library::TensorDescription const &tensor_desc, | |
char const *name, | |
ProblemSpace const &problem_space, | |
ProblemSpace::Problem const &problem) { | |
size_t idx = problem_space.argument_index(name); | |
KernelArgument::Value const *value_ptr = problem.at(idx).get(); | |
if (value_ptr->argument->description->type == ArgumentTypeID::kTensor) { | |
return tensor_description_satisfies( | |
tensor_desc, | |
static_cast<TensorArgument::TensorValue const *>(value_ptr)); | |
} | |
else { | |
throw std::runtime_error("Kernel argument mismatch"); | |
} | |
return false; | |
} | |
///////////////////////////////////////////////////////////////////////////////////////////////// | |
/// Returns true if conv_kind satisfies the value | |
bool conv_kind_satisfies( | |
library::ConvKind const &conv_kind, | |
EnumeratedTypeArgument::EnumeratedTypeValue const *value_ptr) { | |
if (value_ptr->not_null) { | |
library::ConvKind conv_kind_cmd_line = | |
library::from_string<library::ConvKind>(value_ptr->element); | |
if (conv_kind_cmd_line != library::ConvKind::kUnknown && | |
conv_kind_cmd_line != conv_kind) { | |
return false; | |
} | |
} | |
return true; | |
} | |
/// Returns true if conv_kind satisfies the value | |
bool conv_kind_satisfies( | |
library::ConvKind const &conv_kind, | |
char const *name, | |
ProblemSpace const &problem_space, | |
ProblemSpace::Problem const &problem) { | |
size_t idx = problem_space.argument_index(name); | |
KernelArgument::Value const *value_ptr = problem.at(idx).get(); | |
if (value_ptr->argument->description->type == ArgumentTypeID::kEnumerated) { | |
return conv_kind_satisfies( | |
conv_kind, | |
static_cast<EnumeratedTypeArgument::EnumeratedTypeValue const *>(value_ptr)); | |
} | |
else { | |
throw std::runtime_error("Kernel argument mismatch"); | |
} | |
return false; | |
} | |
///////////////////////////////////////////////////////////////////////////////////////////////// | |
/// Returns true if a iterator algorithm satisfies the value | |
bool iterator_algorithm_satisfies( | |
library::IteratorAlgorithmID const &iterator_algorithm, | |
EnumeratedTypeArgument::EnumeratedTypeValue const *value_ptr) { | |
if (value_ptr->not_null) { | |
library::IteratorAlgorithmID iterator_algorithm_cmd_line = | |
library::from_string<library::IteratorAlgorithmID>(value_ptr->element); | |
if (iterator_algorithm_cmd_line != library::IteratorAlgorithmID::kNone && | |
iterator_algorithm_cmd_line != iterator_algorithm) { | |
return false; | |
} | |
} | |
return true; | |
} | |
/// Returns true if a iterator algorithm satisfies the value | |
bool iterator_algorithm_satisfies( | |
library::IteratorAlgorithmID const &iterator_algorithm, | |
char const *name, | |
ProblemSpace const &problem_space, | |
ProblemSpace::Problem const &problem) { | |
size_t idx = problem_space.argument_index(name); | |
KernelArgument::Value const *value_ptr = problem.at(idx).get(); | |
if (value_ptr->argument->description->type == ArgumentTypeID::kEnumerated) { | |
return iterator_algorithm_satisfies( | |
iterator_algorithm, | |
static_cast<EnumeratedTypeArgument::EnumeratedTypeValue const *>(value_ptr)); | |
} | |
else { | |
throw std::runtime_error("Kernel argument mismatch"); | |
} | |
return false; | |
} | |
///////////////////////////////////////////////////////////////////////////////////////////////// | |
} // namespace profiler | |
} // namespace cutlass | |
///////////////////////////////////////////////////////////////////////////////////////////////// | |