ManishW commited on
Commit
022acf4
·
1 Parent(s): 80f5f82

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .gitignore +166 -0
  3. .pre-commit-config.yaml +17 -0
  4. LICENSE +21 -0
  5. Makefile +15 -0
  6. NewsClassifier.egg-info/PKG-INFO +6 -0
  7. NewsClassifier.egg-info/SOURCES.txt +16 -0
  8. NewsClassifier.egg-info/dependency_links.txt +1 -0
  9. NewsClassifier.egg-info/requires.txt +34 -0
  10. NewsClassifier.egg-info/top_level.txt +1 -0
  11. README.md +4 -8
  12. app.py +50 -0
  13. artifacts/model.pt +3 -0
  14. dataset/preprocessed/test.csv +0 -0
  15. dataset/preprocessed/train.csv +0 -0
  16. dataset/raw/news_dataset.csv +3 -0
  17. docs/index.md +35 -0
  18. docs/newsclassifier/config.md +1 -0
  19. docs/newsclassifier/data.md +1 -0
  20. docs/newsclassifier/inference.md +1 -0
  21. docs/newsclassifier/models.md +1 -0
  22. docs/newsclassifier/train.md +1 -0
  23. docs/newsclassifier/tune.md +1 -0
  24. docs/newsclassifier/utils.md +1 -0
  25. logs/error.log +0 -0
  26. logs/info.log +186 -0
  27. mkdocs.yml +20 -0
  28. newsclassifier/__init__.py +0 -0
  29. newsclassifier/__pycache__/__init__.cpython-310.pyc +0 -0
  30. newsclassifier/__pycache__/config.cpython-310.pyc +0 -0
  31. newsclassifier/__pycache__/data.cpython-310.pyc +0 -0
  32. newsclassifier/__pycache__/models.cpython-310.pyc +0 -0
  33. newsclassifier/__pycache__/predict.cpython-310.pyc +0 -0
  34. newsclassifier/__pycache__/serve.cpython-310.pyc +0 -0
  35. newsclassifier/config/__init__.py +0 -0
  36. newsclassifier/config/__pycache__/__init__.cpython-310.pyc +0 -0
  37. newsclassifier/config/__pycache__/config.cpython-310.pyc +0 -0
  38. newsclassifier/config/config.py +265 -0
  39. newsclassifier/config/sweep_config.yaml +17 -0
  40. newsclassifier/data.py +197 -0
  41. newsclassifier/inference.py +54 -0
  42. newsclassifier/models.py +60 -0
  43. newsclassifier/predict.py +32 -0
  44. newsclassifier/train.py +151 -0
  45. newsclassifier/tune.py +85 -0
  46. newsclassifier/utils.py +20 -0
  47. notebooks/eda.ipynb +257 -0
  48. notebooks/newsclassifier-roberta-base-wandb-track-sweep.ipynb +1035 -0
  49. requirements.txt +34 -0
  50. setup.py +23 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ dataset/raw/news_dataset.csv filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
162
+ # make
163
+ Makefile
164
+
165
+ # artifacts
166
+ artifacts/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # See https://pre-commit.com for more information
2
+ # See https://pre-commit.com/hooks.html for more hooks
3
+ repos:
4
+ - repo: https://github.com/pre-commit/pre-commit-hooks
5
+ rev: v4.5.0
6
+ hooks:
7
+ - id: trailing-whitespace
8
+ exclude: "docs/index.md"
9
+ - id: check-yaml
10
+ - repo: local
11
+ hooks:
12
+ - id: style
13
+ name: Style
14
+ entry: make
15
+ args: ["style"]
16
+ language: system
17
+ pass_filenames: false
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Manish Wahale
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Makefile ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ifeq ($(OS), Windows_NT)
2
+ # Styling
3
+ .PHONY: style
4
+ style:
5
+ black . --line-length 150
6
+ isort . -rc
7
+ flake8 . --exit-zero
8
+ else
9
+ # Styling
10
+ .PHONY: style
11
+ style:
12
+ python3 -m black . --line-length 150
13
+ python3 -m isort . -rc
14
+ python3 -m flake8 . --exit-zero
15
+ endif
NewsClassifier.egg-info/PKG-INFO ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: NewsClassifier
3
+ Version: 1.0
4
+ Author: ManishW
5
+ Author-email: [email protected]
6
+ License-File: LICENSE
NewsClassifier.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LICENSE
2
+ README.md
3
+ setup.py
4
+ NewsClassifier.egg-info/PKG-INFO
5
+ NewsClassifier.egg-info/SOURCES.txt
6
+ NewsClassifier.egg-info/dependency_links.txt
7
+ NewsClassifier.egg-info/requires.txt
8
+ NewsClassifier.egg-info/top_level.txt
9
+ newsclassifier/__init__.py
10
+ newsclassifier/data.py
11
+ newsclassifier/inference.py
12
+ newsclassifier/models.py
13
+ newsclassifier/train.py
14
+ newsclassifier/tune.py
15
+ newsclassifier/config/__init__.py
16
+ newsclassifier/config/config.py
NewsClassifier.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
NewsClassifier.egg-info/requires.txt ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiosignal==1.3.1
2
+ attrs==23.1.0
3
+ certifi==2023.7.22
4
+ charset-normalizer==3.3.1
5
+ click==8.1.7
6
+ colorama==0.4.6
7
+ contourpy==1.1.1
8
+ cycler==0.12.1
9
+ filelock==3.12.4
10
+ fonttools==4.43.1
11
+ frozenlist==1.4.0
12
+ idna==3.4
13
+ jsonschema==4.19.1
14
+ jsonschema-specifications==2023.7.1
15
+ kiwisolver==1.4.5
16
+ matplotlib==3.8.0
17
+ msgpack==1.0.7
18
+ numpy==1.26.1
19
+ packaging==23.2
20
+ pandas==2.1.2
21
+ Pillow==10.1.0
22
+ protobuf==4.24.4
23
+ pyparsing==3.1.1
24
+ python-dateutil==2.8.2
25
+ pytz==2023.3.post1
26
+ PyYAML==6.0.1
27
+ ray==2.7.1
28
+ referencing==0.30.2
29
+ requests==2.31.0
30
+ rpds-py==0.10.6
31
+ seaborn==0.13.0
32
+ six==1.16.0
33
+ tzdata==2023.3
34
+ urllib3==2.0.7
NewsClassifier.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ newsclassifier
README.md CHANGED
@@ -1,12 +1,8 @@
1
  ---
2
- title: News Classifier
3
- emoji: 🏃
4
- colorFrom: indigo
5
- colorTo: purple
6
  sdk: gradio
7
  sdk_version: 4.0.2
8
- app_file: app.py
9
- pinned: false
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: News-Classifier
3
+ app_file: app.py
 
 
4
  sdk: gradio
5
  sdk_version: 4.0.2
 
 
6
  ---
