Spaces:
Sleeping
Sleeping
<html class="no-js" lang="en"> | |
<head><meta charset="utf-8"/> | |
<meta name="viewport" content="width=device-width,initial-scale=1"/> | |
<meta name="color-scheme" content="light dark"><meta name="generator" content="Docutils 0.19: https://docutils.sourceforge.io/" /> | |
<link rel="index" title="Index" href="../genindex.html" /><link rel="search" title="Search" href="../search.html" /><link rel="next" title="Exporting a CUTLASS grouped GEMM kernel to a PyTorch CUDA extension" href="02_pytorch_extension_grouped_gemm.html" /><link rel="prev" title="Examples" href="../examples.html" /> | |
<link rel="canonical" href="docs/externals/01_epilogue.html" /> | |
<!-- Generated with Sphinx 6.1.3 and Furo 2023.03.27 --> | |
<title>Example of using elementwise activation functions in the CUTLASS Python interface - CUTLASS Python</title> | |
<link rel="stylesheet" type="text/css" href="../_static/pygments.css" /> | |
<link rel="stylesheet" type="text/css" href="../_static/styles/furo.css?digest=fad236701ea90a88636c2a8c73b44ae642ed2a53" /> | |
<link rel="stylesheet" type="text/css" href="../_static/copybutton.css" /> | |
<link rel="stylesheet" type="text/css" href="../_static/tabs.css" /> | |
<link rel="stylesheet" type="text/css" href="../_static/nbsphinx-code-cells.css" /> | |
<link rel="stylesheet" type="text/css" href="../_static/styles/furo-extensions.css?digest=30d1aed668e5c3a91c3e3bf6a60b675221979f0e" /> | |
<style> | |
body { | |
--color-code-background: #eeffcc; | |
--color-code-foreground: black; | |
--color-brand-primary: #76B900; | |
--color-brand-content: #76B900; | |
} | |
@media not print { | |
body[data-theme="dark"] { | |
--color-code-background: #272822; | |
--color-code-foreground: #f8f8f2; | |
--color-brand-primary: #76B900; | |
--color-brand-content: #76B900; | |
} | |
@media (prefers-color-scheme: dark) { | |
body:not([data-theme="light"]) { | |
--color-code-background: #272822; | |
--color-code-foreground: #f8f8f2; | |
--color-brand-primary: #76B900; | |
--color-brand-content: #76B900; | |
} | |
} | |
} | |
</style></head> | |
<body> | |
<script> | |
document.body.dataset.theme = localStorage.getItem("theme") || "auto"; | |
</script> | |
<svg xmlns="http://www.w3.org/2000/svg" style="display: none;"> | |
<symbol id="svg-toc" viewBox="0 0 24 24"> | |
<title>Contents</title> | |
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 1024 1024"> | |
<path d="M408 442h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8zm-8 204c0 4.4 3.6 8 8 8h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56zm504-486H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zm0 632H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zM115.4 518.9L271.7 642c5.8 4.6 14.4.5 14.4-6.9V388.9c0-7.4-8.5-11.5-14.4-6.9L115.4 505.1a8.74 8.74 0 0 0 0 13.8z"/> | |
</svg> | |
</symbol> | |
<symbol id="svg-menu" viewBox="0 0 24 24"> | |
<title>Menu</title> | |
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor" | |
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-menu"> | |
<line x1="3" y1="12" x2="21" y2="12"></line> | |
<line x1="3" y1="6" x2="21" y2="6"></line> | |
<line x1="3" y1="18" x2="21" y2="18"></line> | |
</svg> | |
</symbol> | |
<symbol id="svg-arrow-right" viewBox="0 0 24 24"> | |
<title>Expand</title> | |
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor" | |
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-chevron-right"> | |
<polyline points="9 18 15 12 9 6"></polyline> | |
</svg> | |
</symbol> | |
<symbol id="svg-sun" viewBox="0 0 24 24"> | |
<title>Light mode</title> | |
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor" | |
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather-sun"> | |
<circle cx="12" cy="12" r="5"></circle> | |
<line x1="12" y1="1" x2="12" y2="3"></line> | |
<line x1="12" y1="21" x2="12" y2="23"></line> | |
<line x1="4.22" y1="4.22" x2="5.64" y2="5.64"></line> | |
<line x1="18.36" y1="18.36" x2="19.78" y2="19.78"></line> | |
<line x1="1" y1="12" x2="3" y2="12"></line> | |
<line x1="21" y1="12" x2="23" y2="12"></line> | |
<line x1="4.22" y1="19.78" x2="5.64" y2="18.36"></line> | |
<line x1="18.36" y1="5.64" x2="19.78" y2="4.22"></line> | |
</svg> | |
</symbol> | |
<symbol id="svg-moon" viewBox="0 0 24 24"> | |
<title>Dark mode</title> | |
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor" | |
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-moon"> | |
<path stroke="none" d="M0 0h24v24H0z" fill="none" /> | |
<path d="M12 3c.132 0 .263 0 .393 0a7.5 7.5 0 0 0 7.92 12.446a9 9 0 1 1 -8.313 -12.454z" /> | |
</svg> | |
</symbol> | |
<symbol id="svg-sun-half" viewBox="0 0 24 24"> | |
<title>Auto light/dark mode</title> | |
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor" | |
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-shadow"> | |
<path stroke="none" d="M0 0h24v24H0z" fill="none"/> | |
<circle cx="12" cy="12" r="9" /> | |
<path d="M13 12h5" /> | |
<path d="M13 15h4" /> | |
<path d="M13 18h1" /> | |
<path d="M13 9h4" /> | |
<path d="M13 6h1" /> | |
</svg> | |
</symbol> | |
</svg> | |
<input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation"> | |
<input type="checkbox" class="sidebar-toggle" name="__toc" id="__toc"> | |
<label class="overlay sidebar-overlay" for="__navigation"> | |
<div class="visually-hidden">Hide navigation sidebar</div> | |
</label> | |
<label class="overlay toc-overlay" for="__toc"> | |
<div class="visually-hidden">Hide table of contents sidebar</div> | |
</label> | |
<div class="page"> | |
<header class="mobile-header"> | |
<div class="header-left"> | |
<label class="nav-overlay-icon" for="__navigation"> | |
<div class="visually-hidden">Toggle site navigation sidebar</div> | |
<i class="icon"><svg><use href="#svg-menu"></use></svg></i> | |
</label> | |
</div> | |
<div class="header-center"> | |
<a href="../index.html"><div class="brand">CUTLASS Python</div></a> | |
</div> | |
<div class="header-right"> | |
<div class="theme-toggle-container theme-toggle-header"> | |
<button class="theme-toggle"> | |
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div> | |
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg> | |
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg> | |
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg> | |
</button> | |
</div> | |
<label class="toc-overlay-icon toc-header-icon" for="__toc"> | |
<div class="visually-hidden">Toggle table of contents sidebar</div> | |
<i class="icon"><svg><use href="#svg-toc"></use></svg></i> | |
</label> | |
</div> | |
</header> | |
<aside class="sidebar-drawer"> | |
<div class="sidebar-container"> | |
<div class="sidebar-sticky"><a class="sidebar-brand" href="../index.html"> | |
<div class="sidebar-logo-container"> | |
<img class="sidebar-logo only-light" src="../_static/cutlass-logo-small.png" alt="Light Logo"/> | |
<img class="sidebar-logo only-dark" src="../_static/cutlass-logo-small.png" alt="Dark Logo"/> | |
</div> | |
<span class="sidebar-brand-text">CUTLASS Python</span> | |
</a><form class="sidebar-search-container" method="get" action="../search.html" role="search"> | |
<input class="sidebar-search" placeholder="Search" name="q" aria-label="Search"> | |
<input type="hidden" name="check_keywords" value="yes"> | |
<input type="hidden" name="area" value="default"> | |
</form> | |
<div id="searchbox"></div><div class="sidebar-scroll"><div class="sidebar-tree"> | |
<ul> | |
<li class="toctree-l1"><a class="reference internal" href="../index.html">Home</a></li> | |
</ul> | |
<p class="caption" role="heading"><span class="caption-text">Getting Started:</span></p> | |
<ul> | |
<li class="toctree-l1"><a class="reference internal" href="../install.html">Installation</a></li> | |
<li class="toctree-l1"><a class="reference internal" href="00_basic_gemm.html">Getting Started</a></li> | |
<li class="toctree-l1"><a class="reference internal" href="../contribute.html">Contributing</a></li> | |
</ul> | |
<p class="caption" role="heading"><span class="caption-text">Python Documentation:</span></p> | |
<ul> | |
<li class="toctree-l1 has-children"><a class="reference internal" href="../modules.html">CUTLASS Python API</a><input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" role="switch" type="checkbox"/><label for="toctree-checkbox-1"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul> | |
<li class="toctree-l2 has-children"><a class="reference internal" href="../cutlass.html">CUTLASS</a><input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" role="switch" type="checkbox"/><label for="toctree-checkbox-2"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul> | |
<li class="toctree-l3"><a class="reference internal" href="../cutlass.emit.html">Emitters</a></li> | |
<li class="toctree-l3"><a class="reference internal" href="../cutlass.op.html">Operations</a></li> | |
<li class="toctree-l3"><a class="reference internal" href="../cutlass.utils.html">Utilities</a></li> | |
</ul> | |
</li> | |
</ul> | |
</li> | |
</ul> | |
<p class="caption" role="heading"><span class="caption-text">Examples and Tutorials:</span></p> | |
<ul class="current"> | |
<li class="toctree-l1 current has-children"><a class="reference internal" href="../examples.html">Examples</a><input checked="" class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" role="switch" type="checkbox"/><label for="toctree-checkbox-3"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul class="current"> | |
<li class="toctree-l2"><a class="reference internal" href="00_basic_gemm.html">Basic GEMM</a></li> | |
<li class="toctree-l2 current current-page"><a class="current reference internal" href="#">Epilogue</a></li> | |
<li class="toctree-l2"><a class="reference internal" href="02_pytorch_extension_grouped_gemm.html">PyTorch Extension</a></li> | |
</ul> | |
</li> | |
</ul> | |
<p class="caption" role="heading"><span class="caption-text">Reference:</span></p> | |
<ul> | |
<li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/cutlass">Github</a></li> | |
</ul> | |
</div> | |
</div> | |
</div> | |
</div> | |
</aside> | |
<div class="main"> | |
<div class="content"> | |
<div class="article-container"> | |
<a href="#" class="back-to-top muted-link"> | |
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"> | |
<path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8v12z"></path> | |
</svg> | |
<span>Back to top</span> | |
</a> | |
<div class="content-icon-container"> | |
<div class="theme-toggle-container theme-toggle-content"> | |
<button class="theme-toggle"> | |
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div> | |
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg> | |
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg> | |
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg> | |
</button> | |
</div> | |
<label class="toc-overlay-icon toc-content-icon" for="__toc"> | |
<div class="visually-hidden">Toggle table of contents sidebar</div> | |
<i class="icon"><svg><use href="#svg-toc"></use></svg></i> | |
</label> | |
</div> | |
<article role="main"> | |
<section id="Example-of-using-elementwise-activation-functions-in-the-CUTLASS-Python-interface"> | |
<h1>Example of using elementwise activation functions in the CUTLASS Python interface<a class="headerlink" href="#Example-of-using-elementwise-activation-functions-in-the-CUTLASS-Python-interface" title="Permalink to this heading">#</a></h1> | |
<p>This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs with different epilogues.</p> | |
<p><a class="reference external" href="https://colab.research.google.com/github/NVIDIA/cutlass/blob/main/examples/python/01_epilogue.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a></p> | |
<p>We first import various packages needed for the example and construct the input and output tensors that will be used in our example.</p> | |
<div class="nbinput docutils container"> | |
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[1]: | |
</pre></div> | |
</div> | |
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> | |
<span class="kn">import</span> <span class="nn">cutlass</span> | |
<span class="c1"># This controls whether the C++ GEMM declaration will be printed at each step. Set to `False` to</span> | |
<span class="c1"># omit this information.</span> | |
<span class="n">print_module</span> <span class="o">=</span> <span class="kc">True</span> | |
<span class="n">m</span> <span class="o">=</span> <span class="mi">256</span> | |
<span class="n">n</span> <span class="o">=</span> <span class="n">m</span> | |
<span class="n">k</span> <span class="o">=</span> <span class="n">m</span> | |
<span class="n">type_A</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span> | |
<span class="n">type_B</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span> | |
<span class="n">type_C</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span> | |
<span class="n">type_D</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span> | |
<span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">1234</span><span class="p">)</span> | |
<span class="n">scope_min</span> <span class="o">=</span> <span class="o">-</span><span class="mi">4</span> | |
<span class="n">scope_max</span> <span class="o">=</span> <span class="mi">4</span> | |
<span class="n">tensor_A</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">low</span><span class="o">=</span><span class="n">scope_min</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">scope_max</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">k</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">type_A</span><span class="p">))</span> | |
<span class="n">tensor_B</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">low</span><span class="o">=</span><span class="n">scope_min</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">scope_max</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">n</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">type_B</span><span class="p">))</span> | |
<span class="n">tensor_C</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">low</span><span class="o">=</span><span class="n">scope_min</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">scope_max</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">type_C</span><span class="p">))</span> | |
<span class="n">alpha</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span><span class="p">(</span><span class="mf">1.</span><span class="p">)</span> | |
<span class="n">beta</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span><span class="p">(</span><span class="mf">0.</span><span class="p">)</span> | |
<span class="n">tensor_D</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">tensor_C</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">type_D</span><span class="p">)</span> | |
</pre></div> | |
</div> | |
</div> | |
<div class="nboutput nblast docutils container"> | |
<div class="prompt empty docutils container"> | |
</div> | |
<div class="output_area stderr docutils container"> | |
<div class="highlight"><pre> | |
/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html | |
from .autonotebook import tqdm as notebook_tqdm | |
</pre></div></div> | |
</div> | |
<section id="Run-a-GEMM-with-an-identity-activation-function"> | |
<h2>Run a GEMM with an identity activation function<a class="headerlink" href="#Run-a-GEMM-with-an-identity-activation-function" title="Permalink to this heading">#</a></h2> | |
<p>To begin, we simply run a default GEMM with an identity activation function. This performs the well-known operation <code class="docutils literal notranslate"><span class="pre">D</span> <span class="pre">=</span> <span class="pre">alpha</span> <span class="pre">*</span> <span class="pre">(A</span> <span class="pre">@</span> <span class="pre">B)</span> <span class="pre">+</span> <span class="pre">beta</span> <span class="pre">*</span> <span class="pre">C</span></code>. This is the default activation function used, and does not need to be specified.</p> | |
<div class="nbinput docutils container"> | |
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[2]: | |
</pre></div> | |
</div> | |
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">plan</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">Gemm</span><span class="p">(</span><span class="n">element</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="n">layout</span><span class="o">=</span><span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="o">.</span><span class="n">RowMajor</span><span class="p">)</span> | |
<span class="n">plan</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">tensor_A</span><span class="p">,</span> <span class="n">tensor_B</span><span class="p">,</span> <span class="n">tensor_C</span><span class="p">,</span> <span class="n">tensor_D</span><span class="p">,</span> <span class="n">print_module</span><span class="o">=</span><span class="n">print_module</span><span class="p">)</span> | |
</pre></div> | |
</div> | |
</div> | |
<div class="nboutput docutils container"> | |
<div class="prompt empty docutils container"> | |
</div> | |
<div class="output_area docutils container"> | |
<div class="highlight"><pre> | |
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8 | |
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base = | |
typename cutlass::gemm::kernel::DefaultGemmUniversal< | |
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, | |
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, | |
cutlass::half_t, cutlass::layout::RowMajor, | |
cutlass::half_t, | |
cutlass::arch::OpClassTensorOp, | |
cutlass::arch::Sm80, | |
cutlass::gemm::GemmShape<256, 128, 64>, | |
cutlass::gemm::GemmShape<64, 64, 64>, | |
cutlass::gemm::GemmShape<16, 8, 16>, | |
cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>, | |
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, | |
3, | |
cutlass::arch::OpMultiplyAdd | |
>::GemmKernel; | |
// Define named type | |
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : | |
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { }; | |
</pre></div></div> | |
</div> | |
<div class="nboutput nblast docutils container"> | |
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[2]: | |
</pre></div> | |
</div> | |
<div class="output_area docutils container"> | |
<div class="highlight"><pre> | |
<cutlass.backend.gemm_operation.GemmArguments2x at 0x7fed907287c0> | |
</pre></div></div> | |
</div> | |
</section> | |
<section id="Run-a-GEMM-with-a-ReLU-element-wise-activation-function"> | |
<h2>Run a GEMM with a ReLU element-wise activation function<a class="headerlink" href="#Run-a-GEMM-with-a-ReLU-element-wise-activation-function" title="Permalink to this heading">#</a></h2> | |
<p>CUTLASS makes it easy to support other element-wise activation functions. This results in performing an element-wise operation after the generic linear combination performed in a GEMM. If we call such an activation function <code class="docutils literal notranslate"><span class="pre">act</span></code>, the resulting formulation is:</p> | |
<div class="highlight-none notranslate"><div class="highlight"><pre><span></span>D = alpha * (A @ B) + beta * C | |
D = act(D) | |
</pre></div> | |
</div> | |
<p>Here, we will add a ReLU activation function. Given an input <code class="docutils literal notranslate"><span class="pre">x</span></code>, ReLU returns <code class="docutils literal notranslate"><span class="pre">max(x,</span> <span class="pre">0)</span></code>.</p> | |
<p>This is easy to do in CUTLASS. One only needs to set the plan’s <code class="docutils literal notranslate"><span class="pre">activation</span></code> field.</p> | |
<div class="nbinput docutils container"> | |
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[3]: | |
</pre></div> | |
</div> | |
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">tensor_D_relu</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">tensor_C</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">type_D</span><span class="p">)</span> | |
<span class="n">plan</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">epilogue</span><span class="o">.</span><span class="n">relu</span> | |
<span class="n">plan</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">tensor_A</span><span class="p">,</span> <span class="n">tensor_B</span><span class="p">,</span> <span class="n">tensor_C</span><span class="p">,</span> <span class="n">tensor_D_relu</span><span class="p">,</span> <span class="n">print_module</span><span class="o">=</span><span class="n">print_module</span><span class="p">)</span> | |
</pre></div> | |
</div> | |
</div> | |
<div class="nboutput docutils container"> | |
<div class="prompt empty docutils container"> | |
</div> | |
<div class="output_area docutils container"> | |
<div class="highlight"><pre> | |
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8 | |
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base = | |
typename cutlass::gemm::kernel::DefaultGemmUniversal< | |
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, | |
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, | |
cutlass::half_t, cutlass::layout::RowMajor, | |
cutlass::half_t, | |
cutlass::arch::OpClassTensorOp, | |
cutlass::arch::Sm80, | |
cutlass::gemm::GemmShape<256, 128, 64>, | |
cutlass::gemm::GemmShape<64, 64, 64>, | |
cutlass::gemm::GemmShape<16, 8, 16>, | |
cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::ReLu, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>, | |
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, | |
3, | |
cutlass::arch::OpMultiplyAdd | |
>::GemmKernel; | |
// Define named type | |
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : | |
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { }; | |
</pre></div></div> | |
</div> | |
<div class="nboutput nblast docutils container"> | |
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[3]: | |
</pre></div> | |
</div> | |
<div class="output_area docutils container"> | |
<div class="highlight"><pre> | |
<cutlass.backend.gemm_operation.GemmArguments2x at 0x7fed906f2460> | |
</pre></div></div> | |
</div> | |
<p>We can now verify the result of the GEMM that used a ReLU activation function against a NumPy reference:</p> | |
<div class="nbinput nblast docutils container"> | |
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[4]: | |
</pre></div> | |
</div> | |
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">relu_ref</span> <span class="o">=</span> <span class="p">(</span><span class="n">tensor_D</span> <span class="o">>=</span> <span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">type_D</span><span class="p">)</span> <span class="o">*</span> <span class="n">tensor_D</span> | |
<span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_array_equal</span><span class="p">(</span><span class="n">relu_ref</span><span class="p">,</span> <span class="n">tensor_D_relu</span><span class="p">)</span> | |
</pre></div> | |
</div> | |
</div> | |
</section> | |
<section id="Other-element-wise-activation-functions"> | |
<h2>Other element-wise activation functions<a class="headerlink" href="#Other-element-wise-activation-functions" title="Permalink to this heading">#</a></h2> | |
<p>CUTLASS supports a variety of widely-used element-wise activation functions. We can obtain a list of these functions via the plan’s <code class="docutils literal notranslate"><span class="pre">activations()</span></code> method.</p> | |
<div class="nbinput docutils container"> | |
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[5]: | |
</pre></div> | |
</div> | |
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">activations</span> <span class="o">=</span> <span class="n">plan</span><span class="o">.</span><span class="n">activations</span><span class="p">()</span> | |
<span class="k">for</span> <span class="n">activation</span> <span class="ow">in</span> <span class="n">activations</span><span class="p">:</span> | |
<span class="nb">print</span><span class="p">(</span><span class="n">activation</span><span class="p">)</span> | |
</pre></div> | |
</div> | |
</div> | |
<div class="nboutput nblast docutils container"> | |
<div class="prompt empty docutils container"> | |
</div> | |
<div class="output_area docutils container"> | |
<div class="highlight"><pre> | |
<class 'cutlass.backend.epilogue.gelu'> | |
<class 'cutlass.backend.epilogue.hardswish'> | |
<class 'cutlass.backend.epilogue.identity'> | |
<class 'cutlass.backend.epilogue.leaky_relu'> | |
<class 'cutlass.backend.epilogue.relu'> | |
<class 'cutlass.backend.epilogue.sigmoid'> | |
<class 'cutlass.backend.epilogue.silu'> | |
<class 'cutlass.backend.epilogue.tanh'> | |
</pre></div></div> | |
</div> | |
<p>We can then run each of them:</p> | |
<div class="nbinput docutils container"> | |
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[6]: | |
</pre></div> | |
</div> | |
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="k">for</span> <span class="n">activation</span> <span class="ow">in</span> <span class="n">activations</span><span class="p">:</span> | |
<span class="nb">print</span><span class="p">(</span><span class="s1">'============================================================================================='</span><span class="p">)</span> | |
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Compiling and running activation </span><span class="si">{</span><span class="n">activation</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span> | |
<span class="nb">print</span><span class="p">(</span><span class="s1">'============================================================================================='</span><span class="p">)</span> | |
<span class="n">plan</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">activation</span> | |
<span class="n">plan</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">tensor_A</span><span class="p">,</span> <span class="n">tensor_B</span><span class="p">,</span> <span class="n">tensor_C</span><span class="p">,</span> <span class="n">tensor_D</span><span class="p">,</span> <span class="n">print_module</span><span class="o">=</span><span class="n">print_module</span><span class="p">)</span> | |
</pre></div> | |
</div> | |
</div> | |
<div class="nboutput nblast docutils container"> | |
<div class="prompt empty docutils container"> | |
</div> | |
<div class="output_area docutils container"> | |
<div class="highlight"><pre> | |
============================================================================================= | |
Compiling and running activation <class 'cutlass.backend.epilogue.gelu'> | |
============================================================================================= | |
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8 | |
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base = | |
typename cutlass::gemm::kernel::DefaultGemmUniversal< | |
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, | |
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, | |
cutlass::half_t, cutlass::layout::RowMajor, | |
cutlass::half_t, | |
cutlass::arch::OpClassTensorOp, | |
cutlass::arch::Sm80, | |
cutlass::gemm::GemmShape<256, 128, 64>, | |
cutlass::gemm::GemmShape<64, 64, 64>, | |
cutlass::gemm::GemmShape<16, 8, 16>, | |
cutlass::epilogue::thread::LinearCombinationGeneric&lt;cutlass::epilogue::thread::GELU, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t&gt;, | |
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, | |
3, | |
cutlass::arch::OpMultiplyAdd | |
>::GemmKernel; | |
// Define named type | |
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : | |
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { }; | |
============================================================================================= | |
Compiling and running activation &lt;class 'cutlass.backend.epilogue.hardswish'&gt; | |
============================================================================================= | |
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8 | |
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base = | |
typename cutlass::gemm::kernel::DefaultGemmUniversal< | |
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, | |
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, | |
cutlass::half_t, cutlass::layout::RowMajor, | |
cutlass::half_t, | |
cutlass::arch::OpClassTensorOp, | |
cutlass::arch::Sm80, | |
cutlass::gemm::GemmShape<256, 128, 64>, | |
cutlass::gemm::GemmShape<64, 64, 64>, | |
cutlass::gemm::GemmShape<16, 8, 16>, | |
cutlass::epilogue::thread::LinearCombinationGeneric&lt;cutlass::epilogue::thread::HardSwish, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t&gt;, | |
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, | |
3, | |
cutlass::arch::OpMultiplyAdd | |
>::GemmKernel; | |
// Define named type | |
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : | |
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { }; | |
============================================================================================= | |
Compiling and running activation &lt;class 'cutlass.backend.epilogue.identity'&gt; | |
============================================================================================= | |
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8 | |
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base = | |
typename cutlass::gemm::kernel::DefaultGemmUniversal< | |
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, | |
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, | |
cutlass::half_t, cutlass::layout::RowMajor, | |
cutlass::half_t, | |
cutlass::arch::OpClassTensorOp, | |
cutlass::arch::Sm80, | |
cutlass::gemm::GemmShape<256, 128, 64>, | |
cutlass::gemm::GemmShape<64, 64, 64>, | |
cutlass::gemm::GemmShape<16, 8, 16>, | |
cutlass::epilogue::thread::LinearCombination&lt;cutlass::half_t, 8, cutlass::half_t, cutlass::half_t&gt;, | |
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, | |
3, | |
cutlass::arch::OpMultiplyAdd | |
>::GemmKernel; | |
// Define named type | |
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : | |
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { }; | |
============================================================================================= | |
Compiling and running activation &lt;class 'cutlass.backend.epilogue.leaky_relu'&gt; | |
============================================================================================= | |
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8 | |
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base = | |
typename cutlass::gemm::kernel::DefaultGemmUniversal< | |
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, | |
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, | |
cutlass::half_t, cutlass::layout::RowMajor, | |
cutlass::half_t, | |
cutlass::arch::OpClassTensorOp, | |
cutlass::arch::Sm80, | |
cutlass::gemm::GemmShape<256, 128, 64>, | |
cutlass::gemm::GemmShape<64, 64, 64>, | |
cutlass::gemm::GemmShape<16, 8, 16>, | |
cutlass::epilogue::thread::LinearCombinationGeneric&lt;cutlass::epilogue::thread::LeakyReLU, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t&gt;, | |
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, | |
3, | |
cutlass::arch::OpMultiplyAdd | |
>::GemmKernel; | |
// Define named type | |
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : | |
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { }; | |
============================================================================================= | |
Compiling and running activation &lt;class 'cutlass.backend.epilogue.relu'&gt; | |
============================================================================================= | |
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8 | |
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base = | |
typename cutlass::gemm::kernel::DefaultGemmUniversal< | |
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, | |
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, | |
cutlass::half_t, cutlass::layout::RowMajor, | |
cutlass::half_t, | |
cutlass::arch::OpClassTensorOp, | |
cutlass::arch::Sm80, | |
cutlass::gemm::GemmShape<256, 128, 64>, | |
cutlass::gemm::GemmShape<64, 64, 64>, | |
cutlass::gemm::GemmShape<16, 8, 16>, | |
cutlass::epilogue::thread::LinearCombinationGeneric&lt;cutlass::epilogue::thread::ReLu, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t&gt;, | |
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, | |
3, | |
cutlass::arch::OpMultiplyAdd | |
>::GemmKernel; | |
// Define named type | |
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : | |
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { }; | |
============================================================================================= | |
Compiling and running activation &lt;class 'cutlass.backend.epilogue.sigmoid'&gt; | |
============================================================================================= | |
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8 | |
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base = | |
typename cutlass::gemm::kernel::DefaultGemmUniversal< | |
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, | |
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, | |
cutlass::half_t, cutlass::layout::RowMajor, | |
cutlass::half_t, | |
cutlass::arch::OpClassTensorOp, | |
cutlass::arch::Sm80, | |
cutlass::gemm::GemmShape<256, 128, 64>, | |
cutlass::gemm::GemmShape<64, 64, 64>, | |
cutlass::gemm::GemmShape<16, 8, 16>, | |
cutlass::epilogue::thread::LinearCombinationGeneric&lt;cutlass::epilogue::thread::Sigmoid, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t&gt;, | |
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, | |
3, | |
cutlass::arch::OpMultiplyAdd | |
>::GemmKernel; | |
// Define named type | |
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : | |
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { }; | |
============================================================================================= | |
Compiling and running activation &lt;class 'cutlass.backend.epilogue.silu'&gt; | |
============================================================================================= | |
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8 | |
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base = | |
typename cutlass::gemm::kernel::DefaultGemmUniversal< | |
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, | |
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, | |
cutlass::half_t, cutlass::layout::RowMajor, | |
cutlass::half_t, | |
cutlass::arch::OpClassTensorOp, | |
cutlass::arch::Sm80, | |
cutlass::gemm::GemmShape<256, 128, 64>, | |
cutlass::gemm::GemmShape<64, 64, 64>, | |
cutlass::gemm::GemmShape<16, 8, 16>, | |
cutlass::epilogue::thread::LinearCombinationGeneric&lt;cutlass::epilogue::thread::SiLu, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t&gt;, | |
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, | |
3, | |
cutlass::arch::OpMultiplyAdd | |
>::GemmKernel; | |
// Define named type | |
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : | |
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { }; | |
============================================================================================= | |
Compiling and running activation &lt;class 'cutlass.backend.epilogue.tanh'&gt; | |
============================================================================================= | |
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8 | |
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base = | |
typename cutlass::gemm::kernel::DefaultGemmUniversal< | |
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, | |
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, | |
cutlass::half_t, cutlass::layout::RowMajor, | |
cutlass::half_t, | |
cutlass::arch::OpClassTensorOp, | |
cutlass::arch::Sm80, | |
cutlass::gemm::GemmShape<256, 128, 64>, | |
cutlass::gemm::GemmShape<64, 64, 64>, | |
cutlass::gemm::GemmShape<16, 8, 16>, | |
cutlass::epilogue::thread::LinearCombinationGeneric&lt;cutlass::epilogue::thread::Tanh, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t&gt;, | |
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, | |
3, | |
cutlass::arch::OpMultiplyAdd | |
>::GemmKernel; | |
// Define named type | |
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : | |
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { }; | |
</pre></div></div> | |
</div> | |
<div class="nbinput nblast docutils container"> | |
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[ ]: | |
</pre></div> | |
</div> | |
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span> | |
</pre></div> | |
</div> | |
</div> | |
</section> | |
</section> | |
</article> | |
</div> | |
<footer> | |
<div class="related-pages"> | |
<a class="next-page" href="02_pytorch_extension_grouped_gemm.html"> | |
<div class="page-info"> | |
<div class="context"> | |
<span>Next</span> | |
</div> | |
<div class="title">Exporting a CUTLASS grouped GEMM kernel to a PyTorch CUDA extension</div> | |
</div> | |
<svg class="furo-related-icon"><use href="#svg-arrow-right"></use></svg> | |
</a> | |
<a class="prev-page" href="../examples.html"> | |
<svg class="furo-related-icon"><use href="#svg-arrow-right"></use></svg> | |
<div class="page-info"> | |
<div class="context"> | |
<span>Previous</span> | |
</div> | |
<div class="title">Examples</div> | |
</div> | |
</a> | |
</div> | |
<div class="bottom-of-page"> | |
<div class="left-details"> | |
<div class="copyright"> | |
Copyright © 2023, NVIDIA | |
</div> | |
Made with <a href="https://www.sphinx-doc.org/">Sphinx</a> and <a class="muted-link" href="https://pradyunsg.me">@pradyunsg</a>'s | |
<a href="https://github.com/pradyunsg/furo">Furo</a> | |
</div> | |
<div class="right-details"> | |
<div class="icons"> | |
<a class="muted-link " href="https://github.com/NVIDIA/cutlass" aria-label="GitHub"> | |
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 16 16"> | |
<path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"></path> | |
</svg> | |
</a> | |
</div> | |
</div> | |
</div> | |
</footer> | |
</div> | |
<aside class="toc-drawer"> | |
<div class="toc-sticky toc-scroll"> | |
<div class="toc-title-container"> | |
<span class="toc-title"> | |
On this page | |
</span> | |
</div> | |
<div class="toc-tree-container"> | |
<div class="toc-tree"> | |
<ul> | |
<li><a class="reference internal" href="#">Example of using elementwise activation functions in the CUTLASS Python interface</a><ul> | |
<li><a class="reference internal" href="#Run-a-GEMM-with-an-identity-activation-function">Run a GEMM with an identity activation function</a></li> | |
<li><a class="reference internal" href="#Run-a-GEMM-with-a-ReLU-element-wise-activation-function">Run a GEMM with a ReLU element-wise activation function</a></li> | |
<li><a class="reference internal" href="#Other-element-wise-activation-functions">Other element-wise activation functions</a></li> | |
</ul> | |
</li> | |
</ul> | |
</div> | |
</div> | |
</div> | |
</aside> | |
</div> | |
</div><script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script> | |
<script src="../_static/doctools.js"></script> | |
<script src="../_static/sphinx_highlight.js"></script> | |
<script src="../_static/scripts/furo.js"></script> | |
<script src="../_static/clipboard.min.js"></script> | |
<script src="../_static/copybutton.js"></script> | |
<script src="../_static/tabs.js"></script> | |
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script> | |
<script>window.MathJax = {"tex": {"inlineMath": [["$", "$"], ["\\(", "\\)"]], "processEscapes": true}, "options": {"ignoreHtmlClass": "tex2jax_ignore|mathjax_ignore|document", "processHtmlClass": "tex2jax_process|mathjax_process|math|output_area"}}</script> | |
<script defer="defer" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> | |
</body> | |
</html> |