---
tags:
- Chinese Medical
- Punctuation Restoration
language:
- zh
license: mit
pipeline_tag: token-classification
base_model: rickltt/pmp-h256
---
## Example Usage
import torch
import jieba
import numpy as np
from classifier import BertForMaskClassification
from transformers import AutoTokenizer, AutoConfig, BertForTokenClassification
# Output classes of the punctuation head. Position i names the class whose
# score sits at index i of the logits, since predictions are looked up here
# by their argmax index.
label_list = ["O", "COMMA", "PERIOD", "COLON"]

# Punctuation mark inserted for each non-"O" class ("O" inserts nothing).
label2punct = dict(
    COMMA=",",
    PERIOD="。",
    COLON=":",
)
# Checkpoint to load (a local directory or a Hub model id).
model_name_or_path = "pmp-h312"

# Load the tokenizer first, then the punctuation classifier, from the
# same checkpoint.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name_or_path),
    BertForMaskClassification.from_pretrained(model_name_or_path),
)

# Preferred inference device.
# NOTE(review): the model is not moved to this device in this snippet,
# so it stays wherever from_pretrained placed it. Confirm whether
# `model.to(device)` is intended.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
def punct(text):
    """Insert predicted punctuation into unpunctuated Chinese *text*.

    How the input is built:
      * The text is segmented with jieba.
      * Every character of each jieba word becomes its own input "word".
      * A ``[MASK]`` slot is appended after each jieba word.

    How the output is built:
      * The model's argmax class for each ``[MASK]`` is looked up in
        ``label_list``.
      * Non-``"O"`` classes are replaced by the mark in ``label2punct``.
      * ``"O"`` slots are dropped.

    Args:
        text: A string, or an iterable of strings that are concatenated
            before segmentation.

    Returns:
        The original characters with predicted punctuation inserted.
        Characters for which the tokenizer emits no token (e.g. whitespace)
        are dropped. Returns ``""`` when there is nothing to segment.
    """
    tokenize_words = jieba.lcut("".join(text))
    mask_tokens = []
    for word in tokenize_words:
        # One input "word" per character, then a candidate punctuation slot.
        mask_tokens.extend(word)
        mask_tokens.append("[MASK]")
    if not mask_tokens:
        # Nothing to classify; don't call the tokenizer with an empty input.
        return ""

    tokenized_inputs = tokenizer(
        mask_tokens, is_split_into_words=True, return_tensors="pt"
    )
    # Maps each token position back to its index in mask_tokens.
    # The entry is None for special tokens such as [CLS]/[SEP].
    # NOTE: word_ids() is only available on "fast" tokenizers.
    word_ids = tokenized_inputs.word_ids(0)

    # Run on whatever device the model's weights live on (CPU unless moved).
    model_device = next(model.parameters()).device
    model_inputs = {
        name: tensor.to(model_device)
        for name, tensor in tokenized_inputs.items()
    }
    with torch.no_grad():
        logits = model(**model_inputs).logits
    predictions = logits.argmax(-1)[0].tolist()

    result = []
    previous_word_id = None
    for word_id, prediction in zip(word_ids, predictions):
        # Skip special tokens and extra sub-tokens of an already-emitted input.
        if word_id is None or word_id == previous_word_id:
            continue
        previous_word_id = word_id
        piece = mask_tokens[word_id]
        if piece == "[MASK]":
            label = label_list[prediction]
            if label != "O":
                result.append(label2punct[label])
        else:
            # Emit the original character, not the tokenizer's normalized
            # form, which may be lowercased or "[UNK]".
            result.append(piece)
    return "".join(result)
# Unpunctuated sample sentence from a physical-examination record.
text = (
    '肝浊音界正常肝上界位于锁骨中线'
    '第五肋间移动浊音阴性肾区无叩痛'
)
restored = punct(text)
print(restored)
# Expected output:
# 肝浊音界正常,肝上界位于锁骨中线第五肋间,移动浊音阴性,肾区无叩痛。
## Acknowledgments
This work was supported in part by the Shenzhen Science and Technology Program (No. JCYJ20210324135809025).
## Citations

Coming soon.
## License

MIT