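# Named-entity extraction for Indonesian court rulings: wraps a fine-tuned
# IndoBERT token-classification model behind a small predict() interface used
# by the Gradio app (read_pdf and token_decode come from src.helper).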
from src.helper import *
import gradio as gr
import torch

class LegalNER:
  def __init__(self, model, tokenizer, ids_to_labels, check_point='IndoBERT (IndoLEM)', label_all_tokens=True):
    self.model = model
    self.tokenizer = tokenizer
    self.check_point = check_point
    self.label_all_tokens = label_all_tokens
    self.prediction_label = ''
    self.data_token = ''
    self.ids_to_labels = ids_to_labels
    self.label_extraction = []
    self.tokenizer_decode = ''
    # Human-readable display names (Indonesian legal terms) for each entity tag.
    self.label_convert = {'B_VERN': 'Nomor Putusan',
                          'B_DEFN': 'Nama Terdakwa',
                          'B_CRIA': 'Tindak Pidana',
                          'B_ARTV': 'Melanggar KUHP',
                          'B_PENA': 'Tuntutan Hukum',
                          'B_PUNI': 'Putusan Hukum',
                          'B_TIMV': 'Tanggal Putusan',
                          'B_JUDP': 'Hakim Ketua',
                          'B_JUDG': 'Hakim Anggota',
                          'B_REGI': 'Panitera',
                          'B_PROS': 'Penuntut Umum',
                          'B_ADVO': 'Pengacara',
                          }

  def align_word_ids(self, texts):
    # Build a mask aligned with the sub-word tokenization: -100 marks positions
    # to ignore (special tokens, and word-piece continuations when
    # label_all_tokens is False); 1 marks positions whose logits are kept.
    tokenized_inputs = self.tokenizer(texts, padding='max_length', max_length=512, truncation=True)
    word_ids = tokenized_inputs.word_ids()
    previous_word_idx = None
    label_ids = []

    for word_idx in word_ids:
      if word_idx is None:
        # Special tokens ([CLS], [SEP], [PAD]) carry no word index.
        label_ids.append(-100)
      elif word_idx != previous_word_idx:
        # The first sub-token of every word is always kept.
        label_ids.append(1)
      else:
        # Continuation pieces are kept only when label_all_tokens is set.
        label_ids.append(1 if self.label_all_tokens else -100)
      previous_word_idx = word_idx

    return label_ids

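  # Illustration (hypothetical values): with data_token
  # ['budi', 'san', '##toso'] and prediction_label
  # ['B_DEFN', 'I_DEFN', 'I_DEFN'], labelToText() yields
  # {'B_DEFN': 'budi santoso'}.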
  def labelToText(self):
    # Merge the word-piece tokens back into full strings, grouped under their
    # predicted entity tag.
    prev_tag = 'O'
    result = {}
    temp = ''

    for i, word in enumerate(self.data_token):
      if self.prediction_label[i] != 'O':
        # A new span starts after an 'O' run: discard any stale buffer.
        if prev_tag == 'O' and temp != '':
          temp = ''

        if '##' in word:
          # '##' marks a word-piece continuation; glue it on without a space.
          temp += word.replace('##', '')
        else:
          temp += ' ' + word
      else:
        # Span ended: store it under its B_ tag.
        if temp != "":
          result[prev_tag.replace("I_", "B_")] = temp.strip()
        temp = ""

      prev_tag = self.prediction_label[i]

    # Flush a span that runs to the very end of the token list.
    if temp != "":
      result[prev_tag.replace("I_", "B_")] = temp.strip()

    return result

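  # For PDF input, label_extraction holds one {tag: span} dict per sentence;
  # dis_pdf_prediction() merges them, preferring the longest span per tag,
  # and renders a numbered summary string.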
  def dis_pdf_prediction(self):
    # Keep the best prediction per entity: when a tag was extracted from
    # several sentences, prefer the longest span.
    entity_result = {}
    for i in self.label_extraction:
      if len(i) > 1:
        for label, span in i.items():
          if label not in entity_result or len(entity_result[label]) < len(span):
            entity_result[label] = span
      else:
        label, span = next(iter(i.items()))
        if label not in entity_result:
          entity_result[label] = span

    # Format the extracted entities as a numbered list; the tab counts are
    # tuned so the '=' columns roughly line up in the output textbox.
    result = ''
    for i, (label, data) in enumerate(entity_result.items()):
      if label in ['B_PENA', 'B_ARTV', 'B_PROS']:
        result += f'{i+1}. {self.label_convert[label]}\t = {data.capitalize()}\n'
      elif label in ['B_JUDP', 'B_CRIA']:
        result += f'{i+1}. {self.label_convert[label]}\t\t\t = {data.capitalize()}\n'
      elif label in ['B_ADVO', 'B_REGI']:
        result += f'{i+1}. {self.label_convert[label]}\t\t\t\t\t = {data.capitalize()}\n'
      else:
        result += f'{i+1}. {self.label_convert[label]}\t\t = {data.capitalize()}\n'

    return result

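  # Illustration (hypothetical values): dis_text_prediction() yields entries
  # such as {'entity': 'B_DEFN', 'word': 'budi', 'start': 9, 'end': 13},
  # a per-entity shape suitable for Gradio-style text highlighting.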
  def dis_text_prediction(self):
    # Build entity spans with character offsets. count_huruf ("huruf" =
    # letter) tracks the character position of the current token within the
    # reconstructed text.
    result = []
    temp_result = {}
    count_huruf = 0
    temp_word = ''
    temp_label = ''
    temp_count_huruf = 0
    for i, (word, label) in enumerate(zip(self.data_token, self.prediction_label)):
      if label != 'O':
        # A new labelled word begins: flush the span buffered so far.
        if temp_word != '' and '##' not in word:
          temp_result['entity'] = temp_label
          temp_result['word'] = temp_word
          temp_result['start'] = temp_count_huruf
          temp_result['end'] = temp_count_huruf + len(temp_word)
          result.append(temp_result)
          temp_word, temp_label, temp_count_huruf, temp_result = '', '', 0, {}

        if '##' in word:
          # Word-piece continuation: extend the buffered word.
          temp_word += word.replace('##', '')
        else:
          temp_label = label
          temp_word = word
          temp_count_huruf = count_huruf

      # Flush whatever is still buffered at the end of the token list.
      if i == len(self.data_token) - 1 and temp_word != '':
        temp_result['entity'] = temp_label
        temp_result['word'] = temp_word
        temp_result['start'] = temp_count_huruf
        temp_result['end'] = temp_count_huruf + len(temp_word)
        result.append(temp_result)
        temp_word, temp_label, temp_count_huruf, temp_result = '', '', 0, {}

      # Advance the character cursor: '##' pieces add len-2 characters (no
      # space, no '##'), whole words add their length plus one trailing space.
      if '##' in word:
        count_huruf += len(word) - 2
      else:
        count_huruf += len(word) + 1

    return result

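  # Despite the scikit-learn-style name, fit_transform() returns nothing: it
  # loads the selected checkpoint, tags each text, and stores per-sentence
  # extractions on the instance for the dis_* methods to consume.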
  def fit_transform(self, texts, progress=gr.Progress()):
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    if use_cuda:
      self.model = self.model.cuda()

    # Pick the fold-4 checkpoint matching the selected pretrained backbone.
    file_check_point = 'model/IndoLEM/model_fold_4.pth' if self.check_point == 'IndoBERT (IndoLEM)' else 'model/IndoNLU/model_fold_4.pth'

    model_weights = torch.load(file_check_point, map_location=device)
    self.model.load_state_dict(model_weights)

    # Start each run with a clean slate so repeated calls do not accumulate
    # extractions from earlier documents.
    self.label_extraction = []

    for text in progress.tqdm(texts, desc="Ekstraksi Entitas"):
      tokenized = self.tokenizer(text, padding='max_length', max_length=512, truncation=True, return_tensors="pt")
      input_ids = tokenized['input_ids'].to(device)
      mask = tokenized['attention_mask'].to(device)

      logits = self.model(input_ids, mask, None)
      # Keep logits only at the positions align_word_ids marked as real tokens.
      label_ids = torch.Tensor(self.align_word_ids(text)).unsqueeze(0).to(device)
      logits_clean = logits[0][label_ids != -100]
      predictions = logits_clean.argmax(dim=1).tolist()
      prediction_label = [self.ids_to_labels[i] for i in predictions]

      input_ids_conv = self.tokenizer.convert_ids_to_tokens(tokenized['input_ids'][0])
      data_token = [word for word in input_ids_conv if word not in ['[CLS]', '[SEP]', '[PAD]']]
      self.tokenizer_decode = token_decode(input_ids_conv)
      self.data_token = data_token
      self.prediction_label = prediction_label
      labelConv = self.labelToText()

      if labelConv:
        self.label_extraction.append(labelConv)

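  # App-facing entry point: dispatches on whether `doc` is a PDF path or raw
  # text, and returns the matching display format.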
  def predict(self, doc):
    # endswith is safer than substring matching for detecting a PDF path.
    if not doc.lower().endswith('.pdf'):
      self.fit_transform([doc.strip()])
      return self.dis_text_prediction()
    else:
      file_pdf = read_pdf(doc)
      # read_pdf is expected to return text with ';' separating sentences.
      sentence_file = file_pdf.split(';')
      self.fit_transform(sentence_file)
      return self.dis_pdf_prediction()
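
# Usage sketch (illustrative, commented out): how this class is typically
# wired up. The concrete model/tokenizer classes, checkpoint names, and the
# full ids_to_labels mapping live in the app that imports this module; the
# transformers classes and label ids below are assumptions, not the app's
# exact setup.
#
# from transformers import BertForTokenClassification, BertTokenizerFast
#
# ids_to_labels = {0: 'O', 1: 'B_VERN', 2: 'B_DEFN'}  # hypothetical subset
# tokenizer = BertTokenizerFast.from_pretrained('indolem/indobert-base-uncased')
# model = BertForTokenClassification.from_pretrained(
#     'indolem/indobert-base-uncased', num_labels=len(ids_to_labels))
#
# ner = LegalNER(model, tokenizer, ids_to_labels)
# spans = ner.predict('terdakwa budi santoso ...')  # raw text -> span dicts
# summary = ner.predict('putusan.pdf')              # PDF path -> numbered summary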