7
+ # NewsClassifier
8
+ See docs here: [NewsClassifier Docs](https://ManishW315.github.io/NewsClassifier/)
app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import gradio as gr
4
+ import torch
5
+ from newsclassifier.config.config import Cfg, logger
6
+ from newsclassifier.data import prepare_input
7
+ from newsclassifier.models import CustomModel
8
+ from transformers import RobertaTokenizer
9
+
10
# Class names shown in the UI, in the order stored in Cfg.index_to_class.
labels = list(Cfg.index_to_class.values())

# Load the tokenizer and restore the fine-tuned weights on CPU.
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
# Take the class count and checkpoint location from Cfg instead of repeating
# the literals here, so the app cannot drift from the training configuration.
model = CustomModel(num_classes=Cfg.num_classes)
model.load_state_dict(torch.load(Cfg.model_path, map_location=torch.device("cpu")))
# Inference only: put the module in eval mode so layers such as dropout are
# disabled. NOTE(review): harmless if predict_proba already does this - confirm.
model.eval()
16
+
17
+
18
def prediction(text: str) -> dict:
    """Classify one headline and return a probability for every label.

    Args:
        text: Raw headline text entered by the user.

    Returns:
        dict: Maps each entry of the module-level ``labels`` list to the
        corresponding value of the first row returned by
        ``model.predict_proba``, converted to ``float``.
    """
    sample_input = prepare_input(tokenizer, text)

    # Add a leading batch dimension of size 1 to each encoded field.
    test_sample = dict(
        input_ids=torch.unsqueeze(sample_input["input_ids"], 0).to("cpu"),
        attention_mask=torch.unsqueeze(sample_input["attention_mask"], 0).to("cpu"),
    )

    # No gradients are needed for inference.
    with torch.no_grad():
        pred_probs = model.predict_proba(test_sample)[0]

    # presumably pred_probs holds one score per label, in label order - verify against CustomModel.predict_proba
    return {label: float(pred_probs[i]) for i, label in enumerate(labels)}
29
+
30
+
31
# Static text rendered around the demo.
title = "NewsClassifier"
description = "Enter a news headline, and this app will classify it into one of the categories."
instructions = "Type or paste a news headline in the textbox and press Enter."

# Sample headlines offered as one-click inputs; each row holds the single textbox value.
example_headlines = [
    ["Global Smartphone Shipments Will Hit Lowest Point in a Decade, IDC Says"],
    ["John Wick's First Spinoff is the Rare Prequel That Justifies Its Existence"],
    ["Research provides a better understanding of how light stimulates the brain"],
    ["Lionel Messi scores free kick golazo for Argentina in World Cup qualifiers"],
]

# Text in, label-with-confidences out; `prediction` supplies the scores.
iface = gr.Interface(
    fn=prediction,
    inputs=gr.Textbox(),
    outputs=gr.Label(num_top_classes=7),
    title=title,
    description=description,
    examples=example_headlines,
    article=instructions,
)

# Start the web server as soon as this script runs.
iface.launch(share=True)
artifacts/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7ad2ee4ee7324989ef530eae760f3cb4a660aaca0bae36469c9ae6723130b83d
3
+ size 498672838
dataset/preprocessed/test.csv ADDED
The diff for this file is too large to render. See raw diff
 
dataset/preprocessed/train.csv ADDED
The diff for this file is too large to render. See raw diff
 
dataset/raw/news_dataset.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98c974915d3871f9fd92985fa2413afb995adb7545e6ee4a036240f3a20abd18
3
+ size 18273585
docs/index.md ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Welcome to NewsClassifier Docs
2
+
3
+ For source visit [ManishW315/NewsClassifier](https://github.com/ManishW315/NewsClassifier).
4
+
5
+ ## Project layout
6
+ <pre>
7
+ NewsClassifier
8
+
9
+ ├───dataset
10
+ │ ├───preprocessed
11
+ │ │ test.csv
12
+ │ │ train.csv
13
+ │ │
14
+ │ └───raw
15
+ │ news_dataset.csv
16
+
17
+ ├───newsclassifier
18
+ │ │ data.py
19
+ │ │ models.py
20
+ │ │ train.py
21
+ │ │ tune.py
22
+ │ │ inference.py
23
+ │ │ utils.py
24
+ │ │
25
+ │ │
26
+ │ └───config
27
+ │ config.py
28
+ │ sweep_config.yaml
29
+
30
+ ├───notebooks
31
+ │ eda.ipynb
32
+ │ newsclassifier-roberta-base-wandb-track-sweep.ipynb
33
+
34
+ └───test
35
+ </pre>
docs/newsclassifier/config.md ADDED
@@ -0,0 +1 @@
 
 
1
+ ::: newsclassifier.config.config
docs/newsclassifier/data.md ADDED
@@ -0,0 +1 @@
 
 
1
+ ::: newsclassifier.data
docs/newsclassifier/inference.md ADDED
@@ -0,0 +1 @@
 
 
1
+ ::: newsclassifier.inference
docs/newsclassifier/models.md ADDED
@@ -0,0 +1 @@
 
 
1
+ ::: newsclassifier.models
docs/newsclassifier/train.md ADDED
@@ -0,0 +1 @@
 
 
1
+ ::: newsclassifier.train
docs/newsclassifier/tune.md ADDED
@@ -0,0 +1 @@
 
 
1
+ ::: newsclassifier.tune
docs/newsclassifier/utils.md ADDED
@@ -0,0 +1 @@
 
 
1
+ ::: newsclassifier.utils
logs/error.log ADDED
File without changes
logs/info.log ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ INFO 2023-11-01 08:36:13,083 [root:data.py:load_dataset:24]
2
+ Loading Data.
3
+
4
+ INFO 2023-11-01 08:40:59,763 [root:data.py:load_dataset:24]
5
+ Loading Data.
6
+
7
+ INFO 2023-11-01 08:43:10,163 [root:data.py:load_dataset:24]
8
+ Loading Data.
9
+
10
+ INFO 2023-11-01 08:44:10,037 [root:data.py:load_dataset:24]
11
+ Loading Data.
12
+
13
+ INFO 2023-11-01 08:47:58,057 [root:data.py:load_dataset:27]
14
+ Loading Data.
15
+
16
+ INFO 2023-11-01 08:48:28,766 [root:data.py:load_dataset:27]
17
+ Loading Data.
18
+
19
+ INFO 2023-11-01 08:49:43,821 [root:data.py:load_dataset:27]
20
+ Loading Data.
21
+
22
+ INFO 2023-11-01 08:49:46,460 [root:data.py:data_split:105]
23
+ Splitting Data.
24
+
25
+ INFO 2023-11-01 08:49:46,564 [root:data.py:data_split:116]
26
+ Saving and storing data splits.
27
+
28
+ INFO 2023-11-02 00:09:13,890 [root:data.py:clean_text:58]
29
+ Cleaning input text.
30
+
31
+ INFO 2023-11-02 00:11:13,522 [root:data.py:clean_text:58]
32
+ Cleaning input text.
33
+
34
+ INFO 2023-11-02 00:23:17,886 [root:data.py:clean_text:58]
35
+ Cleaning input text.
36
+
37
+ INFO 2023-11-02 00:25:53,585 [root:data.py:clean_text:58]
38
+ Cleaning input text.
39
+
40
+ INFO 2023-11-02 00:25:53,642 [root:data.py:prepare_input:146]
41
+ Tokenizing input text.
42
+
43
+ INFO 2023-11-02 00:30:41,901 [root:data.py:clean_text:58]
44
+ Cleaning input text.
45
+
46
+ INFO 2023-11-02 00:30:41,919 [root:data.py:prepare_input:146]
47
+ Tokenizing input text.
48
+
49
+ INFO 2023-11-02 00:36:18,514 [root:data.py:clean_text:58]
50
+ Cleaning input text.
51
+
52
+ INFO 2023-11-02 00:36:18,538 [root:data.py:prepare_input:146]
53
+ Tokenizing input text.
54
+
55
+ INFO 2023-11-02 10:47:32,805 [root:data.py:prepare_input:146]
56
+ Tokenizing input text.
57
+
58
+ INFO 2023-11-02 10:48:36,522 [root:data.py:prepare_input:146]
59
+ Tokenizing input text.
60
+
61
+ INFO 2023-11-02 10:48:52,388 [root:data.py:prepare_input:146]
62
+ Tokenizing input text.
63
+
64
+ INFO 2023-11-02 10:49:14,171 [root:data.py:prepare_input:146]
65
+ Tokenizing input text.
66
+
67
+ INFO 2023-11-02 10:50:10,611 [root:data.py:prepare_input:146]
68
+ Tokenizing input text.
69
+
70
+ INFO 2023-11-02 10:50:27,112 [root:data.py:prepare_input:146]
71
+ Tokenizing input text.
72
+
73
+ INFO 2023-11-02 10:50:51,887 [root:data.py:prepare_input:146]
74
+ Tokenizing input text.
75
+
76
+ INFO 2023-11-02 10:51:44,829 [root:data.py:prepare_input:146]
77
+ Tokenizing input text.
78
+
79
+ INFO 2023-11-02 10:52:06,984 [root:data.py:prepare_input:146]
80
+ Tokenizing input text.
81
+
82
+ INFO 2023-11-02 10:52:20,660 [root:data.py:prepare_input:146]
83
+ Tokenizing input text.
84
+
85
+ INFO 2023-11-02 10:52:33,236 [root:data.py:prepare_input:146]
86
+ Tokenizing input text.
87
+
88
+ INFO 2023-11-02 10:53:05,679 [root:data.py:prepare_input:146]
89
+ Tokenizing input text.
90
+
91
+ INFO 2023-11-02 10:53:20,561 [root:data.py:prepare_input:146]
92
+ Tokenizing input text.
93
+
94
+ INFO 2023-11-02 10:53:29,476 [root:data.py:prepare_input:146]
95
+ Tokenizing input text.
96
+
97
+ INFO 2023-11-02 10:53:38,528 [root:data.py:prepare_input:146]
98
+ Tokenizing input text.
99
+
100
+ INFO 2023-11-02 11:01:28,685 [root:data.py:prepare_input:146]
101
+ Tokenizing input text.
102
+
103
+ INFO 2023-11-02 14:50:33,049 [root:data.py:prepare_input:146]
104
+ Tokenizing input text.
105
+
106
+ INFO 2023-11-02 14:52:09,259 [root:data.py:prepare_input:146]
107
+ Tokenizing input text.
108
+
109
+ INFO 2023-11-02 14:53:30,933 [root:data.py:prepare_input:146]
110
+ Tokenizing input text.
111
+
112
+ INFO 2023-11-02 21:22:31,654 [root:data.py:prepare_input:146]
113
+ Tokenizing input text.
114
+
115
+ INFO 2023-11-02 21:30:09,258 [root:data.py:clean_text:58]
116
+ Cleaning input text.
117
+
118
+ INFO 2023-11-02 21:30:46,696 [root:data.py:prepare_input:146]
119
+ Tokenizing input text.
120
+
121
+ INFO 2023-11-02 21:39:13,401 [root:data.py:prepare_input:146]
122
+ Tokenizing input text.
123
+
124
+ INFO 2023-11-02 21:40:13,665 [root:data.py:prepare_input:146]
125
+ Tokenizing input text.
126
+
127
+ INFO 2023-11-02 21:44:01,779 [root:data.py:prepare_input:146]
128
+ Tokenizing input text.
129
+
130
+ INFO 2023-11-02 21:44:20,110 [root:data.py:prepare_input:146]
131
+ Tokenizing input text.
132
+
133
+ INFO 2023-11-02 21:45:52,673 [root:data.py:prepare_input:146]
134
+ Tokenizing input text.
135
+
136
+ INFO 2023-11-02 21:48:31,415 [root:data.py:prepare_input:146]
137
+ Tokenizing input text.
138
+
139
+ INFO 2023-11-02 21:49:40,642 [root:data.py:prepare_input:146]
140
+ Tokenizing input text.
141
+
142
+ INFO 2023-11-02 21:50:42,110 [root:data.py:prepare_input:146]
143
+ Tokenizing input text.
144
+
145
+ INFO 2023-11-02 21:55:50,749 [root:data.py:prepare_input:146]
146
+ Tokenizing input text.
147
+
148
+ INFO 2023-11-02 21:56:30,951 [root:data.py:prepare_input:146]
149
+ Tokenizing input text.
150
+
151
+ INFO 2023-11-02 21:56:47,555 [root:data.py:prepare_input:146]
152
+ Tokenizing input text.
153
+
154
+ INFO 2023-11-02 21:56:53,879 [root:data.py:prepare_input:146]
155
+ Tokenizing input text.
156
+
157
+ INFO 2023-11-02 21:57:11,729 [root:data.py:prepare_input:146]
158
+ Tokenizing input text.
159
+
160
+ INFO 2023-11-02 21:57:14,827 [root:data.py:prepare_input:146]
161
+ Tokenizing input text.
162
+
163
+ INFO 2023-11-02 21:57:23,501 [root:data.py:prepare_input:146]
164
+ Tokenizing input text.
165
+
166
+ INFO 2023-11-02 22:20:57,360 [root:data.py:prepare_input:146]
167
+ Tokenizing input text.
168
+
169
+ INFO 2023-11-02 22:25:04,600 [root:data.py:prepare_input:146]
170
+ Tokenizing input text.
171
+
172
+ INFO 2023-11-02 22:25:15,152 [root:data.py:prepare_input:146]
173
+ Tokenizing input text.
174
+
175
+ INFO 2023-11-02 22:47:41,043 [root:data.py:prepare_input:146]
176
+ Tokenizing input text.
177
+
178
+ INFO 2023-11-02 22:47:47,106 [root:data.py:prepare_input:146]
179
+ Tokenizing input text.
180
+
181
+ INFO 2023-11-02 22:47:52,655 [root:data.py:prepare_input:146]
182
+ Tokenizing input text.
183
+
184
+ INFO 2023-11-02 22:47:56,948 [root:data.py:prepare_input:146]
185
+ Tokenizing input text.
186
+
mkdocs.yml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ site_name: NewsClassifier Docs
2
+ # site_url:
3
+ repo_url: https://github.com/ManishW315/NewsClassifier
4
+ nav:
5
+ - Home: index.md
6
+ - newsclassifier:
7
+ - config: newsclassifier/config.md
8
+ - data: newsclassifier/data.md
9
+ - models: newsclassifier/models.md
10
+ - train: newsclassifier/train.md
11
+ - tune: newsclassifier/tune.md
12
+ - inference: newsclassifier/inference.md
13
+ # - predict: newsclassifier/predict.md
14
+ # - serve: newsclassifier/serve.md
15
+ - utils: newsclassifier/utils.md
16
+ theme: readthedocs
17
+ plugins:
18
+ - mkdocstrings
19
+ watch:
20
+ - . # reload docs for any file changes
newsclassifier/__init__.py ADDED
File without changes
newsclassifier/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (182 Bytes). View file
 
newsclassifier/__pycache__/config.cpython-310.pyc ADDED
Binary file (2.88 kB). View file
 
newsclassifier/__pycache__/data.cpython-310.pyc ADDED
Binary file (6.76 kB). View file
 
newsclassifier/__pycache__/models.cpython-310.pyc ADDED
Binary file (2.45 kB). View file
 
newsclassifier/__pycache__/predict.cpython-310.pyc ADDED
Binary file (1.31 kB). View file
 
newsclassifier/__pycache__/serve.cpython-310.pyc ADDED
Binary file (1.25 kB). View file
 
newsclassifier/config/__init__.py ADDED
File without changes
newsclassifier/config/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (189 Bytes). View file
 
newsclassifier/config/__pycache__/config.cpython-310.pyc ADDED
Binary file (3.24 kB). View file
 
newsclassifier/config/config.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from dataclasses import dataclass
4
+ from logging.handlers import RotatingFileHandler
5
+ from pathlib import Path
6
+
7
+ import nltk
8
+
9
+ from rich.logging import RichHandler
10
+
11
+ # from nltk.corpus import stopwords
12
+ # nltk.download("stopwords")
13
+
14
+
15
@dataclass
class Cfg:
    """Project-wide configuration constants.

    Every attribute below is a plain, un-annotated class attribute, so
    ``@dataclass`` registers no fields for them; read the values directly
    from the class, e.g. ``Cfg.max_len``.
    """

    # English stop words (hard-coded so no NLTK download is needed at import time).
    STOPWORDS = [
        "i", "me", "my", "myself", "we", "our", "ours", "ourselves",
        "you", "you're", "you've", "you'll", "you'd", "your", "yours", "yourself", "yourselves",
        "he", "him", "his", "himself", "she", "she's", "her", "hers", "herself",
        "it", "it's", "its", "itself", "they", "them", "their", "theirs", "themselves",
        "what", "which", "who", "whom", "this", "that", "that'll", "these", "those",
        "am", "is", "are", "was", "were", "be", "been", "being",
        "have", "has", "had", "having", "do", "does", "did", "doing",
        "a", "an", "the", "and", "but", "if", "or", "because", "as", "until", "while",
        "of", "at", "by", "for", "with", "about", "against", "between", "into", "through",
        "during", "before", "after", "above", "below", "to", "from", "up", "down",
        "in", "out", "on", "off", "over", "under", "again", "further", "then", "once",
        "here", "there", "when", "where", "why", "how",
        "all", "any", "both", "each", "few", "more", "most", "other", "some", "such",
        "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very",
        "s", "t", "can", "will", "just", "don", "don't", "should", "should've", "now",
        "d", "ll", "m", "o", "re", "ve", "y",
        "ain", "aren", "aren't", "couldn", "couldn't", "didn", "didn't", "doesn", "doesn't",
        "hadn", "hadn't", "hasn", "hasn't", "haven", "haven't", "isn", "isn't", "ma",
        "mightn", "mightn't", "mustn", "mustn't", "needn", "needn't", "shan", "shan't",
        "shouldn", "shouldn't", "wasn", "wasn't", "weren", "weren't", "won", "won't",
        "wouldn", "wouldn't",
    ]

    # Directory three levels above this file; computed once and reused for
    # every path below instead of repeating the parent chain.
    _project_root = Path(__file__).parent.parent.parent

    # Dataset locations.
    dataset_loc = os.path.join(_project_root, "dataset", "raw", "news_dataset.csv")
    preprocessed_data_path = os.path.join(_project_root, "dataset", "preprocessed")
    # The sweep configuration sits next to this file.
    sweep_config_path = os.path.join((Path(__file__).parent), "sweep_config.yaml")

    # Logs path
    logs_path = os.path.join(_project_root, "logs")
    # Model artifacts directory and the checkpoint file inside it.
    artifacts_path = os.path.join(_project_root, "artifacts")
    model_path = os.path.join(_project_root, "artifacts", "model.pt")

    # Fraction used for the test split (presumably passed to train_test_split - confirm).
    test_size = 0.2

    # Tokenization options (presumably forwarded to the tokenizer call - confirm).
    add_special_tokens = True
    max_len = 50
    pad_to_max_length = True
    truncation = True

    change_config = False

    # Training hyperparameters.
    dropout_pb = 0.5
    lr = 1e-4
    lr_redfactor = 0.7
    lr_redpatience = 4
    epochs = 10
    batch_size = 128
    num_classes = 7

    sweep_run = 10

    # Integer class index -> human-readable category name.
    index_to_class = {0: "Business", 1: "Entertainment", 2: "Health", 3: "Science", 4: "Sports", 5: "Technology", 6: "Worldwide"}
228
+
229
+
230
# Create logs folder
os.makedirs(Cfg.logs_path, exist_ok=True)

# Get root logger
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Rotation policy shared by both log files: roll over at 10 MiB
# (10485760 bytes) and keep up to 10 rotated backups.
_MAX_LOG_BYTES = 10 * 1024 * 1024
_LOG_BACKUP_COUNT = 10

# Create handlers
console_handler = RichHandler(markup=True)
console_handler.setLevel(logging.INFO)

# info.log receives every record at INFO level and above.
info_handler = RotatingFileHandler(
    filename=Path(Cfg.logs_path, "info.log"),
    maxBytes=_MAX_LOG_BYTES,
    backupCount=_LOG_BACKUP_COUNT,
)
info_handler.setLevel(logging.INFO)

# error.log receives only records at ERROR level and above.
error_handler = RotatingFileHandler(
    filename=Path(Cfg.logs_path, "error.log"),
    maxBytes=_MAX_LOG_BYTES,
    backupCount=_LOG_BACKUP_COUNT,
)
error_handler.setLevel(logging.ERROR)

# Create formatters
minimal_formatter = logging.Formatter(fmt="%(message)s")
detailed_formatter = logging.Formatter(fmt="%(levelname)s %(asctime)s [%(name)s:%(filename)s:%(funcName)s:%(lineno)d]\n%(message)s\n")

# Hook it all up
console_handler.setFormatter(fmt=minimal_formatter)
info_handler.setFormatter(fmt=detailed_formatter)
error_handler.setFormatter(fmt=detailed_formatter)
logger.addHandler(hdlr=console_handler)
logger.addHandler(hdlr=info_handler)
logger.addHandler(hdlr=error_handler)
newsclassifier/config/sweep_config.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ method: random
2
+ metric:
3
+ name: val_loss
4
+ goal: minimize
5
+ parameters:
6
+ dropout_pb:
7
+ values: [0.3, 0.4, 0.5]
8
+ learning_rate:
9
+ values: [0.0001, 0.001, 0.01]
10
+ batch_size:
11
+ values: [32, 64, 128]
12
+ lr_reduce_factor:
13
+ values: [0.5, 0.6, 0.7, 0.8]
14
+ lr_reduce_patience:
15
+ values: [2, 3, 4, 5]
16
+ epochs:
17
+ value: 1
newsclassifier/data.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from typing import Dict, Tuple
4
+ from warnings import filterwarnings
5
+
6
+ import pandas as pd
7
+ from sklearn.model_selection import train_test_split
8
+
9
+ import torch
10
+ from newsclassifier.config.config import Cfg, logger
11
+ from torch.utils.data import Dataset
12
+ from transformers import RobertaTokenizer
13
+
14
+ filterwarnings("ignore")
15
+
16
+
17
def load_dataset(filepath: str, print_i: int = 0) -> pd.DataFrame:
    """Read a CSV file into a Pandas DataFrame, optionally previewing it.

    Args:
        filepath (str): Location of the CSV file.
        print_i (int): Number of leading rows to print; a falsy value
            (the default 0) prints nothing.

    Returns:
        pd.DataFrame: The loaded data.
    """
    logger.info("Loading Data.")
    frame = pd.read_csv(filepath)
    if not print_i:
        return frame
    print(frame.head(print_i), "\n")
    return frame
32
+
33
+
34
def prepare_data(df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Select the model features and split off the "Headlines" category.

    Only the "Title" and "Category" columns are kept, and "Title" is renamed
    to "Text". Rows whose category is "Headlines" are returned separately.

    Args:
        df: original dataframe; must contain "Title" and "Category" columns.

    Returns:
        df: new dataframe with the "Text" and "Category" columns, without
            the "Headlines" rows (index reset).
        headlines_df: dataframe containing only the "Headlines" category
            instances (index reset).

    Raises:
        Exception: any error raised while selecting or splitting the data is
            logged and re-raised. Previously the error was swallowed and the
            function then failed with an unrelated UnboundLocalError.
    """
    logger.info("Preparing Data.")
    try:
        # Non in-place rename on the column selection: the original in-place
        # rename on a selection is something pandas can flag with
        # SettingWithCopyWarning.
        selected = df[["Title", "Category"]].rename(columns={"Title": "Text"})
        is_headline = selected["Category"] == "Headlines"
        headlines_df = selected[is_headline].reset_index(drop=True)
        df = selected[~is_headline].reset_index(drop=True)
    except Exception as e:
        logger.error(e)
        raise

    return df, headlines_df
53
+
54
+
55
def clean_text(text: str) -> str:
    """Clean text (lower-casing, stopword removal, space and punctuation clean-up).

    Steps, in order: lower-case, remove every word listed in Cfg.STOPWORDS
    (together with the whitespace that follows it), strip, collapse runs of
    spaces, and finally replace each run of non-alphanumeric characters with
    a single space.

    Because the non-alphanumeric replacement runs last, the result can still
    start or end with a space. That order is kept on purpose so the output
    stays identical to the original implementation.

    Args:
        text: raw input text.

    Returns:
        The cleaned text.
    """
    logger.info("Cleaning input text.")
    text = text.lower()  # necessary to do before as stopwords are in lower case

    # remove stopwords
    # Escape each word so regex metacharacters are matched literally, and skip
    # the substitution when there is nothing to remove: an empty alternation
    # "\b()\b\s*" would match at every word boundary and eat the whitespace
    # between words, gluing them together.
    stopwords = [re.escape(word) for word in Cfg.STOPWORDS if word]
    if stopwords:
        stp_pattern = re.compile(r"\b(" + r"|".join(stopwords) + r")\b\s*")
        text = stp_pattern.sub("", text)

    # custom cleaning
    text = text.strip()  # remove space at start or end if any
    text = re.sub(" +", " ", text)  # remove extra spaces
    text = re.sub("[^A-Za-z0-9]+", " ", text)  # remove characters that are not alphanumeric

    return text
71
+
72
+
73
def preprocess(df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame, Dict, Dict]:
    """Preprocess the data.

    Runs prepare_data, builds the label mappings, cleans the "Text" column
    with clean_text and label-encodes the "Category" column.

    Args:
        df: Dataframe on which the preprocessing steps need to be performed.

    Returns:
        df: Preprocessed data with cleaned "Text" and integer "Category".
        headlines_df: The held-out "Headlines" rows returned by prepare_data
            (not cleaned or encoded here).
        class_to_index: class labels to indices mapping.
        index_to_class: indices to class labels mapping.

    Raises:
        Exception: errors during cleaning/encoding are logged and re-raised
            instead of returning a partially processed dataframe.
    """
    df, headlines_df = prepare_data(df)

    # NOTE(review): indices follow the order in which categories first appear
    # in the data, while Cfg.index_to_class is hard-coded - confirm they agree
    # for the dataset in use.
    cats = df["Category"].unique().tolist()
    class_to_index = {tag: i for i, tag in enumerate(cats)}
    index_to_class = {v: k for k, v in class_to_index.items()}

    try:
        # Explicit copy so the column assignments below never write into a
        # view of another frame.
        df = df[["Text", "Category"]].copy()
        df["Text"] = df["Text"].apply(clean_text)  # clean text
        df["Category"] = df["Category"].map(class_to_index)  # label encoding
    except Exception as e:
        logger.error(e)
        raise
    return df, headlines_df, class_to_index, index_to_class
97
+
98
+
99
def data_split(
    df: pd.DataFrame, split_size: float = 0.2, stratify_on_target: bool = True, save_dfs: bool = False
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Split data into train and test sets.

    Args:
        df (pd.DataFrame): Data to be split.
        split_size (float): train-test split ratio (test ratio).
        stratify_on_target (bool): Whether to stratify the split on the
            "Category" column.
        save_dfs (bool): Whether to save the splits as "train.csv" and
            "test.csv" under Cfg.preprocessed_data_path.

    Returns:
        (train_ds, test_ds): the two splits as DataFrames.

    Raises:
        Exception: a failure while splitting is logged and re-raised.
            Previously it was swallowed and the function then failed on the
            unbound result variables. A failure while saving is only logged,
            as before, and the splits are still returned.
    """
    logger.info("Splitting Data.")
    try:
        stra = df["Category"] if stratify_on_target else None
        train, test = train_test_split(df, test_size=split_size, random_state=42, stratify=stra)
        train_ds = pd.DataFrame(train, columns=df.columns)
        test_ds = pd.DataFrame(test, columns=df.columns)
    except Exception as e:
        logger.error(e)
        raise

    if save_dfs:
        # Saving stays best-effort: an I/O problem must not discard the split.
        try:
            logger.info("Saving and storing data splits.")

            os.makedirs(Cfg.preprocessed_data_path, exist_ok=True)
            train.to_csv(os.path.join(Cfg.preprocessed_data_path, "train.csv"))
            test.to_csv(os.path.join(Cfg.preprocessed_data_path, "test.csv"))
        except Exception as e:
            logger.error(e)

    return train_ds, test_ds
133
+
134
+
135
def prepare_input(tokenizer: RobertaTokenizer, text: str) -> Dict:
    """Tokenize and prepare the input text using the provided tokenizer.

    Args:
        tokenizer (RobertaTokenizer): The Roberta tokenizer to encode the input.
        text (str): The input text to be tokenized.

    Returns:
        inputs (dict): The tokenizer output (keys such as 'input_ids' and
            'attention_mask') with every value converted to a torch.long
            tensor.
    """
    logger.info("Tokenizing input text.")
    inputs = tokenizer.encode_plus(
        text,
        return_tensors=None,
        add_special_tokens=Cfg.add_special_tokens,
        max_length=Cfg.max_len,
        # `pad_to_max_length` is deprecated in transformers; `padding` is the
        # supported spelling of the same switch.
        padding="max_length" if Cfg.pad_to_max_length else False,
        truncation=Cfg.truncation,
    )
    # Values are replaced in place (keys are unchanged) so the tokenizer's
    # own mapping object is what gets returned.
    for k, v in inputs.items():
        inputs[k] = torch.tensor(v, dtype=torch.long)
    return inputs
158
+
159
+
160
class NewsDataset(Dataset):
    """Torch dataset yielding (tokenized text, label) pairs.

    Args:
        ds: table-like object with "Text" and "Category" columns exposing
            `.values` (e.g. a pandas DataFrame).
    """

    def __init__(self, ds):
        self.texts = ds["Text"].values
        self.labels = ds["Category"].values
        # Loaded lazily on first item access and then reused. The original
        # re-created the tokenizer for every single item.
        self._tokenizer = None

    def __len__(self):
        """Return the number of samples."""
        return len(self.texts)

    def __getitem__(self, item):
        """Return the tokenized inputs and the label (float tensor) for one sample."""
        if self._tokenizer is None:
            self._tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
        inputs = prepare_input(self._tokenizer, self.texts[item])
        # Kept as float for backward compatibility; the training code casts
        # labels with .long() before one-hot encoding.
        labels = torch.tensor(self.labels[item], dtype=torch.float)
        return inputs, labels
173
+
174
+
175
def collate(inputs: Dict) -> Dict:
    """Trim a padded batch to the longest real sequence it contains.

    The number of real (non-padding) tokens of each row is the sum of its
    attention mask; every tensor in the batch is cut to the largest such
    count. The original summed `input_ids` (token ids) instead, which gives a
    number far larger than the sequence length, so nothing was ever trimmed.

    Args:
        inputs (dict): Batched input tensors of shape (batch, seq_len), e.g.
            'input_ids' and 'attention_mask'.

    Returns:
        modified_inputs (dict): The same mapping with every tensor sliced to
            the longest unpadded length in the batch. Returned unchanged when
            no 'attention_mask' is present.
    """
    if "attention_mask" not in inputs:
        # Without a mask the real lengths are unknown; leave the batch as is.
        return inputs
    max_len = int(inputs["attention_mask"].sum(axis=1).max())
    for k, v in inputs.items():
        inputs[k] = v[:, :max_len]
    return inputs
188
+
189
+
190
if __name__ == "__main__":
    # Manual smoke run: load, preprocess, split (saving the splits) and
    # print the first tokenized sample.
    raw_df = load_dataset(Cfg.dataset_loc)
    df, headlines_df, class_to_index, index_to_class = preprocess(raw_df)
    print(df)
    print(class_to_index)
    train_ds, val_ds = data_split(df, save_dfs=True)
    dataset = NewsDataset(df)
    first_sample = dataset[0]
    print(first_sample)
newsclassifier/inference.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Tuple
3
+
4
+ import numpy as np
5
+ from sklearn.metrics import (accuracy_score, f1_score, precision_score,
6
+ recall_score)
7
+ from tqdm.auto import tqdm
8
+
9
+ import torch
10
+ from newsclassifier.config.config import Cfg, logger
11
+ from newsclassifier.data import NewsDataset, collate
12
+ from newsclassifier.models import CustomModel
13
+ from torch.utils.data import DataLoader
14
+
15
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
+
18
def test_step(test_loader: DataLoader, model) -> Tuple[np.ndarray, np.ndarray]:
    """Run the model over a loader and collect true and predicted labels.

    Args:
        test_loader: loader yielding (inputs, labels) batches.
        model: model called with the batch inputs; its output is arg-maxed
            over dim 1 to obtain the predicted class index.

    Returns:
        Two arrays stacked with np.vstack, one row per sample: the true
        labels and the predicted class indices.
    """
    model.eval()
    true_labels, predicted = [], []
    with torch.inference_mode():
        for _step, (inputs, labels) in tqdm(enumerate(test_loader)):
            inputs = collate(inputs)
            for key, tensor in inputs.items():
                inputs[key] = tensor.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            true_labels.extend(labels.cpu().numpy())
            predicted.extend(torch.argmax(logits, dim=1).cpu().numpy())
    return np.vstack(true_labels), np.vstack(predicted)
32
+
33
+
34
def inference():
    """Evaluate the saved model on the preprocessed test split and print metrics.

    Reads "test.csv" from Cfg.preprocessed_data_path, loads the weights at
    Cfg.model_path, runs test_step and prints weighted precision, recall and
    F1 plus accuracy.

    Raises:
        Exception: failures while loading the data or the model are logged
            and re-raised. Previously they were swallowed and the function
            then failed on the unbound `test_loader` / `model`.
    """
    # Local import: this module does not import pandas at file level.
    import pandas as pd

    logger.info("Loading inference data.")
    try:
        # NewsDataset indexes its argument by column name, so it needs the
        # loaded table, not the file path the original passed in.
        test_df = pd.read_csv(os.path.join(Cfg.preprocessed_data_path, "test.csv"))
        # A text that was cleaned down to an empty string is read back as
        # NaN from the CSV; restore the empty string for the tokenizer.
        test_df = test_df.fillna({"Text": ""})
        test_dataset = NewsDataset(test_df)
        test_loader = DataLoader(test_dataset, batch_size=Cfg.batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=False)
    except Exception as e:
        logger.error(e)
        raise

    logger.info("loading model.")
    try:
        model = CustomModel(num_classes=Cfg.num_classes)
        model.load_state_dict(torch.load(Cfg.model_path, map_location=torch.device("cpu")))
        model.to(device)
    except Exception as e:
        logger.error(e)
        raise

    y_true, y_pred = test_step(test_loader, model)
    # test_step returns column vectors; flatten them to the 1-D label arrays
    # the metric functions expect.
    y_true, y_pred = y_true.ravel(), y_pred.ravel()

    print(
        f'Precision: {precision_score(y_true, y_pred, average="weighted")} \n Recall: {recall_score(y_true, y_pred, average="weighted")} \n F1: {f1_score(y_true, y_pred, average="weighted")} \n Accuracy: {accuracy_score(y_true, y_pred)}'
    )
newsclassifier/models.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from transformers import RobertaModel
9
+
10
+
11
class CustomModel(nn.Module):
    """RoBERTa-base backbone followed by dropout and a linear classification head.

    Args:
        num_classes: number of output classes (size of the linear head).
        change_config: accepted for backward compatibility; currently has no
            effect.
        dropout_pb: dropout probability applied before the linear head.
    """

    def __init__(self, num_classes, change_config=False, dropout_pb=0.0):
        super().__init__()
        self.model = RobertaModel.from_pretrained("roberta-base")
        self.hidden_size = self.model.config.hidden_size
        self.num_classes = num_classes
        self.dropout_pb = dropout_pb
        self.dropout = torch.nn.Dropout(self.dropout_pb)
        self.fc = nn.Linear(self.hidden_size, self.num_classes)

    def forward(self, inputs):
        """Return the raw class logits for a batch of tokenized inputs.

        Args:
            inputs: mapping of keyword arguments forwarded to the backbone
                (e.g. input_ids, attention_mask).
        """
        output = self.model(**inputs)
        # output[1] is the second element of the backbone output
        # (presumably the pooled sentence representation - confirm against
        # the transformers RobertaModel documentation).
        z = self.dropout(output[1])
        z = self.fc(z)
        return z

    @torch.inference_mode()
    def predict(self, inputs):
        """Return the predicted class index per sample as a numpy array."""
        self.eval()
        z = self(inputs)
        y_pred = torch.argmax(z, dim=1).cpu().numpy()
        return y_pred

    @torch.inference_mode()
    def predict_proba(self, inputs):
        """Return the softmax class probabilities per sample as a numpy array."""
        self.eval()
        z = self(inputs)
        y_probs = F.softmax(z, dim=1).cpu().numpy()
        return y_probs

    def save(self, dp):
        """Write "args.json" (constructor metadata) and "model.pt" (state dict) into directory `dp`."""
        with open(Path(dp, "args.json"), "w") as fp:
            contents = {
                "dropout_pb": self.dropout_pb,
                "hidden_size": self.hidden_size,
                "num_classes": self.num_classes,
            }
            json.dump(contents, fp, indent=4, sort_keys=False)
        torch.save(self.state_dict(), os.path.join(dp, "model.pt"))

    @classmethod
    def load(cls, args_fp, state_dict_fp):
        """Rebuild a model from the files written by `save`.

        Args:
            args_fp: path to the "args.json" file.
            state_dict_fp: path to the saved state dict.

        Returns:
            The model with its weights loaded on the CPU.
        """
        with open(args_fp, "r") as fp:
            kwargs = json.load(fp=fp)
        # Pass only what __init__ accepts. The original forwarded an `llm`
        # argument and the stored `hidden_size`, neither of which is a
        # constructor parameter, so loading always raised TypeError. The
        # backbone and hidden size are created inside __init__.
        model = cls(num_classes=kwargs["num_classes"], dropout_pb=kwargs.get("dropout_pb", 0.0))
        model.load_state_dict(torch.load(state_dict_fp, map_location=torch.device("cpu")))
        return model
newsclassifier/predict.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import numpy as np
4
+
5
+ import torch
6
+ from newsclassifier.config.config import Cfg, logger
7
+ from newsclassifier.data import clean_text, prepare_input
8
+ from newsclassifier.models import CustomModel
9
+ from transformers import RobertaTokenizer
10
+
11
+
12
_ARTIFACTS = {}  # process-wide cache filled by _load_artifacts()


def _load_artifacts():
    """Load the tokenizer and the trained model once and reuse them afterwards."""
    if not _ARTIFACTS:
        tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
        # Class count comes from the shared config instead of a literal 7.
        model = CustomModel(num_classes=Cfg.num_classes)
        model.load_state_dict(torch.load(os.path.join(Cfg.artifacts_path, "model.pt"), map_location=torch.device("cpu")))
        _ARTIFACTS["tokenizer"] = tokenizer
        _ARTIFACTS["model"] = model
    return _ARTIFACTS["tokenizer"], _ARTIFACTS["model"]


def predict(text: str):
    """Return the class probabilities predicted for a single text.

    The tokenizer and model weights are loaded on the first call and cached;
    the original reloaded both on every call.

    Args:
        text: input text (the script entry point passes it through
            clean_text first).

    Returns:
        The probability vector for the single input sample, i.e. row 0 of
        the model's predict_proba output.
    """
    tokenizer, model = _load_artifacts()
    sample_input = prepare_input(tokenizer, text)
    # Add the batch dimension expected by the model.
    input_ids = torch.unsqueeze(sample_input["input_ids"], 0).to("cpu")
    attention_masks = torch.unsqueeze(sample_input["attention_mask"], 0).to("cpu")
    test_sample = dict(input_ids=input_ids, attention_mask=attention_masks)

    # predict_proba switches to eval mode and is decorated with
    # torch.inference_mode, so no extra no_grad context is needed here.
    y_pred_test_sample = model.predict_proba(test_sample)
    prediction = y_pred_test_sample[0]

    return prediction
27
+
28
+
29
if __name__ == "__main__":
    # Example run: clean one headline, classify it and print the probabilities.
    sample_headline = "Funds punished for owning too few Nvidia"
    pred_prob = predict(clean_text(sample_headline))
    print(pred_prob)
newsclassifier/train.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import os
3
+ import time
4
+ from typing import Tuple
5
+
6
+ import numpy as np
7
+ from tqdm.auto import tqdm
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import wandb
13
+ from newsclassifier.config.config import Cfg, logger
14
+ from newsclassifier.data import (NewsDataset, collate, data_split,
15
+ load_dataset, preprocess)
16
+ from newsclassifier.models import CustomModel
17
+ from torch.utils.data import DataLoader
18
+
19
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+
21
+
22
def train_step(train_loader: DataLoader, model, num_classes: int, loss_fn, optimizer, epoch: int) -> float:
    """Train the model for one pass over the loader.

    Args:
        train_loader: loader yielding (inputs, labels) batches.
        model: model to optimise; called with the batch inputs.
        num_classes: number of classes used to one-hot encode the labels.
        loss_fn: loss called as loss_fn(predictions, one_hot_targets).
        optimizer: optimizer stepped once per batch.
        epoch: zero-based epoch index, used only for the progress bar title.

    Returns:
        The running mean of the per-batch losses.
    """
    model.train()
    running_loss = 0.0
    n_batches = len(train_loader)
    progress = tqdm(enumerate(train_loader), total=n_batches, desc=f"Training - Epoch {epoch+1}")
    for batch_idx, (inputs, labels) in progress:
        inputs = collate(inputs)
        for key, tensor in inputs.items():
            inputs[key] = tensor.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()  # clear gradients from the previous batch
        logits = model(inputs)  # forward pass
        # loss_fn receives one-hot float targets
        targets = F.one_hot(labels.long(), num_classes=num_classes).float()
        batch_loss = loss_fn(logits, targets)
        batch_loss.backward()  # backward pass
        optimizer.step()  # weight update

        # incremental mean: new_mean = mean + (x - mean) / count
        running_loss += (batch_loss.detach().item() - running_loss) / (batch_idx + 1)
    return running_loss
41
+
42
+
43
def eval_step(val_loader: DataLoader, model, num_classes: int, loss_fn, epoch: int) -> Tuple[float, np.ndarray, np.ndarray]:
    """Evaluate the model for one pass over the loader.

    Args:
        val_loader: loader yielding (inputs, labels) batches.
        model: model to evaluate; called with the batch inputs.
        num_classes: number of classes used to one-hot encode the labels.
        loss_fn: loss called as loss_fn(predictions, one_hot_targets).
        epoch: zero-based epoch index, used only for the progress bar title.

    Returns:
        A tuple of the running mean of the per-batch losses, the stacked
        one-hot targets (one row per sample) and the stacked predicted class
        indices (one row per sample). Note that the targets are one-hot rows
        while the predictions are class indices.
    """
    model.eval()
    running_loss = 0.0
    n_batches = len(val_loader)
    title = f"Validation - Epoch {epoch+1}"
    collected_targets, collected_preds = [], []
    with torch.inference_mode():
        for batch_idx, (inputs, labels) in tqdm(enumerate(val_loader), total=n_batches, desc=title):
            inputs = collate(inputs)
            for key, tensor in inputs.items():
                inputs[key] = tensor.to(device)
            labels = labels.to(device)

            logits = model(inputs)
            # loss_fn receives one-hot float targets
            targets = F.one_hot(labels.long(), num_classes=num_classes).float()
            batch_loss = loss_fn(logits, targets).item()
            # incremental mean: new_mean = mean + (x - mean) / count
            running_loss += (batch_loss - running_loss) / (batch_idx + 1)

            collected_targets.extend(targets.cpu().numpy())
            collected_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
    return running_loss, np.vstack(collected_targets), np.vstack(collected_preds)
63
+
64
+
65
def train_loop(config=None):
    """Train the classifier, tracking the run with Weights & Biases.

    Loads and preprocesses the dataset, splits it, trains for the configured
    number of epochs and saves the model to Cfg.artifacts_path whenever the
    validation loss improves.

    Args:
        config: optional mapping of hyperparameter overrides (batch_size,
            num_classes, epochs, dropout_pb, learning_rate, lr_reduce_factor,
            lr_reduce_patience). Missing keys fall back to the Cfg defaults.
            Previously this argument was silently discarded.
    """
    # ====================================================
    # loader
    # ====================================================

    run_config = dict(
        batch_size=Cfg.batch_size,
        num_classes=Cfg.num_classes,
        epochs=Cfg.epochs,
        dropout_pb=Cfg.dropout_pb,
        learning_rate=Cfg.lr,
        lr_reduce_factor=Cfg.lr_redfactor,
        lr_reduce_patience=Cfg.lr_redpatience,
    )
    if config:
        run_config.update(dict(config))

    with wandb.init(project="NewsClassifier", config=run_config):
        config = wandb.config

        df = load_dataset(Cfg.dataset_loc)
        ds, headlines_df, class_to_index, index_to_class = preprocess(df)
        # data_split names this parameter `split_size`; the original passed
        # `test_size=`, which raises TypeError.
        train_ds, val_ds = data_split(ds, split_size=Cfg.test_size)

        logger.info("Preparing Data.")

        train_dataset = NewsDataset(train_ds)
        valid_dataset = NewsDataset(val_ds)

        train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
        valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=False)

        # ====================================================
        # model
        # ====================================================

        logger.info("Creating Custom Model.")
        num_classes = config.num_classes

        # Uses the module-level `device`, the same one train_step/eval_step
        # move their batches to.
        model = CustomModel(num_classes=num_classes, dropout_pb=config.dropout_pb)
        model.to(device)

        # ====================================================
        # Training components
        # ====================================================
        criterion = nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
        # Lower the learning rate when the validation loss stops improving.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=config.lr_reduce_factor, patience=config.lr_reduce_patience
        )

        # ====================================================
        # loop
        # ====================================================
        wandb.watch(model, criterion, log="all", log_freq=10)

        min_loss = np.inf
        logger.info("Staring Training Loop.")
        for epoch in range(config.epochs):
            # A failing epoch is logged and training moves on to the next
            # epoch (best-effort behaviour kept from the original).
            try:
                start_time = time.time()

                # Step
                train_loss = train_step(train_loader, model, num_classes, criterion, optimizer, epoch)
                val_loss, _, _ = eval_step(valid_loader, model, num_classes, criterion, epoch)
                scheduler.step(val_loss)

                # scoring
                elapsed = time.time() - start_time
                wandb.log({"epoch": epoch + 1, "train_loss": train_loss, "val_loss": val_loss})
                print(f"Epoch {epoch+1} - avg_train_loss: {train_loss:.4f}  avg_val_loss: {val_loss:.4f}  time: {elapsed:.0f}s")

                # Keep only the checkpoint with the lowest validation loss.
                if min_loss > val_loss:
                    min_loss = val_loss
                    print("Best Score : saving model.")
                    os.makedirs(Cfg.artifacts_path, exist_ok=True)
                    model.save(Cfg.artifacts_path)
                    print(f"\nSaved Best Model Score: {min_loss:.4f}\n\n")
            except Exception as e:
                logger.error(f"Epoch - {epoch+1}, {e}")

        # Upload the best checkpoint while the run is still active.
        wandb.save(os.path.join(Cfg.artifacts_path, "model.pt"))
        torch.cuda.empty_cache()
        gc.collect()
148
+
149
+
150
if __name__ == "__main__":
    # Script entry point: run a single training session.
    train_loop()
newsclassifier/tune.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import time
3
+ from typing import Tuple
4
+
5
+ import numpy as np
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import wandb
10
+ from newsclassifier.config.config import Cfg, logger
11
+ from newsclassifier.data import (NewsDataset, data_split, load_dataset,
12
+ preprocess)
13
+ from newsclassifier.models import CustomModel
14
+ from newsclassifier.train import eval_step, train_step
15
+ from newsclassifier.utils import read_yaml
16
+ from torch.utils.data import DataLoader
17
+
18
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+
20
+
21
def tune_loop(config=None):
    """Run one hyperparameter-tuning trial tracked with Weights & Biases.

    The script entry point registers this function with wandb.agent. The
    hyperparameters (batch_size, dropout_pb, learning_rate, lr_reduce_factor,
    lr_reduce_patience, epochs) are read from wandb.config. Unlike the
    training loop, no model checkpoint is saved.

    Args:
        config: optional configuration forwarded to wandb.init.
    """
    # ====================================================
    # loader
    # ====================================================
    logger.info("Starting Tuning.")
    with wandb.init(project="NewsClassifier", config=config):
        config = wandb.config

        df = load_dataset(Cfg.dataset_loc)
        ds, headlines_df, class_to_index, index_to_class = preprocess(df)
        # data_split names this parameter `split_size`; the original passed
        # `test_size=`, which raises TypeError.
        train_ds, val_ds = data_split(ds, split_size=Cfg.test_size)

        train_dataset = NewsDataset(train_ds)
        valid_dataset = NewsDataset(val_ds)

        train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
        valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=False)

        # ====================================================
        # model
        # ====================================================
        num_classes = Cfg.num_classes

        # Uses the module-level `device` defined at the top of this file.
        model = CustomModel(num_classes=num_classes, dropout_pb=config.dropout_pb)
        model.to(device)

        # ====================================================
        # Training components
        # ====================================================
        criterion = nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
        # Lower the learning rate when the validation loss stops improving.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=config.lr_reduce_factor, patience=config.lr_reduce_patience
        )

        # ====================================================
        # loop
        # ====================================================
        wandb.watch(model, criterion, log="all", log_freq=10)

        for epoch in range(config.epochs):
            # A failing epoch is logged and the trial moves on to the next
            # epoch (best-effort behaviour kept from the original).
            try:
                start_time = time.time()

                # Step
                train_loss = train_step(train_loader, model, num_classes, criterion, optimizer, epoch)
                val_loss, _, _ = eval_step(valid_loader, model, num_classes, criterion, epoch)
                scheduler.step(val_loss)

                # scoring
                elapsed = time.time() - start_time
                wandb.log({"epoch": epoch + 1, "train_loss": train_loss, "val_loss": val_loss})
                print(f"Epoch {epoch+1} - avg_train_loss: {train_loss:.4f}  avg_val_loss: {val_loss:.4f}  time: {elapsed:.0f}s")
            except Exception as e:
                logger.error(f"Epoch {epoch+1}, {e}")

        torch.cuda.empty_cache()
        gc.collect()
