File size: 5,182 Bytes
926183f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46a030d
926183f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46a030d
 
926183f
 
 
46a030d
926183f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46a030d
926183f
 
 
46a030d
 
926183f
 
 
 
 
 
 
46a030d
 
926183f
46a030d
 
 
926183f
 
 
46a030d
926183f
 
 
 
 
 
 
 
 
 
 
 
 
 
46a030d
 
926183f
46a030d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import torch.optim as optim
import numpy as np
import torch
from torch.autograd import Variable

import random
from torch.nn.utils import clip_grad_norm
import copy

import os
import pickle



def get_decoder_index_XY(batchY):
    '''Build decoder input/target index arrays from boundary label rows.

    For every row of ``batchY`` the positions holding the label ``1`` are
    collected. For example ``[0 0 1 0 0 0 0 1]`` yields Y ``[2 7]`` and
    X ``[0 3]``.

    :param batchY: iterable of 1-D numpy label rows, e.g. ``[0 0 1 0 0 0 0 1]``.
    :return: ``(returnX, returnY)``. Per row, ``returnY`` holds the indices of
        the ``1`` labels. ``returnX`` holds ``0`` followed by every one of those
        indices except the last, each plus one (the position right after a
        boundary). When all rows have the same number of ``1`` labels the
        results are 2-D integer arrays. Otherwise they are 1-D object arrays
        whose elements are the per-row index arrays.
    '''
    returnX = []
    returnY = []
    for curY in batchY:
        decoderY = np.where(curY == 1)[0]

        if len(decoderY) == 1:
            decoderX = np.array([0])
        else:
            decoderX = np.append([0], decoderY[0:-1] + 1)
        returnX.append(decoderX)
        returnY.append(decoderY)

    return _stack_index_rows(returnX), _stack_index_rows(returnY)


def _stack_index_rows(rows):
    '''Stack per-row index arrays; fall back to a 1-D object array if ragged.'''
    try:
        return np.array(rows)
    except ValueError:
        # NumPy >= 1.24 refuses to build an array from rows of unequal length.
        # Keep the rows as a 1-D object array instead, which is what older
        # NumPy produced implicitly.
        stacked = np.empty(len(rows), dtype=object)
        for i, row in enumerate(rows):
            stacked[i] = row
        return stacked

def align_variable_numpy(X,maxL,paddingNumber):
    '''Right-pad every sequence in ``X`` to length ``maxL``.

    :param X: iterable of sequences (lists, arrays, ...).
    :param maxL: target length. Sequences already at least this long are
        copied unchanged; they are never truncated.
    :param paddingNumber: value appended to shorter sequences.
    :return: numpy array built from the padded rows.
    '''
    padded_rows = [
        list(seq) + [paddingNumber] * (maxL - len(seq))
        for seq in X
    ]
    return np.array(padded_rows)


def sample_a_sorted_batch_from_numpy(numpyX,numpyY,batch_size,use_cuda):
    '''Turn a slice of examples into padded model inputs.

    Despite the name, no reordering is performed: examples stay in input order.
    ``TrainSolver.check_accuracy`` pairs each batch row with per-example side
    data by position, so that order must be kept.

    :param numpyX: sequence of token-id sequences.
    :param numpyY: sequence of label rows (``1`` marks a boundary). Only the
        first ``len(numpyX)`` entries are used.
    :param batch_size: unused; kept for interface compatibility.
    :param use_cuda: move the input tensor to the GPU when true.
    :return: ``(numpy_batch_x, batch_x, batch_y, index_decoder_X,
        index_decoder_Y, all_lens, maxL)``:
        ``numpy_batch_x`` is the list of unpadded copies of the inputs,
        ``batch_x`` is an int64 tensor padded with 2000001,
        ``batch_y`` is a numpy array padded with 2,
        ``index_decoder_X``/``index_decoder_Y`` come from ``get_decoder_index_XY``,
        ``all_lens`` holds the label-row lengths and ``maxL`` is their maximum.
    '''
    n_items = len(numpyX)

    # Work on private copies of the examples.
    batch_x = [copy.deepcopy(numpyX[i]) for i in range(n_items)]
    batch_y = [copy.deepcopy(numpyY[i]) for i in range(n_items)]

    index_decoder_X, index_decoder_Y = get_decoder_index_XY(batch_y)
    all_lens = np.array([len(y) for y in batch_y])
    maxL = np.max(all_lens)

    # The previous `np.sort(np.argsort(all_lens))` was always the identity
    # permutation, so the batch is deliberately left in input order. Dropping
    # that no-op also avoids re-stacking the (possibly ragged) decoder index
    # rows with np.array, which raises ValueError on NumPy >= 1.24.

    numpy_batch_x = batch_x

    batch_x = align_variable_numpy(batch_x, maxL, 2000001)
    batch_y = align_variable_numpy(batch_y, maxL, 2)
    batch_x = Variable(torch.from_numpy(np.array(batch_x, dtype="int64")))

    if use_cuda:
        batch_x = batch_x.cuda()

    return numpy_batch_x, batch_x, batch_y, index_decoder_X, index_decoder_Y, all_lens, maxL




