kenichiro commited on
Commit
89a2a73
·
1 Parent(s): d4c02d7
Files changed (1) hide show
  1. run_segbot.py +10 -1
run_segbot.py CHANGED
@@ -11,6 +11,15 @@ import MeCab
11
  import pysbd
12
  import io
13
 
 
 
 
 
 
 
 
 
 
14
  def create_data(doc,fm,split_method):
15
  wakati = MeCab.Tagger("-Owakati -b 81920")
16
  seg = pysbd.Segmenter(language="ja", clean=False)
@@ -67,7 +76,7 @@ def setup():
67
  with open('index2word.pickle', 'rb') as f:
68
  index2word = pickle.load(f)
69
  with open('model.pickle', 'rb') as f:
70
- mysolver = torch.load(io.BytesIO(f), map_location='cpu')
71
  with open('fm.pickle', 'rb') as f:
72
  fm = pickle.load(f)
73
 
 
11
  import pysbd
12
  import io
13
 
14
+
15
+
16
+ class CPU_Unpickler(pickle.Unpickler):
17
+ def find_class(self, module, name):
18
+ if module == 'torch.storage' and name == '_load_from_bytes':
19
+ return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
20
+ else: return super().find_class(module, name)
21
+
22
+
23
  def create_data(doc,fm,split_method):
24
  wakati = MeCab.Tagger("-Owakati -b 81920")
25
  seg = pysbd.Segmenter(language="ja", clean=False)
 
76
  with open('index2word.pickle', 'rb') as f:
77
  index2word = pickle.load(f)
78
  with open('model.pickle', 'rb') as f:
79
+ mysolver = CPU_Unpickler(f).load()
80
  with open('fm.pickle', 'rb') as f:
81
  fm = pickle.load(f)
82