import gradio as gr
import torch
import requests
import PyPDF2
import re
# import nltk
# nltk.download('punkt')

class LegalNER():
  def __init__(self, model, tokenizer, ids_to_labels, check_point='IndoBERT (IndoLEM)'):
    self.model = model
    self.tokenizer = tokenizer
    self.check_point = check_point
    self.prediction_label = ''
    self.data_token = ''
    self.ids_to_labels = ids_to_labels
    self.label_extraction = []
    self.tokenizer_decode = ''
    self.label_convert = {'VERN' : 'Nomor Putusan',
                          'DEFN' : 'Nama Terdakwa',
                          'CRIA' : 'Tindak Pidana',
                          'ARTV' : 'Melanggar KUHP',
                          'PENA' : 'Tuntutan Hukum',
                          'PUNI' : 'Putusan Hukum',
                          'TIMV' : 'Tanggal Putusan',
                          'JUDP' : 'Hakim Ketua',
                          'JUDG' : 'Hakim Anggota',
                          'REGI' : 'Panitera',
                          'PROS' : 'Penuntut Umum',
                          'ADVO' : 'Pengacara'}

  def align_word_ids(self, texts):

    tokenized_inputs = self.tokenizer(texts, padding='max_length', max_length=512, truncation=True)

    word_ids = tokenized_inputs.word_ids()

    previous_word_idx = None
    label_ids = []

    for word_idx in word_ids:

        if word_idx is None:
            label_ids.append(-100)

        elif word_idx != previous_word_idx:
            try:
                label_ids.append(1)
            except:
                label_ids.append(-100)
        else:
            try:
                label_ids.append(1)
            except:
                label_ids.append(-100)
        previous_word_idx = word_idx

    return label_ids

  def labelToText(self):
    prev_tag = 'O'
    result = {}
    temp = ''

    # Menganggabungkan semua token menjadi satu kalimat sesuai dengan labelnya
    for i, word in enumerate(self.data_token):
      # Memproses semua token yang berlabel entitas bukan O
      if self.prediction_label[i] != 'O':
        if prev_tag == 'O' and temp != '':
          temp = ''

        if '##' in word:
          temp += word.replace('##', '')

        else:
          temp +=  ' ' + word

      else:
        # cek jika temp nya ada isinya di tambahkan ke dict result dengan key label sebelumnya
        if temp != "":
          # hanya mengambil label setelah tanda B_ /I_
          result[prev_tag[2:]] = temp.strip()
        temp = ""

      prev_tag = self.prediction_label[i]

    return result # Dictionary {VERN : 120 ...}

  # Menggabungkan setiap token hasil tokenizer dalam bentuk string
  def token_decode(self, input_ids_conv):
    result = ''
    temp = ''
    for i, word in enumerate(input_ids_conv):
      # Memfilter Token tambahan
      if word not in ['[CLS]', '[SEP]', '[PAD]']:
        # cek bahwa token saat ini termasuk token lanjutan atau tidak
        if temp != '' and '##' not in word:
          result += ' ' + temp
        # token lanjutan di tanda i dengan tanda paggar 2 "##"
        if '##' in word:
          temp += word.replace('##', '')
        # untuk posisi awal token
        else:
          temp = word
      # cek token terakhir sudah masuk atau belum
      if i == len(input_ids_conv)-1:
        result += ' ' + temp
    return result.strip()

  def dis_pdf_prediction(self):
    # Memilih prediksi entitas yang paling bagus
    entity_result = {}

    # Hasil dari extraksi label ini kadang double sehingga perlu di cari mana yang isinya lebih panjang
    for i in self.label_extraction:
      # jika hasil extraksinya lebih dari 1
      if len(list(i.keys())) > 1:
        # looping setiap item
        for y in i.items():
          # cek key nya sudah ada atau belum
          if y[0] not in entity_result:
            # jika belum tambahkan
            entity_result[y[0]] = y[1]
          else:
            # membandaingkan mana yang lebih panjang
            if len(entity_result[y[0]]) < len(y[1]):
              entity_result[y[0]] = y[1]
      else:
        # cek ada atu tidak dalam enity_result kalau tdidak langsung di tambahkan
        if list(i.items())[0] not in entity_result:
          entity_result[list(i.items())[0][0]] = list(i.items())[0][1]

    # Mengurutkan hasil entitas yang di dapat berdasarkan label convert
    sorted_entitu_result = {key: entity_result[key] for key in self.label_convert if key in entity_result}

    # Mengkonversi hasil ekstraski entitas dalam bentuk String
    result = ''
    for i, (label, data) in enumerate(sorted_entitu_result.items()):
      if label in ['PENA', 'ARTV']:
        result += f'{i+1}. {self.label_convert[label]}\t =   {data.capitalize()}\n'
      elif label in ['PROS']:
        if (i+1) >= 10:
          result += f'{i+1}. {self.label_convert[label]}\t =   {data.capitalize()}\n'
        else:
          result += f'{i+1}. {self.label_convert[label]}\t\t =   {data.capitalize()}\n'
      elif label in ['JUDP', 'CRIA']:
        result += f'{i+1}. {self.label_convert[label]}\t\t\t =   {data.capitalize()}\n'
      elif label in ['ADVO']:
        result += f'{i+1}. {self.label_convert[label]}\t\t\t\t =   {data.capitalize()}\n'
      elif label in ['REGI']:
        if (i+1) >= 10:
          result += f'{i+1}. {self.label_convert[label]}\t\t\t\t\t =   {data.capitalize()}\n'
        else:
          result += f'{i+1}. {self.label_convert[label]}\t\t\t\t\t\t =   {data.capitalize()}\n'
      else:
        result += f'{i+1}. {self.label_convert[label]}\t\t =   {data.capitalize()}\n'

    return result

  def dis_text_prediction(self):
    result = []
    temp_result = {}
    count_huruf = 0
    temp_word = ''
    temp_label = ''
    temp_count_huruf = 0
    prev_word = ''
    for i, (word, label) in enumerate(zip(self.data_token, self.prediction_label)):
      if label != 'O':
        # menambahkan token ketika token merupakan token tunggal atau tidak di pecah dengan tanda pagar
        if temp_word != '' and '##' not in word:
          temp_result['entity'] = temp_label
          temp_result['word'] = temp_word
          temp_result['start'] = temp_count_huruf
          temp_result['end'] = temp_count_huruf + (len(temp_word))
          result.append(temp_result)
          temp_word, temp_label, temp_count_huruf, temp_result = '', '', 0, {}

        # Jika sebuah kata lanjutan maka di tambahakan langung dengan menghapus tanda pagar
        if '##' in word:
          temp_word += word.replace('##', '')

        # Menyimpan token untuk pengecekan iterasi selanjutnya apakah memiliki token lanjutan atau tidak
        else:
          temp_label = label
          temp_word = word
          temp_count_huruf = count_huruf

      # Menambahkan token terakhir yang masih tersimpan dalam temporari variabel
      if i == len(self.data_token)-1:
        temp_result['entity'] = temp_label
        temp_result['word'] = temp_word
        temp_result['start'] = temp_count_huruf
        temp_result['end'] = temp_count_huruf + (len(temp_word))
        result.append(temp_result)
        temp_word, temp_label, temp_count_huruf, temp_result = '', '', 0, {}

      # Perhitungan jumlah huruf untuk pembuatan labelnya
      if '##' in word:
        count_huruf += len(word)-2

      else:
        count_huruf += len(word)+1

    return result

  # Fungsi untuk proses Predict dari inputan
  def fit_transform(self, texts, progress=gr.Progress()):
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    if use_cuda:
      self.model = self.model.cuda()

    file_check_point = 'model/indoBERT-indoLEM-Fold-5.pth' if self.check_point == 'IndoBERT (IndoLEM)' else 'model/indoBERT-indoNLU-Fold-5.pth'

    model_weights = torch.load(file_check_point, map_location=torch.device(device))
    self.model.load_state_dict(model_weights)

    for text in progress.tqdm(texts, desc="Ekstraksi Entitas"):
      toknize = self.tokenizer(text, padding='max_length', max_length = 512, truncation=True, return_tensors="pt")
      input_ids = toknize['input_ids'].to(device)
      mask = toknize['attention_mask'].to(device)

      logits = self.model(input_ids, mask, None)
      label_ids = torch.Tensor(self.align_word_ids(text)).unsqueeze(0).to(device)
      logits_clean = logits[0][label_ids != -100]
      predictions = logits_clean.argmax(dim=1).tolist()
      prediction_label = [self.ids_to_labels[i] for i in predictions]

      input_ids_conv = self.tokenizer.convert_ids_to_tokens(toknize['input_ids'][0])
      data_token = [word for word in input_ids_conv if word not in ['[CLS]', '[SEP]', '[PAD]']]
      self.tokenizer_decode = self.token_decode(input_ids_conv)
      self.data_token = data_token
      self.prediction_label = prediction_label
      labelConv = self.labelToText()

      if labelConv:
        self.label_extraction.append(labelConv) # Dictionary {VERN : 120 ...}

  def clean_text(self, text):
    # Watermark dan Header
    text = text.replace("Mahkamah Agung Republik Indonesia\nMahkamah Agung Republik Indonesia\nMahkamah Agung Republik Indonesia\nMahkamah Agung Republik Indonesia\nMahkamah Agung Republik Indonesia\nDirektori Putusan Mahkamah Agung Republik Indonesia\nputusan.mahkamahagung.go.id\n", "")
    # Footer
    text = text.replace("\nDisclaimer\nKepaniteraan Mahkamah Agung Republik Indonesia berusaha untuk selalu mencantumkan informasi paling kini dan akurat sebagai bentuk komitmen Mahkamah Agung untuk pelayanan publik, transparansi dan akuntabilitas\npelaksanaan fungsi peradilan. Namun dalam hal-hal tertentu masih dimungkinkan terjadi permasalahan teknis terkait dengan akurasi dan keterkinian informasi yang kami sajikan, hal mana akan terus kami perbaiki dari waktu kewaktu.\nDalam hal Anda menemukan inakurasi informasi yang termuat pada situs ini atau informasi yang seharusnya ada, namun belum tersedia, maka harap segera hubungi Kepaniteraan Mahkamah Agung RI melalui :\nEmail : kepaniteraan@mahkamahagung.go.id", "")
    text = text.replace("Telp : 021-384 3348 (ext.318)", "")
    # Membetulkan penulisan token
    text = text.replace('P U T U S A N', 'PUTUSAN').replace('T erdakwa', 'Terdakwa').replace('T empat', 'Tempat').replace('T ahun', 'Tahun')
    text = text.replace('P  E  N  E  T  A  P  A  N', 'PENETAPAN').replace('J u m l a h', 'Jumlah').replace('\n', '')
    # Menghapus Halaman
    text = re.sub(r'\nHalaman \d+ dari \d+ .*', '', text)
    text = re.sub(r'Halaman \d+ dari \d+ .*', '', text)
    text = re.sub(r'\nHal. \d+ dari \d+ .*', '', text)
    text = re.sub(r'Hal. \d+ dari \d+ .*', '', text)
    # Menghapus kode tidak digunakan
    text = re.sub(r' +|[\uf0fc\uf0a7\uf0a8\uf0b7]', ' ', text)
    text = re.sub(r'[\u2026]+|\.{3,}', '', text)
    return text.strip()

  def read_pdf(self, file_pdf):
    try:
      pdf_text = ''
      pdf_file = open(file_pdf, 'rb')
      pdf_reader = PyPDF2.PdfReader(pdf_file)

      for page_num in range(len(pdf_reader.pages)):
          page = pdf_reader.pages[page_num]
          # clean text
          text = self.clean_text(page.extract_text())

          pdf_text += text

      pdf_file.close()
      return pdf_text.strip()

    except requests.exceptions.RequestException as e:
      print("Error:", e)

  def predict(self, doc):
    if '.pdf' not in doc:
      self.fit_transform([doc.strip()])
      return self.dis_text_prediction()
    else:
      file_pdf = self.read_pdf(doc)
      sentence_file = file_pdf.split(';')
      self.fit_transform(sentence_file)
      return self.dis_pdf_prediction()