from src.helper import *
import gradio as gr
import torch
class LegalNER():
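    """Named-entity extraction for Indonesian court decision (putusan) documents.

    Wraps a fine-tuned IndoBERT token-classification model and turns its BIO
    predictions into the display fields listed in label_convert (verdict number,
    defendant, charges, judges, and so on).
    """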
def __init__(self, model, tokenizer, ids_to_labels, check_point='IndoBERT (IndoLEM)', label_all_tokens=True):
self.model = model
self.tokenizer = tokenizer
self.check_point = check_point
self.label_all_tokens = label_all_tokens
self.prediction_label = ''
self.data_token = ''
self.ids_to_labels = ids_to_labels
self.label_extraction = []
self.tokenizer_decode = ''
        # Maps BIO entity tags to the (Indonesian) field names shown to the user
        self.label_convert = {'B_VERN': 'Nomor Putusan',    # verdict number
                              'B_DEFN': 'Nama Terdakwa',    # defendant name
                              'B_CRIA': 'Tindak Pidana',    # criminal offence
                              'B_ARTV': 'Melanggar KUHP',   # criminal-code article violated
                              'B_PENA': 'Tuntutan Hukum',   # prosecution's sentencing demand
                              'B_PUNI': 'Putusan Hukum',    # sentence handed down
                              'B_TIMV': 'Tanggal Putusan',  # verdict date
                              'B_JUDP': 'Hakim Ketua',      # presiding judge
                              'B_JUDG': 'Hakim Anggota',    # member judges
                              'B_REGI': 'Panitera',         # court registrar
                              'B_PROS': 'Penuntut Umum',    # public prosecutor
                              'B_ADVO': 'Pengacara',        # defence counsel
                              }
def align_word_ids(self, texts):
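        """Mark which encoded token positions should receive a prediction.

        Returns a list with -100 for special/padding tokens and 1 for word-initial
        pieces; continuation pieces get 1 only when label_all_tokens is True,
        otherwise -100.
        """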
tokenized_inputs = self.tokenizer(texts, padding='max_length', max_length=512, truncation=True)
word_ids = tokenized_inputs.word_ids()
previous_word_idx = None
label_ids = []
for word_idx in word_ids:
if word_idx is None:
label_ids.append(-100)
            elif word_idx != previous_word_idx:
                label_ids.append(1)
            else:
                label_ids.append(1 if self.label_all_tokens else -100)
previous_word_idx = word_idx
return label_ids
def labelToText(self):
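        """Join the predicted word-pieces into surface text, one entry per entity tag.

        Returns a dict mapping each B_* tag to the reconstructed text predicted
        with that tag for the current sentence.
        """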
prev_tag = 'O'
result = {}
temp = ''
        # Join the word-pieces back into one text span per predicted entity label
for i, word in enumerate(self.data_token):
if self.prediction_label[i] != 'O':
if prev_tag == 'O' and temp != '':
temp = ''
if '##' in word:
temp += word.replace('##', '')
else:
temp += ' ' + word
else:
if temp != "":
result[prev_tag.replace("I_", "B_")] = temp.strip()
temp = ""
prev_tag = self.prediction_label[i]
return result
def dis_pdf_prediction(self):
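        """Merge the per-sentence extractions of a PDF and format them for display."""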
        # Across all sentences, keep the longest extraction found for each entity type
entity_result = {}
for i in self.label_extraction:
if len(list(i.keys())) > 1:
for y in i.items():
if y[0] not in entity_result:
entity_result[y[0]] = y[1]
else:
if len(entity_result[y[0]]) < len(y[1]):
entity_result[y[0]] = y[1]
else:
if tuple(i.items())[0] not in entity_result:
entity_result[tuple(i.items())[0][0]] = tuple(i.items())[0][1]
        # Render the extracted entities as a numbered summary string
result = ''
for i, (label, data) in enumerate(entity_result.items()):
if label in ['B_PENA', 'B_ARTV', 'B_PROS']:
result += f'{i+1}. {self.label_convert[label]}\t = {data.capitalize()}\n'
elif label in ['B_JUDP', 'B_CRIA']:
result += f'{i+1}. {self.label_convert[label]}\t\t\t = {data.capitalize()}\n'
elif label in ['B_ADVO', 'B_REGI']:
result += f'{i+1}. {self.label_convert[label]}\t\t\t\t\t = {data.capitalize()}\n'
else:
result += f'{i+1}. {self.label_convert[label]}\t\t = {data.capitalize()}\n'
return result
def dis_text_prediction(self):
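        """Convert token-level predictions for a plain-text input into span dicts.

        Each dict carries 'entity', 'word', 'start' and 'end' (character offsets),
        suitable for entity highlighting in the UI.
        """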
result = []
temp_result = {}
count_huruf = 0
temp_word = ''
temp_label = ''
        temp_count_huruf = 0
for i, (word, label) in enumerate(zip(self.data_token, self.prediction_label)):
if label != 'O':
if temp_word != '' and '##' not in word:
temp_result['entity'] = temp_label
temp_result['word'] = temp_word
temp_result['start'] = temp_count_huruf
temp_result['end'] = temp_count_huruf + (len(temp_word))
result.append(temp_result)
temp_word, temp_label, temp_count_huruf, temp_result = '', '', 0, {}
if '##' in word:
temp_word += word.replace('##', '')
else:
temp_label = label
temp_word = word
temp_count_huruf = count_huruf
if i == len(self.data_token)-1:
temp_result['entity'] = temp_label
temp_result['word'] = temp_word
temp_result['start'] = temp_count_huruf
temp_result['end'] = temp_count_huruf + (len(temp_word))
result.append(temp_result)
temp_word, temp_label, temp_count_huruf, temp_result = '', '', 0, {}
if '##' in word:
count_huruf += len(word)-2
else:
count_huruf += len(word)+1
return result
def fit_transform(self, texts, progress=gr.Progress()):
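        """Run NER inference over an iterable of texts.

        Loads the fold-4 checkpoint matching the selected IndoBERT variant,
        predicts a tag for every word-piece, and appends each text's
        reconstructed entities to self.label_extraction.
        """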
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
if use_cuda:
self.model = self.model.cuda()
file_check_point = 'model/IndoLEM/model_fold_4.pth' if self.check_point == 'IndoBERT (IndoLEM)' else 'model/IndoNLU/model_fold_4.pth'
model_weights = torch.load(file_check_point, map_location=torch.device(device))
self.model.load_state_dict(model_weights)
for text in progress.tqdm(texts, desc="Ekstraksi Entitas"):
            tokenized = self.tokenizer(text, padding='max_length', max_length=512, truncation=True, return_tensors="pt")
            input_ids = tokenized['input_ids'].to(device)
            mask = tokenized['attention_mask'].to(device)
            logits = self.model(input_ids, mask, None)
            label_ids = torch.Tensor(self.align_word_ids(text)).unsqueeze(0).to(device)
            logits_clean = logits[0][label_ids != -100]
            predictions = logits_clean.argmax(dim=1).tolist()
            prediction_label = [self.ids_to_labels[i] for i in predictions]
            input_ids_conv = self.tokenizer.convert_ids_to_tokens(tokenized['input_ids'][0])
data_token = [word for word in input_ids_conv if word not in ['[CLS]', '[SEP]', '[PAD]']]
self.tokenizer_decode = token_decode(input_ids_conv)
self.data_token = data_token
self.prediction_label = prediction_label
labelConv = self.labelToText()
if labelConv:
self.label_extraction.append(labelConv)
def predict(self, doc):
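        """Extract entities from a raw text string or a path to a .pdf file and
        return them in the matching display format (span list for text,
        summary string for PDF)."""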
if '.pdf' not in doc:
self.fit_transform([doc.strip()])
return self.dis_text_prediction()
else:
file_pdf = read_pdf(doc)
sentence_file = file_pdf.split(';')
self.fit_transform(sentence_file)
            return self.dis_pdf_prediction()
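
# Minimal usage sketch (illustrative only; `model`, `tokenizer` and `ids_to_labels`
# are placeholders for whatever the rest of the app constructs, and the PDF path
# is hypothetical):
#
#   ner = LegalNER(model, tokenizer, ids_to_labels, check_point='IndoBERT (IndoLEM)')
#   spans = ner.predict('Terdakwa dijatuhi pidana penjara ...')  # raw text -> span dicts
#   summary = ner.predict('data/putusan_001.pdf')                # PDF path -> summary string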