File size: 1,882 Bytes
a450bc7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from tqdm import tqdm
import torch
from read_file import *
from align_word_ids import *
from convertTotext import *

def pdf_predict(model, tokenizer, file_path, ids_to_labels, check_point='IndoBERT (IndoLEM)'):
    """Run NER inference over a PDF and collect the extracted entities.

    The text returned by ``read_pdf`` is split on ``';'``. Each non-blank
    segment is tokenized (padded/truncated to 512 tokens) and passed
    through ``model`` on its own.

    Args:
        model: Token-classification model. Its weights are replaced by the
            selected checkpoint; it is moved to CUDA when available and
            switched to eval mode.
        tokenizer: Tokenizer matching ``model``.
        file_path: Path of the PDF to read.
        ids_to_labels: Mapping from predicted class index to label string.
        check_point: ``'IndoBERT (IndoLEM)'`` selects
            ``model/IndoLEM/model_fold_4.pth``; any other value selects
            ``model/IndoNLU/model_fold_4.pth``.

    Returns:
        list: The truthy ``convertTotext`` results, one per segment that
        produced an extraction, in document order.
    """
    sentences = read_pdf(file_path).split(';')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    file_check_point = (
        'model/IndoLEM/model_fold_4.pth'
        if check_point == 'IndoBERT (IndoLEM)'
        else 'model/IndoNLU/model_fold_4.pth'
    )
    model.load_state_dict(torch.load(file_check_point, map_location=device))
    # Inference only: disable dropout (and any other train-time behaviour)
    # so predictions are deterministic.
    model.eval()

    special_tokens = {'[CLS]', '[SEP]', '[PAD]'}
    label_extraction = []
    # No gradients are needed for prediction; skip building the autograd graph.
    with torch.no_grad():
        for text in tqdm(sentences, desc="Prediction Sentence"):
            # Splitting on ';' yields blank segments (e.g. after a trailing
            # ';'); they contain no words, so skip the 512-token forward pass.
            if not text.strip():
                continue

            encoded = tokenizer(text, padding='max_length', max_length=512,
                                truncation=True, return_tensors="pt")
            input_ids = encoded['input_ids'].to(device)
            mask = encoded['attention_mask'].to(device)

            # NOTE(review): element [0] of the output is presumably the logits
            # tensor of shape (1, 512, num_labels) -- confirm against the
            # model's forward().
            logits = model(input_ids, mask, None)
            label_ids = torch.Tensor(align_word_ids(text, tokenizer, True)).unsqueeze(0).to(device)
            # Keep only the positions that align_word_ids did not mark with -100.
            logits_clean = logits[0][label_ids != -100]
            predictions = logits_clean.argmax(dim=1).tolist()
            prediction_label = [ids_to_labels[i] for i in predictions]

            tokens = tokenizer.convert_ids_to_tokens(encoded['input_ids'][0])
            data_token = [tok for tok in tokens if tok not in special_tokens]
            ner_extraction = convertTotext(data_token, prediction_label)

            if ner_extraction:
                label_extraction.append(ner_extraction)

    return label_extraction