class TrainSolver(object):
    '''Holds a model with its train/dev data and runs boundary evaluation.

    Only evaluation (``check_accuracy``) is implemented in this class. The
    optimisation settings (``lr``, ``lr_decay_epoch``, ``weight_decay``,
    ``epoch``, ``eval_size``) are stored but not read by any method here.
    '''

    def __init__(self, model,train_x,train_y,dev_x,dev_y,save_path,batch_size,eval_size,epoch, lr,lr_decay_epoch,weight_decay,use_cuda):
        # Model, data and output location.
        self.model = model
        self.train_x = train_x
        self.train_y = train_y
        self.dev_x, self.dev_y = dev_x, dev_y
        self.save_path = save_path

        # Hyper-parameters / runtime options.
        self.lr = lr
        self.epoch = epoch
        self.batch_size = batch_size
        self.lr_decay_epoch = lr_decay_epoch
        self.eval_size = eval_size
        self.weight_decay = weight_decay
        self.use_cuda = use_cuda

    def get_batch_micro_metric(self,pre_b, ground_b, x,index2word, fukugen, nloop):
        '''Cut each example's surface text at false-positive boundaries.

        For every example ``i``:

        1. The last gold boundary (presumably the sentence end) is removed
           from both the gold and the predicted boundaries.
        2. Every remaining predicted position that is not a gold boundary is
           treated as a false positive.
        3. The surface strings ``fukugen[i]`` are concatenated, cutting the
           running text right after each false-positive position.
        4. Positions whose token is ``"<pad>"`` are skipped.

        :param pre_b: per-example predicted boundary positions.
        :param ground_b: per-example gold label rows; ``1`` marks a boundary.
        :param x: token ids, one row per example (e.g. a tensor or array).
        :param index2word: sequence mapping token id -> token string.
        :param fukugen: per-example sequences of surface strings, aligned
            position-by-position with the rows of ``x``.
        :param nloop: unused; kept for interface compatibility.
        :return: flat list of text pieces over all examples.
        '''
        id2token = dict(enumerate(index2word))
        pieces = []
        for i, cur_seq_y in enumerate(ground_b):
            surface = fukugen[i]
            gold = np.where(cur_seq_y == 1)[0]
            predicted = np.array(pre_b[i])

            end_b = gold[-1]
            gold = set(gold[gold != end_b])
            predicted = predicted[predicted != end_b]
            false_pos = {k for k in predicted if k not in gold}

            tokens = [id2token[int(t)] for t in x[i]]

            ex = ""
            for n, (token, f) in enumerate(zip(tokens, surface)):
                # Compare the token itself; comparing the (token, surface)
                # tuple to "<pad>" never matched, so padding was never skipped.
                if token == "<pad>":
                    continue
                ex += f
                if n in false_pos:
                    pieces.append(ex)
                    ex = ""
            pieces.append(ex)
        return pieces

    def check_accuracy(self,data2X,data2Y,index2word, fukugen2):
        '''Run the model over the first evaluation split and collect text pieces.

        :param data2X: list of splits of token-id sequences; only ``data2X[0]``
            is processed.
        :param data2Y: list of label-row splits, parallel to ``data2X``.
        :param index2word: id -> token mapping forwarded to
            ``get_batch_micro_metric``.
        :param fukugen2: list of splits of per-example surface strings.
        :return: the ``get_batch_micro_metric`` pieces of every batch, in data
            order. An empty split yields ``[]``.
        '''
        output_texts = []
        for nloop in range(1):
            dataX = data2X[nloop]
            dataY = data2Y[nloop]
            fukugen = fukugen2[nloop]
            total = len(dataY)

            for startN in range(0, total, self.batch_size):
                endN = min(startN + self.batch_size, total)
                numpy_batch_x, batch_x, batch_y, index_decoder_X, index_decoder_Y, all_lens, maxL = sample_a_sorted_batch_from_numpy(
                    dataX[startN:endN], dataY[startN:endN], None, self.use_cuda)

                # Only the predicted boundaries are used; the loss, start
                # boundaries and alignment matrix are discarded.
                _, batch_boundary, _, _ = self.model.predict(batch_x, index_decoder_Y, all_lens)
                # Accumulate across batches (previously each batch overwrote
                # the result, so only the last batch was returned).
                output_texts.extend(self.get_batch_micro_metric(
                    batch_boundary, batch_y, batch_x, index2word,
                    fukugen[startN:endN], nloop))

        return output_texts