Spaces:
Sleeping
Sleeping
<html class="writer-html5" lang="en" > | |
<head> | |
<meta charset="utf-8" /> | |
<meta http-equiv="X-UA-Compatible" content="IE=edge" /> | |
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> | |
<link rel="shortcut icon" href="../../img/favicon.ico" /> | |
<title>train - NewsClassifier Docs</title> | |
<link rel="stylesheet" href="../../css/theme.css" /> | |
<link rel="stylesheet" href="../../css/theme_extra.css" /> | |
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.8.0/styles/github.min.css" /> | |
<link href="../../assets/_mkdocstrings.css" rel="stylesheet" /> | |
<script> | |
// Current page data | |
var mkdocs_page_name = "train"; | |
var mkdocs_page_input_path = "newsclassifier\\train.md"; | |
var mkdocs_page_url = null; | |
</script> | |
<!--[if lt IE 9]> | |
<script src="../../js/html5shiv.min.js"></script> | |
<![endif]--> | |
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.8.0/highlight.min.js"></script> | |
<script>hljs.highlightAll();</script> | |
</head> | |
<body class="wy-body-for-nav" role="document"> | |
<div class="wy-grid-for-nav"> | |
<nav data-toggle="wy-nav-shift" class="wy-nav-side stickynav"> | |
<div class="wy-side-scroll"> | |
<div class="wy-side-nav-search"> | |
<a href="../.." class="icon icon-home"> NewsClassifier Docs | |
</a> | |
</div> | |
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu"> | |
<ul> | |
<li class="toctree-l1"><a class="reference internal" href="../..">Home</a> | |
</li> | |
</ul> | |
<p class="caption"><span class="caption-text">newsclassifier</span></p> | |
<ul class="current"> | |
<li class="toctree-l1"><a class="reference internal" href="../config/">config</a> | |
</li> | |
<li class="toctree-l1"><a class="reference internal" href="../data/">data</a> | |
</li> | |
<li class="toctree-l1"><a class="reference internal" href="../models/">models</a> | |
</li> | |
<li class="toctree-l1 current"><a class="reference internal current" href="./">train</a> | |
<ul class="current"> | |
</ul> | |
</li> | |
<li class="toctree-l1"><a class="reference internal" href="../tune/">tune</a> | |
</li> | |
<li class="toctree-l1"><a class="reference internal" href="../inference/">inference</a> | |
</li> | |
<li class="toctree-l1"><a class="reference internal" href="../utils/">utils</a> | |
</li> | |
</ul> | |
</div> | |
</div> | |
</nav> | |
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"> | |
<nav class="wy-nav-top" role="navigation" aria-label="Mobile navigation menu"> | |
<i data-toggle="wy-nav-top" class="fa fa-bars"></i> | |
<a href="../..">NewsClassifier Docs</a> | |
</nav> | |
<div class="wy-nav-content"> | |
<div class="rst-content"><div role="navigation" aria-label="breadcrumbs navigation"> | |
<ul class="wy-breadcrumbs"> | |
<li><a href="../.." class="icon icon-home" aria-label="Docs"></a></li> | |
<li class="breadcrumb-item">newsclassifier</li> | |
<li class="breadcrumb-item active">train</li> | |
<li class="wy-breadcrumbs-aside"> | |
<a href="https://github.com/ManishW315/NewsClassifier/edit/master/docs/newsclassifier/train.md" class="icon icon-github"> Edit on GitHub</a> | |
</li> | |
</ul> | |
<hr/> | |
</div> | |
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article"> | |
<div class="section" itemprop="articleBody"> | |
<div class="doc doc-object doc-module"> | |
<a id="newsclassifier.train"></a> | |
<div class="doc doc-contents first"> | |
<div class="doc doc-children"> | |
<div class="doc doc-object doc-function"> | |
<h2 id="newsclassifier.train.eval_step" class="doc doc-heading"> | |
<code class="highlight language-python"><span class="n">eval_step</span><span class="p">(</span><span class="n">val_loader</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">,</span> <span class="n">epoch</span><span class="p">)</span></code> | |
</h2> | |
<div class="doc doc-contents "> | |
<p>Eval step.</p> | |
<details class="quote"> | |
<summary> <code>newsclassifier\train.py</code></summary> | |
<div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">43</span> | |
<span class="normal">44</span> | |
<span class="normal">45</span> | |
<span class="normal">46</span> | |
<span class="normal">47</span> | |
<span class="normal">48</span> | |
<span class="normal">49</span> | |
<span class="normal">50</span> | |
<span class="normal">51</span> | |
<span class="normal">52</span> | |
<span class="normal">53</span> | |
<span class="normal">54</span> | |
<span class="normal">55</span> | |
<span class="normal">56</span> | |
<span class="normal">57</span> | |
<span class="normal">58</span> | |
<span class="normal">59</span> | |
<span class="normal">60</span> | |
<span class="normal">61</span> | |
<span class="normal">62</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span> <span class="nf">eval_step</span><span class="p">(</span><span class="n">val_loader</span><span class="p">:</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">,</span> <span class="n">epoch</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]:</span> | |
<span class="w"> </span><span class="sd">"""Eval step."""</span> | |
<span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span> | |
<span class="n">loss</span> <span class="o">=</span> <span class="mf">0.0</span> | |
<span class="n">total_iterations</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">val_loader</span><span class="p">)</span> | |
<span class="n">desc</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">"Validation - Epoch </span><span class="si">{</span><span class="n">epoch</span><span class="o">+</span><span class="mi">1</span><span class="si">}</span><span class="s2">"</span> | |
<span class="n">y_trues</span><span class="p">,</span> <span class="n">y_preds</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[]</span> | |
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">inference_mode</span><span class="p">():</span> | |
<span class="k">for</span> <span class="n">step</span><span class="p">,</span> <span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="nb">enumerate</span><span class="p">(</span><span class="n">val_loader</span><span class="p">),</span> <span class="n">total</span><span class="o">=</span><span class="n">total_iterations</span><span class="p">,</span> <span class="n">desc</span><span class="o">=</span><span class="n">desc</span><span class="p">):</span> | |
<span class="n">inputs</span> <span class="o">=</span> <span class="n">collate</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> | |
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">inputs</span><span class="o">.</span><span class="n">items</span><span class="p">():</span> | |
<span class="n">inputs</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> | |
<span class="n">labels</span> <span class="o">=</span> <span class="n">labels</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> | |
<span class="n">y_pred</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> | |
<span class="n">targets</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">labels</span><span class="o">.</span><span class="n">long</span><span class="p">(),</span> <span class="n">num_classes</span><span class="o">=</span><span class="n">num_classes</span><span class="p">)</span><span class="o">.</span><span class="n">float</span><span class="p">()</span> <span class="c1"># one-hot (for loss_fn)</span> | |
<span class="n">J</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> | |
<span class="n">loss</span> <span class="o">+=</span> <span class="p">(</span><span class="n">J</span> <span class="o">-</span> <span class="n">loss</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">step</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> | |
<span class="n">y_trues</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="n">targets</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span> | |
<span class="n">y_preds</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span> | |
<span class="k">return</span> <span class="n">loss</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">vstack</span><span class="p">(</span><span class="n">y_trues</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">vstack</span><span class="p">(</span><span class="n">y_preds</span><span class="p">)</span> | |
</code></pre></div></td></tr></table></div> | |
</details> | |
</div> | |
</div> | |
<div class="doc doc-object doc-function"> | |
<h2 id="newsclassifier.train.train_step" class="doc doc-heading"> | |
<code class="highlight language-python"><span class="n">train_step</span><span class="p">(</span><span class="n">train_loader</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">epoch</span><span class="p">)</span></code> | |
</h2> | |
<div class="doc doc-contents "> | |
<p>Train step.</p> | |
<details class="quote"> | |
<summary> <code>newsclassifier\train.py</code></summary> | |
<div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">22</span> | |
<span class="normal">23</span> | |
<span class="normal">24</span> | |
<span class="normal">25</span> | |
<span class="normal">26</span> | |
<span class="normal">27</span> | |
<span class="normal">28</span> | |
<span class="normal">29</span> | |
<span class="normal">30</span> | |
<span class="normal">31</span> | |
<span class="normal">32</span> | |
<span class="normal">33</span> | |
<span class="normal">34</span> | |
<span class="normal">35</span> | |
<span class="normal">36</span> | |
<span class="normal">37</span> | |
<span class="normal">38</span> | |
<span class="normal">39</span> | |
<span class="normal">40</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="n">train_loader</span><span class="p">:</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">epoch</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="nb">float</span><span class="p">:</span> | |
<span class="w"> </span><span class="sd">"""Train step."""</span> | |
<span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">()</span> | |
<span class="n">loss</span> <span class="o">=</span> <span class="mf">0.0</span> | |
<span class="n">total_iterations</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_loader</span><span class="p">)</span> | |
<span class="n">desc</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">"Training - Epoch </span><span class="si">{</span><span class="n">epoch</span><span class="o">+</span><span class="mi">1</span><span class="si">}</span><span class="s2">"</span> | |
<span class="k">for</span> <span class="n">step</span><span class="p">,</span> <span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="nb">enumerate</span><span class="p">(</span><span class="n">train_loader</span><span class="p">),</span> <span class="n">total</span><span class="o">=</span><span class="n">total_iterations</span><span class="p">,</span> <span class="n">desc</span><span class="o">=</span><span class="n">desc</span><span class="p">):</span> | |
<span class="n">inputs</span> <span class="o">=</span> <span class="n">collate</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> | |
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">inputs</span><span class="o">.</span><span class="n">items</span><span class="p">():</span> | |
<span class="n">inputs</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> | |
<span class="n">labels</span> <span class="o">=</span> <span class="n">labels</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> | |
<span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span> <span class="c1"># reset gradients</span> | |
<span class="n">y_pred</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="c1"># forward pass</span> | |
<span class="n">targets</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">labels</span><span class="o">.</span><span class="n">long</span><span class="p">(),</span> <span class="n">num_classes</span><span class="o">=</span><span class="n">num_classes</span><span class="p">)</span><span class="o">.</span><span class="n">float</span><span class="p">()</span> <span class="c1"># one-hot (for loss_fn)</span> | |
<span class="n">J</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span> <span class="c1"># define loss</span> | |
<span class="n">J</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span> <span class="c1"># backward pass</span> | |
<span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span> <span class="c1"># update weights</span> | |
<span class="n">loss</span> <span class="o">+=</span> <span class="p">(</span><span class="n">J</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="o">-</span> <span class="n">loss</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">step</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="c1"># cumulative loss</span> | |
<span class="k">return</span> <span class="n">loss</span> | |
</code></pre></div></td></tr></table></div> | |
</details> | |
</div> | |
</div> | |
</div> | |
</div> | |
</div> | |
</div> | |
</div><footer> | |
<div class="rst-footer-buttons" role="navigation" aria-label="Footer Navigation"> | |
<a href="../models/" class="btn btn-neutral float-left" title="models"><span class="icon icon-circle-arrow-left"></span> Previous</a> | |
<a href="../tune/" class="btn btn-neutral float-right" title="tune">Next <span class="icon icon-circle-arrow-right"></span></a> | |
</div> | |
<hr/> | |
<div role="contentinfo"> | |
<!-- Copyright etc --> | |
</div> | |
Built with <a href="https://www.mkdocs.org/">MkDocs</a> using a <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>. | |
</footer> | |
</div> | |
</div> | |
</section> | |
</div> | |
<div class="rst-versions" role="note" aria-label="Versions"> | |
<span class="rst-current-version" data-toggle="rst-current-version"> | |
<span> | |
<a href="https://github.com/ManishW315/NewsClassifier" class="fa fa-github" style="color: #fcfcfc"> GitHub</a> | |
</span> | |
<span><a href="../models/" style="color: #fcfcfc">« Previous</a></span> | |
<span><a href="../tune/" style="color: #fcfcfc">Next »</a></span> | |
</span> | |
</div> | |
<script src="../../js/jquery-3.6.0.min.js"></script> | |
<script>var base_url = "../..";</script> | |
<script src="../../js/theme_extra.js"></script> | |
<script src="../../js/theme.js"></script> | |
<script> | |
jQuery(function () { | |
SphinxRtdTheme.Navigation.enable(true); | |
}); | |
</script> | |
</body> | |
</html> | |