File size: 4,711 Bytes
df81629
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f2eb8c
df81629
 
 
 
 
 
 
 
 
 
31aac55
ddfda7d
31aac55
df81629
 
 
66db8e5
87b9111
 
 
 
 
 
 
 
 
3743aa7
 
 
66db8e5
87b9111
 
 
 
 
df81629
 
 
 
7ecd7c4
 
df81629
 
 
66db8e5
87b9111
df81629
 
 
 
 
 
87b9111
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os

import transformers
from transformers import pipeline
from transformers.pipelines.token_classification import TokenClassificationPipeline
import py_vncorenlp

os.system('pwd')
os.system('sudo update-alternatives --config java')
os.mkdir('/home/user/app/vncorenlp')
py_vncorenlp.download_model(save_dir='/home/user/app/vncorenlp')
rdrsegmenter = py_vncorenlp.VnCoreNLP(annotators=["wseg"], save_dir='/home/user/app/vncorenlp')

class MyPipeline(TokenClassificationPipeline):
  def preprocess(self, sentence, offset_mapping=None):
      truncation = True if self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0 else False 
      
      model_inputs = self.tokenizer(
          sentence,
          return_tensors=self.framework,
          truncation=truncation,
          return_special_tokens_mask=True,
          return_offsets_mapping=self.tokenizer.is_fast,
      )


      length = len(model_inputs['input_ids'][0]) - 2
      tokens = self.tokenizer.tokenize(sentence)
      seek = 0
      offset_mapping_list = [[(0, 0)]]
      for i in range(length):
        if tokens[i][-2:] == '@@':
          offset_mapping_list[0].append((seek, seek + len(tokens[i]) - 2))
          seek += len(tokens[i]) - 2            
        else:
          offset_mapping_list[0].append((seek, seek + len(tokens[i])))
          seek += len(tokens[i]) + 1
      offset_mapping_list[0].append((0, 0))

      # if offset_mapping:
      #     model_inputs["offset_mapping"] = offset_mapping

      model_inputs['offset_mapping'] = offset_mapping_list
      model_inputs["sentence"] = sentence

      return model_inputs

model_checkpoint = "DD0101/disfluency-large"

my_classifier = pipeline(
  "token-classification", model=model_checkpoint, aggregation_strategy="simple", pipeline_class=MyPipeline)


import gradio as gr

def ner(text):
  text = " ".join(rdrsegmenter.word_segment(text))
    
  # Some words in lowercase like "đà nẵng" will get error (due to vncorenlp)
  text = text.replace("đà ", " đà ")    
    
  output = my_classifier(text)
  for entity in output:
    entity['entity'] = entity.pop('entity_group')

  # Remove Disfluency-entities to return a sentence with "Fluency" version
  list_str = list(text)

  for entity in output[::-1]: # if we use default order of output list, we will shorten the length of the sentence, so the words later are not in the correct start and end index
    start = max(0, entity['start'] - 1)
    end = min(len(list_str), entity['end'] + 1)

    list_str[start:end] = ' '

  fluency_sentence = "".join(list_str).strip() # use strip() in case we need to remove entity at the beginning or the end of sentence
                                               # (without strip(): "Giá vé khứ hồi à nhầm giá vé một chiều ..." -> " giá vé một chiều ...")
  fluency_sentence = fluency_sentence[0].upper() + fluency_sentence[1:] # since capitalize() just lowercase whole sentence first then uppercase the first letter

  # Replace words like "Đà_Nẵng" to "Đà Nẵng"  
  text = text.replace("_", " ")
  fluency_sentence = fluency_sentence.replace("_", " ")
    
  return {'text': text, 'entities': output}, fluency_sentence

examples = ['Tôi cần thuê à tôi muốn bay một chuyến khứ hồi từ Đà Nẵng đến Đà Lạt', 
            'Giá vé một chiều à không khứ hồi từ Đà Nẵng đến Vinh dưới 2 triệu đồng giá vé khứ hồi từ Quy Nhơn đến Vinh dưới 3 triệu đồng giá vé khứ hồi từ Buôn Ma Thuột đến Quy Nhơn à đến Vinh dưới 4 triệu rưỡi', 
            'Cho tôi biết các chuyến bay đến Đà Nẵng vào ngày 12 mà không ngày 14 tháng sáu',
            'Những chuyến bay nào khởi hành từ Thành phố Hồ Chí Minh bay đến Frankfurt mà nối chuyến ở Singapore và hạ cánh trước 10 giờ ý tôi là 9 giờ tối',
            'Thành Phố nào có VNA ừm thôi cho tôi xem tất cả các chuyến bay từ Thanh Hóa hay Nghệ An nhỉ à Thanh Hóa đến Đà Lạt vào Thứ ba à thôi tôi cần vào Thứ hai'
] 

demo = gr.Interface(ner, 
                    gr.Textbox(label='Sentence', placeholder="Enter your sentence here..."),
                    outputs=[gr.HighlightedText(label='Highlighted Output'), gr.Textbox(label='"Fluency" version')],
                    examples=examples,
                    title="Disfluency Detection",
                    description="This is an easy-to-use built in Gradio for desmontrating a NER System that identifies disfluency-entities in \
                    Vietnamese utterances",
                    theme=gr.themes.Soft())

demo.launch()