80
+
81
+
82
if __name__ == "__main__":
    # Register the sweep described by the YAML config, then let the agent run
    # tune_loop for the configured number of trials.
    sweep_config = read_yaml(Cfg.sweep_config_path)
    sweep_id = wandb.sweep(sweep_config, project="NewsClassifier")
    # The config class defines this setting as `sweep_run`; the original read
    # `Cfg.sweep_runs`, which raises AttributeError.
    wandb.agent(sweep_id, tune_loop, count=Cfg.sweep_run)
newsclassifier/utils.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import pandas as pd
4
+ import yaml
5
+
6
+ from newsclassifier.config.config import Cfg, logger
7
+
8
+
9
def write_yaml(data: pd.DataFrame, filepath: str):
    """Dump `data` to a YAML file in block style, creating parent folders as needed.

    Args:
        data: object to serialise with yaml.dump.
        filepath: destination file; it is opened in "w" mode.
    """
    logger.info("Writing yaml file.")
    parent_dir = os.path.dirname(filepath)
    # A bare file name has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError; only create the folder when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(filepath, "w") as file:
        yaml.dump(data, file, default_flow_style=False)
14
+
15
+
16
def read_yaml(file_path: str):
    """Parse a YAML file and return its content.

    Args:
        file_path: path of the YAML file to read.

    Returns:
        The parsed document, loaded with yaml.FullLoader.
    """
    logger.info("Reading yamlfile")
    # NOTE(review): FullLoader is intended for trusted files only; confirm
    # that callers never pass externally supplied YAML.
    with open(file_path, "r") as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)
