<!--
ManishW's picture
Upload folder using huggingface_hub
022acf4
-->
<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<!-- Document metadata: character encoding, IE compatibility mode, viewport. -->
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="shortcut icon" href="../../img/favicon.ico">
<title>train - NewsClassifier Docs</title>
<!-- Stylesheets; keep this order, later sheets win in the cascade. -->
<link rel="stylesheet" href="../../css/theme.css">
<link rel="stylesheet" href="../../css/theme_extra.css">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.8.0/styles/github.min.css">
<link rel="stylesheet" href="../../assets/_mkdocstrings.css">
<script>
// Globals describing the current page.
var mkdocs_page_name = "train";
var mkdocs_page_input_path = "newsclassifier\\train.md";
var mkdocs_page_url = null;
</script>
<!-- html5shiv is requested only by Internet Explorer below version 9 (conditional comment). -->
<!--[if lt IE 9]>
<script src="../../js/html5shiv.min.js"></script>
<![endif]-->
<!-- Load highlight.js from the CDN, then ask it to highlight the page. -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.8.0/highlight.min.js"></script>
<script>hljs.highlightAll();</script>
</head>
<body class="wy-body-for-nav" role="document">
<div class="wy-grid-for-nav">
<!-- Sidebar: site title link followed by the table of contents for the newsclassifier package. -->
<nav data-toggle="wy-nav-shift" class="wy-nav-side stickynav">
<div class="wy-side-scroll">
<div class="wy-side-nav-search">
<a href="../.." class="icon icon-home"> NewsClassifier Docs
</a>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
<ul>
<li class="toctree-l1"><a class="reference internal" href="../..">Home</a>
</li>
</ul>
<p class="caption"><span class="caption-text">newsclassifier</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../config/">config</a>
</li>
<li class="toctree-l1"><a class="reference internal" href="../data/">data</a>
</li>
<li class="toctree-l1"><a class="reference internal" href="../models/">models</a>
</li>
<!-- aria-current exposes the active page to assistive technology; the "current" class alone is only a styling hook. -->
<li class="toctree-l1 current"><a class="reference internal current" href="./" aria-current="page">train</a>
<ul class="current">
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../tune/">tune</a>
</li>
<li class="toctree-l1"><a class="reference internal" href="../inference/">inference</a>
</li>
<li class="toctree-l1"><a class="reference internal" href="../utils/">utils</a>
</li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<!-- Mobile top bar: bars icon carrying the data-toggle hook, plus a link to the site root. -->
<nav class="wy-nav-top" aria-label="Mobile navigation menu" role="navigation">
<i class="fa fa-bars" data-toggle="wy-nav-top"></i>
<a href="../..">NewsClassifier Docs</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content"><div role="navigation" aria-label="breadcrumbs navigation">
<!-- Breadcrumb trail: home icon, package name, current page, and the "Edit on GitHub" link. -->
<ul class="wy-breadcrumbs">
<li><a href="../.." class="icon icon-home" aria-label="Docs"></a></li>
<li class="breadcrumb-item">newsclassifier</li>
<!-- aria-current marks the crumb for the page being viewed; the "active" class is visual only. -->
<li class="breadcrumb-item active" aria-current="page">train</li>
<li class="wy-breadcrumbs-aside">
<a href="https://github.com/ManishW315/NewsClassifier/edit/master/docs/newsclassifier/train.md" class="icon icon-github"> Edit on GitHub</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div class="section" itemprop="articleBody">
<div class="doc doc-object doc-module">
<a id="newsclassifier.train"></a>
<div class="doc doc-contents first">
<div class="doc doc-children">
<div class="doc doc-object doc-function">
<h2 id="newsclassifier.train.eval_step" class="doc doc-heading">
<code class="highlight language-python"><span class="n">eval_step</span><span class="p">(</span><span class="n">val_loader</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">,</span> <span class="n">epoch</span><span class="p">)</span></code>
</h2>
<div class="doc doc-contents ">
<p>Eval step.</p>
<details class="quote">
<summary>Source code in <code>newsclassifier/train.py</code></summary>
<div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">43</span>
<span class="normal">44</span>
<span class="normal">45</span>
<span class="normal">46</span>
<span class="normal">47</span>
<span class="normal">48</span>
<span class="normal">49</span>
<span class="normal">50</span>
<span class="normal">51</span>
<span class="normal">52</span>
<span class="normal">53</span>
<span class="normal">54</span>
<span class="normal">55</span>
<span class="normal">56</span>
<span class="normal">57</span>
<span class="normal">58</span>
<span class="normal">59</span>
<span class="normal">60</span>
<span class="normal">61</span>
<span class="normal">62</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span> <span class="nf">eval_step</span><span class="p">(</span><span class="n">val_loader</span><span class="p">:</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">,</span> <span class="n">epoch</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Eval step.&quot;&quot;&quot;</span>
<span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
<span class="n">loss</span> <span class="o">=</span> <span class="mf">0.0</span>
<span class="n">total_iterations</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">val_loader</span><span class="p">)</span>
<span class="n">desc</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;Validation - Epoch </span><span class="si">{</span><span class="n">epoch</span><span class="o">+</span><span class="mi">1</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="n">y_trues</span><span class="p">,</span> <span class="n">y_preds</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[]</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">inference_mode</span><span class="p">():</span>
<span class="k">for</span> <span class="n">step</span><span class="p">,</span> <span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="nb">enumerate</span><span class="p">(</span><span class="n">val_loader</span><span class="p">),</span> <span class="n">total</span><span class="o">=</span><span class="n">total_iterations</span><span class="p">,</span> <span class="n">desc</span><span class="o">=</span><span class="n">desc</span><span class="p">):</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="n">collate</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span>
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">inputs</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">inputs</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="n">labels</span> <span class="o">=</span> <span class="n">labels</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="n">y_pred</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span>
<span class="n">targets</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">labels</span><span class="o">.</span><span class="n">long</span><span class="p">(),</span> <span class="n">num_classes</span><span class="o">=</span><span class="n">num_classes</span><span class="p">)</span><span class="o">.</span><span class="n">float</span><span class="p">()</span> <span class="c1"># one-hot (for loss_fn)</span>
<span class="n">J</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
<span class="n">loss</span> <span class="o">+=</span> <span class="p">(</span><span class="n">J</span> <span class="o">-</span> <span class="n">loss</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">step</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">y_trues</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="n">targets</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
<span class="n">y_preds</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
<span class="k">return</span> <span class="n">loss</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">vstack</span><span class="p">(</span><span class="n">y_trues</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">vstack</span><span class="p">(</span><span class="n">y_preds</span><span class="p">)</span>
</code></pre></div></td></tr></table></div>
</details>
</div>
</div>
<div class="doc doc-object doc-function">
<h2 id="newsclassifier.train.train_step" class="doc doc-heading">
<code class="highlight language-python"><span class="n">train_step</span><span class="p">(</span><span class="n">train_loader</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">epoch</span><span class="p">)</span></code>
</h2>
<div class="doc doc-contents ">
<p>Train step.</p>
<details class="quote">
<summary>Source code in <code>newsclassifier/train.py</code></summary>
<div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">22</span>
<span class="normal">23</span>
<span class="normal">24</span>
<span class="normal">25</span>
<span class="normal">26</span>
<span class="normal">27</span>
<span class="normal">28</span>
<span class="normal">29</span>
<span class="normal">30</span>
<span class="normal">31</span>
<span class="normal">32</span>
<span class="normal">33</span>
<span class="normal">34</span>
<span class="normal">35</span>
<span class="normal">36</span>
<span class="normal">37</span>
<span class="normal">38</span>
<span class="normal">39</span>
<span class="normal">40</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="n">train_loader</span><span class="p">:</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">epoch</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Train step.&quot;&quot;&quot;</span>
<span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">()</span>
<span class="n">loss</span> <span class="o">=</span> <span class="mf">0.0</span>
<span class="n">total_iterations</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_loader</span><span class="p">)</span>
<span class="n">desc</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;Training - Epoch </span><span class="si">{</span><span class="n">epoch</span><span class="o">+</span><span class="mi">1</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="k">for</span> <span class="n">step</span><span class="p">,</span> <span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="nb">enumerate</span><span class="p">(</span><span class="n">train_loader</span><span class="p">),</span> <span class="n">total</span><span class="o">=</span><span class="n">total_iterations</span><span class="p">,</span> <span class="n">desc</span><span class="o">=</span><span class="n">desc</span><span class="p">):</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="n">collate</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span>
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">inputs</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">inputs</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="n">labels</span> <span class="o">=</span> <span class="n">labels</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span> <span class="c1"># reset gradients</span>
<span class="n">y_pred</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="c1"># forward pass</span>
<span class="n">targets</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">labels</span><span class="o">.</span><span class="n">long</span><span class="p">(),</span> <span class="n">num_classes</span><span class="o">=</span><span class="n">num_classes</span><span class="p">)</span><span class="o">.</span><span class="n">float</span><span class="p">()</span> <span class="c1"># one-hot (for loss_fn)</span>
<span class="n">J</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span> <span class="c1"># define loss</span>
<span class="n">J</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span> <span class="c1"># backward pass</span>
<span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span> <span class="c1"># update weights</span>
<span class="n">loss</span> <span class="o">+=</span> <span class="p">(</span><span class="n">J</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="o">-</span> <span class="n">loss</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">step</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="c1"># cumulative loss</span>
<span class="k">return</span> <span class="n">loss</span>
</code></pre></div></td></tr></table></div>
</details>
</div>
</div>
</div>
</div>
</div>
</div>
</div><footer>
<!-- Previous/next page links. -->
<div class="rst-footer-buttons" aria-label="Footer Navigation" role="navigation">
<a class="btn btn-neutral float-left" href="../models/" title="models"><span class="icon icon-circle-arrow-left"></span> Previous</a>
<a class="btn btn-neutral float-right" href="../tune/" title="tune">Next <span class="icon icon-circle-arrow-right"></span></a>
</div>
<hr>
<div role="contentinfo">
<!-- Copyright etc -->
</div>
<!-- Generator and theme credits. -->
Built with <a href="https://www.mkdocs.org/">MkDocs</a> using a <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<!-- Bottom bar: repository link plus previous/next page links. -->
<div class="rst-versions" aria-label="Versions" role="note">
<span class="rst-current-version" data-toggle="rst-current-version">
<span>
<a class="fa fa-github" href="https://github.com/ManishW315/NewsClassifier" style="color: #fcfcfc"> GitHub</a>
</span>
<span><a href="../models/" style="color: #fcfcfc">« Previous</a></span>
<span><a href="../tune/" style="color: #fcfcfc">Next »</a></span>
</span>
</div>
<!-- Script order matters: jQuery is loaded first because the inline block below calls jQuery(...). -->
<script src="../../js/jquery-3.6.0.min.js"></script>
<!-- Relative path from this page to the site root, set as a global before the theme scripts load. -->
<script>var base_url = "../..";</script>
<script src="../../js/theme_extra.js"></script>
<script src="../../js/theme.js"></script>
<script>
// Once the DOM is ready, enable the theme's navigation behaviour.
// NOTE(review): SphinxRtdTheme is presumably defined by theme.js above; confirm before reordering.
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>