qgyd2021 committed
Commit b06652e · 1 Parent(s): c82bbd4

[update] add sent_tokenize

Files changed (2):
  1. examples/sent_tokenize/sent_tokenize.py  +78 -0
  2. main.py  +27 -2
examples/sent_tokenize/sent_tokenize.py ADDED
@@ -0,0 +1,78 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ import os
+ import re
+ from typing import List
+
+ from project_settings import project_path
+
+ os.environ['NLTK_DATA'] = (project_path / "thirdparty_data/nltk_data").as_posix()
+
+ import jieba
+ import nltk
+
+
+ nltk_sent_tokenize_languages = [
+     "czech", "danish", "dutch", "flemish", "english", "estonian",
+     "finnish", "french", "german", "italian", "norwegian",
+     "polish", "portuguese", "russian", "spanish", "swedish", "turkish"
+ ]
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--text",
+         default="M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation. It was introduced in this paper and first released in this repository.",
+         # default="我是一个句子。我是另一个句子。",
+         type=str,
+     )
+     parser.add_argument(
+         "--language",
+         default="english",
+         # default="chinese",
+         choices=nltk_sent_tokenize_languages + ["chinese"],  # "chinese" is handled by the rule-based splitter below
+         type=str
+     )
+     args = parser.parse_args()
+     return args
+
+
+ def chinese_sent_tokenize(text: str):
+     # single-character sentence terminators
+     text = re.sub(r"([。!?\?])([^”’])", r"\1\n\2", text)
+     # English ellipsis
+     text = re.sub(r"(\.{6})([^”’])", r"\1\n\2", text)
+     # Chinese ellipsis
+     text = re.sub(r"(\…{2})([^”’])", r"\1\n\2", text)
+     # if a closing quote is preceded by a terminator, the quote is the real end of the sentence, so the \n goes after the quote; the rules above deliberately keep closing quotes attached
+     text = re.sub(r"([。!?\?][”’])([^,。!?\?])", r"\1\n\2", text)
+     # strip any trailing \n at the end of the paragraph
+     # many rule sets also split on the semicolon ";", but it is ignored here, as are dashes, English double quotes, etc.; adjust as needed
+     text = text.rstrip()
+
+     return text.split("\n")
+
+
+ def sent_tokenize(text: str, language: str) -> List[str]:
+     if language in ["chinese"]:
+         sent_list = chinese_sent_tokenize(text)
+     else:
+         sent_list = nltk.sent_tokenize(text, language)
+     return sent_list
+
+
+ def main():
+     args = get_args()
+
+     sent_list = sent_tokenize(args.text, language=args.language)
+
+     for sent in sent_list:
+         print(sent)
+
+     return
+
+
+ if __name__ == '__main__':
+     main()
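
For quick reference, a minimal standalone sketch of the rule-based splitter added above, with the project-specific imports stripped so it runs anywhere; the sample text is the commented-out default from get_args ("I am one sentence. I am another sentence."):

    import re

    def chinese_sent_tokenize(text: str):
        # the same four splitting rules as the committed function
        text = re.sub(r"([。!?\?])([^”’])", r"\1\n\2", text)
        text = re.sub(r"(\.{6})([^”’])", r"\1\n\2", text)
        text = re.sub(r"(\…{2})([^”’])", r"\1\n\2", text)
        text = re.sub(r"([。!?\?][”’])([^,。!?\?])", r"\1\n\2", text)
        return text.rstrip().split("\n")

    print(chinese_sent_tokenize("我是一个句子。我是另一个句子。"))
    # -> ['我是一个句子。', '我是另一个句子。']

The committed script can be exercised the same way, e.g. python examples/sent_tokenize/sent_tokenize.py --language chinese --text "..." (run from the repo root so that project_settings resolves).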
main.py CHANGED
@@ -3,6 +3,8 @@
  import argparse
  import json
  import os
+ import re
+ from typing import List
 
  from project_settings import project_path
 
@@ -12,7 +14,6 @@ os.environ['NLTK_DATA'] = (project_path / "thirdparty_data/nltk_data").as_posix(
  import gradio as gr
  import nltk
  from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
- from transformers.generation.streamers import TextIteratorStreamer
 
 
  language_map = {
@@ -45,6 +46,30 @@ nltk_sent_tokenize_languages = [
  ]
 
 
+ def chinese_sent_tokenize(text: str):
+     # single-character sentence terminators
+     text = re.sub(r"([。!?\?])([^”’])", r"\1\n\2", text)
+     # English ellipsis
+     text = re.sub(r"(\.{6})([^”’])", r"\1\n\2", text)
+     # Chinese ellipsis
+     text = re.sub(r"(\…{2})([^”’])", r"\1\n\2", text)
+     # if a closing quote is preceded by a terminator, the quote is the real end of the sentence, so the \n goes after the quote; the rules above deliberately keep closing quotes attached
+     text = re.sub(r"([。!?\?][”’])([^,。!?\?])", r"\1\n\2", text)
+     # strip any trailing \n at the end of the paragraph
+     # many rule sets also split on the semicolon ";", but it is ignored here, as are dashes, English double quotes, etc.; adjust as needed
+     text = text.rstrip()
+
+     return text.split("\n")
+
+
+ def sent_tokenize(text: str, language: str) -> List[str]:
+     if language in ["chinese"]:
+         sent_list = chinese_sent_tokenize(text)
+     else:
+         sent_list = nltk.sent_tokenize(text, language)
+     return sent_list
+
+
  def main():
      model_dict = {
          "facebook/m2m100_418M": {
@@ -77,7 +102,7 @@ def main():
      tokenizer.src_lang = language_map[src_lang]
 
      if src_lang.lower() in nltk_sent_tokenize_languages:
-         src_t_list = nltk.sent_tokenize(src_text, language="")
+         src_t_list = sent_tokenize(src_text, language=src_lang.lower())
      else:
          src_t_list = [src_text]
 
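
The call this last hunk replaces passed language="" to nltk.sent_tokenize, which punkt cannot resolve to a model; routing through the new sent_tokenize wrapper with the lowercased UI language fixes the call. Note that the surrounding guard only admits languages from nltk_sent_tokenize_languages, so the wrapper's chinese branch is reached only via the standalone example script. For reference, a minimal sketch of the nltk path (assuming the punkt models are available, e.g. vendored under the repo's thirdparty_data/nltk_data or fetched once with nltk.download):

    import nltk
    # nltk.download("punkt")  # one-time fetch if the models are not already vendored

    print(nltk.sent_tokenize("One sentence. Another sentence.", language="english"))
    # -> ['One sentence.', 'Another sentence.']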