---
tags:
- Chinese Medical
- Punctuation Restoration
language:
- zh
license: mit
pipeline_tag: token-classification
base_model: rickltt/pmp-h256
---

## Example Usage

The snippet below segments the unpunctuated input with `jieba`, inserts a `[MASK]` token after every word, lets the model classify each mask as a comma, period, colon, or no punctuation, and then reassembles the punctuated text.

```python
import torch
import jieba
from classifier import BertForMaskClassification
from transformers import AutoTokenizer

# Label set and mapping from labels to Chinese punctuation marks
label_list = ["O", "COMMA", "PERIOD", "COLON"]
label2punct = {
    "COMMA": ",",
    "PERIOD": "。",
    "COLON": ":",
}

model_name_or_path = "pmp-h768"

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = BertForMaskClassification.from_pretrained(model_name_or_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()


def punct(text):
    # Segment the input with jieba and insert a [MASK] token after every word;
    # the model then predicts which punctuation mark (if any) each [MASK] stands for.
    tokenize_words = jieba.lcut(text)
    mask_tokens = []
    for word in tokenize_words:
        mask_tokens.extend(word)
        mask_tokens.append("[MASK]")

    tokenized_inputs = tokenizer(
        mask_tokens, is_split_into_words=True, return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        logits = model(**tokenized_inputs).logits
    predictions = logits.argmax(-1).tolist()[0]
    tokens = tokenizer.convert_ids_to_tokens(tokenized_inputs["input_ids"][0])

    # Rebuild the text: keep the original characters and replace each [MASK]
    # with the predicted punctuation mark (nothing for label "O").
    result = []
    for token, prediction in zip(tokens, predictions):
        if token == "[CLS]" or token == "[SEP]":
            continue
        if token == "[MASK]":
            label = label_list[prediction]
            if label != "O":
                result.append(label2punct[label])
        else:
            result.append(token)
    return "".join(result)


text = "肝浊音界正常肝上界位于锁骨中线第五肋间移动浊音阴性肾区无叩痛"
print(punct(text))

# 肝浊音界正常,肝上界位于锁骨中线第五肋间,移动浊音阴性,肾区无叩痛。
```
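
The example imports `BertForMaskClassification` from a local `classifier.py` that ships with the repository. As an orientation only, the following is a minimal, hypothetical sketch of such a class, assuming it is a plain BERT encoder with a linear token-classification head that scores every position (including the inserted `[MASK]` tokens); the actual implementation in `classifier.py` may differ.

```python
# Hypothetical sketch only (assumption, not the repository's actual code):
# a BERT encoder with a linear head over every token position, so each [MASK]
# receives a distribution over the punctuation labels.
import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput


class BertForMaskClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
                labels=None, **kwargs):
        outputs = self.bert(input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(sequence_output)  # (batch, seq_len, num_labels)
        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(
                logits.view(-1, self.config.num_labels), labels.view(-1))
        return TokenClassifierOutput(loss=loss, logits=logits)
```

With a class shaped like this, `from_pretrained(model_name_or_path)` loads both the encoder weights and the classification head from the checkpoint.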
# Acknowledgments

This work was supported in part by the Shenzhen Science and Technology Program (No. JCYJ20210324135809025).

# Citations

Coming soon.

# License

MIT