/*************************************************************************************************** * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /* \file \brief */ #include #include #include #include "cutlass/library/util.h" #include "cutlass/profiler/problem_space.h" ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { namespace profiler { ///////////////////////////////////////////////////////////////////////////////////////////////// template static T lexical_cast(std::string const &str) { std::stringstream ss; T value; ss << str; ss >> value; return value; } ///////////////////////////////////////////////////////////////////////////////////////////////// std::ostream & KernelArgument::ValueIterator::print(std::ostream &out) const { out << "[" << (void *)this << " " << argument->qualified_name() << "] "; if (this->null_argument) { out << ""; } else { out << ""; } return out; } KernelArgument::~KernelArgument() { } ////////////////////////////////////////////////////////////////////////////////////////////////// ScalarArgument::ScalarValue::ScalarValue( std::string const &value_, ScalarArgument const *argument_, bool not_null_ ): KernelArgument::Value(argument_, not_null_), value(value_) { } std::ostream &ScalarArgument::ScalarValue::print(std::ostream &out) const { out << argument->qualified_name() << ": "; if (not_null) { out << value; } else { out << ""; } return out; } ScalarArgument::ScalarValueIterator::ScalarValueIterator( ScalarArgument const *argument_ ): KernelArgument::ValueIterator(argument_) { if (argument_) { value_it = argument_->values.begin(); } } void ScalarArgument::ScalarValueIterator::operator++() { if (this->null_argument) { this->null_argument = false; } else { ++value_it; } } bool ScalarArgument::ScalarValueIterator::operator==(ValueIterator const &it) const { if (it.type() != ArgumentTypeID::kScalar) { throw std::runtime_error("Cannot compare ScalarValueIterator with iterator of different type"); } auto const & scalar_it = static_cast(it); return value_it == scalar_it.value_it; } /// Gets the value pointed to std::unique_ptr ScalarArgument::ScalarValueIterator::at() const { if (this->null_argument) { return std::unique_ptr( new ScalarArgument::ScalarValue( std::string(), static_cast(argument), false)); } else { return std::unique_ptr( new ScalarArgument::ScalarValue( *value_it, static_cast(argument))); } } std::unique_ptr ScalarArgument::begin() const { return std::unique_ptr(new ScalarValueIterator(this)); } std::unique_ptr ScalarArgument::end() const { ScalarValueIterator *it = new ScalarValueIterator(this); it->value_it = this->values.end(); it->null_argument = false; return std::unique_ptr(it); } ////////////////////////////////////////////////////////////////////////////////////////////////// IntegerArgument::IntegerValue::IntegerValue( int64_t value_, IntegerArgument const *argument_, bool not_null_ ): KernelArgument::Value(argument_, not_null_), value(value_) { } /// Pretty printer for debugging std::ostream &IntegerArgument::IntegerValue::print(std::ostream &out) const { out << argument->qualified_name() << ": "; if (not_null) { out << value; } else { out << ""; } return out; } IntegerArgument::IntegerValueIterator::IntegerValueIterator(IntegerArgument const *argument_): KernelArgument::ValueIterator(argument_) { if (argument_) { range_it = argument_->ranges.begin(); if (range_it != argument_->ranges.end()) { value_it = range_it->begin(); } } } void IntegerArgument::IntegerValueIterator::operator++() { if (this->null_argument) { this->null_argument = false; } else { ++value_it; if (value_it == range_it->end()) { ++range_it; if (range_it != static_cast(argument)->ranges.end()) { value_it = range_it->begin(); } } } } bool IntegerArgument::IntegerValueIterator::operator==(ValueIterator const &it) const { if (it.type() != ArgumentTypeID::kInteger) { throw std::runtime_error("Cannot compare IntegerValueIterator with iterator of different type"); } auto const & integer_iterator = static_cast(it); if (this->null_argument) { return it.null_argument; } else { if (range_it != integer_iterator.range_it) { return false; } if (range_it == static_cast(argument)->ranges.end() && range_it == integer_iterator.range_it) { return true; } return value_it == integer_iterator.value_it; } } std::unique_ptr IntegerArgument::IntegerValueIterator::at() const { if (this->null_argument) { return std::unique_ptr( new IntegerArgument::IntegerValue( 0, static_cast(argument), false)); } else { return std::unique_ptr( new IntegerArgument::IntegerValue( *value_it, static_cast(argument))); } } std::unique_ptr IntegerArgument::begin() const { return std::unique_ptr(new IntegerValueIterator(this)); } std::unique_ptr IntegerArgument::end() const { IntegerValueIterator *it = new IntegerValueIterator(this); it->range_it = this->ranges.end(); it->null_argument = false; return std::unique_ptr(it); } ////////////////////////////////////////////////////////////////////////////////////////////////// TensorArgument::TensorValue::TensorValue( TensorDescription const &desc_, TensorArgument const *argument_, bool not_null_ ): KernelArgument::Value(argument_, not_null_), desc(desc_) { } /// Pretty printer for debugging std::ostream &TensorArgument::TensorValue::print(std::ostream &out) const { out << argument->qualified_name() << ": " << to_string(desc.element) << ": " << to_string(desc.layout); return out; } TensorArgument::TensorValueIterator::TensorValueIterator( TensorArgument const *argument_ ): KernelArgument::ValueIterator(argument_) { if (argument_) { value_it = argument_->values.begin(); } } void TensorArgument::TensorValueIterator::operator++() { if (this->null_argument) { this->null_argument = false; } else { ++value_it; } } bool TensorArgument::TensorValueIterator::operator==(ValueIterator const &it) const { if (it.type() != ArgumentTypeID::kTensor) { throw std::runtime_error("Cannot compare TensorValueIterator with iterator of different type"); } auto const & tensor_it = static_cast(it); return value_it == tensor_it.value_it; } /// Gets the value pointed to std::unique_ptr TensorArgument::TensorValueIterator::at() const { if (this->null_argument) { return std::unique_ptr( new TensorArgument::TensorValue( TensorDescription(), static_cast(argument), false)); } else { return std::unique_ptr( new TensorArgument::TensorValue( *value_it, static_cast(argument))); } } std::unique_ptr TensorArgument::begin() const { return std::unique_ptr(new TensorValueIterator(this)); } std::unique_ptr TensorArgument::end() const { TensorValueIterator *it = new TensorValueIterator(this); it->value_it = this->values.end(); it->null_argument = false; return std::unique_ptr(it); } ////////////////////////////////////////////////////////////////////////////////////////////////// EnumeratedTypeArgument::EnumeratedTypeValue::EnumeratedTypeValue( std::string const & element_, EnumeratedTypeArgument const *argument_, bool not_null_ ): KernelArgument::Value(argument_, not_null_), element(element_) { } /// Pretty printer for debugging std::ostream &EnumeratedTypeArgument::EnumeratedTypeValue::print(std::ostream &out) const { out << argument->qualified_name() << ": " << element; return out; } EnumeratedTypeArgument::EnumeratedTypeValueIterator::EnumeratedTypeValueIterator( EnumeratedTypeArgument const *argument_ ): KernelArgument::ValueIterator(argument_) { if (argument_) { value_it = argument_->values.begin(); } } void EnumeratedTypeArgument::EnumeratedTypeValueIterator::operator++() { if (this->null_argument) { this->null_argument = false; } else { ++value_it; } } bool EnumeratedTypeArgument::EnumeratedTypeValueIterator::operator==(ValueIterator const &it) const { if (it.type() != ArgumentTypeID::kEnumerated) { throw std::runtime_error("Cannot compare EnumeratedTypeValueIterator with iterator of different type"); } auto const & enumerated_type_it = static_cast(it); return value_it == enumerated_type_it.value_it; } /// Gets the value pointed to std::unique_ptr EnumeratedTypeArgument::EnumeratedTypeValueIterator::at() const { if (this->null_argument) { return std::unique_ptr( new EnumeratedTypeValue( std::string(), static_cast(argument), false)); } else { return std::unique_ptr( new EnumeratedTypeValue( *value_it, static_cast(argument))); } } std::unique_ptr EnumeratedTypeArgument::begin() const { return std::unique_ptr(new EnumeratedTypeValueIterator(this)); } std::unique_ptr EnumeratedTypeArgument::end() const { EnumeratedTypeValueIterator *it = new EnumeratedTypeValueIterator(this); it->value_it = this->values.end(); it->null_argument = false; return std::unique_ptr(it); } ////////////////////////////////////////////////////////////////////////////////////////////////// ProblemSpace::Iterator::Iterator() { } ProblemSpace::Iterator::Iterator(ProblemSpace const &problem_space) { for (auto const & arg_ptr : problem_space.arguments) { construct_(arg_ptr.get()); } } ProblemSpace::Iterator::Iterator(Iterator && it) { iterators = std::move(it.iterators); } /// Helper for recursively constructing iterators void ProblemSpace::Iterator::construct_(KernelArgument const *argument) { iterators.emplace_back(argument->begin()); } /// Given a set of ranges, iterate over the points within their Cartesian product. No big deal. void ProblemSpace::Iterator::operator++() { // Define a pair of iterator into the vector of iterators. IteratorVector::iterator iterator_it = iterators.begin(); IteratorVector::iterator next_iterator = iterator_it; // Advance the first argument. ++(**iterator_it); // Maintain a pair of iterators over consecutive arguments. ++next_iterator; // Carry logic while (next_iterator != iterators.end() && **iterator_it == *((*iterator_it)->argument->end())) { // Did an iterator reach the end of its range? (*iterator_it) = (*iterator_it)->argument->begin(); // Reset that iterator, ++(**next_iterator); // and increment the next argument's iterator. iterator_it = next_iterator; // Advance to the next argument ++next_iterator; } } /// Moves iterator to end void ProblemSpace::Iterator::move_to_end() { if (!iterators.empty()) { std::unique_ptr new_iter = iterators.back()->argument->end(); std::swap(iterators.back(), new_iter); } } ProblemSpace::Problem ProblemSpace::Iterator::at() const { Problem problem; for (std::unique_ptr const & it : iterators) { problem.emplace_back(it->at()); } return problem; } /// Equality operator bool ProblemSpace::Iterator::operator==(Iterator const &it) const { // This would be an opportunity for auto, but explicitly denoting references to // owning smart pointers to dynamic polymorphic objects seems like a kindness to the reader. IteratorVector::const_iterator first_it = iterators.begin(); IteratorVector::const_iterator second_it = it.iterators.begin(); int idx = 0; for (; first_it != iterators.end(); ++first_it, ++second_it, ++idx) { KernelArgument::ValueIterator const *my_it = first_it->get(); KernelArgument::ValueIterator const *their_it = second_it->get(); if (*my_it != *their_it) { return false; } } return true; } std::ostream &ProblemSpace::Iterator::print(std::ostream &out) const { for (std::unique_ptr const & iter_ptr : iterators) { out << " [iter " << (iter_ptr->null_argument ? "null" : "") << ", type: " << to_string(iter_ptr->argument->description->type) << "]" << std::endl; } return out; } ///////////////////////////////////////////////////////////////////////////////////////////////// ProblemSpace::ProblemSpace(ArgumentDescriptionVector const &schema, CommandLine const &cmdline) { // Clone the arguments for (ArgumentDescription const & arg_desc : schema) { clone_(arguments, &arg_desc); } // Parse values from the command line for (auto & arg : arguments) { parse_(arg.get(), cmdline); } } /// Returns the index of an argument by name size_t ProblemSpace::argument_index(char const *name) const { return argument_index_map.at(name); } /// Helper for recursively cloning void ProblemSpace::clone_( KernelArgumentVector &kernel_args, ArgumentDescription const *arg_desc) { KernelArgument *kernel_arg = nullptr; switch (arg_desc->type) { case ArgumentTypeID::kScalar: kernel_arg = new ScalarArgument(arg_desc); break; case ArgumentTypeID::kInteger: kernel_arg = new IntegerArgument(arg_desc); break; case ArgumentTypeID::kTensor: kernel_arg = new TensorArgument(arg_desc); break; case ArgumentTypeID::kStructure: { throw std::runtime_error("ArgumentTypeID::kStructure not supported"); } break; case ArgumentTypeID::kEnumerated: kernel_arg = new EnumeratedTypeArgument(arg_desc); break; default: break; } if (kernel_arg) { size_t idx = kernel_args.size(); for (auto const &alias : arg_desc->aliases) { argument_index_map.insert(std::make_pair(alias, idx)); } kernel_args.emplace_back(kernel_arg); } } /// Parses a command line void ProblemSpace::parse_(KernelArgument *arg, CommandLine const &cmdline) { switch (arg->description->type) { case ArgumentTypeID::kScalar: { auto * scalar = static_cast(arg); for (auto const &alias : arg->description->aliases) { if (cmdline.check_cmd_line_flag(alias.c_str())) { std::vector> tokens; cmdline.get_cmd_line_argument_ranges(alias.c_str(), tokens); for (auto const & vec : tokens) { if (!vec.empty()) { scalar->values.push_back(vec.front()); } } break; } } } break; case ArgumentTypeID::kInteger: { auto *integer = static_cast(arg); for (auto const &alias : arg->description->aliases) { if (cmdline.check_cmd_line_flag(alias.c_str())) { std::vector > tokens; cmdline.get_cmd_line_argument_ranges(alias.c_str(), tokens); for (auto &range_tokens : tokens) { if (!range_tokens.empty()) { Range range; if (range_tokens.front() == "rand") { range.mode = Range::Mode::kRandom; } else if (range_tokens.front() == "randlg2") { range.mode = Range::Mode::kRandomLog2; } switch (range.mode) { case Range::Mode::kSequence: { range.first = lexical_cast(range_tokens.front()); if (range_tokens.size() > 1) { range.last = lexical_cast(range_tokens.at(1)); } else { range.last = range.first; } if (range_tokens.size() > 2) { range.increment = lexical_cast(range_tokens.at(2)); } else { range.increment = 1; } } break; case Range::Mode::kRandom: // fall-through case Range::Mode::kRandomLog2: { if (range_tokens.size() < 4) { throw std::runtime_error( "Range of mode 'rand' must have four tokens showing " "the minimum, maximum, and number of iterations. For example, " "rand:16:128:1000"); } range.minimum = lexical_cast(range_tokens.at(1)); range.maximum = lexical_cast(range_tokens.at(2)); range.first = 1; range.last = lexical_cast(range_tokens.at(3)); range.increment = 1; if (range_tokens.size() > 4) { range.divisible = lexical_cast(range_tokens.at(4)); } } break; default: throw std::runtime_error("Unsupported range mode."); break; } integer->ranges.push_back(range); } } break; } } } break; case ArgumentTypeID::kTensor: { auto *tensor = static_cast(arg); for (auto const &alias : arg->description->aliases) { if (cmdline.check_cmd_line_flag(alias.c_str())) { std::vector> tokens; cmdline.get_cmd_line_argument_ranges(alias.c_str(), tokens); for (auto const & tensor_tokens : tokens) { if (!tensor_tokens.empty()) { TensorArgument::TensorDescription tensor_desc; tensor_desc.element = cutlass::library::from_string(tensor_tokens.front()); // Layout if (tensor_tokens.size() > 1) { tensor_desc.layout = cutlass::library::from_string(tensor_tokens.at(1)); } // Stride for (size_t i = 2; i < tensor_tokens.size(); ++i) { tensor_desc.stride.push_back(lexical_cast(tensor_tokens.at(i))); } tensor->values.push_back(tensor_desc); } } break; } } } break; case ArgumentTypeID::kStructure: { throw std::runtime_error("Structure arguments not supported"); } break; case ArgumentTypeID::kEnumerated: { auto *enumerated_type = static_cast(arg); for (auto const &alias : arg->description->aliases) { if (cmdline.check_cmd_line_flag(alias.c_str())) { std::vector tokens; cmdline.get_cmd_line_arguments(alias.c_str(), tokens); for (auto const & token : tokens) { enumerated_type->values.push_back(token); } break; } } } break; default: break; } } ///////////////////////////////////////////////////////////////////////////////////////////////// ProblemSpace::Iterator ProblemSpace::begin() const { return ProblemSpace::Iterator(*this); } ProblemSpace::Iterator ProblemSpace::end() const { ProblemSpace::Iterator it(*this); it.move_to_end(); return it; } /// Gets all argument names as an ordered vector std::vector ProblemSpace::argument_names() const { Problem problem = this->begin().at(); std::vector names; names.reserve(problem.size()); for (auto const & arg : problem) { names.push_back(arg->argument->description->aliases.front()); } return names; } ///////////////////////////////////////////////////////////////////////////////////////////////// /// Lexically casts an argument to an int64 if it is defined. Returns true if not null. bool arg_as_int(int64_t &int_value, KernelArgument::Value const *value_ptr) { if (value_ptr->not_null) { if (value_ptr->argument->description->type == ArgumentTypeID::kInteger) { int_value = static_cast(value_ptr)->value; } else if (value_ptr->argument->description->type == ArgumentTypeID::kScalar) { std::stringstream ss; ss << static_cast(value_ptr)->value; ss >> int_value; } else { throw std::runtime_error( "arg_as_int64_t() - illegal cast. Problem space argument must be integer or scalar"); } return true; } return false; } /// Lexically casts an argument to an int64 if it is defined. Returns true if not null. bool arg_as_int(int &int_value, KernelArgument::Value const *value_ptr) { int64_t value64; bool obtained = arg_as_int(value64, value_ptr); if (obtained) { int_value = int(value64); return true; } return false; } /// Lexically casts an argument to an int bool arg_as_int( int &int_value, char const *name, ProblemSpace const &problem_space, ProblemSpace::Problem const &problem) { size_t idx = problem_space.argument_index(name); KernelArgument::Value const *value_ptr = problem.at(idx).get(); return arg_as_int(int_value, value_ptr); } /// Lexically casts an argument to an int64 bool arg_as_int( int64_t &int_value, char const *name, ProblemSpace const &problem_space, ProblemSpace::Problem const &problem) { size_t idx = problem_space.argument_index(name); KernelArgument::Value const *value_ptr = problem.at(idx).get(); return arg_as_int(int_value, value_ptr); } ///////////////////////////////////////////////////////////////////////////////////////////////// /// Lexically casts an argument to an int64 if it is defined. Returns true if not null. bool arg_as_NumericTypeID( library::NumericTypeID &numeric_type, KernelArgument::Value const *value_ptr) { if (value_ptr->not_null) { if (value_ptr->argument->description->type == ArgumentTypeID::kEnumerated) { numeric_type = library::from_string( static_cast(value_ptr)->element); if (numeric_type == library::NumericTypeID::kInvalid) { throw std::runtime_error( "arg_as_NumericTypeID() - illegal cast."); } } else { throw std::runtime_error( "arg_as_NumericTypeID() - illegal cast."); } return true; } return false; } /// Lexically casts an argument to an int64 if it is defined. Returns true if not null. bool arg_as_NumericTypeID( library::NumericTypeID &numeric_type, char const *name, ProblemSpace const &problem_space, ProblemSpace::Problem const &problem) { size_t idx = problem_space.argument_index(name); KernelArgument::Value const *value_ptr = problem.at(idx).get(); return arg_as_NumericTypeID(numeric_type, value_ptr); } ///////////////////////////////////////////////////////////////////////////////////////////////// /// Lexically casts an argument to an int64 if it is defined. Returns true if not null. bool arg_as_RasterOrder( library::RasterOrder &raster_order, KernelArgument::Value const *value_ptr) { if (value_ptr->not_null) { if (value_ptr->argument->description->type == ArgumentTypeID::kEnumerated) { raster_order = library::from_string( static_cast(value_ptr)->element); if (raster_order == library::RasterOrder::kInvalid) { throw std::runtime_error( "arg_as_RasterOrder() - illegal cast."); } } else { throw std::runtime_error( "arg_as_RasterOrder() - illegal cast."); } return true; } return false; } /// Lexically casts an argument to an int64 if it is defined. Returns true if not null. bool arg_as_RasterOrder( library::RasterOrder &raster_order, char const *name, ProblemSpace const &problem_space, ProblemSpace::Problem const &problem) { size_t idx = problem_space.argument_index(name); KernelArgument::Value const *value_ptr = problem.at(idx).get(); return arg_as_RasterOrder(raster_order, value_ptr); } ///////////////////////////////////////////////////////////////////////////////////////////////// /// Lexically casts an argument to an int64 if it is defined. Returns true if not null. bool arg_as_LayoutTypeID( library::LayoutTypeID &layout_type, KernelArgument::Value const *value_ptr) { if (value_ptr->not_null) { if (value_ptr->argument->description->type == ArgumentTypeID::kEnumerated) { layout_type = library::from_string( static_cast(value_ptr)->element); if (layout_type == library::LayoutTypeID::kInvalid) { throw std::runtime_error( "arg_as_LayoutTypeID() - illegal cast."); } } else { throw std::runtime_error( "arg_as_LayoutTypeID() - illegal cast."); } return true; } return false; } /// Lexically casts an argument to an int64 if it is defined. Returns true if not null. bool arg_as_LayoutTypeID( library::LayoutTypeID &layout_type, char const *name, ProblemSpace const &problem_space, ProblemSpace::Problem const &problem) { size_t idx = problem_space.argument_index(name); KernelArgument::Value const *value_ptr = problem.at(idx).get(); return arg_as_LayoutTypeID(layout_type, value_ptr); } ///////////////////////////////////////////////////////////////////////////////////////////////// /// Lexically casts an argument to an int64 if it is defined. Returns true if not null. bool arg_as_OpcodeClassID( library::OpcodeClassID &opcode_class, KernelArgument::Value const *value_ptr) { if (value_ptr->not_null) { if (value_ptr->argument->description->type == ArgumentTypeID::kEnumerated) { opcode_class = library::from_string( static_cast(value_ptr)->element); if (opcode_class == library::OpcodeClassID::kInvalid) { throw std::runtime_error( "arg_as_OpcodeClassID() - illegal cast."); } } else { throw std::runtime_error( "arg_as_OpcodeClassID() - illegal cast."); } return true; } return false; } /// Lexically casts an argument to an int64 if it is defined. Returns true if not null. bool arg_as_OpcodeClassID( library::OpcodeClassID &opcode_class, char const *name, ProblemSpace const &problem_space, ProblemSpace::Problem const &problem) { size_t idx = problem_space.argument_index(name); KernelArgument::Value const *value_ptr = problem.at(idx).get(); return arg_as_OpcodeClassID(opcode_class, value_ptr); } /// Lexically casts an argument to an int64 if it is defined. Returns true if not null. bool arg_as_SplitKModeID( library::SplitKMode &split_k_mode, KernelArgument::Value const *value_ptr) { if (value_ptr->not_null) { if (value_ptr->argument->description->type == ArgumentTypeID::kEnumerated) { split_k_mode = library::from_string( static_cast(value_ptr)->element); if (split_k_mode == library::SplitKMode::kInvalid) { throw std::runtime_error( "arg_as_SplitKModeID() - illegal cast."); } } else { throw std::runtime_error( "arg_as_SplitKModeID() - illegal cast."); } return true; } return false; } /// Lexically casts an argument to an int64 if it is defined. Returns true if not null. bool arg_as_SplitKModeID( library::SplitKMode &split_k_mode, char const *name, ProblemSpace const &problem_space, ProblemSpace::Problem const &problem) { size_t idx = problem_space.argument_index(name); KernelArgument::Value const *value_ptr = problem.at(idx).get(); return arg_as_SplitKModeID(split_k_mode, value_ptr); } ///////////////////////////////////////////////////////////////////////////////////////////////// /// Lexically casts an argument to an int64 if it is defined. Returns true if not null. bool arg_as_ConvModeID( library::ConvModeID &conv_mode, KernelArgument::Value const *value_ptr) { if (value_ptr->not_null) { if (value_ptr->argument->description->type == ArgumentTypeID::kEnumerated) { conv_mode = library::from_string( static_cast(value_ptr)->element); if (conv_mode == library::ConvModeID::kInvalid) { throw std::runtime_error( "arg_as_ConvModeID() - illegal cast."); } } else { throw std::runtime_error( "arg_as_ConvModeID() - illegal cast."); } return true; } return false; } /// Lexically casts an argument to an int64 if it is defined. Returns true if not null. bool arg_as_ConvModeID( library::ConvModeID &conv_mode, char const *name, ProblemSpace const &problem_space, ProblemSpace::Problem const &problem) { size_t idx = problem_space.argument_index(name); KernelArgument::Value const *value_ptr = problem.at(idx).get(); return arg_as_ConvModeID(conv_mode, value_ptr); } /// Lexically casts an argument to an int64 if it is defined. Returns true if not null. bool arg_as_ProviderID( library::Provider &provider, KernelArgument::Value const *value_ptr) { if (value_ptr->not_null) { if (value_ptr->argument->description->type == ArgumentTypeID::kEnumerated) { provider = library::from_string( static_cast(value_ptr)->element); if (provider == library::Provider::kInvalid) { throw std::runtime_error( "arg_as_ProviderID() - illegal cast."); } } else { throw std::runtime_error( "arg_as_ProviderID() - illegal cast."); } return true; } return false; } /// Lexically casts an argument to an int64 if it is defined. Returns true if not null. bool arg_as_ProviderID( library::Provider &provider, char const *name, ProblemSpace const &problem_space, ProblemSpace::Problem const &problem) { size_t idx = problem_space.argument_index(name); KernelArgument::Value const *value_ptr = problem.at(idx).get(); return arg_as_ProviderID(provider, value_ptr); } ///////////////////////////////////////////////////////////////////////////////////////////////// /// Lexically casts an argument to a given type stored in a byte array. Returns true if not null. bool arg_as_scalar( std::vector &bytes, library::NumericTypeID numeric_type, KernelArgument::Value const *value_ptr) { if (value_ptr->not_null) { if (value_ptr->argument->description->type == ArgumentTypeID::kInteger) { int64_t int_value = static_cast(value_ptr)->value; // TODO - convert int64_t => destination type } else if (value_ptr->argument->description->type == ArgumentTypeID::kScalar) { std::string const &str_value = static_cast(value_ptr)->value; return lexical_cast(bytes, numeric_type, str_value); } else { throw std::runtime_error( "arg_as_int() - illegal cast. Problem space argument must be integer or scalar"); } return true; } return false; } /// Lexically casts an argument to a given type and returns a byte array bool arg_as_scalar( std::vector &bytes, library::NumericTypeID numeric_type, char const *name, ProblemSpace const &problem_space, ProblemSpace::Problem const &problem) { size_t idx = problem_space.argument_index(name); KernelArgument::Value const *value_ptr = problem.at(idx).get(); return arg_as_scalar(bytes, numeric_type, value_ptr); } ///////////////////////////////////////////////////////////////////////////////////////////////// /// Returns true if a tensor description satisfies a `tensor` value bool tensor_description_satisfies( library::TensorDescription const &tensor_desc, TensorArgument::TensorValue const *value_ptr) { if (value_ptr->not_null) { if (value_ptr->desc.element != library::NumericTypeID::kUnknown && value_ptr->desc.element != tensor_desc.element) { return false; } if (value_ptr->desc.layout != library::LayoutTypeID::kUnknown && value_ptr->desc.layout != tensor_desc.layout) { return false; } } return true; } /// Returns true if a tensor description satisfies a `tensor` value bool tensor_description_satisfies( library::TensorDescription const &tensor_desc, char const *name, ProblemSpace const &problem_space, ProblemSpace::Problem const &problem) { size_t idx = problem_space.argument_index(name); KernelArgument::Value const *value_ptr = problem.at(idx).get(); if (value_ptr->argument->description->type == ArgumentTypeID::kTensor) { return tensor_description_satisfies( tensor_desc, static_cast(value_ptr)); } else { throw std::runtime_error("Kernel argument mismatch"); } return false; } ///////////////////////////////////////////////////////////////////////////////////////////////// /// Returns true if conv_kind satisfies the value bool conv_kind_satisfies( library::ConvKind const &conv_kind, EnumeratedTypeArgument::EnumeratedTypeValue const *value_ptr) { if (value_ptr->not_null) { library::ConvKind conv_kind_cmd_line = library::from_string(value_ptr->element); if (conv_kind_cmd_line != library::ConvKind::kUnknown && conv_kind_cmd_line != conv_kind) { return false; } } return true; } /// Returns true if conv_kind satisfies the value bool conv_kind_satisfies( library::ConvKind const &conv_kind, char const *name, ProblemSpace const &problem_space, ProblemSpace::Problem const &problem) { size_t idx = problem_space.argument_index(name); KernelArgument::Value const *value_ptr = problem.at(idx).get(); if (value_ptr->argument->description->type == ArgumentTypeID::kEnumerated) { return conv_kind_satisfies( conv_kind, static_cast(value_ptr)); } else { throw std::runtime_error("Kernel argument mismatch"); } return false; } ///////////////////////////////////////////////////////////////////////////////////////////////// /// Returns true if a iterator algorithm satisfies the value bool iterator_algorithm_satisfies( library::IteratorAlgorithmID const &iterator_algorithm, EnumeratedTypeArgument::EnumeratedTypeValue const *value_ptr) { if (value_ptr->not_null) { library::IteratorAlgorithmID iterator_algorithm_cmd_line = library::from_string(value_ptr->element); if (iterator_algorithm_cmd_line != library::IteratorAlgorithmID::kNone && iterator_algorithm_cmd_line != iterator_algorithm) { return false; } } return true; } /// Returns true if a iterator algorithm satisfies the value bool iterator_algorithm_satisfies( library::IteratorAlgorithmID const &iterator_algorithm, char const *name, ProblemSpace const &problem_space, ProblemSpace::Problem const &problem) { size_t idx = problem_space.argument_index(name); KernelArgument::Value const *value_ptr = problem.at(idx).get(); if (value_ptr->argument->description->type == ArgumentTypeID::kEnumerated) { return iterator_algorithm_satisfies( iterator_algorithm, static_cast(value_ptr)); } else { throw std::runtime_error("Kernel argument mismatch"); } return false; } ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace profiler } // namespace cutlass /////////////////////////////////////////////////////////////////////////////////////////////////