rickltt commited on
Commit
7172c83
1 Parent(s): c49599a

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +78 -3
README.md CHANGED
@@ -1,3 +1,78 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - Chinese Medical
4
+ - Punctuation Restoration
5
+ language:
6
+ - zh
7
+ license: mit
8
+ pipeline_tag: token-classification
9
+ base_model: rickltt/pmp-h256
10
+ ---
11
+
12
+ ## Example Usage
13
+ ```python
14
+ import torch
15
+ import jieba
16
+ import numpy as np
17
+ from classifier import BertForMaskClassification
18
+ from transformers import AutoTokenizer, AutoConfig, BertForTokenClassification
19
+
20
+ label_list = ["O","COMMA","PERIOD","COLON"]
21
+
22
+ label2punct = {
23
+ "COMMA": ",",
24
+ "PERIOD": "。",
25
+ "COLON":":",
26
+ }
27
+
28
+ model_name_or_path = "pmp-h768"
29
+
30
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
31
+ model = BertForMaskClassification.from_pretrained(model_name_or_path)
32
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
+
34
+ def punct(text):
35
+
36
+ tokenize_words = jieba.lcut(''.join(text))
37
+ mask_tokens = []
38
+ for word in tokenize_words:
39
+ mask_tokens.extend(word)
40
+ mask_tokens.append("[MASK]")
41
+ tokenized_inputs = tokenizer(mask_tokens,is_split_into_words=True, return_tensors="pt")
42
+ with torch.no_grad():
43
+ logits = model(**tokenized_inputs).logits
44
+ predictions = logits.argmax(-1).tolist()
45
+ predictions = predictions[0]
46
+ tokens = tokenizer.convert_ids_to_tokens(tokenized_inputs["input_ids"][0])
47
+
48
+ result =[]
49
+ print(tokens)
50
+ print(predictions)
51
+ for token, prediction in zip(tokens, predictions):
52
+ if token =="[CLS]" or token =="[SEP]":
53
+ continue
54
+ if token == "[MASK]":
55
+ label = label_list[prediction]
56
+ if label != "O":
57
+ punct = label2punct[label]
58
+ result.append(punct)
59
+ else:
60
+ result.append(token)
61
+
62
+ return "".join(result)
63
+
64
+ text = '肝浊音界正常肝上界位于锁骨中线第五肋间移动浊音阴性肾区无叩痛'
65
+ print(punct(text))
66
+
67
+ # 肝浊音界正常,肝上界位于锁骨中线第五肋间,移动浊音阴性,肾区无叩痛。
68
+ ```
69
+
70
+ # Acknowledgments
71
+ This work was in part supported by Shenzhen Science and Technology Program (No:JCYJ20210324135809025).
72
+
73
+ # Citations
74
+ Coming Soon
75
+
76
+ # License
77
+
78
+ MIT