Spaces:
Runtime error
Runtime error
Hiroaki OGASAWARA
committed on
Commit
·
06a4fa8
1
Parent(s):
0aa8103
Upload folder using huggingface_hub
Browse files- .gitignore +5 -0
- README.md +24 -7
- app.py +78 -0
- environment.yml +285 -0
- notebooks/embeddings.ipynb +85 -0
- notebooks/preprocessing.ipynb +126 -0
- notebooks/train.ipynb +0 -0
- requirements.txt +4 -0
- utils/ClassifierModel.py +16 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data/
|
2 |
+
figures/
|
3 |
+
models/*
|
4 |
+
|
5 |
+
__pycache__
|
README.md
CHANGED
@@ -1,12 +1,29 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
|
4 |
-
colorFrom: indigo
|
5 |
-
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.13.0
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: chiikawa-yonezu
|
3 |
+
app_file: app.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
sdk_version: 4.13.0
|
|
|
|
|
6 |
---
|
7 |
+
# ちいかわか米津玄師か分類タスク
|
8 |
+
|
9 |
+
```powershell
|
10 |
+
conda env create -f environment.yml
|
11 |
+
conda activate chiikawa-yonezu
|
12 |
+
pip install fugashi ipadic
|
13 |
+
```
|
14 |
+
|
15 |
+
## Run gradio
|
16 |
+
|
17 |
+
```powershell
|
18 |
+
conda activate chiikawa-yonezu
|
19 |
+
python app.py
|
20 |
+
# or
|
21 |
+
conda run -n chiikawa-yonezu python app.py # not recommended because standard output is not displayed
|
22 |
+
```
|
23 |
+
|
24 |
+
## Deploy to gradio
|
25 |
|
26 |
+
```powershell
|
27 |
+
conda activate chiikawa-yonezu
|
28 |
+
gradio deploy
|
29 |
+
```
|
app.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pprint import pprint
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from safetensors import safe_open
|
7 |
+
from transformers import BertTokenizer
|
8 |
+
|
9 |
+
from utils.ClassifierModel import ClassifierModel
|
10 |
+
|
11 |
+
|
12 |
+
def _classify_text(text, model, device, tokenizer, max_length=20):
|
13 |
+
"""
|
14 |
+
テキストが、'ちいかわ' と '米津玄師' のどちらに該当するかの確率を出力する。
|
15 |
+
"""
|
16 |
+
|
17 |
+
# テキストをトークナイズし、PyTorchのテンソルに変換
|
18 |
+
inputs = tokenizer.encode_plus(
|
19 |
+
text,
|
20 |
+
add_special_tokens=True,
|
21 |
+
max_length=max_length,
|
22 |
+
padding="max_length",
|
23 |
+
truncation=True,
|
24 |
+
return_attention_mask=True,
|
25 |
+
return_tensors="pt",
|
26 |
+
)
|
27 |
+
pprint(f"inputs: {inputs}")
|
28 |
+
|
29 |
+
# モデルの推論
|
30 |
+
model.eval()
|
31 |
+
with torch.no_grad():
|
32 |
+
outputs = model(
|
33 |
+
inputs["input_ids"].to(device), inputs["attention_mask"].to(device)
|
34 |
+
)
|
35 |
+
pprint(f"outputs: {outputs}")
|
36 |
+
probabilities = torch.nn.functional.softmax(outputs, dim=1)
|
37 |
+
|
38 |
+
# 確率の取得
|
39 |
+
chiikawa_prob = probabilities[0][0].item()
|
40 |
+
yonezu_prob = probabilities[0][1].item()
|
41 |
+
|
42 |
+
return chiikawa_prob, yonezu_prob
|
43 |
+
|
44 |
+
|
45 |
+
# Select the computation device: GPU when available, otherwise CPU.
is_cuda = torch.cuda.is_available()
device = torch.device("cuda" if is_cuda else "cpu")
pprint(f"device: {device}")

# Load the fine-tuned classifier weights from the safetensors checkpoint.
model_save_path = "models/model.safetensors"
with safe_open(model_save_path, framework="pt", device="cpu") as checkpoint:
    tensors = {name: checkpoint.get_tensor(name) for name in checkpoint.keys()}

inference_model: torch.nn.Module = ClassifierModel().to(device)
inference_model.load_state_dict(tensors)

# Tokenizer matching the pretrained Japanese BERT the classifier was built on.
# NOTE(review): this checkpoint actually ships a BertJapaneseTokenizer; loading
# it through BertTokenizer works but may tokenize differently — confirm.
tokenizer = BertTokenizer.from_pretrained(
    "cl-tohoku/bert-base-japanese-whole-word-masking"
)
|
61 |
+
|
62 |
+
|
63 |
+
def classify_text(text):
    """Gradio handler: map *text* to a {label: probability} dict for both classes."""
    probs = _classify_text(text, inference_model, device, tokenizer)
    return {"ちいかわ": probs[0], "米津玄師": probs[1]}
|
66 |
+
|
67 |
+
|
68 |
+
# Example inputs: one Chiikawa line and one Kenshi Yonezu lyric.
_examples = [
    "炊き立て・・・・ってコト!?",
    "晴れた空に種を蒔こう",
]

# Wire the classifier into a simple Gradio interface:
# free-text in, label widget (class probabilities) out.
demo = gr.Interface(
    fn=classify_text,
    inputs="textbox",
    outputs="label",
    examples=_examples,
)

# share=True additionally exposes a temporary public URL for the demo.
demo.launch(share=True)
|
environment.yml
ADDED
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: chiikawa-yonezu
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- nvidia
|
5 |
+
- conda-forge
|
6 |
+
- defaults
|
7 |
+
dependencies:
|
8 |
+
- aiohttp=3.9.1=py311ha68e1ae_0
|
9 |
+
- aiosignal=1.3.1=pyhd8ed1ab_0
|
10 |
+
- asttokens=2.4.1=pyhd8ed1ab_0
|
11 |
+
- attrs=23.2.0=pyh71513ae_0
|
12 |
+
- aws-c-auth=0.7.11=h9ca94be_0
|
13 |
+
- aws-c-cal=0.6.9=h7f0e5be_2
|
14 |
+
- aws-c-common=0.9.10=hcfcfb64_0
|
15 |
+
- aws-c-compression=0.2.17=h7f0e5be_7
|
16 |
+
- aws-c-event-stream=0.4.0=h51e6447_0
|
17 |
+
- aws-c-http=0.8.0=h80119a0_0
|
18 |
+
- aws-c-io=0.13.36=ha737126_3
|
19 |
+
- aws-c-mqtt=0.10.0=h2889a98_2
|
20 |
+
- aws-c-s3=0.4.7=h876bada_3
|
21 |
+
- aws-c-sdkutils=0.1.13=h7f0e5be_0
|
22 |
+
- aws-checksums=0.1.17=h7f0e5be_6
|
23 |
+
- aws-crt-cpp=0.26.0=h3a8c176_4
|
24 |
+
- aws-sdk-cpp=1.11.210=h79fa1a6_8
|
25 |
+
- blas=2.120=mkl
|
26 |
+
- blas-devel=3.9.0=20_win64_mkl
|
27 |
+
- brotli=1.1.0=hcfcfb64_1
|
28 |
+
- brotli-bin=1.1.0=hcfcfb64_1
|
29 |
+
- brotli-python=1.1.0=py311h12c1d0e_1
|
30 |
+
- bzip2=1.0.8=hcfcfb64_5
|
31 |
+
- c-ares=1.24.0=hcfcfb64_0
|
32 |
+
- ca-certificates=2023.11.17=h56e8100_0
|
33 |
+
- certifi=2023.11.17=pyhd8ed1ab_0
|
34 |
+
- charset-normalizer=3.3.2=pyhd8ed1ab_0
|
35 |
+
- colorama=0.4.6=pyhd8ed1ab_0
|
36 |
+
- comm=0.2.1=pyhd8ed1ab_0
|
37 |
+
- contourpy=1.2.0=py311h005e61a_0
|
38 |
+
- cuda-cccl=12.3.101=0
|
39 |
+
- cuda-cudart=12.1.105=0
|
40 |
+
- cuda-cudart-dev=12.1.105=0
|
41 |
+
- cuda-cupti=12.1.105=0
|
42 |
+
- cuda-libraries=12.1.0=0
|
43 |
+
- cuda-libraries-dev=12.1.0=0
|
44 |
+
- cuda-nvrtc=12.1.105=0
|
45 |
+
- cuda-nvrtc-dev=12.1.105=0
|
46 |
+
- cuda-nvtx=12.1.105=0
|
47 |
+
- cuda-opencl=12.3.101=0
|
48 |
+
- cuda-opencl-dev=12.3.101=0
|
49 |
+
- cuda-profiler-api=12.3.101=0
|
50 |
+
- cuda-runtime=12.1.0=0
|
51 |
+
- cycler=0.12.1=pyhd8ed1ab_0
|
52 |
+
- datasets=2.16.1=pyhd8ed1ab_0
|
53 |
+
- debugpy=1.8.0=py311h12c1d0e_1
|
54 |
+
- decorator=5.1.1=pyhd8ed1ab_0
|
55 |
+
- dill=0.3.7=pyhd8ed1ab_0
|
56 |
+
- exceptiongroup=1.2.0=pyhd8ed1ab_0
|
57 |
+
- executing=2.0.1=pyhd8ed1ab_0
|
58 |
+
- filelock=3.13.1=pyhd8ed1ab_0
|
59 |
+
- fonttools=4.47.0=py311ha68e1ae_0
|
60 |
+
- freetype=2.12.1=hdaf720e_2
|
61 |
+
- frozenlist=1.4.1=py311ha68e1ae_0
|
62 |
+
- fsspec=2023.10.0=pyhca7485f_0
|
63 |
+
- gettext=0.21.1=h5728263_0
|
64 |
+
- glib=2.78.3=h12be248_0
|
65 |
+
- glib-tools=2.78.3=h12be248_0
|
66 |
+
- gst-plugins-base=1.22.8=h001b923_0
|
67 |
+
- gstreamer=1.22.8=hb4038d2_0
|
68 |
+
- huggingface_hub=0.20.2=pyhd8ed1ab_0
|
69 |
+
- icu=73.2=h63175ca_0
|
70 |
+
- idna=3.6=pyhd8ed1ab_0
|
71 |
+
- importlib-metadata=7.0.1=pyha770c72_0
|
72 |
+
- importlib_metadata=7.0.1=hd8ed1ab_0
|
73 |
+
- intel-openmp=2023.2.0=h57928b3_50497
|
74 |
+
- ipykernel=6.28.0=pyha63f2e9_0
|
75 |
+
- ipython=8.19.0=pyh7428d3b_0
|
76 |
+
- jaconv=0.3.4=pyhd8ed1ab_0
|
77 |
+
- jedi=0.19.1=pyhd8ed1ab_0
|
78 |
+
- jinja2=3.1.2=pyhd8ed1ab_1
|
79 |
+
- joblib=1.3.2=pyhd8ed1ab_0
|
80 |
+
- jupyter_client=8.6.0=pyhd8ed1ab_0
|
81 |
+
- jupyter_core=5.7.0=py311h1ea47a8_0
|
82 |
+
- kiwisolver=1.4.5=py311h005e61a_1
|
83 |
+
- krb5=1.21.2=heb0366b_0
|
84 |
+
- lcms2=2.16=h67d730c_0
|
85 |
+
- lerc=4.0.0=h63175ca_0
|
86 |
+
- libabseil=20230802.1=cxx17_h63175ca_0
|
87 |
+
- libarrow=14.0.2=he5f67d5_2_cpu
|
88 |
+
- libarrow-acero=14.0.2=h63175ca_2_cpu
|
89 |
+
- libarrow-dataset=14.0.2=h63175ca_2_cpu
|
90 |
+
- libarrow-flight=14.0.2=h53b1db0_2_cpu
|
91 |
+
- libarrow-flight-sql=14.0.2=h78eab7c_2_cpu
|
92 |
+
- libarrow-gandiva=14.0.2=hb2eaab1_2_cpu
|
93 |
+
- libarrow-substrait=14.0.2=hd4c9904_2_cpu
|
94 |
+
- libblas=3.9.0=20_win64_mkl
|
95 |
+
- libbrotlicommon=1.1.0=hcfcfb64_1
|
96 |
+
- libbrotlidec=1.1.0=hcfcfb64_1
|
97 |
+
- libbrotlienc=1.1.0=hcfcfb64_1
|
98 |
+
- libcblas=3.9.0=20_win64_mkl
|
99 |
+
- libclang=15.0.7=default_h77d9078_3
|
100 |
+
- libclang13=15.0.7=default_h77d9078_3
|
101 |
+
- libcrc32c=1.1.2=h0e60522_0
|
102 |
+
- libcublas=12.1.0.26=0
|
103 |
+
- libcublas-dev=12.1.0.26=0
|
104 |
+
- libcufft=11.0.2.4=0
|
105 |
+
- libcufft-dev=11.0.2.4=0
|
106 |
+
- libcurand=10.3.4.107=0
|
107 |
+
- libcurand-dev=10.3.4.107=0
|
108 |
+
- libcurl=8.5.0=hd5e4a3a_0
|
109 |
+
- libcusolver=11.4.4.55=0
|
110 |
+
- libcusolver-dev=11.4.4.55=0
|
111 |
+
- libcusparse=12.0.2.55=0
|
112 |
+
- libcusparse-dev=12.0.2.55=0
|
113 |
+
- libdeflate=1.19=hcfcfb64_0
|
114 |
+
- libevent=2.1.12=h3671451_1
|
115 |
+
- libexpat=2.5.0=h63175ca_1
|
116 |
+
- libffi=3.4.2=h8ffe710_5
|
117 |
+
- libglib=2.78.3=h16e383f_0
|
118 |
+
- libgoogle-cloud=2.12.0=h39f2fc6_4
|
119 |
+
- libgrpc=1.59.3=h5bbd4a7_0
|
120 |
+
- libhwloc=2.9.3=default_haede6df_1009
|
121 |
+
- libiconv=1.17=hcfcfb64_2
|
122 |
+
- libjpeg-turbo=3.0.0=hcfcfb64_1
|
123 |
+
- liblapack=3.9.0=20_win64_mkl
|
124 |
+
- liblapacke=3.9.0=20_win64_mkl
|
125 |
+
- libnpp=12.0.2.50=0
|
126 |
+
- libnpp-dev=12.0.2.50=0
|
127 |
+
- libnvjitlink=12.1.105=0
|
128 |
+
- libnvjitlink-dev=12.1.105=0
|
129 |
+
- libnvjpeg=12.1.1.14=0
|
130 |
+
- libnvjpeg-dev=12.1.1.14=0
|
131 |
+
- libogg=1.3.4=h8ffe710_1
|
132 |
+
- libparquet=14.0.2=h7ec3a38_2_cpu
|
133 |
+
- libpng=1.6.39=h19919ed_0
|
134 |
+
- libprotobuf=4.24.4=hb8276f3_0
|
135 |
+
- libre2-11=2023.06.02=h8c5ae5e_0
|
136 |
+
- libsodium=1.0.18=h8d14728_1
|
137 |
+
- libsqlite=3.44.2=hcfcfb64_0
|
138 |
+
- libssh2=1.11.0=h7dfc565_0
|
139 |
+
- libthrift=0.19.0=ha2b3283_1
|
140 |
+
- libtiff=4.6.0=h6e2ebb7_2
|
141 |
+
- libutf8proc=2.8.0=h82a8f57_0
|
142 |
+
- libuv=1.44.2=hcfcfb64_1
|
143 |
+
- libvorbis=1.3.7=h0e60522_0
|
144 |
+
- libwebp-base=1.3.2=hcfcfb64_0
|
145 |
+
- libxcb=1.15=hcd874cb_0
|
146 |
+
- libxml2=2.11.6=hc3477c8_0
|
147 |
+
- libzlib=1.2.13=hcfcfb64_5
|
148 |
+
- lz4-c=1.9.4=hcfcfb64_0
|
149 |
+
- m2w64-gcc-libgfortran=5.3.0=6
|
150 |
+
- m2w64-gcc-libs=5.3.0=7
|
151 |
+
- m2w64-gcc-libs-core=5.3.0=7
|
152 |
+
- m2w64-gmp=6.1.0=2
|
153 |
+
- m2w64-libwinpthread-git=5.0.0.4634.697f757=2
|
154 |
+
- markupsafe=2.1.3=py311ha68e1ae_1
|
155 |
+
- matplotlib=3.8.2=py311h1ea47a8_0
|
156 |
+
- matplotlib-base=3.8.2=py311h6e989c2_0
|
157 |
+
- matplotlib-inline=0.1.6=pyhd8ed1ab_0
|
158 |
+
- mkl=2023.2.0=h6a75c08_50497
|
159 |
+
- mkl-devel=2023.2.0=h57928b3_50497
|
160 |
+
- mkl-include=2023.2.0=h6a75c08_50497
|
161 |
+
- mpmath=1.3.0=pyhd8ed1ab_0
|
162 |
+
- msys2-conda-epoch=20160418=1
|
163 |
+
- multidict=6.0.4=py311ha68e1ae_1
|
164 |
+
- multiprocess=0.70.15=py311ha68e1ae_1
|
165 |
+
- munkres=1.1.4=pyh9f0ad1d_0
|
166 |
+
- nest-asyncio=1.5.8=pyhd8ed1ab_0
|
167 |
+
- networkx=3.2.1=pyhd8ed1ab_0
|
168 |
+
- numpy=1.26.3=py311h0b4df5a_0
|
169 |
+
- openjpeg=2.5.0=h3d672ee_3
|
170 |
+
- openssl=3.2.0=hcfcfb64_1
|
171 |
+
- orc=1.9.2=hf0b6bd4_0
|
172 |
+
- packaging=23.2=pyhd8ed1ab_0
|
173 |
+
- pandas=2.1.4=py311hf63dbb6_0
|
174 |
+
- parso=0.8.3=pyhd8ed1ab_0
|
175 |
+
- pcre2=10.42=h17e33f8_0
|
176 |
+
- pickleshare=0.7.5=py_1003
|
177 |
+
- pillow=10.2.0=py311h4dd8a23_0
|
178 |
+
- pip=23.3.2=pyhd8ed1ab_0
|
179 |
+
- platformdirs=4.1.0=pyhd8ed1ab_0
|
180 |
+
- ply=3.11=py_1
|
181 |
+
- prompt-toolkit=3.0.42=pyha770c72_0
|
182 |
+
- psutil=5.9.7=py311ha68e1ae_0
|
183 |
+
- pthread-stubs=0.4=hcd874cb_1001
|
184 |
+
- pthreads-win32=2.9.1=hfa6e2cd_3
|
185 |
+
- pure_eval=0.2.2=pyhd8ed1ab_0
|
186 |
+
- pyarrow=14.0.2=py311h6a6099b_2_cpu
|
187 |
+
- pyarrow-hotfix=0.6=pyhd8ed1ab_0
|
188 |
+
- pygments=2.17.2=pyhd8ed1ab_0
|
189 |
+
- pyparsing=3.1.1=pyhd8ed1ab_0
|
190 |
+
- pyqt=5.15.9=py311h125bc19_5
|
191 |
+
- pyqt5-sip=12.12.2=py311h12c1d0e_5
|
192 |
+
- pysocks=1.7.1=pyh0701188_6
|
193 |
+
- python=3.11.7=h2628c8c_1_cpython
|
194 |
+
- python-dateutil=2.8.2=pyhd8ed1ab_0
|
195 |
+
- python-tzdata=2023.4=pyhd8ed1ab_0
|
196 |
+
- python-xxhash=3.4.1=py311ha68e1ae_0
|
197 |
+
- python_abi=3.11=4_cp311
|
198 |
+
- pytorch=2.1.2=py3.11_cuda12.1_cudnn8_0
|
199 |
+
- pytorch-cuda=12.1=hde6ce7c_5
|
200 |
+
- pytorch-mutex=1.0=cuda
|
201 |
+
- pytz=2023.3.post1=pyhd8ed1ab_0
|
202 |
+
- pywin32=306=py311h12c1d0e_2
|
203 |
+
- pyyaml=6.0.1=py311ha68e1ae_1
|
204 |
+
- pyzmq=25.1.2=py311h9250fbb_0
|
205 |
+
- qt-main=5.15.8=h9e85ed6_18
|
206 |
+
- re2=2023.06.02=hcbb65ff_0
|
207 |
+
- regex=2023.12.25=py311ha68e1ae_0
|
208 |
+
- requests=2.31.0=pyhd8ed1ab_0
|
209 |
+
- safetensors=0.3.3=py311hc37eb10_1
|
210 |
+
- scikit-learn=1.3.2=py311h142b183_2
|
211 |
+
- scipy=1.11.4=py311h0b4df5a_0
|
212 |
+
- setuptools=69.0.3=pyhd8ed1ab_0
|
213 |
+
- sip=6.7.12=py311h12c1d0e_0
|
214 |
+
- six=1.16.0=pyh6c4a22f_0
|
215 |
+
- snappy=1.1.10=hfb803bf_0
|
216 |
+
- stack_data=0.6.2=pyhd8ed1ab_0
|
217 |
+
- sympy=1.12=pyh04b8f61_3
|
218 |
+
- tbb=2021.11.0=h91493d7_0
|
219 |
+
- threadpoolctl=3.2.0=pyha21a80b_0
|
220 |
+
- tk=8.6.13=h5226925_1
|
221 |
+
- tokenizers=0.15.0=py311h91c4a10_1
|
222 |
+
- toml=0.10.2=pyhd8ed1ab_0
|
223 |
+
- tomli=2.0.1=pyhd8ed1ab_0
|
224 |
+
- tornado=6.3.3=py311ha68e1ae_1
|
225 |
+
- tqdm=4.66.1=pyhd8ed1ab_0
|
226 |
+
- traitlets=5.14.1=pyhd8ed1ab_0
|
227 |
+
- transformers=4.36.2=pyhd8ed1ab_0
|
228 |
+
- typing-extensions=4.9.0=hd8ed1ab_0
|
229 |
+
- typing_extensions=4.9.0=pyha770c72_0
|
230 |
+
- tzdata=2023d=h0c530f3_0
|
231 |
+
- ucrt=10.0.22621.0=h57928b3_0
|
232 |
+
- urllib3=2.1.0=pyhd8ed1ab_0
|
233 |
+
- vc=14.3=hcf57466_18
|
234 |
+
- vc14_runtime=14.38.33130=h82b7239_18
|
235 |
+
- vs2015_runtime=14.38.33130=hcb4865c_18
|
236 |
+
- wcwidth=0.2.12=pyhd8ed1ab_0
|
237 |
+
- wheel=0.42.0=pyhd8ed1ab_0
|
238 |
+
- win_inet_pton=1.1.0=pyhd8ed1ab_6
|
239 |
+
- xorg-libxau=1.0.11=hcd874cb_0
|
240 |
+
- xorg-libxdmcp=1.1.3=hcd874cb_0
|
241 |
+
- xxhash=0.8.2=hcfcfb64_0
|
242 |
+
- xz=5.2.6=h8d14728_0
|
243 |
+
- yaml=0.2.5=h8ffe710_2
|
244 |
+
- yarl=1.9.3=py311ha68e1ae_0
|
245 |
+
- zeromq=4.3.5=h63175ca_0
|
246 |
+
- zipp=3.17.0=pyhd8ed1ab_0
|
247 |
+
- zstd=1.5.5=h12be248_0
|
248 |
+
- pip:
|
249 |
+
- aiofiles==23.2.1
|
250 |
+
- altair==5.2.0
|
251 |
+
- annotated-types==0.6.0
|
252 |
+
- anyio==4.2.0
|
253 |
+
- click==8.1.7
|
254 |
+
- fastapi==0.108.0
|
255 |
+
- ffmpy==0.3.1
|
256 |
+
- gradio==4.13.0
|
257 |
+
- gradio-client==0.8.0
|
258 |
+
- h11==0.14.0
|
259 |
+
- httpcore==1.0.2
|
260 |
+
- httpx==0.26.0
|
261 |
+
- importlib-resources==6.1.1
|
262 |
+
- jsonschema==4.20.0
|
263 |
+
- jsonschema-specifications==2023.12.1
|
264 |
+
- markdown-it-py==3.0.0
|
265 |
+
- mdurl==0.1.2
|
266 |
+
- orjson==3.9.10
|
267 |
+
- pydantic==2.5.3
|
268 |
+
- pydantic-core==2.14.6
|
269 |
+
- pydub==0.25.1
|
270 |
+
- python-multipart==0.0.6
|
271 |
+
- referencing==0.32.1
|
272 |
+
- rich==13.7.0
|
273 |
+
- rpds-py==0.16.2
|
274 |
+
- semantic-version==2.10.0
|
275 |
+
- shellingham==1.5.4
|
276 |
+
- sniffio==1.3.0
|
277 |
+
- starlette==0.32.0.post1
|
278 |
+
- tomlkit==0.12.0
|
279 |
+
- toolz==0.12.0
|
280 |
+
- torchaudio==2.1.2
|
281 |
+
- torchvision==0.16.2
|
282 |
+
- typer==0.9.0
|
283 |
+
- uvicorn==0.25.0
|
284 |
+
- websockets==11.0.3
|
285 |
+
|
notebooks/embeddings.ipynb
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 4,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"from transformers import BertTokenizer, BertModel\n",
|
10 |
+
"import torch\n",
|
11 |
+
"from pprint import pprint"
|
12 |
+
]
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"cell_type": "code",
|
16 |
+
"execution_count": 5,
|
17 |
+
"metadata": {},
|
18 |
+
"outputs": [
|
19 |
+
{
|
20 |
+
"name": "stderr",
|
21 |
+
"output_type": "stream",
|
22 |
+
"text": [
|
23 |
+
"The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. \n",
|
24 |
+
"The tokenizer class you load from this checkpoint is 'BertJapaneseTokenizer'. \n",
|
25 |
+
"The class this function is called from is 'BertTokenizer'.\n"
|
26 |
+
]
|
27 |
+
},
|
28 |
+
{
|
29 |
+
"name": "stdout",
|
30 |
+
"output_type": "stream",
|
31 |
+
"text": [
|
32 |
+
"{'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]]),\n",
|
33 |
+
" 'input_ids': tensor([[ 2, 73, 371, 37, 1541, 546, 3]]),\n",
|
34 |
+
" 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0]])}\n",
|
35 |
+
"torch.Size([1, 7, 768])\n"
|
36 |
+
]
|
37 |
+
}
|
38 |
+
],
|
39 |
+
"source": [
|
40 |
+
"# 日本語の事前学習済みモデルとトークナイザーの読み込み\n",
|
41 |
+
"tokenizer = BertTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')\n",
|
42 |
+
"model = BertModel.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')\n",
|
43 |
+
"\n",
|
44 |
+
"# テキストをトークン化し、PyTorchテンソルに変換\n",
|
45 |
+
"text = \"お正月休み\"\n",
|
46 |
+
"encoded_input = tokenizer(text, return_tensors='pt')\n",
|
47 |
+
"pprint(encoded_input)\n",
|
48 |
+
"\n",
|
49 |
+
"# 単語埋め込みを取得\n",
|
50 |
+
"with torch.no_grad():\n",
|
51 |
+
" output = model(**encoded_input)\n",
|
52 |
+
" embeddings = output.last_hidden_state\n",
|
53 |
+
" pprint(embeddings.shape)"
|
54 |
+
]
|
55 |
+
},
|
56 |
+
{
|
57 |
+
"cell_type": "code",
|
58 |
+
"execution_count": null,
|
59 |
+
"metadata": {},
|
60 |
+
"outputs": [],
|
61 |
+
"source": []
|
62 |
+
}
|
63 |
+
],
|
64 |
+
"metadata": {
|
65 |
+
"kernelspec": {
|
66 |
+
"display_name": "chiikawa-yonezu",
|
67 |
+
"language": "python",
|
68 |
+
"name": "python3"
|
69 |
+
},
|
70 |
+
"language_info": {
|
71 |
+
"codemirror_mode": {
|
72 |
+
"name": "ipython",
|
73 |
+
"version": 3
|
74 |
+
},
|
75 |
+
"file_extension": ".py",
|
76 |
+
"mimetype": "text/x-python",
|
77 |
+
"name": "python",
|
78 |
+
"nbconvert_exporter": "python",
|
79 |
+
"pygments_lexer": "ipython3",
|
80 |
+
"version": "3.11.7"
|
81 |
+
}
|
82 |
+
},
|
83 |
+
"nbformat": 4,
|
84 |
+
"nbformat_minor": 2
|
85 |
+
}
|
notebooks/preprocessing.ipynb
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 15,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import csv\n",
|
10 |
+
"import jaconv\n",
|
11 |
+
"import re\n",
|
12 |
+
"\n",
|
13 |
+
"def preprocess(csv_path: str, preprocessed_csv_path: str):\n",
|
14 |
+
" \"\"\"\n",
|
15 |
+
" 与えられたCSVファイルを読み込んだのち、以下の処理をしてから、{CSVファイル名}_preprocessed.csvとして保存する\n",
|
16 |
+
" CSVファイルのフォーマットは\"TEXT,LABEL\"の2列である。\n",
|
17 |
+
"\n",
|
18 |
+
" TEXTの変換ルールは次の通り。\n",
|
19 |
+
" 1. 文字列が半角・全角スペース・改行を含む場合、その文字列を複数の文字列に分割する\n",
|
20 |
+
" 2. 記号(!,?,!,?,・,.,…,',\",♪,♫)と全ての絵文字を削除する\n",
|
21 |
+
" 3. ()または()で囲まれた文字列を削除する\n",
|
22 |
+
" 4. 半角カタカナを全角カタカナに、~を~に、-をーに変換する\n",
|
23 |
+
" 5. 2つ以上連続する~~を~に、ーーをーに変換する\n",
|
24 |
+
" 6. 空文字列を削除する\n",
|
25 |
+
"\n",
|
26 |
+
" 保存する前にフィルタリングを行う。\n",
|
27 |
+
" 1. TEXTが空文字列の行を削除する\n",
|
28 |
+
" 2. TEXTとLABELの組み合わせが重複している行を削除する\n",
|
29 |
+
" \"\"\"\n",
|
30 |
+
" # Read the CSV file\n",
|
31 |
+
" with open(csv_path, 'r', encoding='utf-8') as file:\n",
|
32 |
+
" reader = csv.reader(file)\n",
|
33 |
+
" data = list(reader)\n",
|
34 |
+
" \n",
|
35 |
+
" preprocessed_data = []\n",
|
36 |
+
"\n",
|
37 |
+
" # Preprocess the TEXT column\n",
|
38 |
+
" for i in range(len(data)):\n",
|
39 |
+
" text, label = data[i]\n",
|
40 |
+
" # Split the text into multiple strings if it contains spaces or newlines\n",
|
41 |
+
" text = re.split(r'\\s+', text)\n",
|
42 |
+
" # Remove symbols\n",
|
43 |
+
" text = [re.sub(r'[!?!?・.…\\'\"’”\\♪♫]', '', word) for word in text]\n",
|
44 |
+
" # Remove strings enclosed in parentheses\n",
|
45 |
+
" text = [re.sub(r'\\(.*?\\)|(.*?)', '', word) for word in text]\n",
|
46 |
+
" # Convert half-width katakana to full-width katakana\n",
|
47 |
+
" text = [jaconv.h2z(word) for word in text]\n",
|
48 |
+
" # Convert ~ to ~ and - to ー\n",
|
49 |
+
" # Note: 〜(U+301C) is a different character from ~(U+FF5E\n",
|
50 |
+
" text = [re.sub(r'[~〜]', '~', word) for word in text]\n",
|
51 |
+
" text = [re.sub(r'-', 'ー', word) for word in text]\n",
|
52 |
+
" # Convert multiple consecutive ~ to ~ and ーー to ー\n",
|
53 |
+
" text = [re.sub(r'~+', '~', word) for word in text]\n",
|
54 |
+
" text = [re.sub(r'ー+', 'ー', word) for word in text]\n",
|
55 |
+
" \n",
|
56 |
+
" [preprocessed_data.append([word, label]) for word in text if word != '' ]\n",
|
57 |
+
"\n",
|
58 |
+
" # Remove duplicate rows based on TEXT and LABEL combination\n",
|
59 |
+
" preprocessed_data = [list(x) for x in set(tuple(x) for x in preprocessed_data)]\n",
|
60 |
+
"\n",
|
61 |
+
" # Sort the data by LABEL, TEXT\n",
|
62 |
+
" preprocessed_data.sort(key=lambda x: (x[1], x[0]))\n",
|
63 |
+
"\n",
|
64 |
+
" # Save the preprocessed data to a new CSV file\n",
|
65 |
+
" with open(preprocessed_csv_path, 'w', encoding='utf-8', newline='') as file:\n",
|
66 |
+
" writer = csv.writer(file)\n",
|
67 |
+
" writer.writerows(preprocessed_data)\n"
|
68 |
+
]
|
69 |
+
},
|
70 |
+
{
|
71 |
+
"cell_type": "code",
|
72 |
+
"execution_count": 16,
|
73 |
+
"metadata": {},
|
74 |
+
"outputs": [],
|
75 |
+
"source": [
|
76 |
+
"import pandas as pd\n",
|
77 |
+
"from sklearn.model_selection import train_test_split\n",
|
78 |
+
"\n",
|
79 |
+
"\n",
|
80 |
+
"def split(csv_path: str):\n",
|
81 |
+
" # 元のCSVファイルを読み込む\n",
|
82 |
+
" df = pd.read_csv(csv_path, encoding='utf-8')\n",
|
83 |
+
"\n",
|
84 |
+
" # 訓練用データセットとテスト用データセットに分割\n",
|
85 |
+
" train_df, test_df = train_test_split(df, test_size=0.05) # 高速化のため検証データの数を減らす\n",
|
86 |
+
"\n",
|
87 |
+
" # 新しいCSVファイルとして保存\n",
|
88 |
+
" train_df.to_csv(csv_path.replace('.csv', '_train.csv'), index=False)\n",
|
89 |
+
" test_df.to_csv(csv_path.replace('.csv', '_test.csv'), index=False)\n"
|
90 |
+
]
|
91 |
+
},
|
92 |
+
{
|
93 |
+
"cell_type": "code",
|
94 |
+
"execution_count": 17,
|
95 |
+
"metadata": {},
|
96 |
+
"outputs": [],
|
97 |
+
"source": [
|
98 |
+
"csv_path = '../data/data.csv'\n",
|
99 |
+
"preprocessed_csv_path = csv_path.replace('.csv', '_preprocessed.csv')\n",
|
100 |
+
"preprocess(csv_path, preprocessed_csv_path)\n",
|
101 |
+
"split(preprocessed_csv_path)"
|
102 |
+
]
|
103 |
+
}
|
104 |
+
],
|
105 |
+
"metadata": {
|
106 |
+
"kernelspec": {
|
107 |
+
"display_name": "chiikawa-yonezu",
|
108 |
+
"language": "python",
|
109 |
+
"name": "python3"
|
110 |
+
},
|
111 |
+
"language_info": {
|
112 |
+
"codemirror_mode": {
|
113 |
+
"name": "ipython",
|
114 |
+
"version": 3
|
115 |
+
},
|
116 |
+
"file_extension": ".py",
|
117 |
+
"mimetype": "text/x-python",
|
118 |
+
"name": "python",
|
119 |
+
"nbconvert_exporter": "python",
|
120 |
+
"pygments_lexer": "ipython3",
|
121 |
+
"version": "3.11.7"
|
122 |
+
}
|
123 |
+
},
|
124 |
+
"nbformat": 4,
|
125 |
+
"nbformat_minor": 2
|
126 |
+
}
|
notebooks/train.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
safetensors
|
3 |
+
torch
|
4 |
+
transformers
|
utils/ClassifierModel.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import BertModel
|
3 |
+
|
4 |
+
|
5 |
+
class ClassifierModel(torch.nn.Module):
    """Binary text classifier: frozen Japanese BERT encoder plus a linear head."""

    def __init__(self):
        super().__init__()
        # Pretrained Japanese BERT, used purely as a feature extractor.
        self.bert = BertModel.from_pretrained(
            'cl-tohoku/bert-base-japanese-whole-word-masking'
        )
        # 768 = BERT hidden size; 2 = number of output classes.
        self.linear = torch.nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask):
        # The encoder runs under no_grad, so BERT's weights stay frozen;
        # only the linear head receives gradients during training.
        with torch.no_grad():
            encoder_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # [CLS] embedding: first position of the last hidden state.
        cls_embedding = encoder_out[0][:, 0]
        return self.linear(cls_embedding)
|