import torch from transformers import RobertaForTokenClassification, AutoTokenizer model=RobertaForTokenClassification.from_pretrained('jiangchengchengNLP/Chinese_resume_extract') tokenizer = AutoTokenizer.from_pretrained('jiangchengchengNLP/Chinese_resume_extract',do_lower_case=True) device=torch.device('cuda' if torch.cuda.is_available() else 'cpu') model.eval() model.to(device) import json label_list={ 0:'其他', 1:'电话', 2:'毕业时间', #毕业时间 3:'出生日期', #出生日期 4:'项目名称', #项目名称 5:'毕业院校', #毕业院校 6:'职务', #职务 7:'籍贯', #籍贯 8:'学位', #学位 9:'性别', #性别 10:'姓名', #姓名 11:'工作时间', #工作时间 12:'落户市县', #落户市县 13:'项目时间', #项目时间 14:'最高学历', #最高学历 15:'工作单位', #工作单位 16:'政治面貌', #政治面貌 17:'工作内容', #工作内容 18:'项目责任', #项目责任 } def get_info(text): #文本处理 text=text.strip() text=text.replace('\n',',') # 将换行符替换为逗号 text=text.replace('\r',',') # 将回车符替换为逗号 text=text.replace('\t',',') # 将制表符替换为逗号 text=text.replace(' ',',') # 将空格替换为逗号 #将连续的逗号合并成一个逗号 while ',,' in text: text=text.replace(',,',',') block_list=[] if len(text)>300: #切块策略 #先切分成句 sentence_list=text.split(',') #然后拼接句子长度不超过300,一旦超过300,当前句子放到下一个块中 boundary=300 block_list=[] block=sentence_list[0] for i in range(1,len(sentence_list)): if len(block)+len(sentence_list[i])<=boundary: block+=sentence_list[i] else: block_list.append(block) block=sentence_list[i] block_list.append(block) else: block_list.append(text) _input = tokenizer(block_list, return_tensors='pt',padding=True,truncation=True) #如果有GPU,将输入数据移到GPU input_ids = _input['input_ids'].to(device) attention_mask = _input['attention_mask'].to(device) # 模型推理 with torch.no_grad(): logits = model(input_ids=input_ids, attention_mask=attention_mask)[0] # 获取预测的标签ID #print(logits.shape) ids = torch.argmax(logits, dim=-1) input_ids=input_ids.reshape(-1) #将张量在最后一个维度拼接,并以0为分界,拼接成句 ids =ids.reshape(-1) # 按标签组合成提取内容 extracted_info = {} word_list=[] flag=None for idx, label_id in enumerate(ids): label_id = label_id.item() if label_id!= 0 and (flag==None or flag==label_id): #不等于零时 if flag==None: flag=label_id label = label_list[label_id] # 获取对应的标签 word_list.append(input_ids[idx].item()) if label not in extracted_info: extracted_info[label] = [] else: if word_list: sentence=''.join(tokenizer.decode(word_list)) extracted_info[label].append(sentence) flag=None word_list=[] if label_id!= 0: label = label_list[label_id] # 获取对应的标签 word_list.append(input_ids[idx].item()) if label not in extracted_info: extracted_info[label] = [] # 返回JSON格式的提取内容 return extracted_info