notebooks/eda.ipynb ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "## Setup"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": 19,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "# Imports\n",
17
+ "import pandas as pd\n",
18
+ "import matplotlib.pyplot as plt\n",
19
+ "import seaborn as sns\n",
20
+ "import ipywidgets as widgets\n",
21
+ "from wordcloud import WordCloud, STOPWORDS"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "markdown",
26
+ "metadata": {},
27
+ "source": [
28
+ "## Data"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": 20,
34
+ "metadata": {},
35
+ "outputs": [
36
+ {
37
+ "data": {
38
+ "text/html": [
39
+ "<div>\n",
40
+ "<style scoped>\n",
41
+ " .dataframe tbody tr th:only-of-type {\n",
42
+ " vertical-align: middle;\n",
43
+ " }\n",
44
+ "\n",
45
+ " .dataframe tbody tr th {\n",
46
+ " vertical-align: top;\n",
47
+ " }\n",
48
+ "\n",
49
+ " .dataframe thead th {\n",
50
+ " text-align: right;\n",
51
+ " }\n",
52
+ "</style>\n",
53
+ "<table border=\"1\" class=\"dataframe\">\n",
54
+ " <thead>\n",
55
+ " <tr style=\"text-align: right;\">\n",
56
+ " <th></th>\n",
57
+ " <th>Title</th>\n",
58
+ " <th>Publisher</th>\n",
59
+ " <th>DateTime</th>\n",
60
+ " <th>Link</th>\n",
61
+ " <th>Category</th>\n",
62
+ " </tr>\n",
63
+ " </thead>\n",
64
+ " <tbody>\n",
65
+ " <tr>\n",
66
+ " <th>0</th>\n",
67
+ " <td>Chainlink (LINK) Falters, Hedera (HBAR) Wobble...</td>\n",
68
+ " <td>Analytics Insight</td>\n",
69
+ " <td>2023-08-30T06:54:49Z</td>\n",
70
+ " <td>https://news.google.com/articles/CBMibGh0dHBzO...</td>\n",
71
+ " <td>Business</td>\n",
72
+ " </tr>\n",
73
+ " <tr>\n",
74
+ " <th>1</th>\n",
75
+ " <td>Funds punished for owning too few Nvidia share...</td>\n",
76
+ " <td>ZAWYA</td>\n",
77
+ " <td>2023-08-30T07:15:59Z</td>\n",
78
+ " <td>https://news.google.com/articles/CBMigwFodHRwc...</td>\n",
79
+ " <td>Business</td>\n",
80
+ " </tr>\n",
81
+ " <tr>\n",
82
+ " <th>2</th>\n",
83
+ " <td>Crude oil prices stalled as hedge funds sold: ...</td>\n",
84
+ " <td>ZAWYA</td>\n",
85
+ " <td>2023-08-30T07:31:31Z</td>\n",
86
+ " <td>https://news.google.com/articles/CBMibGh0dHBzO...</td>\n",
87
+ " <td>Business</td>\n",
88
+ " </tr>\n",
89
+ " <tr>\n",
90
+ " <th>3</th>\n",
91
+ " <td>Grayscale's Bitcoin Win Is Still Only Half the...</td>\n",
92
+ " <td>Bloomberg</td>\n",
93
+ " <td>2023-08-30T10:38:40Z</td>\n",
94
+ " <td>https://news.google.com/articles/CBMib2h0dHBzO...</td>\n",
95
+ " <td>Business</td>\n",
96
+ " </tr>\n",
97
+ " <tr>\n",
98
+ " <th>4</th>\n",
99
+ " <td>I'm a Home Shopping Editor, and These Are the ...</td>\n",
100
+ " <td>Better Homes &amp; Gardens</td>\n",
101
+ " <td>2023-08-30T11:00:00Z</td>\n",
102
+ " <td>https://news.google.com/articles/CBMiPWh0dHBzO...</td>\n",
103
+ " <td>Business</td>\n",
104
+ " </tr>\n",
105
+ " </tbody>\n",
106
+ "</table>\n",
107
+ "</div>"
108
+ ],
109
+ "text/plain": [
110
+ " Title Publisher \\\n",
111
+ "0 Chainlink (LINK) Falters, Hedera (HBAR) Wobble... Analytics Insight \n",
112
+ "1 Funds punished for owning too few Nvidia share... ZAWYA \n",
113
+ "2 Crude oil prices stalled as hedge funds sold: ... ZAWYA \n",
114
+ "3 Grayscale's Bitcoin Win Is Still Only Half the... Bloomberg \n",
115
+ "4 I'm a Home Shopping Editor, and These Are the ... Better Homes & Gardens \n",
116
+ "\n",
117
+ " DateTime Link \\\n",
118
+ "0 2023-08-30T06:54:49Z https://news.google.com/articles/CBMibGh0dHBzO... \n",
119
+ "1 2023-08-30T07:15:59Z https://news.google.com/articles/CBMigwFodHRwc... \n",
120
+ "2 2023-08-30T07:31:31Z https://news.google.com/articles/CBMibGh0dHBzO... \n",
121
+ "3 2023-08-30T10:38:40Z https://news.google.com/articles/CBMib2h0dHBzO... \n",
122
+ "4 2023-08-30T11:00:00Z https://news.google.com/articles/CBMiPWh0dHBzO... \n",
123
+ "\n",
124
+ " Category \n",
125
+ "0 Business \n",
126
+ "1 Business \n",
127
+ "2 Business \n",
128
+ "3 Business \n",
129
+ "4 Business "
130
+ ]
131
+ },
132
+ "execution_count": 20,
133
+ "metadata": {},
134
+ "output_type": "execute_result"
135
+ }
136
+ ],
137
+ "source": [
138
+ "# Data Ingestion\n",
139
+ "df = pd.read_csv(\"../dataset/news_dataset.csv\")\n",
140
+ "df.head()"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "execution_count": 21,
146
+ "metadata": {},
147
+ "outputs": [
148
+ {
149
+ "data": {
150
+ "text/plain": [
151
+ "Text(0.5, 1.0, 'Category Distribution')"
152
+ ]
153
+ },
154
+ "execution_count": 21,
155
+ "metadata": {},
156
+ "output_type": "execute_result"
157
+ },
158
+ {
159
+ "data": {
160
+ "image/png": "",
161
+ "text/plain": [
162
+ "<Figure size 1000x500 with 1 Axes>"
163
+ ]
164
+ },
165
+ "metadata": {},
166
+ "output_type": "display_data"
167
+ }
168
+ ],
169
+ "source": [
170
+ "# Distribution Bar plot (Count plot)\n",
171
+ "plt.figure(figsize=(10, 5))\n",
172
+ "sns.barplot(x=df[\"Category\"].value_counts().index, y=df[\"Category\"].value_counts())\n",
173
+ "plt.ylabel(\"Number of News\")\n",
174
+ "plt.title(\"Category Distribution\")"
175
+ ]
176
+ },
177
+ {
178
+ "cell_type": "markdown",
179
+ "metadata": {},
180
+ "source": [
181
+ "**There's no extreme data imbalance except \"Health\" and \"Science\" news are almost half the \"Sports\" (majority) news.**"
182
+ ]
183
+ },
184
+ {
185
+ "cell_type": "code",
186
+ "execution_count": 22,
187
+ "metadata": {},
188
+ "outputs": [
189
+ {
190
+ "data": {
191
+ "application/vnd.jupyter.widget-view+json": {
192
+ "model_id": "8368a3df9eea413b99d2d0c5876fbcf6",
193
+ "version_major": 2,
194
+ "version_minor": 0
195
+ },
196
+ "text/plain": [
197
+ "interactive(children=(Dropdown(description='category', options=('Business', 'Entertainment', 'Headlines', 'Hea…"
198
+ ]
199
+ },
200
+ "metadata": {},
201
+ "output_type": "display_data"
202
+ }
203
+ ],
204
+ "source": [
205
+ "# Word cloud\n",
206
+ "categories = df[\"Category\"].unique().tolist()\n",
207
+ "\n",
208
+ "\n",
209
+ "@widgets.interact(category=categories)\n",
210
+ "def display_categotical_plots(category=categories[0]):\n",
211
+ " subset = df[df[\"Category\"] == category].sample(n=100, random_state=42)\n",
212
+ " text = subset[\"Title\"].values\n",
213
+ " cloud = WordCloud(stopwords=STOPWORDS, background_color=\"black\", collocations=False, width=600, height=400).generate(\" \".join(text))\n",
214
+ " plt.axis(\"off\")\n",
215
+ " plt.imshow(cloud)"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "markdown",
220
+ "metadata": {},
221
+ "source": [
222
+ "**From the word cloud we can immediately draw one insight about the redundant key words like \"New\" which is coming a lot in different categories.**</br>\n",
223
+ "We can also see some action verbs, adjectives, adverbs which need to be removed to some extent before training the model.</br>\n",
224
+ "Other than that the word cloud seems very intuitive to what the respective categorical tag/name is.</br></br>\n",
225
+ "We can also see the \"Headlines\" category contains mixed words (will be mixed as it can be a ground breaking news of any category), so we'll hold out those data instances as a test set without targets just to analyze the number of headlines with different categories."
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "code",
230
+ "execution_count": null,
231
+ "metadata": {},
232
+ "outputs": [],
233
+ "source": []
234
+ }
235
+ ],
236
+ "metadata": {
237
+ "kernelspec": {
238
+ "display_name": "news_venv",
239
+ "language": "python",
240
+ "name": "python3"
241
+ },
242
+ "language_info": {
243
+ "codemirror_mode": {
244
+ "name": "ipython",
245
+ "version": 3
246
+ },
247
+ "file_extension": ".py",
248
+ "mimetype": "text/x-python",
249
+ "name": "python",
250
+ "nbconvert_exporter": "python",
251
+ "pygments_lexer": "ipython3",
252
+ "version": "3.10.13"
253
+ }
254
+ },
255
+ "nbformat": 4,
256
+ "nbformat_minor": 2
257
+ }
notebooks/newsclassifier-roberta-base-wandb-track-sweep.ipynb ADDED
@@ -0,0 +1,1035 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# NewsClassifier"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": 1,
13
+ "metadata": {
14
+ "id": "mtVYEQSYsswc",
15
+ "outputId": "6f16c0c1-ef25-406c-dd14-edd1a72dc760",
16
+ "trusted": true
17
+ },
18
+ "outputs": [
19
+ {
20
+ "name": "stderr",
21
+ "output_type": "stream",
22
+ "text": [
23
+ "[nltk_data] Downloading package stopwords to\n",
24
+ "[nltk_data] C:\\Users\\manis\\AppData\\Roaming\\nltk_data...\n",
25
+ "[nltk_data] Package stopwords is already up-to-date!\n"
26
+ ]
27
+ },
28
+ {
29
+ "data": {
30
+ "text/plain": [
31
+ "True"
32
+ ]
33
+ },
34
+ "execution_count": 1,
35
+ "metadata": {},
36
+ "output_type": "execute_result"
37
+ }
38
+ ],
39
+ "source": [
40
+ "# Imports\n",
41
+ "import os\n",
42
+ "import gc\n",
43
+ "import time\n",
44
+ "from pathlib import Path\n",
45
+ "import json\n",
46
+ "from typing import Tuple, Dict\n",
47
+ "from warnings import filterwarnings\n",
48
+ "\n",
49
+ "filterwarnings(\"ignore\")\n",
50
+ "\n",
51
+ "import pandas as pd\n",
52
+ "import numpy as np\n",
53
+ "from sklearn.model_selection import train_test_split\n",
54
+ "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n",
55
+ "\n",
56
+ "import matplotlib.pyplot as plt\n",
57
+ "import seaborn as sns\n",
58
+ "import ipywidgets as widgets\n",
59
+ "from wordcloud import WordCloud, STOPWORDS\n",
60
+ "\n",
61
+ "from tqdm.auto import tqdm\n",
62
+ "from dataclasses import dataclass\n",
63
+ "\n",
64
+ "import re\n",
65
+ "import nltk\n",
66
+ "from nltk.corpus import stopwords\n",
67
+ "\n",
68
+ "import torch\n",
69
+ "import torch.nn as nn\n",
70
+ "import torch.nn.functional as F\n",
71
+ "from torch.utils.data import DataLoader, Dataset\n",
72
+ "\n",
73
+ "from transformers import RobertaTokenizer, RobertaModel\n",
74
+ "\n",
75
+ "import wandb\n",
76
+ "\n",
77
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
78
+ "\n",
79
+ "nltk.download(\"stopwords\")"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "code",
84
+ "execution_count": 2,
85
+ "metadata": {
86
+ "trusted": true
87
+ },
88
+ "outputs": [
89
+ {
90
+ "name": "stderr",
91
+ "output_type": "stream",
92
+ "text": [
93
+ "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n"
94
+ ]
95
+ },
96
+ {
97
+ "name": "stderr",
98
+ "output_type": "stream",
99
+ "text": [
100
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmanishdrw1\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
101
+ ]
102
+ },
103
+ {
104
+ "data": {
105
+ "text/plain": [
106
+ "True"
107
+ ]
108
+ },
109
+ "execution_count": 2,
110
+ "metadata": {},
111
+ "output_type": "execute_result"
112
+ }
113
+ ],
114
+ "source": [
115
+ "wandb.login()"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "execution_count": 3,
121
+ "metadata": {
122
+ "id": "fGW_WYn31JHT",
123
+ "trusted": true
124
+ },
125
+ "outputs": [],
126
+ "source": [
127
+ "@dataclass\n",
128
+ "class Cfg:\n",
129
+ " STOPWORDS = stopwords.words(\"english\")\n",
130
+ " dataset_loc = \"../dataset/raw/news_dataset.csv\"\n",
131
+ " test_size = 0.2\n",
132
+ "\n",
133
+ " add_special_tokens = True\n",
134
+ " max_len = 50\n",
135
+ " pad_to_max_length = True\n",
136
+ " truncation = True\n",
137
+ "\n",
138
+ " change_config = False\n",
139
+ "\n",
140
+ " dropout_pb = 0.5\n",
141
+ " lr = 1e-4\n",
142
+ " lr_redfactor = 0.7\n",
143
+ " lr_redpatience = 4\n",
144
+ " epochs = 10\n",
145
+ " batch_size = 128\n",
146
+ "\n",
147
+ " wandb_sweep = False"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "execution_count": 13,
153
+ "metadata": {
154
+ "id": "7V5OJWw4sswg",
155
+ "outputId": "8eb13263-d31a-4d49-f1f6-3c2dc0595c78",
156
+ "trusted": true
157
+ },
158
+ "outputs": [
159
+ {
160
+ "name": "stdout",
161
+ "output_type": "stream",
162
+ "text": [
163
+ "Matthew McConaughey Gives Joy Behar A Foot Massage On ‘The View’\n",
164
+ "Entertainment\n"
165
+ ]
166
+ }
167
+ ],
168
+ "source": [
169
+ "df = pd.read_csv(Cfg.dataset_loc)\n",
170
+ "print(df[\"Title\"][10040])\n",
171
+ "print(df[\"Category\"][10040])"
172
+ ]
173
+ },
174
+ {
175
+ "cell_type": "markdown",
176
+ "metadata": {
177
+ "id": "w05pkO5RN1H2"
178
+ },
179
+ "source": [
180
+ "## Prepare Data"
181
+ ]
182
+ },
183
+ {
184
+ "cell_type": "code",
185
+ "execution_count": 14,
186
+ "metadata": {
187
+ "id": "l8Z3Hhk3sswg",
188
+ "trusted": true
189
+ },
190
+ "outputs": [],
191
+ "source": [
192
+ "def prepare_data(df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:\n",
193
+ " \"\"\"Select the relevant features and separate out the \"Headlines\" instances.\n",
194
+ "\n",
195
+ " Args:\n",
196
+ " df: original dataframe.\n",
197
+ "\n",
198
+ " Returns:\n",
199
+ " df: new dataframe with appropriate features.\n",
200
+ " headlines_df: dataframe containing \"Headlines\" category instances.\n",
201
+ " \"\"\"\n",
202
+ " df = df[[\"Title\", \"Category\"]]\n",
203
+ " df.rename(columns={\"Title\": \"Text\"}, inplace=True)\n",
204
+ " df, headlines_df = df[df[\"Category\"] != \"Headlines\"].reset_index(drop=True), df[df[\"Category\"] == \"Headlines\"].reset_index(drop=True)\n",
205
+ "\n",
206
+ " return df, headlines_df"
207
+ ]
208
+ },
209
+ {
210
+ "cell_type": "code",
211
+ "execution_count": 15,
212
+ "metadata": {
213
+ "id": "d4t7JjIEsswg",
214
+ "trusted": true
215
+ },
216
+ "outputs": [],
217
+ "source": [
218
+ "def clean_text(text: str) -> str:\n",
219
+ " \"\"\"Clean text (lowercasing, stopword removal, punctuation removal, extra-space removal).\"\"\"\n",
220
+ " # lower case the text\n",
221
+ " text = text.lower() # necessary to do before as stopwords are in lower case\n",
222
+ "\n",
223
+ " # remove stopwords\n",
224
+ " stp_pattern = re.compile(r\"\\b(\" + r\"|\".join(Cfg.STOPWORDS) + r\")\\b\\s*\")\n",
225
+ " text = stp_pattern.sub(\"\", text)\n",
226
+ "\n",
227
+ " # custom cleaning\n",
228
+ " text = text.strip() # remove space at start or end if any\n",
229
+ " text = re.sub(\" +\", \" \", text) # remove extra spaces\n",
230
+ " text = re.sub(\"[^A-Za-z0-9]+\", \" \", text) # remove characters that are not alphanumeric\n",
231
+ "\n",
232
+ " return text"
233
+ ]
234
+ },
235
+ {
236
+ "cell_type": "code",
237
+ "execution_count": 16,
238
+ "metadata": {
239
+ "id": "NokmvVFusswh",
240
+ "trusted": true
241
+ },
242
+ "outputs": [],
243
+ "source": [
244
+ "def preprocess(df: pd.DataFrame) -> Tuple[pd.DataFrame, Dict, Dict]:\n",
245
+ " \"\"\"Preprocess the data.\n",
246
+ "\n",
247
+ " Args:\n",
248
+ " df: Dataframe on which the preprocessing steps need to be performed.\n",
249
+ "\n",
250
+ " Returns:\n",
251
+ " df: Preprocessed Data.\n",
252
+ " class_to_index: class labels to indices mapping\n",
253
+ " index_to_class: indices to class labels mapping\n",
254
+ " \"\"\"\n",
255
+ " df, headlines_df = prepare_data(df)\n",
256
+ "\n",
257
+ " cats = df[\"Category\"].unique().tolist()\n",
258
+ " num_classes = len(cats)\n",
259
+ " class_to_index = {tag: i for i, tag in enumerate(cats)}\n",
260
+ " index_to_class = {v: k for k, v in class_to_index.items()}\n",
261
+ "\n",
262
+ " df[\"Text\"] = df[\"Text\"].apply(clean_text) # clean text\n",
263
+ " df = df[[\"Text\", \"Category\"]]\n",
264
+ " df[\"Category\"] = df[\"Category\"].map(class_to_index) # label encoding\n",
265
+ " return df, class_to_index, index_to_class"
266
+ ]
267
+ },
268
+ {
269
+ "cell_type": "code",
270
+ "execution_count": 17,
271
+ "metadata": {
272
+ "id": "f45cNikCsswh",
273
+ "outputId": "880e338e-11a3-4048-ccf7-d30bf13e996b",
274
+ "trusted": true
275
+ },
276
+ "outputs": [
277
+ {
278
+ "data": {
279
+ "text/html": [
280
+ "<div>\n",
281
+ "<style scoped>\n",
282
+ " .dataframe tbody tr th:only-of-type {\n",
283
+ " vertical-align: middle;\n",
284
+ " }\n",
285
+ "\n",
286
+ " .dataframe tbody tr th {\n",
287
+ " vertical-align: top;\n",
288
+ " }\n",
289
+ "\n",
290
+ " .dataframe thead th {\n",
291
+ " text-align: right;\n",
292
+ " }\n",
293
+ "</style>\n",
294
+ "<table border=\"1\" class=\"dataframe\">\n",
295
+ " <thead>\n",
296
+ " <tr style=\"text-align: right;\">\n",
297
+ " <th></th>\n",
298
+ " <th>Text</th>\n",
299
+ " <th>Category</th>\n",
300
+ " </tr>\n",
301
+ " </thead>\n",
302
+ " <tbody>\n",
303
+ " <tr>\n",
304
+ " <th>0</th>\n",
305
+ " <td>chainlink link falters hedera hbar wobbles yet...</td>\n",
306
+ " <td>0</td>\n",
307
+ " </tr>\n",
308
+ " <tr>\n",
309
+ " <th>1</th>\n",
310
+ " <td>funds punished owning nvidia shares stunning 2...</td>\n",
311
+ " <td>0</td>\n",
312
+ " </tr>\n",
313
+ " <tr>\n",
314
+ " <th>2</th>\n",
315
+ " <td>crude oil prices stalled hedge funds sold kemp</td>\n",
316
+ " <td>0</td>\n",
317
+ " </tr>\n",
318
+ " <tr>\n",
319
+ " <th>3</th>\n",
320
+ " <td>grayscale bitcoin win still half battle</td>\n",
321
+ " <td>0</td>\n",
322
+ " </tr>\n",
323
+ " <tr>\n",
324
+ " <th>4</th>\n",
325
+ " <td>home shopping editor miss labor day deals eyeing</td>\n",
326
+ " <td>0</td>\n",
327
+ " </tr>\n",
328
+ " <tr>\n",
329
+ " <th>...</th>\n",
330
+ " <td>...</td>\n",
331
+ " <td>...</td>\n",
332
+ " </tr>\n",
333
+ " <tr>\n",
334
+ " <th>44142</th>\n",
335
+ " <td>slovakia election could echo ukraine expect</td>\n",
336
+ " <td>6</td>\n",
337
+ " </tr>\n",
338
+ " <tr>\n",
339
+ " <th>44143</th>\n",
340
+ " <td>things know nobel prizes washington post</td>\n",
341
+ " <td>6</td>\n",
342
+ " </tr>\n",
343
+ " <tr>\n",
344
+ " <th>44144</th>\n",
345
+ " <td>brief calm protests killing 2 students rock im...</td>\n",
346
+ " <td>6</td>\n",
347
+ " </tr>\n",
348
+ " <tr>\n",
349
+ " <th>44145</th>\n",
350
+ " <td>one safe france vows action bedbugs sweep paris</td>\n",
351
+ " <td>6</td>\n",
352
+ " </tr>\n",
353
+ " <tr>\n",
354
+ " <th>44146</th>\n",
355
+ " <td>slovakia election polls open knife edge vote u...</td>\n",
356
+ " <td>6</td>\n",
357
+ " </tr>\n",
358
+ " </tbody>\n",
359
+ "</table>\n",
360
+ "<p>44147 rows × 2 columns</p>\n",
361
+ "</div>"
362
+ ],
363
+ "text/plain": [
364
+ " Text Category\n",
365
+ "0 chainlink link falters hedera hbar wobbles yet... 0\n",
366
+ "1 funds punished owning nvidia shares stunning 2... 0\n",
367
+ "2 crude oil prices stalled hedge funds sold kemp 0\n",
368
+ "3 grayscale bitcoin win still half battle 0\n",
369
+ "4 home shopping editor miss labor day deals eyeing 0\n",
370
+ "... ... ...\n",
371
+ "44142 slovakia election could echo ukraine expect 6\n",
372
+ "44143 things know nobel prizes washington post 6\n",
373
+ "44144 brief calm protests killing 2 students rock im... 6\n",
374
+ "44145 one safe france vows action bedbugs sweep paris 6\n",
375
+ "44146 slovakia election polls open knife edge vote u... 6\n",
376
+ "\n",
377
+ "[44147 rows x 2 columns]"
378
+ ]
379
+ },
380
+ "execution_count": 17,
381
+ "metadata": {},
382
+ "output_type": "execute_result"
383
+ }
384
+ ],
385
+ "source": [
386
+ "ds, class_to_index, index_to_class = preprocess(df)\n",
387
+ "ds"
388
+ ]
389
+ },
390
+ {
391
+ "cell_type": "code",
392
+ "execution_count": null,
393
+ "metadata": {},
394
+ "outputs": [],
395
+ "source": [
396
+ "index_to_class"
397
+ ]
398
+ },
399
+ {
400
+ "cell_type": "code",
401
+ "execution_count": 20,
402
+ "metadata": {
403
+ "id": "zGlMz2UJsswi",
404
+ "trusted": true
405
+ },
406
+ "outputs": [],
407
+ "source": [
408
+ "# Data splits\n",
409
+ "train_ds, val_ds = train_test_split(ds, test_size=Cfg.test_size, stratify=ds[\"Category\"])"
410
+ ]
411
+ },
412
+ {
413
+ "cell_type": "code",
414
+ "execution_count": 21,
415
+ "metadata": {
416
+ "id": "zTeAsruMsswi",
417
+ "outputId": "bffed91d-04c6-490e-d682-03537d3182dd",
418
+ "trusted": true
419
+ },
420
+ "outputs": [
421
+ {
422
+ "data": {
423
+ "text/plain": [
424
+ "{'input_ids': tensor([ 0, 462, 25744, 7188, 155, 23, 462, 11485, 112, 2,\n",
425
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
426
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
427
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
428
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
429
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
430
+ " 0, 0])}"
431
+ ]
432
+ },
433
+ "execution_count": 21,
434
+ "metadata": {},
435
+ "output_type": "execute_result"
436
+ }
437
+ ],
438
+ "source": [
439
+ "def prepare_input(tokenizer: RobertaTokenizer, text: str) -> Dict:\n",
440
+ " \"\"\"Tokenize and prepare the input text using the provided tokenizer.\n",
441
+ "\n",
442
+ " Args:\n",
443
+ " tokenizer (RobertaTokenizer): The Roberta tokenizer to encode the input.\n",
444
+ " text (str): The input text to be tokenized.\n",
445
+ "\n",
446
+ " Returns:\n",
447
+ " inputs (dict): A dictionary containing the tokenized input with keys such as 'input_ids',\n",
448
+ " 'attention_mask', etc.\n",
449
+ " \"\"\"\n",
450
+ " inputs = tokenizer.encode_plus(\n",
451
+ " text,\n",
452
+ " return_tensors=None,\n",
453
+ " add_special_tokens=Cfg.add_special_tokens,\n",
454
+ " max_length=Cfg.max_len,\n",
455
+ " pad_to_max_length=Cfg.pad_to_max_length,\n",
456
+ " truncation=Cfg.truncation,\n",
457
+ " )\n",
458
+ " for k, v in inputs.items():\n",
459
+ " inputs[k] = torch.tensor(v, dtype=torch.long)\n",
460
+ " return inputs\n",
461
+ "\n",
462
+ "\n",
463
+ "class NewsDataset(Dataset):\n",
464
+ " def __init__(self, ds):\n",
465
+ " self.texts = ds[\"Text\"].values\n",
466
+ " self.labels = ds[\"Category\"].values\n",
467
+ "\n",
468
+ " def __len__(self):\n",
469
+ " return len(self.texts)\n",
470
+ "\n",
471
+ " def __getitem__(self, item):\n",
472
+ " inputs = prepare_input(tokenizer, self.texts[item])\n",
473
+ " labels = torch.tensor(self.labels[item], dtype=torch.float)\n",
474
+ " return inputs, labels\n",
475
+ "\n",
476
+ "\n",
477
+ "def collate(inputs: Dict) -> Dict:\n",
478
+ " \"\"\"Collate and modify the input dictionary to have the same sequence length for a particular input batch.\n",
479
+ "\n",
480
+ " Args:\n",
481
+ " inputs (dict): A dictionary containing input tensors with varying sequence lengths.\n",
482
+ "\n",
483
+ " Returns:\n",
484
+ " modified_inputs (dict): A modified dictionary with input tensors trimmed to have the same sequence length.\n",
485
+ " \"\"\"\n",
486
+ " max_len = int(inputs[\"input_ids\"].sum(axis=1).max())\n",
487
+ " for k, v in inputs.items():\n",
488
+ " inputs[k] = inputs[k][:, :max_len]\n",
489
+ " return inputs\n",
490
+ "\n",
491
+ "\n",
492
+ "tokenizer = RobertaTokenizer.from_pretrained(\"roberta-base\")\n",
493
+ "\n",
494
+ "sample_input = prepare_input(tokenizer, train_ds[\"Text\"].values[10])\n",
495
+ "sample_input"
496
+ ]
497
+ },
498
+ {
499
+ "cell_type": "markdown",
500
+ "metadata": {
501
+ "id": "-qp-4d-aN503"
502
+ },
503
+ "source": [
504
+ "## Model"
505
+ ]
506
+ },
507
+ {
508
+ "cell_type": "code",
509
+ "execution_count": 22,
510
+ "metadata": {
511
+ "id": "XIJ6ARJfsswj",
512
+ "trusted": true
513
+ },
514
+ "outputs": [],
515
+ "source": [
516
+ "class CustomModel(nn.Module):\n",
517
+ " def __init__(self, num_classes, change_config=False, dropout_pb=0.0):\n",
518
+ " super(CustomModel, self).__init__()\n",
519
+ " if change_config:\n",
520
+ " pass\n",
521
+ " self.model = RobertaModel.from_pretrained(\"roberta-base\")\n",
522
+ " self.hidden_size = self.model.config.hidden_size\n",
523
+ " self.num_classes = num_classes\n",
524
+ " self.dropout_pb = dropout_pb\n",
525
+ " self.dropout = torch.nn.Dropout(self.dropout_pb)\n",
526
+ " self.fc = nn.Linear(self.hidden_size, self.num_classes)\n",
527
+ "\n",
528
+ " def forward(self, inputs):\n",
529
+ " output = self.model(**inputs)\n",
530
+ " z = self.dropout(output[1])\n",
531
+ " z = self.fc(z)\n",
532
+ " return z\n",
533
+ "\n",
534
+ " @torch.inference_mode()\n",
535
+ " def predict(self, inputs):\n",
536
+ " self.eval()\n",
537
+ " z = self(inputs)\n",
538
+ " y_pred = torch.argmax(z, dim=1).cpu().numpy()\n",
539
+ " return y_pred\n",
540
+ "\n",
541
+ " @torch.inference_mode()\n",
542
+ " def predict_proba(self, inputs):\n",
543
+ " self.eval()\n",
544
+ " z = self(inputs)\n",
545
+ " y_probs = F.softmax(z, dim=1).cpu().numpy()\n",
546
+ " return y_probs\n",
547
+ "\n",
548
+ " def save(self, dp):\n",
549
+ " with open(Path(dp, \"args.json\"), \"w\") as fp:\n",
550
+ " contents = {\n",
551
+ " \"dropout_pb\": self.dropout_pb,\n",
552
+ " \"hidden_size\": self.hidden_size,\n",
553
+ " \"num_classes\": self.num_classes,\n",
554
+ " }\n",
555
+ " json.dump(contents, fp, indent=4, sort_keys=False)\n",
556
+ " torch.save(self.state_dict(), os.path.join(dp, \"model.pt\"))\n",
557
+ "\n",
558
+ " @classmethod\n",
559
+ " def load(cls, args_fp, state_dict_fp):\n",
560
+ " with open(args_fp, \"r\") as fp:\n",
561
+ " kwargs = json.load(fp=fp)\n",
562
+ " llm = RobertaModel.from_pretrained(\"roberta-base\")\n",
563
+ " model = cls(llm=llm, **kwargs)\n",
564
+ " model.load_state_dict(torch.load(state_dict_fp, map_location=torch.device(\"cpu\")))\n",
565
+ " return model"
566
+ ]
567
+ },
568
+ {
569
+ "cell_type": "code",
570
+ "execution_count": null,
571
+ "metadata": {
572
+ "id": "YZEM0lIlsswj",
573
+ "outputId": "c05d70cf-e75d-4514-b730-3070484ceee3",
574
+ "trusted": true
575
+ },
576
+ "outputs": [],
577
+ "source": [
578
+ "# Initialize model check\n",
579
+ "num_classes = len(ds[\"Category\"].unique())\n",
580
+ "model = CustomModel(num_classes=num_classes, dropout_pb=Cfg.dropout_pb)\n",
581
+ "print(model.named_parameters)"
582
+ ]
583
+ },
584
+ {
585
+ "cell_type": "markdown",
586
+ "metadata": {
587
+ "id": "ztUd4m9CN8qM"
588
+ },
589
+ "source": [
590
+ "## Training"
591
+ ]
592
+ },
593
+ {
594
+ "cell_type": "code",
595
+ "execution_count": null,
596
+ "metadata": {
597
+ "id": "a3VPiwjqsswk",
598
+ "trusted": true
599
+ },
600
+ "outputs": [],
601
+ "source": [
602
+ "def train_step(train_loader: DataLoader, model, num_classes: int, loss_fn, optimizer, epoch: int) -> float:\n",
603
+ " \"\"\"Train step.\"\"\"\n",
604
+ " model.train()\n",
605
+ " loss = 0.0\n",
606
+ " total_iterations = len(train_loader)\n",
607
+ " desc = f\"Training - Epoch {epoch+1}\"\n",
608
+ " for step, (inputs, labels) in tqdm(enumerate(train_loader), total=total_iterations, desc=desc):\n",
609
+ " inputs = collate(inputs)\n",
610
+ " for k, v in inputs.items():\n",
611
+ " inputs[k] = v.to(device)\n",
612
+ " labels = labels.to(device)\n",
613
+ " optimizer.zero_grad() # reset gradients\n",
614
+ " y_pred = model(inputs) # forward pass\n",
615
+ " targets = F.one_hot(labels.long(), num_classes=num_classes).float() # one-hot (for loss_fn)\n",
616
+ " J = loss_fn(y_pred, targets) # define loss\n",
617
+ " J.backward() # backward pass\n",
618
+ " optimizer.step() # update weights\n",
619
+ " loss += (J.detach().item() - loss) / (step + 1) # cumulative loss\n",
620
+ " return loss\n",
621
+ "\n",
622
+ "\n",
623
+ "def eval_step(val_loader: DataLoader, model, num_classes: int, loss_fn, epoch: int) -> Tuple[float, np.ndarray, np.ndarray]:\n",
624
+ " \"\"\"Eval step.\"\"\"\n",
625
+ " model.eval()\n",
626
+ " loss = 0.0\n",
627
+ " total_iterations = len(val_loader)\n",
628
+ " desc = f\"Validation - Epoch {epoch+1}\"\n",
629
+ " y_trues, y_preds = [], []\n",
630
+ " with torch.inference_mode():\n",
631
+ " for step, (inputs, labels) in tqdm(enumerate(val_loader), total=total_iterations, desc=desc):\n",
632
+ " inputs = collate(inputs)\n",
633
+ " for k, v in inputs.items():\n",
634
+ " inputs[k] = v.to(device)\n",
635
+ " labels = labels.to(device)\n",
636
+ " y_pred = model(inputs)\n",
637
+ " targets = F.one_hot(labels.long(), num_classes=num_classes).float() # one-hot (for loss_fn)\n",
638
+ " J = loss_fn(y_pred, targets).item()\n",
639
+ " loss += (J - loss) / (step + 1)\n",
640
+ " y_trues.extend(targets.cpu().numpy())\n",
641
+ " y_preds.extend(torch.argmax(y_pred, dim=1).cpu().numpy())\n",
642
+ " return loss, np.vstack(y_trues), np.vstack(y_preds)"
643
+ ]
644
+ },
645
+ {
646
+ "cell_type": "markdown",
647
+ "metadata": {},
648
+ "source": [
649
+ "### Sweep config"
650
+ ]
651
+ },
652
+ {
653
+ "cell_type": "code",
654
+ "execution_count": null,
655
+ "metadata": {
656
+ "trusted": true
657
+ },
658
+ "outputs": [],
659
+ "source": [
660
+ "sweep_config = {\"method\": \"random\"}\n",
661
+ "\n",
662
+ "metric = {\"name\": \"val_loss\", \"goal\": \"minimize\"}\n",
663
+ "\n",
664
+ "sweep_config[\"metric\"] = metric\n",
665
+ "\n",
666
+ "parameters_dict = {\n",
667
+ " \"dropout_pb\": {\n",
668
+ " \"values\": [0.3, 0.4, 0.5],\n",
669
+ " },\n",
670
+ " \"learning_rate\": {\n",
671
+ " \"values\": [0.0001, 0.001, 0.01],\n",
672
+ " },\n",
673
+ " \"batch_size\": {\n",
674
+ " \"values\": [32, 64, 128],\n",
675
+ " },\n",
676
+ " \"lr_reduce_factor\": {\n",
677
+ " \"values\": [0.5, 0.6, 0.7, 0.8],\n",
678
+ " },\n",
679
+ " \"lr_reduce_patience\": {\n",
680
+ " \"values\": [2, 3, 4, 5],\n",
681
+ " },\n",
682
+ " \"epochs\": {\"value\": 1},\n",
683
+ "}\n",
684
+ "\n",
685
+ "sweep_config[\"parameters\"] = parameters_dict"
686
+ ]
687
+ },
688
+ {
689
+ "cell_type": "code",
690
+ "execution_count": null,
691
+ "metadata": {
692
+ "trusted": true
693
+ },
694
+ "outputs": [],
695
+ "source": [
696
+ "# create sweep\n",
697
+ "if Cfg.wandb_sweep:\n",
698
+ " sweep_id = wandb.sweep(sweep_config, project=\"NewsClassifier\")"
699
+ ]
700
+ },
701
+ {
702
+ "cell_type": "code",
703
+ "execution_count": null,
704
+ "metadata": {
705
+ "id": "oG-4tz-Lsswk",
706
+ "trusted": true
707
+ },
708
+ "outputs": [],
709
+ "source": [
710
+ "def train_loop(config=None):\n",
711
+ " # ====================================================\n",
712
+ " # loader\n",
713
+ " # ====================================================\n",
714
+ "\n",
715
+ " if not Cfg.wandb_sweep:\n",
716
+ " config = dict(\n",
717
+ " batch_size=Cfg.batch_size,\n",
718
+ " num_classes=7,\n",
719
+ " epochs=Cfg.epochs,\n",
720
+ " dropout_pb=Cfg.dropout_pb,\n",
721
+ " learning_rate=Cfg.lr,\n",
722
+ " lr_reduce_factor=Cfg.lr_redfactor,\n",
723
+ " lr_reduce_patience=Cfg.lr_redpatience,\n",
724
+ " )\n",
725
+ "\n",
726
+ " with wandb.init(project=\"NewsClassifier\", config=config):\n",
727
+ " config = wandb.config\n",
728
+ "\n",
729
+ " train_ds, val_ds = train_test_split(ds, test_size=Cfg.test_size, stratify=ds[\"Category\"])\n",
730
+ "\n",
731
+ " train_dataset = NewsDataset(train_ds)\n",
732
+ " valid_dataset = NewsDataset(val_ds)\n",
733
+ "\n",
734
+ " train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)\n",
735
+ " valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=False)\n",
736
+ "\n",
737
+ " # ====================================================\n",
738
+ " # model\n",
739
+ " # ====================================================\n",
740
+ " num_classes = 7\n",
741
+ " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
742
+ "\n",
743
+ " model = CustomModel(num_classes=num_classes, dropout_pb=config.dropout_pb)\n",
744
+ " model.to(device)\n",
745
+ "\n",
746
+ " # ====================================================\n",
747
+ " # Training components\n",
748
+ " # ====================================================\n",
749
+ " criterion = nn.BCEWithLogitsLoss()\n",
750
+ " optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)\n",
751
+ " scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
752
+ " optimizer, mode=\"min\", factor=config.lr_reduce_factor, patience=config.lr_reduce_patience\n",
753
+ " )\n",
754
+ "\n",
755
+ " # ====================================================\n",
756
+ " # loop\n",
757
+ " # ====================================================\n",
758
+ " wandb.watch(model, criterion, log=\"all\", log_freq=10)\n",
759
+ "\n",
760
+ " min_loss = np.inf\n",
761
+ "\n",
762
+ " for epoch in range(config.epochs):\n",
763
+ " start_time = time.time()\n",
764
+ "\n",
765
+ " # Step\n",
766
+ " train_loss = train_step(train_loader, model, num_classes, criterion, optimizer, epoch)\n",
767
+ " val_loss, _, _ = eval_step(valid_loader, model, num_classes, criterion, epoch)\n",
768
+ " scheduler.step(val_loss)\n",
769
+ "\n",
770
+ " # scoring\n",
771
+ " elapsed = time.time() - start_time\n",
772
+ " wandb.log({\"epoch\": epoch + 1, \"train_loss\": train_loss, \"val_loss\": val_loss})\n",
773
+ " print(f\"Epoch {epoch+1} - avg_train_loss: {train_loss:.4f} avg_val_loss: {val_loss:.4f} time: {elapsed:.0f}s\")\n",
774
+ "\n",
775
+ " if min_loss > val_loss:\n",
776
+ " min_loss = val_loss\n",
777
+ " print(\"Best Score : saving model.\")\n",
778
+ " os.makedirs(\"../artifacts\", exist_ok=True)\n",
779
+ " model.save(\"../artifacts\")\n",
780
+ " print(f\"\\nSaved Best Model Score: {min_loss:.4f}\\n\\n\")\n",
781
+ "\n",
782
+ " wandb.save(\"../artifacts/model.pt\")\n",
783
+ " torch.cuda.empty_cache()\n",
784
+ " gc.collect()"
785
+ ]
786
+ },
787
+ {
788
+ "cell_type": "code",
789
+ "execution_count": null,
790
+ "metadata": {
791
+ "id": "tIBl_kvssswk",
792
+ "outputId": "4bff057f-a3a7-45ca-f3c2-5b5fbd15bab5",
793
+ "trusted": true
794
+ },
795
+ "outputs": [],
796
+ "source": [
797
+ "# Train/Tune\n",
798
+ "if not Cfg.wandb_sweep:\n",
799
+ " train_loop()\n",
800
+ "else:\n",
801
+ " wandb.agent(sweep_id, train_loop, count=10)"
802
+ ]
803
+ },
804
+ {
805
+ "cell_type": "markdown",
806
+ "metadata": {
807
+ "id": "qxXv-FaNNtKJ"
808
+ },
809
+ "source": [
810
+ "## Inference"
811
+ ]
812
+ },
813
+ {
814
+ "cell_type": "code",
815
+ "execution_count": 34,
816
+ "metadata": {
817
+ "id": "SHCGJBhABesw",
818
+ "outputId": "a62f9ff6-d47d-46d0-f971-cfeb76adc6d5",
819
+ "trusted": true
820
+ },
821
+ "outputs": [
822
+ {
823
+ "name": "stderr",
824
+ "output_type": "stream",
825
+ "text": [
826
+ "Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']\n",
827
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
828
+ ]
829
+ },
830
+ {
831
+ "data": {
832
+ "text/plain": [
833
+ "CustomModel(\n",
834
+ " (model): RobertaModel(\n",
835
+ " (embeddings): RobertaEmbeddings(\n",
836
+ " (word_embeddings): Embedding(50265, 768, padding_idx=1)\n",
837
+ " (position_embeddings): Embedding(514, 768, padding_idx=1)\n",
838
+ " (token_type_embeddings): Embedding(1, 768)\n",
839
+ " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
840
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
841
+ " )\n",
842
+ " (encoder): RobertaEncoder(\n",
843
+ " (layer): ModuleList(\n",
844
+ " (0-11): 12 x RobertaLayer(\n",
845
+ " (attention): RobertaAttention(\n",
846
+ " (self): RobertaSelfAttention(\n",
847
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
848
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
849
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
850
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
851
+ " )\n",
852
+ " (output): RobertaSelfOutput(\n",
853
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
854
+ " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
855
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
856
+ " )\n",
857
+ " )\n",
858
+ " (intermediate): RobertaIntermediate(\n",
859
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
860
+ " (intermediate_act_fn): GELUActivation()\n",
861
+ " )\n",
862
+ " (output): RobertaOutput(\n",
863
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
864
+ " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
865
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
866
+ " )\n",
867
+ " )\n",
868
+ " )\n",
869
+ " )\n",
870
+ " (pooler): RobertaPooler(\n",
871
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
872
+ " (activation): Tanh()\n",
873
+ " )\n",
874
+ " )\n",
875
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
876
+ " (fc): Linear(in_features=768, out_features=7, bias=True)\n",
877
+ ")"
878
+ ]
879
+ },
880
+ "execution_count": 34,
881
+ "metadata": {},
882
+ "output_type": "execute_result"
883
+ }
884
+ ],
885
+ "source": [
886
+ "model = CustomModel(num_classes=7)\n",
887
+ "model.load_state_dict(torch.load(\"../artifacts/model.pt\", map_location=torch.device(\"cpu\")))\n",
888
+ "model.to(device)"
889
+ ]
890
+ },
891
+ {
892
+ "cell_type": "code",
893
+ "execution_count": null,
894
+ "metadata": {
895
+ "id": "BjupBkbOCI22",
896
+ "trusted": true
897
+ },
898
+ "outputs": [],
899
+ "source": [
900
+ "def test_step(test_loader: DataLoader, model, num_classes: int) -> Tuple[np.ndarray, np.ndarray]:\n",
901
+ " \"\"\"Test step.\"\"\"\n",
902
+ " model.eval()\n",
903
+ " y_trues, y_preds = [], []\n",
904
+ " with torch.inference_mode():\n",
905
+ " for step, (inputs, labels) in tqdm(enumerate(test_loader)):\n",
906
+ " inputs = collate(inputs)\n",
907
+ " for k, v in inputs.items():\n",
908
+ " inputs[k] = v.to(device)\n",
909
+ " labels = labels.to(device)\n",
910
+ " y_pred = model(inputs)\n",
911
+ " y_trues.extend(labels.cpu().numpy())\n",
912
+ " y_preds.extend(torch.argmax(y_pred, dim=1).cpu().numpy())\n",
913
+ " return np.vstack(y_trues), np.vstack(y_preds)"
914
+ ]
915
+ },
916
+ {
917
+ "cell_type": "code",
918
+ "execution_count": null,
919
+ "metadata": {
920
+ "id": "QimlSstFDsbJ",
921
+ "outputId": "8c903f7f-eddd-417c-c85e-4d57a4206501",
922
+ "trusted": true
923
+ },
924
+ "outputs": [],
925
+ "source": [
926
+ "test_dataset = NewsDataset(val_ds)\n",
927
+ "test_loader = DataLoader(test_dataset, batch_size=Cfg.batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=False)\n",
928
+ "\n",
929
+ "y_true, y_pred = test_step(test_loader, model, 7)"
930
+ ]
931
+ },
932
+ {
933
+ "cell_type": "code",
934
+ "execution_count": null,
935
+ "metadata": {
936
+ "id": "CLz_GuoeEEgz",
937
+ "outputId": "8870b27c-46a6-4695-e526-e5c1e778f96a",
938
+ "trusted": true
939
+ },
940
+ "outputs": [],
941
+ "source": [
942
+ "print(\n",
943
+ " f'Precision: {precision_score(y_true, y_pred, average=\"weighted\")} \\n Recall: {recall_score(y_true, y_pred, average=\"weighted\")} \\n F1: {f1_score(y_true, y_pred, average=\"weighted\")} \\n Accuracy: {accuracy_score(y_true, y_pred)}'\n",
944
+ ")"
945
+ ]
946
+ },
947
+ {
948
+ "cell_type": "markdown",
949
+ "metadata": {
950
+ "id": "j_D8B0aNOBiI"
951
+ },
952
+ "source": [
953
+ "## Prediction on single sample"
954
+ ]
955
+ },
956
+ {
957
+ "cell_type": "code",
958
+ "execution_count": null,
959
+ "metadata": {},
960
+ "outputs": [],
961
+ "source": [
962
+ "val_ds"
963
+ ]
964
+ },
965
+ {
966
+ "cell_type": "code",
967
+ "execution_count": 36,
968
+ "metadata": {
969
+ "id": "-wU3xnKkH0Tt",
970
+ "outputId": "171245e5-4844-4e71-82b7-a0f3e97879e7",
971
+ "trusted": true
972
+ },
973
+ "outputs": [
974
+ {
975
+ "name": "stdout",
976
+ "output_type": "stream",
977
+ "text": [
978
+ "Ground Truth: 5, Sports\n",
979
+ "Predicted: 5, Sports\n",
980
+ "Predicted Probabilities: [9.8119999e-05 1.0613000e-04 7.7200002e-06 3.2520002e-05 8.3100003e-06\n",
981
+ " 9.9973667e-01 1.0560000e-05]\n"
982
+ ]
983
+ }
984
+ ],
985
+ "source": [
986
+ "sample = 2\n",
987
+ "sample_input = prepare_input(tokenizer, val_ds[\"Text\"].values[sample])\n",
988
+ "\n",
989
+ "cats = df[\"Category\"].unique().tolist()\n",
990
+ "num_classes = len(cats)\n",
991
+ "class_to_index = {tag: i for i, tag in enumerate(cats)}\n",
992
+ "index_to_class = {v: k for k, v in class_to_index.items()}\n",
993
+ "\n",
994
+ "label = val_ds[\"Category\"].values[sample]\n",
995
+ "input_ids = torch.unsqueeze(sample_input[\"input_ids\"], 0).to(device)\n",
996
+ "attention_masks = torch.unsqueeze(sample_input[\"attention_mask\"], 0).to(device)\n",
997
+ "test_sample = dict(input_ids=input_ids, attention_mask=attention_masks)\n",
998
+ "\n",
999
+ "with torch.no_grad():\n",
1000
+ " y_pred_test_sample = model.predict_proba(test_sample)\n",
1001
+ " print(f\"Ground Truth: {label}, {index_to_class[int(label)]}\")\n",
1002
+ " print(f\"Predicted: {np.argmax(y_pred_test_sample)}, {index_to_class[int(np.argmax(y_pred_test_sample))]}\")\n",
1003
+ " print(f\"Predicted Probabilities: {np.round(y_pred_test_sample, 8)[0]}\")"
1004
+ ]
1005
+ },
1006
+ {
1007
+ "cell_type": "code",
1008
+ "execution_count": null,
1009
+ "metadata": {},
1010
+ "outputs": [],
1011
+ "source": []
1012
+ }
1013
+ ],
1014
+ "metadata": {
1015
+ "kernelspec": {
1016
+ "display_name": "Python 3",
1017
+ "language": "python",
1018
+ "name": "python3"
1019
+ },
1020
+ "language_info": {
1021
+ "codemirror_mode": {
1022
+ "name": "ipython",
1023
+ "version": 3
1024
+ },
1025
+ "file_extension": ".py",
1026
+ "mimetype": "text/x-python",
1027
+ "name": "python",
1028
+ "nbconvert_exporter": "python",
1029
+ "pygments_lexer": "ipython3",
1030
+ "version": "3.10.13"
1031
+ }
1032
+ },
1033
+ "nbformat": 4,
1034
+ "nbformat_minor": 4
1035
+ }
requirements.txt ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiosignal==1.3.1
2
+ attrs==23.1.0
3
+ certifi==2023.7.22
4
+ charset-normalizer==3.3.1
5
+ click==8.1.7
6
+ colorama==0.4.6
7
+ contourpy==1.1.1
8
+ cycler==0.12.1
9
+ filelock==3.12.4
10
+ fonttools==4.43.1
11
+ frozenlist==1.4.0
12
+ idna==3.4
13
+ jsonschema==4.19.1
14
+ jsonschema-specifications==2023.7.1
15
+ kiwisolver==1.4.5
16
+ matplotlib==3.8.0
17
+ msgpack==1.0.7
18
+ numpy==1.26.1
19
+ packaging==23.2
20
+ pandas==2.1.2
21
+ Pillow==10.1.0
22
+ protobuf==4.24.4
23
+ pyparsing==3.1.1
24
+ python-dateutil==2.8.2
25
+ pytz==2023.3.post1
26
+ PyYAML==6.0.1
27
+ ray==2.7.1
28
+ referencing==0.30.2
29
+ requests==2.31.0
30
+ rpds-py==0.10.6
31
+ seaborn==0.13.0
32
+ six==1.16.0
33
+ tzdata==2023.3
34
+ urllib3==2.0.7
setup.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from setuptools import find_packages, setup
4
+
5
+
6
def get_requirements(file_path: str) -> List[str]:
    """Read a pip requirements file and return its requirement specifiers.

    Args:
        file_path: Path to the requirements file (one requirement per line).

    Returns:
        The requirement strings in file order, with surrounding whitespace
        removed. Blank lines and full-line ``#`` comments are skipped so that
        only real specifiers are handed to ``install_requires``.
    """
    # Explicit encoding keeps the result independent of the platform default.
    with open(file_path, encoding="utf-8") as f:
        stripped = (line.strip() for line in f)
        # strip() also drops a stray "\r" left over from CRLF line endings.
        requirements = [line for line in stripped if line and not line.startswith("#")]

    return requirements
13
+
14
+
15
+ setup(
16
+ name="NewsClassifier",
17
+ version="1.0",
18
+ author="ManishW",
19
+ author_email="[email protected]",
20
+ description="",
21
+ packages=find_packages(),
22
+ install_requires=get_requirements("requirements.txt"),
23
+ )