metadata
license: apache-2.0
language: zh
uie-micro
介绍
- PaddlePaddle/uie-micro 的 Pytorch 实现
代码调用
forward
Parameters
input_ids: Optional[torch.Tensor] = None
token_type_ids: Optional[torch.Tensor] = None
position_ids: Optional[torch.Tensor] = None
attention_mask: Optional[torch.Tensor] = None
head_mask: Optional[torch.Tensor] = None
inputs_embeds: Optional[torch.Tensor] = None
start_positions: Optional[torch.Tensor] = None
end_positions: Optional[torch.Tensor] = None
output_attentions: Optional[bool] = None
output_hidden_states: Optional[bool] = None
return_dict: Optional[bool] = None
Returns UIEModelOutput or tuple(torch.FloatTensor)
predict
Parameters
schema: Union[Dict, List[str], str]
input_texts: Union[List[str], str]
tokenizer: PreTrainedTokenizerFast
max_length: int = 512
batch_size: int = 32
position_prob: int = 0.5
progress_hook=None
Returns * List[Dict]*
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
model = AutoModel.from_pretrained('Casually/uie-micro', trust_remote_code=True)
model.eval().to('cuda')
tokenizer = AutoTokenizer.from_pretrained('Casually/uie-micro')
hook = tqdm()
schema = {'地震触发词': ['地震强度', '时间', '震中位置', '震源深度']}
model.predict(schema=schema,
input_texts='中国地震台网正式测定:5月16日06时08分在云南临沧市凤庆县(北纬24.34度,东经99.98度)发生3.5级地震,震源深度10千米。',
tokenizer=tokenizer,
progress_hook=hook
)
100%|█████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 5.92it/s]
[{'地震触发词': [{'end': 58,
'probability': 0.9953331562546808,
'relations': {'地震强度': [{'end': 56,
'probability': 0.9719095188981299,
'start': 52,
'text': '3.5级'}],
'时间': [{'end': 22,
'probability': 0.9653931540843175,
'start': 11,
'text': '5月16日06时08分'}],
'震中位置': [{'end': 50,
'probability': 0.6063880101553423,
'start': 23,
'text': '云南临沧市凤庆县(北纬24.34度,东经99.98度)'}],
'震源深度': [{'end': 67,
'probability': 0.989598549365315,
'start': 63,
'text': '10千米'}]},
'start': 56,
'text': '地震'}]}]
应用示例
实体抽取
from transformers import AutoModel, AutoTokenizer
model = AutoModel.from_pretrained('Casually/uie-micro', trust_remote_code=True)
model.eval().to('cuda')
tokenizer = AutoTokenizer.from_pretrained('Casually/uie-micro')
schema = ['时间', '选手', '赛事名称']
res = model.predict(schema=schema,
input_texts="2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌!",
tokenizer=tokenizer,
)
>>> from pprint import pprint
>>> pprint(res)
[{'时间': [{'end': 6,
'probability': 0.906374348561485,
'start': 0,
'text': '2月8日上午'}],
'选手': [{'end': 31,
'probability': 0.8768158783169611,
'start': 28,
'text': '谷爱凌'}]}]
关系抽取
from transformers import AutoModel, AutoTokenizer
model = AutoModel.from_pretrained('Casually/uie-micro', trust_remote_code=True)
model.eval().to('cuda')
tokenizer = AutoTokenizer.from_pretrained('Casually/uie-micro')
schema = {'竞赛名称': ['主办方', '承办方', '已举办次数']}
res = model.predict(schema=schema,
input_texts='2022语言与智能技术竞赛由中国中文信息学会和中国计算机学会联合主办,百度公司、中国中文信息学会评测工作委员会和中国计算机学会自然语言处理专委会承办,已连续举办4届,成为全球最热门的中文NLP赛事之一。',
tokenizer=tokenizer,
)
>>> from pprint import pprint
>>> pprint(res)
[{'竞赛名称': [{'end': 13,
'probability': 0.6710958346658344,
'relations': {'已举办次数': [{'end': 82,
'probability': 0.7727802061877256,
'start': 80,
'text': '4届'}]},
'start': 0,
'text': '2022语言与智能技术竞赛'}]}]
事件抽取
from transformers import AutoModel, AutoTokenizer
model = AutoModel.from_pretrained('Casually/uie-micro', trust_remote_code=True)
model.eval().to('cuda')
tokenizer = AutoTokenizer.from_pretrained('Casually/uie-micro')
schema = {'地震触发词': ['地震强度', '时间', '震中位置', '震源深度']}
res = model.predict(schema=schema,
input_texts='中国地震台网正式测定:5月16日06时08分在云南临沧市凤庆县(北纬24.34度,东经99.98度)发生3.5级地震,震源深度10千米。',
tokenizer=tokenizer,
)
>>> from pprint import pprint
>>> pprint(res)
[{'地震触发词': [{'end': 58,
'probability': 0.9953331562546808,
'relations': {'地震强度': [{'end': 56,
'probability': 0.9719095188981299,
'start': 52,
'text': '3.5级'}],
'时间': [{'end': 22,
'probability': 0.9653931540843175,
'start': 11,
'text': '5月16日06时08分'}],
'震中位置': [{'end': 50,
'probability': 0.6063880101553423,
'start': 23,
'text': '云南临沧市凤庆县(北纬24.34度,东经99.98度)'}],
'震源深度': [{'end': 67,
'probability': 0.989598549365315,
'start': 63,
'text': '10千米'}]},
'start': 56,
'text': '地震'}]}]