# clinical_segment_splitter / run_segbot.py
# kenichiro
# commit
# 97b1ca5
# raw
# history blame
# 2.09 kB
import re
import pickle
import numpy as np
import random
import torch
from solver import TrainSolver
from model import PointerNetworks
import gensim
import MeCab
import pysbd
def create_data(doc,fm,split_method):
    """Convert one raw document into per-sentence model inputs.

    The document is split into sentences, every sentence is tokenized with
    MeCab in wakati mode, and each token is mapped to its index in
    ``fm.vocab``.

    Args:
        doc: Raw document text (str).
        fm: Object exposing ``fm.vocab[token].index``; a token whose lookup
            raises ``KeyError`` is mapped to the ``"<unk>"`` entry instead.
            Presumably a gensim word-vector model -- verify against caller.
        split_method: ``"pySBD"`` splits sentences with pySBD (Japanese,
            ``clean=False``). Any other value inserts a line break after
            every "。" and "." and then splits on line breaks.

    Returns:
        A tuple ``(altex, allab, fukugenss)``. Each element is a list holding
        exactly one entry (the single input document), and that entry is a
        list with one item per non-empty sentence:

        * ``altex[0][k]``     -- ``np.ndarray`` of vocabulary indices,
        * ``allab[0][k]``     -- ``np.ndarray`` of zeros whose last element
          is 1 (marks the final token of the sentence),
        * ``fukugenss[0][k]`` -- list of the surface token strings, kept so
          the text can be restored later.

        Lines that are blank, or for which the tokenizer yields no tokens,
        are skipped.
    """
    # NOTE: the tagger is rebuilt on every call; callers processing many
    # documents may want to cache it outside this function.
    tagger = MeCab.Tagger("-Owakati -b 81920")

    if split_method == "pySBD":
        # Only build the pySBD segmenter when it is actually used.
        segmenter = pysbd.Segmenter(language="ja", clean=False)
        lines = segmenter.segment(doc)
    else:
        # NOTE(review): breaking after every ASCII "." also splits decimals
        # such as "3.5" -- confirm this is intended.
        doc = doc.strip().replace("。","。\n").replace(".",".\n")
        # Collapse runs of line breaks so no empty sentences are produced.
        doc = re.sub("(\n)+","\n",doc)
        lines = doc.split("\n")

    text, labels, fukugens = [], [], []
    for line in lines:
        line = line.strip()
        if line == "":
            continue

        # Wakati output is space separated; the final piece is the trailing
        # line break emitted by the tagger, so it is dropped.
        tokens = tagger.parse(line).split(" ")[:-1]
        if not tokens:
            # Nothing to label; previously this crashed on ``label[-1]``.
            continue

        token_ids = []
        for token in tokens:
            try:
                token_ids.append(fm.vocab[token].index)
            except KeyError:
                token_ids.append(fm.vocab["<unk>"].index)

        # All zeros except the last position, which flags the sentence end.
        label = [0] * len(tokens)
        label[-1] = 1

        labels.append(np.array(label))
        text.append(np.array(token_ids))
        fukugens.append(tokens)

    # The outer list level represents documents; there is always exactly one.
    return [text], [labels], [fukugens]
def generate(doc, mymodel, fm, index2word, split_method):
    """Run the segmentation model on one raw document.

    Args:
        doc: Raw document text.
        mymodel: Object providing ``check_accuracy(X, Y, index2word, tokens)``.
        fm: Vocabulary/embedding object forwarded to ``create_data``.
        index2word: Index-to-word mapping forwarded to ``check_accuracy``.
        split_method: Sentence splitting strategy forwarded to ``create_data``.

    Returns:
        Whatever ``mymodel.check_accuracy`` returns for the prepared inputs.
    """
    # create_data yields (token index arrays, label arrays, surface tokens).
    token_ids, boundary_labels, surface_tokens = create_data(doc, fm, split_method)
    return mymodel.check_accuracy(token_ids, boundary_labels, index2word, surface_tokens)
def setup():
with open('index2word.pickle', 'rb') as f:
index2word = pickle.load(f)
with open('model.pickle', 'rb') as f:
mysolver = torch.load(f, torch.device('cpu'))
with open('fm.pickle', 'rb') as f:
fm = pickle.load(f)
return mysolver,fm,index2word