Hiroaki OGASAWARA committed
Commit 06a4fa8 · 1 Parent(s): 0aa8103

Upload folder using huggingface_hub

.gitignore ADDED
@@ -0,0 +1,5 @@
+ data/
+ figures/
+ models/*
+
+ __pycache__
README.md CHANGED
@@ -1,12 +1,29 @@
  ---
- title: Chiikawa Yonezu
- emoji: 🏢
- colorFrom: indigo
- colorTo: yellow
  sdk: gradio
  sdk_version: 4.13.0
- app_file: app.py
- pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

  ---
+ title: chiikawa-yonezu
+ app_file: app.py
  sdk: gradio
  sdk_version: 4.13.0
  ---
+ # Chiikawa or Kenshi Yonezu classification task
+
+ ```powershell
+ conda env create -f environment.yml
+ conda activate chiikawa-yonezu
+ pip install fugashi ipadic
+ ```
+
+ ## Run gradio
+
+ ```powershell
+ conda activate chiikawa-yonezu
+ python app.py
+ # or
+ conda run -n chiikawa-yonezu python app.py # not recommended because standard output is not displayed
+ ```
+
+ ## Deploy to gradio
+
+ ```powershell
+ conda activate chiikawa-yonezu
+ gradio deploy
+ ```
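Once `python app.py` is running (or the Space is deployed), the Interface can also be called programmatically. A minimal sketch using `gradio_client` (pinned to 0.8.0 in `environment.yml`); the local URL is Gradio's default and is an assumption, as is the example text:

```python
# Query the running app; swap the URL for the Space URL after `gradio deploy`.
from gradio_client import Client

client = Client("http://127.0.0.1:7860/")  # assumed default local address
result = client.predict(
    "炊き立て・・・・ってコト!?",  # the single textbox input
    api_name="/predict",          # default endpoint name for a gr.Interface
)
print(result)  # label output: predicted class with per-class confidences
```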
app.py ADDED
@@ -0,0 +1,78 @@
+ from pprint import pprint
+
+ import gradio as gr
+ import torch
+
+ from safetensors import safe_open
+ from transformers import BertTokenizer
+
+ from utils.ClassifierModel import ClassifierModel
+
+
+ def _classify_text(text, model, device, tokenizer, max_length=20):
+     """
+     Output the probabilities that the text is 'ちいかわ' (Chiikawa) or '米津玄師' (Kenshi Yonezu).
+     """
+
+     # Tokenize the text and convert it to PyTorch tensors
+     inputs = tokenizer.encode_plus(
+         text,
+         add_special_tokens=True,
+         max_length=max_length,
+         padding="max_length",
+         truncation=True,
+         return_attention_mask=True,
+         return_tensors="pt",
+     )
+     pprint(f"inputs: {inputs}")
+
+     # Run model inference
+     model.eval()
+     with torch.no_grad():
+         outputs = model(
+             inputs["input_ids"].to(device), inputs["attention_mask"].to(device)
+         )
+     pprint(f"outputs: {outputs}")
+     probabilities = torch.nn.functional.softmax(outputs, dim=1)
+
+     # Extract the per-class probabilities
+     chiikawa_prob = probabilities[0][0].item()
+     yonezu_prob = probabilities[0][1].item()
+
+     return chiikawa_prob, yonezu_prob
+
+
+ is_cuda = torch.cuda.is_available()
+ device = torch.device("cuda" if is_cuda else "cpu")
+ pprint(f"device: {device}")
+
+ model_save_path = "models/model.safetensors"
+ tensors = {}
+ with safe_open(model_save_path, framework="pt", device="cpu") as f:
+     for key in f.keys():
+         tensors[key] = f.get_tensor(key)
+
+ inference_model: torch.nn.Module = ClassifierModel().to(device)
+ inference_model.load_state_dict(tensors)
+
+ tokenizer = BertTokenizer.from_pretrained(
+     "cl-tohoku/bert-base-japanese-whole-word-masking"
+ )
+
+
+ def classify_text(text):
+     chii_prob, yone_prob = _classify_text(text, inference_model, device, tokenizer)
+     return {"ちいかわ": chii_prob, "米津玄師": yone_prob}
+
+
+ demo = gr.Interface(
+     fn=classify_text,
+     inputs="textbox",
+     outputs="label",
+     examples=[
+         "炊き立て・・・・ってコト!?",
+         "晴れた空に種を蒔こう",
+     ],
+ )
+
+ demo.launch(share=True)  # Share your demo with just 1 extra parameter 🚀
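The `safe_open` loop in `app.py` builds a plain dict of tensors key by key; for reference, `safetensors.torch.load_file` collapses the loading step into one call. A sketch of the equivalent:

```python
# Equivalent to the safe_open loop above: load all tensors into a dict at once.
from safetensors.torch import load_file

tensors = load_file("models/model.safetensors", device="cpu")
inference_model.load_state_dict(tensors)
```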
environment.yml ADDED
@@ -0,0 +1,285 @@
+ name: chiikawa-yonezu
+ channels:
+   - pytorch
+   - nvidia
+   - conda-forge
+   - defaults
+ dependencies:
+   - aiohttp=3.9.1=py311ha68e1ae_0
+   - aiosignal=1.3.1=pyhd8ed1ab_0
+   - asttokens=2.4.1=pyhd8ed1ab_0
+   - attrs=23.2.0=pyh71513ae_0
+   - aws-c-auth=0.7.11=h9ca94be_0
+   - aws-c-cal=0.6.9=h7f0e5be_2
+   - aws-c-common=0.9.10=hcfcfb64_0
+   - aws-c-compression=0.2.17=h7f0e5be_7
+   - aws-c-event-stream=0.4.0=h51e6447_0
+   - aws-c-http=0.8.0=h80119a0_0
+   - aws-c-io=0.13.36=ha737126_3
+   - aws-c-mqtt=0.10.0=h2889a98_2
+   - aws-c-s3=0.4.7=h876bada_3
+   - aws-c-sdkutils=0.1.13=h7f0e5be_0
+   - aws-checksums=0.1.17=h7f0e5be_6
+   - aws-crt-cpp=0.26.0=h3a8c176_4
+   - aws-sdk-cpp=1.11.210=h79fa1a6_8
+   - blas=2.120=mkl
+   - blas-devel=3.9.0=20_win64_mkl
+   - brotli=1.1.0=hcfcfb64_1
+   - brotli-bin=1.1.0=hcfcfb64_1
+   - brotli-python=1.1.0=py311h12c1d0e_1
+   - bzip2=1.0.8=hcfcfb64_5
+   - c-ares=1.24.0=hcfcfb64_0
+   - ca-certificates=2023.11.17=h56e8100_0
+   - certifi=2023.11.17=pyhd8ed1ab_0
+   - charset-normalizer=3.3.2=pyhd8ed1ab_0
+   - colorama=0.4.6=pyhd8ed1ab_0
+   - comm=0.2.1=pyhd8ed1ab_0
+   - contourpy=1.2.0=py311h005e61a_0
+   - cuda-cccl=12.3.101=0
+   - cuda-cudart=12.1.105=0
+   - cuda-cudart-dev=12.1.105=0
+   - cuda-cupti=12.1.105=0
+   - cuda-libraries=12.1.0=0
+   - cuda-libraries-dev=12.1.0=0
+   - cuda-nvrtc=12.1.105=0
+   - cuda-nvrtc-dev=12.1.105=0
+   - cuda-nvtx=12.1.105=0
+   - cuda-opencl=12.3.101=0
+   - cuda-opencl-dev=12.3.101=0
+   - cuda-profiler-api=12.3.101=0
+   - cuda-runtime=12.1.0=0
+   - cycler=0.12.1=pyhd8ed1ab_0
+   - datasets=2.16.1=pyhd8ed1ab_0
+   - debugpy=1.8.0=py311h12c1d0e_1
+   - decorator=5.1.1=pyhd8ed1ab_0
+   - dill=0.3.7=pyhd8ed1ab_0
+   - exceptiongroup=1.2.0=pyhd8ed1ab_0
+   - executing=2.0.1=pyhd8ed1ab_0
+   - filelock=3.13.1=pyhd8ed1ab_0
+   - fonttools=4.47.0=py311ha68e1ae_0
+   - freetype=2.12.1=hdaf720e_2
+   - frozenlist=1.4.1=py311ha68e1ae_0
+   - fsspec=2023.10.0=pyhca7485f_0
+   - gettext=0.21.1=h5728263_0
+   - glib=2.78.3=h12be248_0
+   - glib-tools=2.78.3=h12be248_0
+   - gst-plugins-base=1.22.8=h001b923_0
+   - gstreamer=1.22.8=hb4038d2_0
+   - huggingface_hub=0.20.2=pyhd8ed1ab_0
+   - icu=73.2=h63175ca_0
+   - idna=3.6=pyhd8ed1ab_0
+   - importlib-metadata=7.0.1=pyha770c72_0
+   - importlib_metadata=7.0.1=hd8ed1ab_0
+   - intel-openmp=2023.2.0=h57928b3_50497
+   - ipykernel=6.28.0=pyha63f2e9_0
+   - ipython=8.19.0=pyh7428d3b_0
+   - jaconv=0.3.4=pyhd8ed1ab_0
+   - jedi=0.19.1=pyhd8ed1ab_0
+   - jinja2=3.1.2=pyhd8ed1ab_1
+   - joblib=1.3.2=pyhd8ed1ab_0
+   - jupyter_client=8.6.0=pyhd8ed1ab_0
+   - jupyter_core=5.7.0=py311h1ea47a8_0
+   - kiwisolver=1.4.5=py311h005e61a_1
+   - krb5=1.21.2=heb0366b_0
+   - lcms2=2.16=h67d730c_0
+   - lerc=4.0.0=h63175ca_0
+   - libabseil=20230802.1=cxx17_h63175ca_0
+   - libarrow=14.0.2=he5f67d5_2_cpu
+   - libarrow-acero=14.0.2=h63175ca_2_cpu
+   - libarrow-dataset=14.0.2=h63175ca_2_cpu
+   - libarrow-flight=14.0.2=h53b1db0_2_cpu
+   - libarrow-flight-sql=14.0.2=h78eab7c_2_cpu
+   - libarrow-gandiva=14.0.2=hb2eaab1_2_cpu
+   - libarrow-substrait=14.0.2=hd4c9904_2_cpu
+   - libblas=3.9.0=20_win64_mkl
+   - libbrotlicommon=1.1.0=hcfcfb64_1
+   - libbrotlidec=1.1.0=hcfcfb64_1
+   - libbrotlienc=1.1.0=hcfcfb64_1
+   - libcblas=3.9.0=20_win64_mkl
+   - libclang=15.0.7=default_h77d9078_3
+   - libclang13=15.0.7=default_h77d9078_3
+   - libcrc32c=1.1.2=h0e60522_0
+   - libcublas=12.1.0.26=0
+   - libcublas-dev=12.1.0.26=0
+   - libcufft=11.0.2.4=0
+   - libcufft-dev=11.0.2.4=0
+   - libcurand=10.3.4.107=0
+   - libcurand-dev=10.3.4.107=0
+   - libcurl=8.5.0=hd5e4a3a_0
+   - libcusolver=11.4.4.55=0
+   - libcusolver-dev=11.4.4.55=0
+   - libcusparse=12.0.2.55=0
+   - libcusparse-dev=12.0.2.55=0
+   - libdeflate=1.19=hcfcfb64_0
+   - libevent=2.1.12=h3671451_1
+   - libexpat=2.5.0=h63175ca_1
+   - libffi=3.4.2=h8ffe710_5
+   - libglib=2.78.3=h16e383f_0
+   - libgoogle-cloud=2.12.0=h39f2fc6_4
+   - libgrpc=1.59.3=h5bbd4a7_0
+   - libhwloc=2.9.3=default_haede6df_1009
+   - libiconv=1.17=hcfcfb64_2
+   - libjpeg-turbo=3.0.0=hcfcfb64_1
+   - liblapack=3.9.0=20_win64_mkl
+   - liblapacke=3.9.0=20_win64_mkl
+   - libnpp=12.0.2.50=0
+   - libnpp-dev=12.0.2.50=0
+   - libnvjitlink=12.1.105=0
+   - libnvjitlink-dev=12.1.105=0
+   - libnvjpeg=12.1.1.14=0
+   - libnvjpeg-dev=12.1.1.14=0
+   - libogg=1.3.4=h8ffe710_1
+   - libparquet=14.0.2=h7ec3a38_2_cpu
+   - libpng=1.6.39=h19919ed_0
+   - libprotobuf=4.24.4=hb8276f3_0
+   - libre2-11=2023.06.02=h8c5ae5e_0
+   - libsodium=1.0.18=h8d14728_1
+   - libsqlite=3.44.2=hcfcfb64_0
+   - libssh2=1.11.0=h7dfc565_0
+   - libthrift=0.19.0=ha2b3283_1
+   - libtiff=4.6.0=h6e2ebb7_2
+   - libutf8proc=2.8.0=h82a8f57_0
+   - libuv=1.44.2=hcfcfb64_1
+   - libvorbis=1.3.7=h0e60522_0
+   - libwebp-base=1.3.2=hcfcfb64_0
+   - libxcb=1.15=hcd874cb_0
+   - libxml2=2.11.6=hc3477c8_0
+   - libzlib=1.2.13=hcfcfb64_5
+   - lz4-c=1.9.4=hcfcfb64_0
+   - m2w64-gcc-libgfortran=5.3.0=6
+   - m2w64-gcc-libs=5.3.0=7
+   - m2w64-gcc-libs-core=5.3.0=7
+   - m2w64-gmp=6.1.0=2
+   - m2w64-libwinpthread-git=5.0.0.4634.697f757=2
+   - markupsafe=2.1.3=py311ha68e1ae_1
+   - matplotlib=3.8.2=py311h1ea47a8_0
+   - matplotlib-base=3.8.2=py311h6e989c2_0
+   - matplotlib-inline=0.1.6=pyhd8ed1ab_0
+   - mkl=2023.2.0=h6a75c08_50497
+   - mkl-devel=2023.2.0=h57928b3_50497
+   - mkl-include=2023.2.0=h6a75c08_50497
+   - mpmath=1.3.0=pyhd8ed1ab_0
+   - msys2-conda-epoch=20160418=1
+   - multidict=6.0.4=py311ha68e1ae_1
+   - multiprocess=0.70.15=py311ha68e1ae_1
+   - munkres=1.1.4=pyh9f0ad1d_0
+   - nest-asyncio=1.5.8=pyhd8ed1ab_0
+   - networkx=3.2.1=pyhd8ed1ab_0
+   - numpy=1.26.3=py311h0b4df5a_0
+   - openjpeg=2.5.0=h3d672ee_3
+   - openssl=3.2.0=hcfcfb64_1
+   - orc=1.9.2=hf0b6bd4_0
+   - packaging=23.2=pyhd8ed1ab_0
+   - pandas=2.1.4=py311hf63dbb6_0
+   - parso=0.8.3=pyhd8ed1ab_0
+   - pcre2=10.42=h17e33f8_0
+   - pickleshare=0.7.5=py_1003
+   - pillow=10.2.0=py311h4dd8a23_0
+   - pip=23.3.2=pyhd8ed1ab_0
+   - platformdirs=4.1.0=pyhd8ed1ab_0
+   - ply=3.11=py_1
+   - prompt-toolkit=3.0.42=pyha770c72_0
+   - psutil=5.9.7=py311ha68e1ae_0
+   - pthread-stubs=0.4=hcd874cb_1001
+   - pthreads-win32=2.9.1=hfa6e2cd_3
+   - pure_eval=0.2.2=pyhd8ed1ab_0
+   - pyarrow=14.0.2=py311h6a6099b_2_cpu
+   - pyarrow-hotfix=0.6=pyhd8ed1ab_0
+   - pygments=2.17.2=pyhd8ed1ab_0
+   - pyparsing=3.1.1=pyhd8ed1ab_0
+   - pyqt=5.15.9=py311h125bc19_5
+   - pyqt5-sip=12.12.2=py311h12c1d0e_5
+   - pysocks=1.7.1=pyh0701188_6
+   - python=3.11.7=h2628c8c_1_cpython
+   - python-dateutil=2.8.2=pyhd8ed1ab_0
+   - python-tzdata=2023.4=pyhd8ed1ab_0
+   - python-xxhash=3.4.1=py311ha68e1ae_0
+   - python_abi=3.11=4_cp311
+   - pytorch=2.1.2=py3.11_cuda12.1_cudnn8_0
+   - pytorch-cuda=12.1=hde6ce7c_5
+   - pytorch-mutex=1.0=cuda
+   - pytz=2023.3.post1=pyhd8ed1ab_0
+   - pywin32=306=py311h12c1d0e_2
+   - pyyaml=6.0.1=py311ha68e1ae_1
+   - pyzmq=25.1.2=py311h9250fbb_0
+   - qt-main=5.15.8=h9e85ed6_18
+   - re2=2023.06.02=hcbb65ff_0
+   - regex=2023.12.25=py311ha68e1ae_0
+   - requests=2.31.0=pyhd8ed1ab_0
+   - safetensors=0.3.3=py311hc37eb10_1
+   - scikit-learn=1.3.2=py311h142b183_2
+   - scipy=1.11.4=py311h0b4df5a_0
+   - setuptools=69.0.3=pyhd8ed1ab_0
+   - sip=6.7.12=py311h12c1d0e_0
+   - six=1.16.0=pyh6c4a22f_0
+   - snappy=1.1.10=hfb803bf_0
+   - stack_data=0.6.2=pyhd8ed1ab_0
+   - sympy=1.12=pyh04b8f61_3
+   - tbb=2021.11.0=h91493d7_0
+   - threadpoolctl=3.2.0=pyha21a80b_0
+   - tk=8.6.13=h5226925_1
+   - tokenizers=0.15.0=py311h91c4a10_1
+   - toml=0.10.2=pyhd8ed1ab_0
+   - tomli=2.0.1=pyhd8ed1ab_0
+   - tornado=6.3.3=py311ha68e1ae_1
+   - tqdm=4.66.1=pyhd8ed1ab_0
+   - traitlets=5.14.1=pyhd8ed1ab_0
+   - transformers=4.36.2=pyhd8ed1ab_0
+   - typing-extensions=4.9.0=hd8ed1ab_0
+   - typing_extensions=4.9.0=pyha770c72_0
+   - tzdata=2023d=h0c530f3_0
+   - ucrt=10.0.22621.0=h57928b3_0
+   - urllib3=2.1.0=pyhd8ed1ab_0
+   - vc=14.3=hcf57466_18
+   - vc14_runtime=14.38.33130=h82b7239_18
+   - vs2015_runtime=14.38.33130=hcb4865c_18
+   - wcwidth=0.2.12=pyhd8ed1ab_0
+   - wheel=0.42.0=pyhd8ed1ab_0
+   - win_inet_pton=1.1.0=pyhd8ed1ab_6
+   - xorg-libxau=1.0.11=hcd874cb_0
+   - xorg-libxdmcp=1.1.3=hcd874cb_0
+   - xxhash=0.8.2=hcfcfb64_0
+   - xz=5.2.6=h8d14728_0
+   - yaml=0.2.5=h8ffe710_2
+   - yarl=1.9.3=py311ha68e1ae_0
+   - zeromq=4.3.5=h63175ca_0
+   - zipp=3.17.0=pyhd8ed1ab_0
+   - zstd=1.5.5=h12be248_0
+   - pip:
+     - aiofiles==23.2.1
+     - altair==5.2.0
+     - annotated-types==0.6.0
+     - anyio==4.2.0
+     - click==8.1.7
+     - fastapi==0.108.0
+     - ffmpy==0.3.1
+     - gradio==4.13.0
+     - gradio-client==0.8.0
+     - h11==0.14.0
+     - httpcore==1.0.2
+     - httpx==0.26.0
+     - importlib-resources==6.1.1
+     - jsonschema==4.20.0
+     - jsonschema-specifications==2023.12.1
+     - markdown-it-py==3.0.0
+     - mdurl==0.1.2
+     - orjson==3.9.10
+     - pydantic==2.5.3
+     - pydantic-core==2.14.6
+     - pydub==0.25.1
+     - python-multipart==0.0.6
+     - referencing==0.32.1
+     - rich==13.7.0
+     - rpds-py==0.16.2
+     - semantic-version==2.10.0
+     - shellingham==1.5.4
+     - sniffio==1.3.0
+     - starlette==0.32.0.post1
+     - tomlkit==0.12.0
+     - toolz==0.12.0
+     - torchaudio==2.1.2
+     - torchvision==0.16.2
+     - typer==0.9.0
+     - uvicorn==0.25.0
+     - websockets==11.0.3
+
notebooks/embeddings.ipynb ADDED
@@ -0,0 +1,85 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": 4,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from transformers import BertTokenizer, BertModel\n",
+     "import torch\n",
+     "from pprint import pprint"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 5,
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. \n",
+       "The tokenizer class you load from this checkpoint is 'BertJapaneseTokenizer'. \n",
+       "The class this function is called from is 'BertTokenizer'.\n"
+      ]
+     },
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "{'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]]),\n",
+       " 'input_ids': tensor([[ 2, 73, 371, 37, 1541, 546, 3]]),\n",
+       " 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0]])}\n",
+       "torch.Size([1, 7, 768])\n"
+      ]
+     }
+    ],
+    "source": [
+     "# Load the pretrained Japanese model and tokenizer\n",
+     "tokenizer = BertTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')\n",
+     "model = BertModel.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')\n",
+     "\n",
+     "# Tokenize the text and convert it to PyTorch tensors\n",
+     "text = \"お正月休み\"\n",
+     "encoded_input = tokenizer(text, return_tensors='pt')\n",
+     "pprint(encoded_input)\n",
+     "\n",
+     "# Get the word embeddings\n",
+     "with torch.no_grad():\n",
+     "    output = model(**encoded_input)\n",
+     "    embeddings = output.last_hidden_state\n",
+     "    pprint(embeddings.shape)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "chiikawa-yonezu",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.11.7"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
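The printed `torch.Size([1, 7, 768])` is (batch, tokens, hidden size). The `[CLS]` vector that `ClassifierModel` later pools for classification is simply the first of those 7 token embeddings, e.g.:

```python
# Slice out the [CLS] embedding -- the same pooling ClassifierModel.forward uses.
cls_embedding = embeddings[:, 0]
print(cls_embedding.shape)  # torch.Size([1, 768])
```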
notebooks/preprocessing.ipynb ADDED
@@ -0,0 +1,126 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": 15,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import csv\n",
+     "import jaconv\n",
+     "import re\n",
+     "\n",
+     "def preprocess(csv_path: str, preprocessed_csv_path: str):\n",
+     "    \"\"\"\n",
+     "    Read the given CSV file, apply the processing below, and save the result as {CSV filename}_preprocessed.csv.\n",
+     "    The CSV format is two columns: \"TEXT,LABEL\".\n",
+     "\n",
+     "    TEXT transformation rules:\n",
+     "    1. If a string contains half-width or full-width spaces or newlines, split it into multiple strings\n",
+     "    2. Remove the symbols (!,?,!,?,・,.,…,',\",♪,♫) and all emoji\n",
+     "    3. Remove substrings enclosed in () or ()\n",
+     "    4. Convert half-width katakana to full-width katakana, normalize wave dashes to ~, and - to ー\n",
+     "    5. Collapse two or more consecutive ~~ into ~ and ーー into ー\n",
+     "    6. Remove empty strings\n",
+     "\n",
+     "    Filtering is applied before saving:\n",
+     "    1. Remove rows whose TEXT is an empty string\n",
+     "    2. Remove rows whose TEXT and LABEL combination is duplicated\n",
+     "    \"\"\"\n",
+     "    # Read the CSV file\n",
+     "    with open(csv_path, 'r', encoding='utf-8') as file:\n",
+     "        reader = csv.reader(file)\n",
+     "        data = list(reader)\n",
+     "\n",
+     "    preprocessed_data = []\n",
+     "\n",
+     "    # Preprocess the TEXT column\n",
+     "    for i in range(len(data)):\n",
+     "        text, label = data[i]\n",
+     "        # Split the text into multiple strings if it contains spaces or newlines\n",
+     "        text = re.split(r'\\s+', text)\n",
+     "        # Remove symbols\n",
+     "        text = [re.sub(r'[!?!?・.…\\'\"’”\\♪♫]', '', word) for word in text]\n",
+     "        # Remove strings enclosed in parentheses\n",
+     "        text = [re.sub(r'\\(.*?\\)|(.*?)', '', word) for word in text]\n",
+     "        # Convert half-width katakana to full-width katakana\n",
+     "        text = [jaconv.h2z(word) for word in text]\n",
+     "        # Normalize wave dashes to ~ and - to ー\n",
+     "        # Note: 〜 (U+301C) is a different character from ~ (U+FF5E)\n",
+     "        text = [re.sub(r'[~〜]', '~', word) for word in text]\n",
+     "        text = [re.sub(r'-', 'ー', word) for word in text]\n",
+     "        # Collapse runs of ~ into a single ~ and runs of ー into a single ー\n",
+     "        text = [re.sub(r'~+', '~', word) for word in text]\n",
+     "        text = [re.sub(r'ー+', 'ー', word) for word in text]\n",
+     "\n",
+     "        # Keep only non-empty strings\n",
+     "        for word in text:\n",
+     "            if word != '':\n",
+     "                preprocessed_data.append([word, label])\n",
+     "\n",
+     "    # Remove duplicate rows based on TEXT and LABEL combination\n",
+     "    preprocessed_data = [list(x) for x in set(tuple(x) for x in preprocessed_data)]\n",
+     "\n",
+     "    # Sort the data by LABEL, TEXT\n",
+     "    preprocessed_data.sort(key=lambda x: (x[1], x[0]))\n",
+     "\n",
+     "    # Save the preprocessed data to a new CSV file\n",
+     "    with open(preprocessed_csv_path, 'w', encoding='utf-8', newline='') as file:\n",
+     "        writer = csv.writer(file)\n",
+     "        writer.writerows(preprocessed_data)\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 16,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import pandas as pd\n",
+     "from sklearn.model_selection import train_test_split\n",
+     "\n",
+     "\n",
+     "def split(csv_path: str):\n",
+     "    # Read the original CSV file\n",
+     "    df = pd.read_csv(csv_path, encoding='utf-8')\n",
+     "\n",
+     "    # Split into training and test datasets\n",
+     "    train_df, test_df = train_test_split(df, test_size=0.05)  # keep the held-out set small for speed\n",
+     "\n",
+     "    # Save as new CSV files\n",
+     "    train_df.to_csv(csv_path.replace('.csv', '_train.csv'), index=False)\n",
+     "    test_df.to_csv(csv_path.replace('.csv', '_test.csv'), index=False)\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 17,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "csv_path = '../data/data.csv'\n",
+     "preprocessed_csv_path = csv_path.replace('.csv', '_preprocessed.csv')\n",
+     "preprocess(csv_path, preprocessed_csv_path)\n",
+     "split(preprocessed_csv_path)"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "chiikawa-yonezu",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.11.7"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
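To make the docstring's TEXT rules concrete, here is a small standalone sketch of rules 1-3 and 6 on a made-up input (the sample string and the simplified symbol class are illustrative, not from the dataset):

```python
import re

text = "炊き立て・・・・ってコト!? (ナレーション)"            # hypothetical raw TEXT
words = re.split(r'\s+', text)                            # rule 1: split on whitespace
words = [re.sub(r'[!?!?・.…♪♫]', '', w) for w in words]  # rule 2: strip symbols
words = [re.sub(r'\(.*?\)|(.*?)', '', w) for w in words]  # rule 3: drop (...) / (...) spans
words = [w for w in words if w != '']                     # rule 6: drop empty strings
print(words)  # ['炊き立てってコト']
```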
notebooks/train.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ gradio
+ safetensors
+ torch
+ transformers
utils/ClassifierModel.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+ from transformers import BertModel
+
+
+ class ClassifierModel(torch.nn.Module):
+     def __init__(self):
+         super(ClassifierModel, self).__init__()
+         self.bert = BertModel.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
+         self.linear = torch.nn.Linear(768, 2)  # BERT hidden size and number of output classes
+
+     def forward(self, input_ids, attention_mask):
+         with torch.no_grad():
+             outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
+         last_hidden_state = outputs[0]
+         pooled_output = last_hidden_state[:, 0]
+         return self.linear(pooled_output)
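For reference, a minimal smoke-test sketch of the model's contract; with fresh (untrained) weights the logits are meaningless until `models/model.safetensors` is loaded, so only the shapes matter here:

```python
# One sentence in, a (1, 2) logit tensor out.
import torch
from transformers import BertTokenizer
from utils.ClassifierModel import ClassifierModel

tokenizer = BertTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-whole-word-masking")
model = ClassifierModel()

inputs = tokenizer("晴れた空に種を蒔こう", return_tensors="pt")
logits = model(inputs["input_ids"], inputs["attention_mask"])
print(logits.shape)  # torch.Size([1, 2]) -- one logit each for Chiikawa and Kenshi Yonezu
```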