Spaces commit 46a030d (parent: b615e10)
Committed by kenichiro
Files changed:
- LICENSE +201 -0
- README.md +1 -1
- __pycache__/chat.cpython-38.pyc +0 -0
- __pycache__/functionforDownloadButtons.cpython-36.pyc +0 -0
- __pycache__/functionforDownloadButtons.cpython-38.pyc +0 -0
- __pycache__/model.cpython-36.pyc +0 -0
- __pycache__/model.cpython-38.pyc +0 -0
- __pycache__/model2.cpython-36.pyc +0 -0
- __pycache__/run_segbot.cpython-36.pyc +0 -0
- __pycache__/run_segbot.cpython-38.pyc +0 -0
- __pycache__/solver.cpython-36.pyc +0 -0
- __pycache__/solver.cpython-38.pyc +0 -0
- __pycache__/solver2.cpython-36.pyc +0 -0
- app.py +117 -14
- credata.py +653 -0
- fm.pickle +3 -0
- functionforDownloadButtons.py +171 -0
- logo.png +0 -0
- model.py +465 -0
- requirements.txt +7 -0
- run_segbot.py +61 -94
- solver.py +16 -282
LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
README.md
CHANGED
@@ -1,6 +1,6 @@
 ---
 title: Clinical Segnemt
-emoji:
+emoji: 🌖
 colorFrom: purple
 colorTo: yellow
 sdk: streamlit
__pycache__/chat.cpython-38.pyc
DELETED
Binary file (1.46 kB)

__pycache__/functionforDownloadButtons.cpython-36.pyc
ADDED
Binary file (4.54 kB)

__pycache__/functionforDownloadButtons.cpython-38.pyc
ADDED
Binary file (4.59 kB)

__pycache__/model.cpython-36.pyc
ADDED
Binary file (7.24 kB)

__pycache__/model.cpython-38.pyc
ADDED
Binary file (7.26 kB)

__pycache__/model2.cpython-36.pyc
ADDED
Binary file (7.04 kB)

__pycache__/run_segbot.cpython-36.pyc
ADDED
Binary file (1.9 kB)

__pycache__/run_segbot.cpython-38.pyc
ADDED
Binary file (1.9 kB)

__pycache__/solver.cpython-36.pyc
ADDED
Binary file (4.81 kB)

__pycache__/solver.cpython-38.pyc
ADDED
Binary file (4.83 kB)

__pycache__/solver2.cpython-36.pyc
ADDED
Binary file (4.43 kB)
app.py
CHANGED
@@ -1,19 +1,122 @@
Removed (old Flask endpoint):

app = Flask(__name__)

@app.post("/predict")
def predict():
    text = request.get_json().get("message")
    response = get_response(text)
    message = {"answer": response}
    return jsonify(message)

Added (new Streamlit app):

import streamlit as st
import numpy as np
from pandas import DataFrame
import run_segbot
from functionforDownloadButtons import download_button
import os
import json

st.set_page_config(
    page_title="Clinical segment generater",
    page_icon="🚑",
    layout="wide"
)


def _max_width_():
    max_width_str = f"max-width: 1400px;"
    st.markdown(
        f"""
        <style>
        .reportview-container .main .block-container{{
            {max_width_str}
        }}
        </style>
        """,
        unsafe_allow_html=True,
    )


#_max_width_()

#c30 = st.columns([1,])

#with c30:
#    st.image("logo.png", width=400)
st.title("🚑 Clinical segment generater")
st.header("")


with st.expander("ℹ️ - About this app", expanded=True):

    st.write(
        """
- The *Clinical segment generater* app is an implementation of [our paper](https://journals.plos.org/digitalhealth/article?id=10.1371/journal.pdig.0000099).
- It automatically splits Japanese sentences into smaller units representing medical meanings.
        """
    )

    st.markdown("")

st.markdown("")
st.markdown("## 📌 Paste document")
@st.cache(allow_output_mutation=True)
def model_load():
    return run_segbot.setup()
model, fm, index = model_load()
with st.form(key="my_form"):

    ce, c1, ce, c2, c3 = st.columns([0.07, 1, 0.07, 5, 0.07])
    with c1:
        ModelType = st.radio(
            "Choose the method of sentence split",
            ["fullstop & linebreak (Default)", "pySBD"],
            help="""
At present, you can choose between 2 methods to split your text into sentences.

The fullstop & linebreak is naive and robust to noise, but has low accuracy.
pySBD is more accurate, but more complex and less robust to noise.
            """,
        )

        if ModelType == "fullstop & linebreak (Default)":
            split_method = "fullstop"

        else:
            split_method = "pySBD"

    with c2:
        doc = st.text_area(
            "Paste your text below",
            height=510,
        )

        submit_button = st.form_submit_button(label="👍 Go to split!")


if not submit_button:
    st.stop()

keywords = run_segbot.generate(doc, model, fm, index, split_method)


st.markdown("## 🎈 Check & download results")

st.header("")


cs, c1, c2, c3, cLast = st.columns([2, 1.5, 1.5, 1.5, 2])

with c1:
    CSVButton2 = download_button(keywords, "Data.csv", "📥 Download (.csv)")
with c2:
    CSVButton2 = download_button(keywords, "Data.txt", "📥 Download (.txt)")
with c3:
    CSVButton2 = download_button(keywords, "Data.json", "📥 Download (.json)")

st.header("")

#df = DataFrame(keywords, columns=["Keyword/Keyphrase", "Relevancy"])
df = DataFrame(keywords)
df.index += 1
df.columns = ['Segment']
print(df)
# Add styling

#c1, c2, c3 = st.columns([1, 3, 1])

#with c2:
st.table(df)
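The radio button above only controls how the pasted document is cut into sentences before run_segbot segments them; run_segbot's own implementation is not part of this diff. As a hedged sketch of the two options, assuming the pysbd package for the second one and using made-up helper names (split_fullstop, split_pysbd), the split step could look like this:

# Illustrative sketch only; names and the naive rule are assumptions, not the repo's code.
import pysbd

def split_fullstop(doc):
    # naive option: cut on the Japanese full stop and on line breaks
    lines = doc.replace("。", "。\n").splitlines()
    return [s.strip() for s in lines if s.strip()]

def split_pysbd(doc):
    # pySBD option: rule-based sentence boundary detection for Japanese
    seg = pysbd.Segmenter(language="ja", clean=False)
    return seg.segment(doc)

print(split_fullstop("腹痛あり。嘔吐なし"))   # ['腹痛あり。', '嘔吐なし']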
credata.py
ADDED
@@ -0,0 +1,653 @@
import gensim
import MeCab
import pickle
from gensim.models.wrappers.fasttext import FastText
#import fasttext as ft
import random
import mojimoji
import numpy as np
from tqdm import tqdm

def ymyi(lis):
    wakati = MeCab.Tagger("-Owakati -b 81920 -u /clwork/ando/SEGBOT/MANBYO_201907_Dic-utf8.dic")

    with open('fm_space.pickle', 'rb') as f:
        fm = pickle.load(f)
    #model = gensim.models.KeyedVectors.load_word2vec_format("/clwork/ando/SEGBOT/cc.ja.300.vec", binary=False)
    model = ft.load_model("/clwork/ando/SEGBOT/fast/cc.ja.300.bin")
    texts = []
    sent = ""
    sparate = []
    label = []
    ruiseki = 0
    ruiseki2 = 0
    alls = []
    labels, text, num = [], [], []
    for n, line in enumerate(open(lis)):
        line = line.strip("\t").rstrip("\n")
        #print(line)
        if line == "":
            if sent == "":
                continue
            sent = wakati.parse(sent).split(" ")[:-1]
            flag = 0
            for i in sent:
                for j in sparate:
                    if ruiseki+len(i) > j and ruiseki < j:
                        label.append(1)
                        flag = 1
                    elif ruiseki+len(i) == j:
                        label.append(1)
                        flag = 1
                if flag == 0:
                    label.append(0)
                flag = 0
                ruiseki += len(i)
                #texts += i + " "
                try:
                    texts.append(model[i])
                    #texts.append(np.array(fm.vocab[i]))
                    #texts += str(fm.vocab[i].index) + " "
                    #print(i,str(fm.vocab[i].index))
                except KeyError:
                    texts.append(fm["<unk>"])
            label[-1] = 1
            #texts = texts.rstrip() + "\t"
            #texts += " ".join(label) + "\n"
            #alls.append((n,texts,label))
            labels.append(label)
            text.append(texts)
            num.append(n)
            sent = ""
            sparate = []
            texts = []
            label = []
            ruiseki = 0
            ruiseki2 = 0
            continue
        sent += mojimoji.han_to_zen(line, digit=False, ascii=False)
        ruiseki2 += len(line)
        sparate.append(ruiseki2)
    return num,text,labels

def nmni(lis):
    #wakati = MeCab.Tagger("-Owakati -b 81920 -u /clwork/ando/SEGBOT/MANBYO_201907_Dic-utf8.dic")
    wakati = MeCab.Tagger("-Owakati -b 81920")

    with open('fm_space.pickle', 'rb') as f:
        fm = pickle.load(f)
    #model = gensim.models.KeyedVectors.load_word2vec_format("/clwork/ando/SEGBOT/cc.ja.300.vec", binary=False)
    #model = ft.load_model("/clwork/ando/SEGBOT/fast/cc.ja.300.bin")
    texts = []
    sent = ""
    sparate = []
    label = []
    ruiseki = 0
    ruiseki2 = 0
    alls = []
    labels, text, num = [], [], []
    for n, line in enumerate(open(lis)):
        line = line.strip("\t").rstrip("\n")
        #print(line)
        if line == "":
            if sent == "":
                continue
            sent = wakati.parse(sent).split(" ")[:-1]
            flag = 0
            for i in sent:
                for j in sparate:
                    if ruiseki+len(i) > j and ruiseki < j:
                        label.append(1)
                        flag = 1
                    elif ruiseki+len(i) == j:
                        label.append(1)
                        flag = 1
                if flag == 0:
                    label.append(0)
                flag = 0
                ruiseki += len(i)
                #texts += i + " "
                try:
                    #texts.append(model[i])
                    texts.append(fm[i])
                    #texts += str(fm.vocab[i].index) + " "
                    #print(i,str(fm.vocab[i].index))
                except KeyError:
                    texts.append(fm["<unk>"])
            label[-1] = 1
            #texts = texts.rstrip() + "\t"
            #texts += " ".join(label) + "\n"
            #alls.append((n,texts,label))
            labels.append(label)
            text.append(texts)
            num.append(n)
            sent = ""
            sparate = []
            texts = []
            label = []
            ruiseki = 0
            ruiseki2 = 0
            continue
        sent += mojimoji.han_to_zen(line, digit=False, ascii=False)
        ruiseki2 += len(line)
        sparate.append(ruiseki2)
    return num,text,labels

def nmni_finetune(lis):
    #wakati = MeCab.Tagger("-Owakati -b 81920 -u /clwork/ando/SEGBOT/MANBYO_201907_Dic-utf8.dic")
    wakati = MeCab.Tagger("-Owakati -b 81920")
    #fm = gensim.models.KeyedVectors.load_word2vec_format('/clwork/ando/SEGBOT/cc.ja.300.vec', binary=False)
    with open('fm.pickle', 'rb') as f:
        fm = pickle.load(f)
    #fm = gensim.models.KeyedVectors.load_word2vec_format('cc.ja.300.vec', binary=False)
    #with open('fm.pickle', 'wb') as f:
    #    pickle.dump(fm, f)
    #model = gensim.models.KeyedVectors.load_word2vec_format("/clwork/ando/SEGBOT/cc.ja.300.vec", binary=False)
    #model = ft.load_model("/clwork/ando/SEGBOT/fast/cc.ja.300.bin")
    texts = []
    sent = ""
    sparate = []
    label = []
    ruiseki = 0
    ruiseki2 = 0
    alls = []
    labels, text, num = [], [], []
    for n, line in enumerate(open(lis)):
        line = line.strip("\t").rstrip("\n")
        #print(line)
        if line == "":
            if sent == "":
                continue
            sent = wakati.parse(sent).split(" ")[:-1]
            flag = 0
            for i in sent:
                for j in sparate:
                    if ruiseki+len(i) > j and ruiseki < j:
                        label.append(1)
                        flag = 1
                    elif ruiseki+len(i) == j:
                        label.append(1)
                        flag = 1
                if flag == 0:
                    label.append(0)
                flag = 0
                ruiseki += len(i)
                #texts += i + " "
                try:
                    #texts.append(model[i])
                    #texts.append(fm[i])
                    texts.append(fm.vocab[i].index)
                    #print(i,str(fm.vocab[i].index))
                except KeyError:
                    texts.append(fm.vocab["<unk>"].index)
            label[-1] = 1
            #texts = texts.rstrip() + "\t"
            #texts += " ".join(label) + "\n"
            #alls.append((n,texts,label))
            labels.append(np.array(label))
            text.append(np.array(texts))
            num.append(n)
            sent = ""
            sparate = []
            texts = []
            label = []
            ruiseki = 0
            ruiseki2 = 0
            continue
        sent += mojimoji.han_to_zen(line, digit=False, ascii=False)
        ruiseki2 += len(line)
        sparate.append(ruiseki2)
    return text,labels


def nmni_carte(lis):
    #wakati = MeCab.Tagger("-Owakati -b 81920 -u /clwork/ando/SEGBOT/MANBYO_201907_Dic-utf8.dic")
    wakati = MeCab.Tagger("-Owakati -b 81920")
    #fm = gensim.models.KeyedVectors.load_word2vec_format('/clwork/ando/SEGBOT/cc.ja.300.vec', binary=False)
    #fm = gensim.models.KeyedVectors.load_word2vec_format('cc.ja.300.vec', binary=False)
    #with open('fm.pickle', 'wb') as f:
    #    pickle.dump(fm, f)
    #model = gensim.models.KeyedVectors.load_word2vec_format("/clwork/ando/SEGBOT/cc.ja.300.vec", binary=False)
    #model = ft.load_model("/clwork/ando/SEGBOT/fast/cc.ja.300.bin")
    with open('fm.pickle', 'rb') as f:
        fm = pickle.load(f)
    texts = []
    sent = ""
    sparate = []
    label = []
    ruiseki = 0
    ruiseki2 = 0
    alls = []
    labels, text, num = [], [], []
    allab, altex, fukugenss = [], [], []
    #for n in tqdm(range(26431)):
    for n in tqdm(range(108)):
        fukugens = []
        for line in open(lis+str(n)+".txt"):
            line = line.strip()
            if line == "":
                continue
            sent = wakati.parse(line).split(" ")[:-1]
            flag = 0
            label = []
            texts = []
            fukugen = []
            for i in sent:
                try:
                    texts.append(fm.vocab[i].index)
                except KeyError:
                    texts.append(fm.vocab["<unk>"].index)
                fukugen.append(i)
                label.append(0)
            label[-1] = 1
            labels.append(np.array(label))
            text.append(np.array(texts))
            #labels.append(label)
            #text.append(texts)
            fukugens.append(fukugen)
        allab.append(labels)
        altex.append(text)
        fukugenss.append(fukugens)
        labels, text, fukugens = [], [], []
    return altex, allab, fukugenss


def nmni_finetune_s(lis):
    #wakati = MeCab.Tagger("-Owakati -b 81920 -u /clwork/ando/SEGBOT/MANBYO_201907_Dic-utf8.dic")
    wakati = MeCab.Tagger("-Owakati -b 81920")
    #fm = gensim.models.KeyedVectors.load_word2vec_format('/clwork/ando/SEGBOT/cc.ja.300.vec', binary=False)
    fm = gensim.models.KeyedVectors.load_word2vec_format('cc.ja.300.vec', binary=False)
    with open('fm.pickle', 'wb') as f:
        pickle.dump(fm, f)
    #model = gensim.models.KeyedVectors.load_word2vec_format("/clwork/ando/SEGBOT/cc.ja.300.vec", binary=False)
    #model = ft.load_model("/clwork/ando/SEGBOT/fast/cc.ja.300.bin")
    texts = []
    sent = ""
    sparate = []
    label = []
    ruiseki = 0
    ruiseki2 = 0
    alls = []
    labels, text, num = [], [], []
    for n, line in enumerate(open(lis)):
        line = line.strip("\t").rstrip("\n")
        sent = wakati.parse(line).split(" ")[:-1]
        flag = 0
        label = []
        texts = []
        for i in sent:
            try:
                texts.append(fm.vocab[i].index)
            except KeyError:
                texts.append(fm.vocab["<unk>"].index)
            label.append(0)
        label[-1] = 1
        labels.append(np.array(label))
        text.append(np.array(texts))
    return text,labels


def nmni_finetune_ss(lis):
    #wakati = MeCab.Tagger("-Owakati -b 81920 -u /clwork/ando/SEGBOT/MANBYO_201907_Dic-utf8.dic")
    wakati = MeCab.Tagger("-Owakati -b 81920")
    fm = gensim.models.KeyedVectors.load_word2vec_format('cc.ja.300.vec', binary=False)
    with open('fm.pickle', 'wb') as f:
        pickle.dump(fm, f)
    #with open('fm.pickle', 'rb') as f:
    #    fm = pickle.load(f)
    #model = gensim.models.KeyedVectors.load_word2vec_format("/clwork/ando/SEGBOT/cc.ja.300.vec", binary=False)
    #model = ft.load_model("/clwork/ando/SEGBOT/fast/cc.ja.300.bin")
    t,l =[],[]
    for i in range(108):
        texts = []
        sent = ""
        sparate = []
        label = []
        ruiseki = 0
        ruiseki2 = 0
        alls = []
        labels, text, num = [], [], []
        for n, line in enumerate(open(lis+str(i)+".txt")):
            line = line.strip("\t").rstrip("\n")
            if line == "":
                continue
            sent = wakati.parse(line).split(" ")[:-1]
            flag = 0
            label = []
            texts = []
            for i in sent:
                try:
                    texts.append(fm.vocab[i].index)
                except KeyError:
                    texts.append(fm.vocab["<unk>"].index)
                label.append(0)
            label[-1] = 1
            labels.append(np.array(label))
            text.append(np.array(texts))
        t.append(text)
        l.append(labels)
    return t,l

#model = ft.load_model("/clwork/ando/SEGBOT/fast/cc.ja.300.bin")
#print(model.get_subwords("間質性肺炎"))
#print(model.get_subwords("誤嚥性肺炎"))
#print(model.get_subwords("談話ユニット分割"))

"""
texts = []
sent = ""
sparate = []
label = []
ruiseki = 0
ruiseki2 = 0
alls = []
for n, line in enumerate(open("/clwork/ando/SEGBOT/randomdata.tsv")):
    line = line.strip("\t").rstrip("\n")
    if line == "":
        if sent == "":
            continue
        alls.append(sent)
        sent = ""
        continue
    else:
        sent += line
if len(sent) != 0:
    alls.append(sent)
random.shuffle(alls)
#v = random.sample(alls, 300)
#for i in v:
#    alls.remove(i)
#t = random.sample(alls, 300)
#for i in t:
#    alls.remove(i)
with open("randomdata_concat.tsv","a")as f:
    f.write("\n".join())
#with open("dev_fix.tsv","a")as f:
#    for i in v:
#        f.write("\n".join(i))
#        f.write("\n\n")
#with open("test_fix.tsv","a")as f:
#    for i in t:
#        f.write("\n".join(i))
#        f.write("\n\n")
"""

"""
out = ""
for line in open("/clwork/ando/SEGBOT_BERT/alldata2_bert.tsv"):
    line = line.split("\t")
    line = line[0].strip()
    if line == "" or "サマリ" in line:
        continue
    out += line + "\n"
with open("alldata3.tsv","w")as f:
    f.write(out)
"""
"""
#wakati = MeCab.Tagger("-Owakati -b 81920 -u /clwork/ando/SEGBOT/MANBYO_201907_Dic-utf8.dic")
wakati = MeCab.Tagger("-Owakati -b 81920")

with open('fm_space.pickle', 'rb') as f:
    fm = pickle.load(f)
#model = gensim.models.KeyedVectors.load_word2vec_format("/clwork/ando/SEGBOT/cc.ja.300.vec", binary=False)
#model = ft.load_model("/clwork/ando/SEGBOT/fast/cc.ja.300.bin")
texts = []
sent = ""
sparate = []
label = []
ruiseki = 0
ruiseki2 = 0
alls = []
for n, line in enumerate(open("/clwork/ando/SEGBOT/train_fix.tsv")):
    line = line.strip("\t").rstrip("\n")
    #print(line)
    if line == "":
        if sent == "":
            continue
        sent = wakati.parse(sent).split(" ")[:-1]
        flag = 0
        for i in sent:
            for j in sparate:
                if ruiseki+len(i) > j and ruiseki < j:
                    label.append(1)
                    flag = 1
                elif ruiseki+len(i) == j:
                    label.append(1)
                    flag = 1
            if flag == 0:
                label.append(0)
            flag = 0
            ruiseki += len(i)
            #texts += i + " "
            try:
                #texts.append(model[i])
                texts.append(fm.vocab[i])
                #texts += str(fm.vocab[i].index) + " "
                #print(i,str(fm.vocab[i].index))
            except KeyError:
                texts.append(fm.vocab["<unk>"])
                print(i)
        label[-1] = 1
        #texts = texts.rstrip() + "\t"
        #texts += " ".join(label) + "\n"
        alls.append((str(n),texts,label))
        sent = ""
        sparate = []
        texts = []
        label = []
        ruiseki = 0
        ruiseki2 = 0
        continue
    sent += mojimoji.han_to_zen(line, digit=False, ascii=False)
    ruiseki2 += len(line)
    sparate.append(ruiseki2)
with open('nm_ni/train.pickle', 'wb') as f:
    pickle.dump(alls, f)
#print(alls)
#with open("resepdata_seped.tsv","w")as f:
#    f.write(texts)
"""

wakati = MeCab.Tagger("-Owakati")

#fm = gensim.models.KeyedVectors.load_word2vec_format('/clwork/ando/SEGBOT/cc.ja.300.vec', binary=False)
#with open('fm.pickle', 'wb') as f:
#    pickle.dump(fm, f)
texts = ""
sent = ""
sparate = []
label = []
ruiseki = 0
ruiseki2 = 0
for line in open("alldata.tsv"):
    line = line.split("\t")
    line = line[0].strip()
    if line == "" or "サマリ" in line:
        if sent == "":
            continue
        sent = wakati.parse(sent).split(" ")[:-1]
        flag = 0
        #print(sent,sparate)
        for i in sent:
            #print(i)
            for j in sparate:
                if ruiseki+len(i) > j and ruiseki < j:
                    #print(j)
                    label.append("1")
                    flag = 1
                elif ruiseki+len(i) == j:
                    #print(j)
                    label.append("1")
                    flag = 1
            if flag == 0:
                label.append("0")
            flag = 0
            ruiseki += len(i)
            #texts += i + " "

            try:
                texts += str(0) + " "
            except KeyError:
                print(i)
                #texts += str(fm.vocab["<unk>"].index) + " "

        label[-1] = "1"
        texts = texts.rstrip() + "\t"
        texts += " ".join(label) + "\n"
        sent = ""
        sparate = []
        label = []
        ruiseki = 0
        ruiseki2 = 0
        #print(texts)
        continue
    sent += line.strip()
    ruiseki2 += len(line.strip())
    sparate.append(ruiseki2)
with open("random_labbeled.tsv","w")as f:
    f.write(texts)

"""
wakati = MeCab.Tagger("-Owakati -u /clwork/ando/SEGBOT/MANBYO_201907_Dic-utf8.dic")

#fm = gensim.models.KeyedVectors.load_word2vec_format('/clwork/ando/SEGBOT/cc.ja.300_space.vec', binary=False)
#with open('fm_space.pickle', 'wb') as f:
#    pickle.dump(fm, f)

with open('fm_space.pickle', 'rb') as f:
    fm = pickle.load(f)
texts = ""
sent = ""
sparate = []
label = []
ruiseki = 0
ruiseki2 = 0
for line in open("/clwork/ando/SEGBOT/alldata_resep.tsv"):
    line = line.split("\t")
    line = line[0].strip("\t").rstrip("\n")
    #print(line)
    if line == "" or "サマリ" in line:
        if sent == "":
            continue
        print(sent)
        sent = sent.replace(" ","<space>")
        sent = wakati.parse(sent).split(" ")[:-1]
        print(sent)
        flag = 0
        #print(sent,sparate)
        for i in sent:
            #print(i)
            for j in sparate:
                if ruiseki+len(i) > j and ruiseki < j:
                    #print(j)
                    label.append("1")
                    flag = 1
                elif ruiseki+len(i) == j:
                    #print(j)
                    label.append("1")
                    flag = 1
            if flag == 0:
                label.append("0")
            flag = 0
            ruiseki += len(i)
            #texts += i + " "

            try:
                texts += str(fm.vocab[i].index) + " "
                #print(i,str(fm.vocab[i].index))
            except KeyError:
                texts += str(fm.vocab["<unk>"].index) + " "
        label[-1] = "1"
        texts = texts.rstrip() + "\t"
        texts += " ".join(label) + "\n"
        sent = ""
        sparate = []
        label = []
        ruiseki = 0
        ruiseki2 = 0
        #print(texts)
        continue
    sent += line.strip("\t")
    ruiseki2 += len(line)
    sparate.append(ruiseki2)
with open("alldata2_space.tsv","w")as f:
    f.write(texts)
"""

"""
wakati = MeCab.Tagger("-Owakati")

fm = gensim.models.KeyedVectors.load_word2vec_format('/clwork/ando/SEGBOT/cc.ja.300.vec', binary=False)
texts = ""
sent = ""
cand = ""
sparate = []
label = []
ruiseki = 0
ruiseki2 = 0
flag2 = 1
for line in open("data2.tsv"):
    line = line.split("\t")
    if flag2 == 1:
        cand = line
        flag2 = 2
        continue
    if flag2 == 2:
        flag2 = 1
    #print(line,cand)
    for n,z in enumerate(zip(cand,line)):
        i = z[0]
        j = z[1]
        n = n+1
        if i == "":
            sent = wakati.parse(sent).split(" ")[:-1]
            flag = 0
            #print(sent,sparate)
            for i in sent:
                #print(i)
                for j in sparate:
                    if ruiseki+len(i) > j and ruiseki < j:
                        #print(j)
                        label.append("1")
                        flag = 1
                    elif ruiseki+len(i) == j:
                        #print(j)
                        label.append("1")
                        flag = 1
                if flag == 0:
                    label.append("0")
                flag = 0
                ruiseki += len(i)
                #texts += i + " "

                try:
                    texts += str(fm.vocab[i].index) + " "
                except KeyError:
                    texts += str(fm.vocab["<unk>"].index) + " "

            label[-1] = "1"
            texts = texts.rstrip() + "\t"
            texts += " ".join(label) + "\n"
            sent = ""
            sparate = []
            label = []
            ruiseki = 0
            ruiseki2 = 0
            #print(texts)
            break
        if j == "|":
            sparate.append(n)
        sent += i
with open("alldata.tsv","w")as f:
    f.write(texts)
"""
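All of these loaders share one labelling convention: segment boundaries arrive as character offsets accumulated in sparate, and a MeCab token gets label 1 if a boundary falls inside or exactly at the end of that token (the final token of a block is always forced to 1). A toy illustration of that convention follows; the tokens, the offsets, and the helper name boundary_labels are invented for the example and are not part of credata.py.

# Toy illustration of the 0/1 boundary labels produced by ymyi/nmni above (assumed example).
def boundary_labels(tokens, boundaries):
    labels, consumed = [], 0
    for tok in tokens:
        end = consumed + len(tok)
        # 1 if any annotated boundary falls inside or exactly at this token's end
        labels.append(1 if any(consumed < b <= end for b in boundaries) else 0)
        consumed = end
    labels[-1] = 1  # the last token always closes a segment
    return labels

tokens = ["腹部", "症状", "なし", "嘔吐", "あり"]   # 2 characters each
print(boundary_labels(tokens, boundaries=[6]))      # [0, 0, 1, 0, 1]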
fm.pickle
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f4c02d5957824106f6217e9a56d89ee5b7ca9ae399c7a49af8dc062e1ea0be99
size 2521658187
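fm.pickle is tracked with Git LFS, so only this pointer file appears in the diff; the actual object is about 2.5 GB of pickled gensim KeyedVectors (credata.py builds it from cc.ja.300.vec). Elsewhere in this commit it is consumed with a plain pickle load, roughly as sketched here (standard library only, no repository internals assumed beyond the file name):

# Minimal sketch of how the LFS-tracked fm.pickle is read back (as in credata.py).
import pickle

with open("fm.pickle", "rb") as f:   # ~2.5 GB word-vector KeyedVectors
    fm = pickle.load(f)
print(type(fm))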
functionforDownloadButtons.py
ADDED
@@ -0,0 +1,171 @@
import streamlit as st
import pickle
import pandas as pd
import json
import base64
import uuid
import re

import importlib.util


def import_from_file(module_name: str, filepath: str):
    """
    Imports a module from file.

    Args:
        module_name (str): Assigned to the module's __name__ parameter (does not
            influence how the module is named outside of this function)
        filepath (str): Path to the .py file

    Returns:
        The module
    """
    spec = importlib.util.spec_from_file_location(module_name, filepath)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def notebook_header(text):
    """
    Insert section header into a jinja file, formatted as notebook cell.

    Leave 2 blank lines before the header.
    """
    return f"""# # {text}

"""


def code_header(text):
    """
    Insert section header into a jinja file, formatted as Python comment.

    Leave 2 blank lines before the header.
    """
    seperator_len = (75 - len(text)) / 2
    seperator_len_left = math.floor(seperator_len)
    seperator_len_right = math.ceil(seperator_len)
    return f"# {'-' * seperator_len_left} {text} {'-' * seperator_len_right}"


def to_notebook(code):
    """Converts Python code to Jupyter notebook format."""
    notebook = jupytext.reads(code, fmt="py")
    return jupytext.writes(notebook, fmt="ipynb")


def open_link(url, new_tab=True):
    """Dirty hack to open a new web page with a streamlit button."""
    # From: https://discuss.streamlit.io/t/how-to-link-a-button-to-a-webpage/1661/3
    if new_tab:
        js = f"window.open('{url}')"  # New tab or window
    else:
        js = f"window.location.href = '{url}'"  # Current tab
    html = '<img src onerror="{}">'.format(js)
    div = Div(text=html)
    st.bokeh_chart(div)


def download_button(object_to_download, download_filename, button_text):
    """
    Generates a link to download the given object_to_download.

    From: https://discuss.streamlit.io/t/a-download-button-with-custom-css/4220

    Params:
    ------
    object_to_download: The object to be downloaded.
    download_filename (str): filename and extension of file. e.g. mydata.csv,
        some_txt_output.txt download_link_text (str): Text to display for download
        link.
    button_text (str): Text to display on download button (e.g. 'click here to download file')
    pickle_it (bool): If True, pickle file.

    Returns:
    -------
    (str): the anchor tag to download object_to_download

    Examples:
    --------
    download_link(your_df, 'YOUR_DF.csv', 'Click to download data!')
    download_link(your_str, 'YOUR_STRING.txt', 'Click to download text!')

    """
    # if pickle_it:
    #     try:
    #         object_to_download = pickle.dumps(object_to_download)
    #     except pickle.PicklingError as e:
    #         st.write(e)
    #         return None

    # if:
    if isinstance(object_to_download, bytes):
        pass

    elif isinstance(object_to_download, pd.DataFrame):
        object_to_download = object_to_download.to_csv(index=False)
    # Try JSON encode for everything else
    else:
        object_to_download = json.dumps(object_to_download)

    try:
        # some strings <-> bytes conversions necessary here
        b64 = base64.b64encode(object_to_download.encode()).decode()
    except AttributeError as e:
        b64 = base64.b64encode(object_to_download).decode()

    button_uuid = str(uuid.uuid4()).replace("-", "")
    button_id = re.sub("\d+", "", button_uuid)

    custom_css = f"""
        <style>
            #{button_id} {{
                display: inline-flex;
                align-items: center;
                justify-content: center;
                background-color: rgb(255, 255, 255);
                color: rgb(38, 39, 48);
                padding: .25rem .75rem;
                position: relative;
                text-decoration: none;
                border-radius: 4px;
                border-width: 1px;
                border-style: solid;
                border-color: rgb(230, 234, 241);
                border-image: initial;
            }}
            #{button_id}:hover {{
                border-color: rgb(246, 51, 102);
                color: rgb(246, 51, 102);
            }}
            #{button_id}:active {{
                box-shadow: none;
                background-color: rgb(246, 51, 102);
                color: white;
            }}
        </style> """

    dl_link = (
        custom_css
        + f'<a download="{download_filename}" id="{button_id}" href="data:file/txt;base64,{b64}">{button_text}</a><br><br>'
    )
    # dl_link = f'<a download="{download_filename}" id="{button_id}" href="data:file/txt;base64,{b64}"><input type="button" kind="primary" value="{button_text}"></a><br></br>'

    st.markdown(dl_link, unsafe_allow_html=True)


# def download_link(
#     content, label="Download", filename="file.txt", mimetype="text/plain"
# ):
#     """Create a HTML link to download a string as a file."""
#     # From: https://discuss.streamlit.io/t/how-to-download-file-in-streamlit/1806/9
#     b64 = base64.b64encode(
#         content.encode()
#     ).decode()  # some strings <-> bytes conversions necessary here
#     href = (
#         f'<a href="data:{mimetype};base64,{b64}" download="{filename}">{label}</a>'
#     )
#     return href
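As used by app.py above, download_button receives the plain Python list returned by run_segbot.generate; a list is neither bytes nor a DataFrame, so it falls through to the json.dumps branch and all three buttons serve JSON-encoded text under different file names. A minimal usage sketch (the segment list here is an invented example):

# Hedged usage sketch for download_button; the segment list is made up.
from functionforDownloadButtons import download_button

segments = ["腹痛を認める。", "嘔気なし。"]                       # e.g. output of run_segbot.generate
download_button(segments, "Data.json", "📥 Download (.json)")   # renders a styled <a> tag via st.markdown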
logo.png
ADDED
Binary file (image)
model.py
ADDED
@@ -0,0 +1,465 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.utils.rnn as R
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch.autograd import Variable
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
|
9 |
+
|
10 |
+
class PointerNetworks(nn.Module):
|
11 |
+
def __init__(self,voca_size, voc_embeddings,word_dim, hidden_dim,is_bi_encoder_rnn,rnn_type,rnn_layers,
|
12 |
+
dropout_prob,use_cuda,finedtuning,isbanor,batchsize):
|
13 |
+
super(PointerNetworks,self).__init__()
|
14 |
+
|
15 |
+
self.word_dim = word_dim
|
16 |
+
self.voca_size = voca_size
|
17 |
+
|
18 |
+
self.hidden_dim = hidden_dim
|
19 |
+
self.dropout_prob = dropout_prob
|
20 |
+
self.is_bi_encoder_rnn = is_bi_encoder_rnn
|
21 |
+
self.num_rnn_layers = rnn_layers
|
22 |
+
self.rnn_type = rnn_type
|
23 |
+
self.voc_embeddings = voc_embeddings
|
24 |
+
self.finedtuning = finedtuning
|
25 |
+
self.batchsize = batchsize
|
26 |
+
|
27 |
+
self.nnDropout = nn.Dropout(dropout_prob)
|
28 |
+
|
29 |
+
self.isbanor = isbanor
|
30 |
+
|
31 |
+
|
32 |
+
if rnn_type in ['LSTM', 'GRU']:
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
self.decoder_rnn = getattr(nn, rnn_type)(input_size=word_dim,
|
37 |
+
hidden_size=2 * hidden_dim if is_bi_encoder_rnn else hidden_dim,
|
38 |
+
num_layers=rnn_layers,
|
39 |
+
+                                                     dropout=dropout_prob,
+                                                     batch_first=True)
+
+            self.encoder_rnn = getattr(nn, rnn_type)(input_size=word_dim,
+                                                     hidden_size=hidden_dim,
+                                                     num_layers=rnn_layers,
+                                                     bidirectional=is_bi_encoder_rnn,
+                                                     dropout=dropout_prob,
+                                                     batch_first=True)
+
+        else:
+            print('rnn_type should be LSTM,GRU')
+
+        self.use_cuda = True
+
+        self.nnSELU = nn.SELU()
+
+        self.nnEm = nn.Embedding(self.voca_size,self.word_dim,padding_idx=2000001)
+        #self.nnEm = nn.Embedding.from_pretrained(self.voc_embeddings,freeze=self.finedtuning,padding_idx=-1)
+        self.initEmbeddings(self.voc_embeddings)
+        if self.use_cuda:
+            self.nnEm = self.nnEm.cuda()
+
+        if self.is_bi_encoder_rnn:
+            self.num_encoder_bi = 2
+        else:
+            self.num_encoder_bi = 1
+
+        self.nnW1 = nn.Linear(self.num_encoder_bi * hidden_dim, self.num_encoder_bi * hidden_dim, bias=False)
+        self.nnW2 = nn.Linear(self.num_encoder_bi * hidden_dim, self.num_encoder_bi * hidden_dim, bias=False)
+        self.nnV = nn.Linear(self.num_encoder_bi * hidden_dim, 1, bias=False)
+
+    def initEmbeddings(self,weights):
+        self.nnEm.weight.data.copy_(torch.from_numpy(weights))
+        self.nnEm.weight.requires_grad = self.finedtuning
+
+    def initHidden(self,hsize,batchsize):
+
+        #hsize=self.hidden_dim
+        #batchsize=self.batchsize
+        if self.rnn_type == 'LSTM':
+
+            h_0 = Variable(torch.zeros(self.num_encoder_bi*self.num_rnn_layers, batchsize, hsize))
+            c_0 = Variable(torch.zeros(self.num_encoder_bi*self.num_rnn_layers, batchsize, hsize))
+
+            if self.use_cuda:
+                h_0 = h_0.cuda()
+                c_0 = c_0.cuda()
+
+            return (h_0, c_0)
+        else:
+
+            h_0 = Variable(torch.zeros(self.num_encoder_bi*self.num_rnn_layers, batchsize, hsize))
+
+            if self.use_cuda:
+                h_0 = h_0.cuda()
+
+            return h_0
+
+    def _run_rnn_packed(self, cell, x, x_lens, h=None):
+        #print(x_lens)
+        x_packed = R.pack_padded_sequence(x, x_lens.data.tolist(),
+                                          batch_first=True, enforce_sorted=False)
+        if h is not None:
+            output, h = cell(x_packed, h)
+        else:
+            output, h = cell(x_packed)
+
+        output, _ = R.pad_packed_sequence(output, batch_first=True)
+
+        return output, h
+
+    def pointerEncoder(self,Xin,lens):
+        self.bn_inputdata = nn.BatchNorm1d(self.word_dim, affine=False, track_running_stats=False)
+
+        batch_size,maxL = Xin.size()
+
+        X = self.nnEm(Xin)  # N L C
+
+        if self.isbanor and maxL>1:
+            X = X.permute(0,2,1)  # N C L
+            X = self.bn_inputdata(X)
+            X = X.permute(0, 2, 1)  # N L C
+
+        X = self.nnDropout(X)
+
+        encoder_lstm_co_h_o = self.initHidden(self.hidden_dim, batch_size)
+        o, h = self._run_rnn_packed(self.encoder_rnn, X, lens, encoder_lstm_co_h_o)  # batch_first=True
+        o = o.contiguous()
+
+        o = self.nnDropout(o)
+
+        return o,h
+
+    def pointerLayer(self,en,di):
+        """
+        :param en: [L,H]
+        :param di: [H,]
+        :return:
+        """
+
+        WE = self.nnW1(en)
+
+        exdi = di.expand_as(en)
+
+        WD = self.nnW2(exdi)
+
+        nnV = self.nnV(self.nnSELU(WE+WD))
+
+        nnV = nnV.permute(1,0)
+
+        nnV = self.nnSELU(nnV)
+
+        #TODO: for log loss
+        att_weights = F.softmax(nnV)
+        logits = F.log_softmax(nnV)
+
+        return logits,att_weights
+
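Note: the additive pointer scoring above (nnW1/nnW2/nnV followed by SELU and a softmax over the candidate positions) reduces to the following self-contained sketch. The sizes and variable names here are illustrative, not taken from the commit.

import torch
import torch.nn as nn
import torch.nn.functional as F

# score_j = V(SELU(W1 @ enc_j + W2 @ dec_state)), normalized over encoder steps.
L, H = 7, 16                       # hypothetical: 7 encoder steps, hidden size 16
W1 = nn.Linear(H, H, bias=False)   # plays the role of self.nnW1
W2 = nn.Linear(H, H, bias=False)   # plays the role of self.nnW2
V = nn.Linear(H, 1, bias=False)    # plays the role of self.nnV
selu = nn.SELU()

en = torch.randn(L, H)             # encoder states [L, H]
di = torch.randn(H)                # current decoder state [H]

scores = V(selu(W1(en) + W2(di.expand_as(en)))).squeeze(-1)  # [L]
scores = selu(scores)                       # the model applies SELU to the raw scores as well
att_weights = F.softmax(scores, dim=-1)     # attention over candidate boundaries
logits = F.log_softmax(scores, dim=-1)      # log-probabilities fed to NLLLoss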
+    def training_decoder(self,hn,hend,X,Xindex,Yindex,lens):
+        """
+        """
+
+        loss_function = nn.NLLLoss()
+        batch_loss = 0
+        LoopN = 0
+        batch_size = len(lens)
+        for i in range(len(lens)):  # Loop batch size
+
+            curX_index = Xindex[i]
+            #print(curX_index)
+            curY_index = Yindex[i]
+            curL = lens[i]
+            curX = X[i]
+            #print(curX)
+
+            x_index_var = Variable(torch.from_numpy(curX_index.astype(np.int64)))
+            if self.use_cuda:
+                x_index_var = x_index_var.cuda()
+            cur_lookup = curX[x_index_var]
+            #print(cur_lookup)
+
+            curX_vectors = self.nnEm(cur_lookup)  # output: [seq,features]
+
+            curX_vectors = curX_vectors.unsqueeze(0)  # [batch, seq, features]
+
+            if self.rnn_type == 'LSTM':  # need h_end,c_end
+
+                h_end = hend[0].permute(1, 0, 2).contiguous().view(batch_size, self.num_rnn_layers,-1)
+                c_end = hend[1].permute(1, 0, 2).contiguous().view(batch_size, self.num_rnn_layers,-1)
+
+                curh0 = h_end[i].unsqueeze(0).permute(1, 0, 2)
+                curc0 = c_end[i].unsqueeze(0).permute(1, 0, 2)
+
+                h_pass = (curh0,curc0)
+            else:
+
+                h_end = hend.permute(1, 0, 2).contiguous().view(batch_size, self.num_rnn_layers,-1)
+                curh0 = h_end[i].unsqueeze(0).permute(1, 0, 2)
+                h_pass = curh0
+
+            decoder_out,_ = self.decoder_rnn(curX_vectors,h_pass)
+            decoder_out = decoder_out.squeeze(0)  # [seq,features]
+
+            curencoder_hn = hn[i,0:curL,:]  # hn[batch,seq,H] --> [seq,H], i is loop over batch size
+
+            for j in range(len(decoder_out)):  # Loop di
+                #print(len(decoder_out),curY_index)
+                cur_dj = decoder_out[j]
+                cur_groundy = curY_index[j]
+
+                cur_start_index = curX_index[j]
+                predict_range = list(range(cur_start_index,curL))
+
+                # TODO: make it point backward, only consider predict_range in current time step
+                # align groundtruth
+                cur_groundy_var = Variable(torch.LongTensor([int(cur_groundy) - int(cur_start_index)]))
+                if self.use_cuda:
+                    cur_groundy_var = cur_groundy_var.cuda()
+
+                curencoder_hn_back = curencoder_hn[predict_range,:]
+
+                cur_logists, cur_weights = self.pointerLayer(curencoder_hn_back,cur_dj)
+
+                batch_loss = batch_loss + loss_function(cur_logists,cur_groundy_var)
+                LoopN = LoopN + 1
+
+        batch_loss = batch_loss/LoopN
+
+        return batch_loss
+
+    def neg_log_likelihood(self,Xin,index_decoder_x, index_decoder_y,lens):
+        '''
+        :param Xin: stack_x, [allseq,wordDim]
+        :param Yin:
+        :param lens:
+        :return:
+        '''
+
+        encoder_hn, encoder_h_end = self.pointerEncoder(Xin,lens)
+
+        loss = self.training_decoder(encoder_hn, encoder_h_end,Xin,index_decoder_x, index_decoder_y,lens)
+
+        return loss
+
+    def test_decoder(self,hn,hend,X,Yindex,lens):
+
+        loss_function = nn.NLLLoss()
+        batch_loss = 0
+        LoopN = 0
+
+        batch_boundary = []
+        batch_boundary_start = []
+        batch_align_matrix = []
+
+        batch_size = len(lens)
+
+        for i in range(len(lens)):  # Loop batch size
+
+            curL = lens[i]
+            curY_index = Yindex[i]
+            curX = X[i]
+            cur_end_boundary = curY_index[-1]
+
+            cur_boundary = []
+            cur_b_start = []
+            cur_align_matrix = []
+
+            cur_sentence_vectors = self.nnEm(curX)  # output: [seq,features]
+
+            if self.rnn_type == 'LSTM':  # need h_end,c_end
+
+                h_end = hend[0].permute(1, 0, 2).contiguous().view(batch_size, self.num_rnn_layers,-1)
+                c_end = hend[1].permute(1, 0, 2).contiguous().view(batch_size, self.num_rnn_layers,-1)
+
+                curh0 = h_end[i].unsqueeze(0).permute(1, 0, 2)
+                curc0 = c_end[i].unsqueeze(0).permute(1, 0, 2)
+
+                h_pass = (curh0,curc0)
+            else:  # only need h_end
+
+                h_end = hend.permute(1, 0, 2).contiguous().view(batch_size, self.num_rnn_layers,-1)
+                curh0 = h_end[i].unsqueeze(0).permute(1, 0, 2)
+                h_pass = curh0
+
+            curencoder_hn = hn[i, 0:curL, :]  # hn[batch,seq,H] --> [seq,H], i is loop over batch size
+
+            Not_break = True
+
+            loop_in = cur_sentence_vectors[0,:].unsqueeze(0).unsqueeze(0)  # [1,1,H]
+            loop_hc = h_pass
+
+            loopstart = 0
+
+            loop_j = 0
+            while (Not_break):  # if not end
+
+                loop_o, loop_hc = self.decoder_rnn(loop_in,loop_hc)
+
+                #TODO: make it point backward
+                predict_range = list(range(loopstart,curL))
+                curencoder_hn_back = curencoder_hn[predict_range,:]
+                cur_logists, cur_weights = self.pointerLayer(curencoder_hn_back, loop_o.squeeze(0).squeeze(0))
+
+                cur_align_vector = np.zeros(curL)
+                cur_align_vector[predict_range] = cur_weights.data.cpu().numpy()[0]
+                cur_align_matrix.append(cur_align_vector)
+
+                #TODO: align groundtruth
+                if loop_j > len(curY_index)-1:
+                    cur_groundy = curY_index[-1]
+                else:
+                    cur_groundy = curY_index[loop_j]
+
+                cur_groundy_var = Variable(torch.LongTensor([max(0,int(cur_groundy) - loopstart)]))
+                if self.use_cuda:
+                    cur_groundy_var = cur_groundy_var.cuda()
+
+                batch_loss = batch_loss + loss_function(cur_logists, cur_groundy_var)
+
+                #TODO: get predicted boundary
+                topv, topi = cur_logists.data.topk(1)
+
+                pred_index = topi[0][0]
+
+                #TODO: align pred_index to original seq
+                ori_pred_index = pred_index + loopstart
+
+                if cur_end_boundary == ori_pred_index:
+                    cur_boundary.append(ori_pred_index)
+                    cur_b_start.append(loopstart)
+                    Not_break = False
+                    loop_j = loop_j + 1
+                    LoopN = LoopN + 1
+                    break
+                else:
+                    cur_boundary.append(ori_pred_index)
+
+                    loop_in = cur_sentence_vectors[ori_pred_index+1,:].unsqueeze(0).unsqueeze(0)
+                    cur_b_start.append(loopstart)
+
+                    loopstart = ori_pred_index+1  # start = pred_end + 1
+
+                    loop_j = loop_j + 1
+                    LoopN = LoopN + 1
+
+            # For each instance in batch
+            batch_boundary.append(cur_boundary)
+            batch_boundary_start.append(cur_b_start)
+            batch_align_matrix.append(cur_align_matrix)
+
+        batch_loss = batch_loss / LoopN
+
+        batch_boundary = np.array(batch_boundary)
+        batch_boundary_start = np.array(batch_boundary_start)
+        batch_align_matrix = np.array(batch_align_matrix)
+
+        return batch_loss,batch_boundary,batch_boundary_start,batch_align_matrix
+
+    def predict(self,Xin,index_decoder_y,lens):
+
+        batch_size = index_decoder_y.shape[0]
+
+        encoder_hn, encoder_h_end = self.pointerEncoder(Xin, lens)
+
+        batch_loss, batch_boundary, batch_boundary_start, batch_align_matrix = self.test_decoder(encoder_hn,encoder_h_end,Xin,index_decoder_y,lens)
+
+        return batch_loss,batch_boundary,batch_boundary_start,batch_align_matrix
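Note: the greedy decoding loop in test_decoder can be summarized by this standalone sketch; the scoring function and values are hypothetical stand-ins for pointerLayer, and the real method additionally stops as soon as the final gold boundary is predicted.

import numpy as np

def greedy_segment(scores_fn, seq_len):
    """Starting at position 0, repeatedly pick the most likely boundary in
    [start, seq_len) and restart decoding just after it."""
    boundaries, starts, start = [], [], 0
    while start < seq_len:
        span_scores = scores_fn(start)               # scores over positions start..seq_len-1
        pred = start + int(np.argmax(span_scores))   # boundary index in the original sequence
        boundaries.append(pred)
        starts.append(start)
        start = pred + 1                             # next segment begins after the boundary
    return boundaries, starts

# Toy usage: a fake scorer that always points two tokens ahead (or to the last position).
fake = lambda s: np.eye(10 - s)[min(2, 10 - s - 1)]
print(greedy_segment(fake, 10))   # -> ([2, 5, 8, 9], [0, 3, 6, 9])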
requirements.txt
ADDED
@@ -0,0 +1,7 @@
+seaborn
+matplotlib
+streamlit == 0.87
+pandas == 1.2.4
+keybert
+flair
+click<8
run_segbot.py
CHANGED
@@ -1,5 +1,4 @@
 import re
-from nltk.tokenize import word_tokenize
 import pickle
 import numpy as np
 import random
@@ -8,99 +7,67 @@ from solver import TrainSolver

 from model import PointerNetworks
 import gensim
-
-            self.word2index[word] = self.n_words
-            self.word2count[word] = 1
-            self.index2word[self.n_words] = word
-            self.n_words += 1
-        else:
-            self.word2count[word] += 1
-
-
-def mytokenizer(inS,all_dict):
-
-    #repDig = re.sub(r'\d+[\.,/]?\d+','RE_DIGITS',inS)
-    #repDig = re.sub(r'\d*[\d,]*\d+', 'RE_DIGITS', inS)
-    toked = inS
-    or_toked = inS
-    re_unk_list = []
-    ori_list = []
-
-    for (i,t) in enumerate(toked):
-        if t not in all_dict and t not in ['RE_DIGITS']:
-            re_unk_list.append('UNKNOWN')
-            ori_list.append(or_toked[i])
-        else:
-            re_unk_list.append(t)
-            ori_list.append(or_toked[i])
-
-    labey_edus = [0]*len(re_unk_list)
-    labey_edus[-1] = 1
-
-    return ori_list,re_unk_list,labey_edus
-
-
-def get_mapping(X,Y,D):
-
-    X_map = []
-    for w in X:
-        if w in D:
-            X_map.append(D[w])
+import MeCab
+import pysbd
+
+def create_data(doc,fm,split_method):
+    wakati = MeCab.Tagger("-Owakati -b 81920")
+    seg = pysbd.Segmenter(language="ja", clean=False)
+    texts = []
+    sent = ""
+    label = []
+    alls = []
+    labels, text, num = [], [], []
+    allab, altex, fukugenss = [], [], []
+    for n in range(1):
+        fukugens = []
+        if split_method == "pySBD":
+            lines = seg.segment(doc)
         else:
+            doc = doc.strip().replace("。","。\n").replace(".",".\n")
+            doc = re.sub("(\n)+","\n",doc)
+            lines = doc.split("\n")
+        for line in lines:
+            line = line.strip()
+            if line == "":
+                continue
+            sent = wakati.parse(line).split(" ")[:-1]
+            flag = 0
+            label = []
+            texts = []
+            fukugen = []
+            for i in sent:
+                try:
+                    texts.append(fm.vocab[i].index)
+                except KeyError:
+                    texts.append(fm.vocab["<unk>"].index)
+                fukugen.append(i)
+                label.append(0)
+            label[-1] = 1
+            labels.append(np.array(label))
+            text.append(np.array(texts))
+            fukugens.append(fukugen)
+        allab.append(labels)
+        altex.append(text)
+        fukugenss.append(fukugens)
+        labels, text, fukugens = [], [], []
+    return altex, allab, fukugenss
+
+
+def generate(doc, mymodel, fm, index2word, split_method):
+    X_tes, Y_tes, fukugen = create_data(doc,fm,split_method)
+    output_texts = mymodel.check_accuracy(X_tes, Y_tes,index2word, fukugen)
+
+    return output_texts
+
+
+def setup():
+    with open('index2word.pickle', 'rb') as f:
+        index2word = pickle.load(f)
     with open('model.pickle', 'rb') as f:
        mysolver = pickle.load(f)
-
-    #test_batch_ave_loss, test_pre, test_rec, test_f1, visdata = mysolver.check_accuracy(X_tes, Y_tes,0)
-    #with open(str(i)+"seped","w")as f:
-    #    f.write(o)
-    #test_batch_ave_loss, test_pre, test_rec, test_f1, visdata = mysolver.check_accuracy(X_tes, Y_tes,0)
-    print(test_pre, test_rec, test_f1)
-    #start_b = visdata[3][0]
-    #end_b = visdata[2][0] + 1
-    #segments = []
-
-    #for i, END in enumerate(end_b):
-    #    START = start_b[i]
-    #    segments.append(' '.join(ori_X[START:END]))
-
-    return test_pre, test_rec, test_f1
+    with open('fm.pickle', 'rb') as f:
+        fm = pickle.load(f)
+
+    return mysolver,fm,index2word
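Note: a minimal sketch of how the new setup()/generate() pair above is presumably driven by the app; it assumes the pickled artifacts (model.pickle, fm.pickle, index2word.pickle) and a working MeCab/pySBD install are available, and the sample document string is illustrative.

from run_segbot import setup, generate

mysolver, fm, index2word = setup()     # unpickle the solver, the embedding vocabulary, and the id-to-token map
doc = "これは一つ目の文です。これは二つ目の文です。"   # any Japanese document
segments = generate(doc, mysolver, fm, index2word, split_method="pySBD")
for s in segments:
    print(s)                           # one predicted discourse segment per line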
solver.py
CHANGED
@@ -6,7 +6,6 @@ from torch.autograd import Variable
 import random
 from torch.nn.utils import clip_grad_norm
 import copy
-from tqdm import tqdm

 import os
 import pickle
@@ -56,76 +55,36 @@ def align_variable_numpy(X,maxL,paddingNumber):


 def sample_a_sorted_batch_from_numpy(numpyX,numpyY,batch_size,use_cuda):
-
-    if batch_size != None:
-        select_index = random.sample(range(len(numpyY)), batch_size)
-    else:
-        select_index = np.array(range(len(numpyY)))
+    select_index = np.array(range(len(numpyY)))

     select_index = np.array(range(len(numpyX)))

     batch_x = [copy.deepcopy(numpyX[i]) for i in select_index]
     batch_y = [copy.deepcopy(numpyY[i]) for i in select_index]

-    #print(batch_y)
     index_decoder_X,index_decoder_Y = get_decoder_index_XY(batch_y)
-    #index_decoder = [get_decoder_index_XY(i) for i in batch_y]
-    #index_decoder_X = [i[0] for i in index_decoder]
-    #index_decoder_Y = [i[1] for i in index_decoder]
-    #print(index_decoder_Y)
-
-    #all_lens = []
     all_lens = np.array([len(x) for x in batch_y])
-    #for x in batch_y:
-    #    print(x)
-    #    try:
-    #        all_lens.append(len(x))
-    #    except:
-    #        all_lens.append(1)
-    #all_lens = np.array(all_lens)

     maxL = np.max(all_lens)

-    #idx = all_lens
-    #print(idx)
     idx = np.argsort(all_lens)
     idx = np.sort(idx)
-    #print(idx)
-    #idx = idx[::-1] # decreasing
-    #print(idx)
     batch_x = [batch_x[i] for i in idx]
     batch_y = [batch_y[i] for i in idx]
     all_lens = all_lens[idx]

     index_decoder_X = np.array([index_decoder_X[i] for i in idx])
     index_decoder_Y = np.array([index_decoder_Y[i] for i in idx])
-    #print(index_decoder_Y)

     numpy_batch_x = batch_x

     batch_x = align_variable_numpy(batch_x,maxL,2000001)
     batch_y = align_variable_numpy(batch_y,maxL,2)
-
-    print(len(batch_x))
-    #batch_x = Variable(torch.from_numpy(batch_x.astype(np.int64)))
     batch_x = Variable(torch.from_numpy(np.array(batch_x, dtype="int64")))

     if use_cuda:
         batch_x = batch_x.cuda()

     return numpy_batch_x,batch_x,batch_y,index_decoder_X,index_decoder_Y,all_lens,maxL


@@ -144,7 +103,6 @@ class TrainSolver(object):
         self.lr_decay_epoch = lr_decay_epoch
         self.eval_size = eval_size

-
         self.dev_x, self.dev_y = dev_x, dev_y

         self.model = model
@@ -152,294 +110,70 @@
         self.weight_decay =weight_decay


-
-    def sample_dev(self):
-        test_tr_x = []
-        test_tr_y = []
-        select_index = random.sample(range(len(self.train_y)),self.eval_size)
-        test_tr_x = [self.train_x[n] for n in select_index]
-        test_tr_y = [self.train_y[n] for n in select_index]
-
-        return test_tr_x,test_tr_y
-
-
     def get_batch_micro_metric(self,pre_b, ground_b, x,index2word, fukugen, nloop):

+
+
         tokendic = {}
-        #with open('index2word.pickle', 'rb') as f:
-        #    index2word = pickle.load(f)
         for n,i in enumerate(index2word):
             tokendic[n] = i
+        sents = []
-
-        All_R = []
-        All_G = []
-        """
-        for i,cur_seq_y in enumerate(zip(ground_b,fukugen[nloop])):
-            #print(fukugen[nloop])
-            fuku = cur_seq_y[1]
-            cur_seq_y = cur_seq_y[0]
-            index_of_1 = np.where(cur_seq_y==1)[0]
-            #print(index_of_1)
-            index_pre = pre_b[i]
-            inp = x[i]
-            #print(len(inp))
-        """
-        print(len(pre_b), len(ground_b), len(fukugen))
-        #global leng
-        #print(fukugen)
         for i,cur_seq_y in enumerate(ground_b):
-            #print(fukugen[nloop])
             fuku = fukugen[i]
-            #cur_seq_y = cur_seq_y[0]
             index_of_1 = np.where(cur_seq_y==1)[0]
-            #print(index_of_1)
             index_pre = pre_b[i]
             inp = x[i]
-            #print(len(inp))

             index_pre = np.array(index_pre)
             END_B = index_of_1[-1]
             index_pre = index_pre[index_pre != END_B]
             index_of_1 = index_of_1[index_of_1 != END_B]

-            no_correct = len(np.intersect1d(list(index_of_1), list(index_pre)))
-            All_C.append(no_correct)
-            All_R.append(len(index_pre))
-            All_G.append(len(index_of_1))

             index_of_1 = list(index_of_1)
             index_pre = list(index_pre)

-            FN = []
             FP = []
-            TP = []
             sent = []
             ex = ""
-            for j in inp
-                sent.append(tokendic[int(j.to('cpu').detach().numpy().copy())])
-            for k in index_of_1:
-                if k not in index_pre:
-                    FN.append(k)
-                if k in index_pre:
-                    TP.append(k)
+            sent = [tokendic[int(j.to('cpu').detach().numpy().copy())] for j in inp]
             for k in index_pre:
                 if k not in index_of_1:
                     FP.append(k)
-            #
-
-            #for n,i in enumerate(sent):
+            #FP = [int(j.to('cpu').detach().numpy().copy()) for j in FP]
+
             for n,k in enumerate(zip(sent, fuku)):
                 f = k[1]
                 i = k[0]
                 if k == "<pad>":
                     continue
                 if n in FP:
-                    ex += f + "<FP>"
-                else:
                     ex += f
+                    sents.append(ex)
+                    ex = ""
-
-                    #ex += i + "<FN>"
-                    ex += i
-                elif n in FP:
-                    ex += i + "<FP>"
-                elif n in TP:
-                    ex += i + "<TP>"
                 else:
-                    ex +=
-
-            #    f.write(ex+"\n")
-            #print(i)
-            #leng += 1
-
-        return All_C,All_R,All_G
+                    ex += f
+            sents.append(ex)
+        return sents
-
-
-    def get_batch_metric(self,pre_b, ground_b):
-
-        b_pr = []
-        b_re = []
-        b_f1 = []
-        for i,cur_seq_y in enumerate(ground_b):
-            index_of_1 = np.where(cur_seq_y==1)[0]
-            index_pre = pre_b[i]
-
-            no_correct = len(np.intersect1d(index_of_1,index_pre))
-
-            cur_pre = no_correct / len(index_pre)
-            cur_rec = no_correct / len(index_of_1)
-            cur_f1 = 2*cur_pre*cur_rec/ (cur_pre+cur_rec)
-
-            b_pr.append(cur_pre)
-            b_re.append(cur_rec)
-            b_f1.append(cur_f1)
-
-        return b_pr,b_re,b_f1
-


     def check_accuracy(self,data2X,data2Y,index2word, fukugen2):
-        for nloop in
+        for nloop in range(1):
             dataY = data2Y[nloop]
             dataX = data2X[nloop]
             fukugen = fukugen2[nloop]
-            #print(len(dataX), len(dataY), len(fukugen))
             need_loop = int(np.ceil(len(dataY) / self.batch_size))
-            #need_loop = int(np.ceil(len(dataY) / 1))
-            all_ave_loss = []
-            all_boundary = []
-            all_boundary_start = []
-            all_align_matrix = []
-            all_index_decoder_y = []
-            all_x_save = []
-
-            all_C = []
-            all_R = []
-            all_G = []

             for lp in range(need_loop):
                 startN = lp*self.batch_size
                 endN = (lp+1)*self.batch_size
                 if endN > len(dataY):
                     endN = len(dataY)
-                #print(fukugen)
                 fukuge = fukugen[startN:endN]
-                #print(startN, endN)
-                #print(len(fukugen))
-                #print(fukugen)
-                #for nloop in tqdm(range(0,26431)):
                 numpy_batch_x, batch_x, batch_y, index_decoder_X, index_decoder_Y, all_lens, maxL = sample_a_sorted_batch_from_numpy(
                     dataX[startN:endN], dataY[startN:endN], None, self.use_cuda)
-                #numpy_batch_x, batch_x, batch_y, index_decoder_X, index_decoder_Y, all_lens, maxL = sample_a_sorted_batch_from_numpy(
-                #    dataX, dataY, None, self.use_cuda)
-
-                batch_ave_loss, batch_boundary, batch_boundary_start, batch_align_matrix = self.model.predict(batch_x,
-                                                                                                              index_decoder_Y,
-                                                                                                              all_lens)
-
-                all_ave_loss.extend([batch_ave_loss.data.item()])  #[batch_ave_loss.data[0]]
-                all_boundary.extend(batch_boundary)
-                all_boundary_start.extend(batch_boundary_start)
-                all_align_matrix.extend(batch_align_matrix)
-                all_index_decoder_y.extend(index_decoder_Y)
-                all_x_save.extend(numpy_batch_x)
-
-                #print(batch_y)
-                ba_C,ba_R,ba_G = self.get_batch_micro_metric(batch_boundary,batch_y,batch_x,index2word, fukuge, nloop)
-
-                all_C.extend(ba_C)
-                all_R.extend(ba_R)
-                all_G.extend(ba_G)
-
-        ba_pre = np.sum(all_C)/ np.sum(all_R)
-        ba_rec = np.sum(all_C)/ np.sum(all_G)
-        ba_f1 = 2*ba_pre*ba_rec/ (ba_pre+ba_rec)
-
-        return np.mean(all_ave_loss),ba_pre,ba_rec,ba_f1, (all_x_save,all_index_decoder_y,all_boundary, all_boundary_start, all_align_matrix)
-
-
-    def adjust_learning_rate(self,optimizer,epoch,lr_decay=0.5, lr_decay_epoch=5):
-
-        if (epoch % lr_decay_epoch == 0) and (epoch != 0):
-            for param_group in optimizer.param_groups:
-                param_group['lr'] *= lr_decay
-
-
-    def train(self,n):
-
-        self.test_train_x, self.test_train_y = self.sample_dev()
-
-        optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.lr, weight_decay=self.weight_decay)
-
-        num_each_batch = int(np.round(len(self.train_y) / self.batch_size))
-
-        #os.mkdir(self.save_path)
-
-        best_i = 0
-        best_f1 = 0
-
-        for epoch in range(self.epoch):
-            print(epoch)
-            self.adjust_learning_rate(optimizer, epoch, 0.8, self.lr_decay_epoch)
-
-            track_epoch_loss = []
-            for iter in tqdm(range(num_each_batch)):
-                #print("epoch:%d,iteration:%d" % (epoch, iter))
-
-                self.model.zero_grad()
-
-                numpy_batch_x,batch_x, batch_y, index_decoder_X, index_decoder_Y, all_lens, maxL = sample_a_sorted_batch_from_numpy(
-                    self.train_x, self.train_y, self.batch_size, self.use_cuda)
-
-                neg_loss = self.model.neg_log_likelihood(batch_x, index_decoder_X, index_decoder_Y,all_lens)
-
-                neg_loss_v = float(neg_loss.data.item())
-                #print(neg_loss_v)
-                track_epoch_loss.append(neg_loss_v)
-
-                neg_loss.backward()
-
-                clip_grad_norm(self.model.parameters(), 5)
-                optimizer.step()
-
-            #TODO: after each epoch,check accuracy
-
-            self.model.eval()
-
-            #tr_batch_ave_loss, tr_pre, tr_rec, tr_f1 ,visdata= self.check_accuracy(self.test_train_x,self.test_train_y)
-
-            dev_batch_ave_loss, dev_pre, dev_rec, dev_f1, visdata = self.check_accuracy(self.dev_x,self.dev_y,n)
-            print("f1="+str(dev_f1))
-            print("loss="+str(dev_batch_ave_loss))
-            """
-            if best_f1 < dev_f1:
-                best_f1 = dev_f1
-                best_rec = dev_rec
-                best_pre = dev_pre
-                best_i = epoch
-
-            save_data = [epoch,dev_batch_ave_loss,dev_pre,dev_rec,dev_f1]
-
-            save_file_name = 'bs_{}_es_{}_lr_{}_lrdc_{}_wd_{}_epoch_loss_acc_pk_wd.txt'.format(self.batch_size,self.eval_size,self.lr,self.lr_decay_epoch,self.weight_decay)
-            """
-            #with open(os.path.join(self.save_path,save_file_name), 'a') as f:
-            #    f.write(','.join(map(str,save_data))+'\n')
-
-            #if epoch % 1 ==0 and epoch !=0:
-            #    torch.save(self.model, os.path.join(self.save_path,r'model_epoch_%d.torchsave'%(epoch)))
-
-        return best_i,best_f1,n
+            batch_ave_loss, batch_boundary, batch_boundary_start, batch_align_matrix = self.model.predict(batch_x,index_decoder_Y,all_lens)
+            output_texts = self.get_batch_micro_metric(batch_boundary,batch_y,batch_x,index2word, fukuge, nloop)

+        return output_texts
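Note: the rewritten check_accuracy now returns the segmented texts rather than the micro precision/recall/F1 that the removed code accumulated in All_C/All_R/All_G. For reference, that metric reduces to the following sketch (the example boundary lists are illustrative; the trivial end-of-sequence boundary is excluded upstream in get_batch_micro_metric):

import numpy as np

def boundary_prf(pred_boundaries, gold_boundaries):
    """Micro P/R/F1 over boundary indices: correct = predicted boundaries
    that also appear in the gold boundary set."""
    correct = len(np.intersect1d(pred_boundaries, gold_boundaries))
    precision = correct / max(len(pred_boundaries), 1)
    recall = correct / max(len(gold_boundaries), 1)
    f1 = 2 * precision * recall / max(precision + recall, 1e-12)
    return precision, recall, f1

print(boundary_prf([2, 5, 9], [2, 6, 9]))  # -> (0.666..., 0.666..., 0.666...)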