from typing import Any, Dict, List, Optional, Union
from transformers import Pipeline
import requests
import re
from io import BytesIO
import pandas as pd
import math
import queue
from datetime import date
import time
import logging
import torch
import torch.nn.functional as F
class Predictor():
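    """Extract structured fields (name, age, gender, contacts, schools,
    education, work history, titles) from raw resume text by combining
    NER pipelines with regular expressions."""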
    def __init__(
        self,
        pipelines: Optional[Dict[str, Pipeline]] = None,
        paths: Optional[List[str]] = None,
        today: Optional[date] = None,
    ) -> None:
        # Avoid mutable / eagerly-evaluated defaults: a shared dict or a
        # date.today() frozen at import time would be a latent bug.
        pipelines = pipelines if pipelines is not None else {}
        paths = paths if paths is not None else []
        today = today if today is not None else date.today()
        if "name" not in pipelines:
            raise ValueError("missing required 'name' pipeline")
        if "common" not in pipelines:
            raise ValueError("missing required 'common' pipeline")
self.pipelines = pipelines
self.today = today
self.logger = logging.getLogger(__name__)
self.__init_split_data()
self.__init_schools_data(paths)
self.__init_patterns()
def __init_patterns(
self
):
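        """Compile the regular expressions used by process().

        Note: the surname pattern below is a character class, so it matches
        any single listed character (including the literal comma separators);
        compound surnames such as 万俟 are effectively reduced to their
        individual characters.
        """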
last_name = r"[赵,钱,孙,李,周,吴,郑,王,冯,陈,楮,卫,蒋,沈,韩,杨,朱,秦,尤,许,何,吕,施,张,孔,曹,严,华,金,魏,陶,姜,戚,谢,邹,喻,"\
+r"柏,水,窦,章,云,苏,潘,葛,奚,范,彭,郎,鲁,韦,昌,马,苗,凤,花,方,俞,任,袁,柳,酆,鲍,史,唐,费,廉,岑,薛,雷,贺,倪,汤,滕,殷,罗," \
+ r"毕,郝,邬,安,常,乐,于,时,傅,皮,卞,齐,康,伍,余,元,卜,顾,孟,平,黄,和,穆,萧,尹,姚,邵,湛,汪,祁,毛,禹,狄,米,贝,明,臧,计,伏,成,戴,谈,宋,茅," \
+ r"庞,熊,纪,舒,屈,项,祝,董,梁,杜,阮,蓝,闽,席,季,麻,强,贾,路,娄,危,江,童,颜,郭,梅,盛,林,刁,锺,徐,丘,骆,高,夏,蔡,田,樊,胡,凌,霍,虞,万,支," \
+ r"柯,昝,管,卢,莫,经,房,裘,缪,干,解,应,宗,丁,宣,贲,邓,郁,单,杭,洪,包,诸,左,石,崔,吉,钮,龚,程,嵇,邢,滑,裴,陆,荣,翁,荀,羊,於,惠,甄,麹,家," \
+ r"封,芮,羿,储,靳,汲,邴,糜,松,井,段,富,巫,乌,焦,巴,弓,牧,隗,山,谷,车,侯,宓,蓬,全,郗,班,仰,秋,仲,伊,宫,宁,仇,栾,暴,甘,斜,厉,戎,祖,武,符," \
+ r"刘,景,詹,束,龙,叶,幸,司,韶,郜,黎,蓟,薄,印,宿,白,怀,蒲,邰,从,鄂,索,咸,籍,赖,卓,蔺,屠,蒙,池,乔,阴,郁,胥,能,苍,双,闻,莘,党,翟,谭,贡,劳," \
+ r"逄,姬,申,扶,堵,冉,宰,郦,雍,郤,璩,桑,桂,濮,牛,寿,通,边,扈,燕,冀,郏,浦,尚,农,温,别,庄,晏,柴,瞿,阎,充,慕,连,茹,习,宦,艾,鱼,容,向,古,易," \
+ r"慎,戈,廖,庾,终,暨,居,衡,步,都,耿,满,弘,匡,国,文,寇,广,禄,阙,东,欧,殳,沃,利,蔚,越,夔,隆,师,巩,厍,聂,晁,勾,敖,融,冷,訾,辛,阚,那,简,饶," \
+ r"空,曾,毋,沙,乜,养,鞠,须,丰,巢,关,蒯,相,查,后,荆,红,游,竺,权,逑,盖,益,桓,公,万俟,司马,上官,欧阳,夏侯,诸葛,闻人,东方,赫连,皇甫,尉迟," \
+ r"公羊,澹台,公冶,宗政,濮阳,淳于,单于,太叔,申屠,公孙,仲孙,轩辕,令狐,锺离,宇文,长孙,慕容,鲜于,闾丘,司徒,司空,丌官,司寇,仉,督,子车," \
+ r"颛孙,端木,巫马,公西,漆雕,乐正,壤驷,公良,拓拔,夹谷,宰父,谷梁,晋,楚,阎,法,汝,鄢,涂,钦,段干,百里,东郭,南门,呼延,归,海,羊舌,微生,岳," \
+ r"帅,缑,亢,况,后,有,琴,梁丘,左丘,东门,西门,商,牟,佘,佴,伯,赏,南宫,墨,哈,谯,笪,年,爱,阳,佟,第五,言,福,邱,钟]"
first_name = r' {0,3}[\u4e00-\u9fa5]( {0,3}[\u4e00-\u9fa5]){0,3}'
self.name_pattern = re.compile(last_name + first_name)
self.phone_pattern = re.compile(r'1 {0,4}(3 {0,4}\d|4 {0,4}[5-9]|5 {0,4}[0-35-9]|6 {0,4}[2567]|7 {0,4}[0-8]|8 {0,4}\d|9 {0,4}[0-35-9]) {0,4}(\d {0,4}){8}')
self.email_pattern = re.compile(r'([a-zA-Z0-9_-] {0,4})+@([a-zA-Z0-9_-] {0,4})+(\. {0,4}([a-zA-Z0-9_-] {0,4})+)+')
self.gender_pattern = re.compile(r'(性 {0,8}别.*?)?\s*?(男|女)')
self.age_patterns = [
re.compile(r"(\d{1,2})岁|年龄.{0,10}?(\d{1,2})"),
re.compile(r"生.{0,12}(([12]\d{3})[年|.]?(([01]?\d)[月|.]?)?(([0-3]?\d)[日|.]?)?)"),
]
        self.works_key_pattern = re.compile("工作|experience|work", re.M | re.I)
        self.job_time_patterns = re.compile(r'([1-2]\d{3}(\D?[01]?\d){0,2})\D?([1-2]\d{3}(\D?[01]?\d){0,2}|至今)')
self.edu_index = ["博士","硕士","研究生","学士","本科","大专","专科","中专","高中","初中","小学"]
self.edu_patterns = list(re.compile(i) for i in self.edu_index)
        self.school_pattern = re.compile(r"([a-zA-Z0-9 \u4e00-\u9fa5]{1,18}(学院|大学|中学|小学|学校|University|College))")
    def _is_url(self, path: str):
        return path.startswith(('http://', 'https://'))
def __init_schools_data(
self,
paths: List[str],
):
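        """Load the school -> education-level table from local or remote
        Excel files (column 1 holds the school name, column 5 the level)."""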
schools = {}
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/73.0.3683.103 Safari/537.36",
}
        for path in paths:
            stream = None
            if self._is_url(path):
                res = requests.get(path, headers=headers)
                if res.status_code == 200:
                    stream = BytesIO(res.content)
            else:
                with open(path, 'rb') as f:
                    stream = BytesIO(f.read())
            if stream is None:
                # Remote fetch failed; skip instead of crashing pd.read_excel.
                self.logger.warning(f"failed to fetch school data from {path}")
                continue
            df = pd.read_excel(stream)
            for _, row in df.iterrows():
                name = row.iloc[1]
                if isinstance(name, float) and math.isnan(name):
                    continue
                if name == '学校名称':  # skip the header row ("school name")
                    continue
                # schools[name] = education level (e.g. undergraduate / junior college)
                if len(row) > 5:
                    schools[name] = row.iloc[5]
                else:
                    schools[name] = "成人学校"  # default level: adult school
        self.schools = schools
        if len(schools) == 0:
            raise ValueError("school data is empty")
def __init_split_data(
self
):
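        """Define the separator characters used to cut text into blocks."""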
self.splits = {'\\', '_', '"', '%', '{', '《', ')', '$', '(', '\n', '~', '*', ':', '!', ';', '”', '’', '\t', '?', '-', ';', '》', '】', '`', '、', '+', '“', '[', '—', '·', ')', '=', '‘', '}', '?', ',', '&', '@', '#', ']', '——', ' ', '.', '【', "'", '>', ',', '/', ':', '。', '...', '^', '(', '<', '|', '……', '!'}
    def to_date(self, datestr: str) -> Optional[date]:
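        """Parse a year or year-month string (or 至今, "until now") into a
        date; return None if it cannot be parsed."""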
        if re.match(r"^\d{4}$", datestr):
            return date(int(datestr), 1, 1)
        match = re.match(r"^(\d{4})\D(\d{1,2})", datestr)
        if match is not None:
            try:
                y = int(match.group(1))
                m = min(max(int(match.group(2)), 1), 12)
                return date(y, m, 1)
            except ValueError:
                self.logger.warning(f"unparsable date string: {datestr}")
        if datestr == "至今":  # "until now"
            return self.today
        return None
def split_to_blocks(
self,
text: str,
max_block_len: int = 510,
overlap: bool = True,
max_overlap_len: int = 20,
):
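        """Cut text at separator characters into blocks of at most
        max_block_len characters. With overlap=True, consecutive blocks share
        up to max_overlap_len trailing characters, so entities spanning a
        block boundary are not lost. Returns dicts with start/end/text."""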
block = {
"start": -1,
"end": -1,
"text": "",
}
blocks = []
overlap_end = queue.Queue()
for i in range(len(text)):
if text[i] in self.splits:
if block["start"]==-1:
continue
if block["end"]!=-1 and i-block['start']>=max_block_len:
block["text"] = text[block["start"]:block["end"]]
blocks.append(block)
block = {
"start": overlap_end.queue[0]+1 if overlap else block['end']+1,
"end": -1,
"text": "",
}
block["end"] = i
while overlap_end.qsize()>0 and overlap_end.queue[0]+max_overlap_len<=i:
overlap_end.get()
overlap_end.put(i)
else:
if block["start"]==-1:
block["start"] = i
# last block
if block["start"]!=-1:
block["end"] = len(text)
block["text"] = text[block["start"]:block["end"]]
blocks.append(block)
return blocks
def get_expand_span(
self,
text: str,
start: int,
end: int,
max_expand_length=10,
):
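        """Expand [start, end) left and right to separator boundaries within
        max_expand_length on each side; return the expanded text and its
        bounds."""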
expand_l,expand_r = start,end
for l in range(max(start-max_expand_length,0), start):
if text[l] in self.splits:
expand_l = l+1
break
for r in range(min(end+max_expand_length,len(text)-1), end, -1):
if text[r] in self.splits:
expand_r = r
break
return text[expand_l:expand_r], expand_l, expand_r
def remove_blanks(
self,
text: str,
blank_pattern: re.Pattern,
):
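        """Remove characters matching blank_pattern; return the cleaned text
        plus a mapping from cleaned-text indices back to original indices."""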
index_mapper = {}
new_text = []
for i in range(len(text)):
if blank_pattern.match(text[i]) is not None:
continue
index_mapper[len(new_text)] = i
new_text.append(text[i])
return ''.join(new_text), index_mapper
    def process(self, text: str) -> Dict[str, Any]:
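        """Run all extractors over the resume text and return a dict of
        entity lists, plus work_time (total working years)."""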
return_obj = {
"name": [],
"age": [],
"gender": [],
"phone": [],
"email": [],
"schools": [],
"work_time": 0,
"edus": [],
"jobs": [],
"titles": []
}
        # Extract names. Strip all spaces first so names with internal spaces still match.
remove_blanks_text, index_mapper = self.remove_blanks(text, re.compile(r' '))
start_time = time.perf_counter()
backup_name = []
for block in self.split_to_blocks(remove_blanks_text):
block_text,block_l = block['text'],block['start']
entities = self.pipelines['name'](block_text)
            for entity in entities:
                if entity['entity'] != 'NAME':
                    continue
                obj = {
                    'start': index_mapper[block_l + entity['start']],
                    'end': index_mapper[block_l + entity['end'] - 1] + 1,
                    'entity': 'NAME',
                    'text': entity['word']
                }
                # NAME entities that also match the surname pattern are primary
                # results; the rest are kept as backups used only when no
                # primary name is found.
                target = (return_obj['name']
                          if self.name_pattern.match(entity['word']) is not None
                          else backup_name)
                repeat = any(obj['start'] == o['start'] and obj['end'] == o['end']
                             for o in return_obj['name'])
                if not repeat:
                    obj['origin'] = text[obj['start']:obj['end']]
                    target.append(obj)
if len(return_obj['name'])==0:
return_obj['name'] = backup_name
end_time = time.perf_counter()
self.logger.info(f"process name time: {end_time-start_time}")
        # Extract age
start_time = time.perf_counter()
for age_match in self.age_patterns[0].finditer(remove_blanks_text):
age = None
s,e = -1,-1
if age_match.group(1) is not None:
age = age_match.group(1)
s,e = age_match.span(1)
elif age_match.group(2) is not None:
age = age_match.group(2)
s,e = age_match.span(2)
if age is not None:
return_obj['age'].append({
'start': index_mapper[s],
'end': index_mapper[e-1]+1,
'text': str(age),
'entity': 'AGE',
'origin': text[index_mapper[s]:index_mapper[e-1]+1]
})
for age_match in self.age_patterns[1].finditer(remove_blanks_text):
age = None
s,e = -1,-1
year = age_match.group(2)
if year is not None:
year = int(year)
month = age_match.group(4)
if month is not None:
month = int(month)
else:
month = 1
day = age_match.group(6)
if day is not None:
day = int(day)
else:
day = 1
                age = self.today.year - year
                if self.today.month < month or (self.today.month == month and self.today.day < day):
                    age -= 1
if age is not None:
s,e = age_match.span(1)
return_obj['age'].append({
'start': index_mapper[s],
'end': index_mapper[e-1]+1,
'text': str(age),
'entity': 'AGE',
'origin': text[index_mapper[s]:index_mapper[e-1]+1]
})
end_time = time.perf_counter()
self.logger.info(f"process age time: {end_time-start_time}")
start_time = time.perf_counter()
        # Extract schools
for school_match in self.school_pattern.finditer(remove_blanks_text):
start,end = school_match.span()
expand_text, start, end = self.get_expand_span(remove_blanks_text, start, end)
entities = self.pipelines['common'](expand_text)
for entity in entities:
if entity['entity']=="ORG" and self.school_pattern.search(entity['word']) is not None:
obj = {
'start': index_mapper[start+entity['start']],
'end': index_mapper[start+entity['end']-1]+1,
'entity': 'SCHOOL'
}
for school in self.schools:
if school in entity['word']:
obj['text'] = school
obj["level"] = self.schools[school]
break
repeat = False
for o in return_obj['schools']:
if obj['start']==o['start'] and obj['end']==o['end']:
repeat = True
break
if not repeat:
obj['origin'] = text[obj['start']:obj['end']]
if "text" not in obj:
obj['text'] = obj['origin'].split("\n")[-1]
return_obj['schools'].append(obj)
        # Find schools by matching known school names directly
        for school_match in re.finditer("|".join(map(re.escape, self.schools.keys())), remove_blanks_text):
start,end = school_match.span()
obj = {
'start': index_mapper[start],
'end': index_mapper[end-1]+1,
'entity': 'SCHOOL',
'text': school_match.group().split('\n')[-1],
}
repeat = False
for o in return_obj['schools']:
if obj['start']==o['start'] and obj['end']==o['end']:
repeat = True
break
if not repeat:
obj['origin'] = text[obj['start']:obj['end']]
obj['level'] = self.schools[obj['text']]
return_obj['schools'].append(obj)
return_obj['schools'] = sorted(return_obj['schools'], key=lambda x: x['start'])
end_time = time.perf_counter()
self.logger.info(f"process school time: {end_time-start_time}")
start_time = time.perf_counter()
        # Extract education levels
for i,pattern in enumerate(self.edu_patterns):
for edu_match in pattern.finditer(remove_blanks_text):
start,end = edu_match.span()
expand_text, start, end = self.get_expand_span(remove_blanks_text, start, end)
entities = self.pipelines['common'](expand_text)
for entity in entities:
if entity['entity']=='EDU' and pattern.search(entity['word']) is not None:
obj = {
'start': index_mapper[start+entity['start']],
'end': index_mapper[start+entity['end']-1]+1,
'text': self.edu_index[i],
'entity': 'EDU',
}
repeat = False
for o in return_obj['edus']:
if obj['start']==o['start'] and obj['end']==o['end']:
repeat = True
break
if not repeat:
obj['origin'] = text[obj['start']:obj['end']]
return_obj['edus'].append(obj)
end_time = time.perf_counter()
self.logger.info(f"process edu time: {end_time-start_time}")
start_time = time.perf_counter()
        # If the resume has a work-experience section
if self.works_key_pattern.search(remove_blanks_text) is not None:
for job_time_match in self.job_time_patterns.finditer(remove_blanks_text):
origin_start,origin_end = job_time_match.span()
# convert_to_date
fr = self.to_date(job_time_match.group(1))
if fr is None:
continue
fs,fe = job_time_match.span(1)
to = self.to_date(job_time_match.group(3))
if to is None:
continue
ts,te = job_time_match.span(3)
expand_text, start, end = self.get_expand_span(remove_blanks_text, origin_start, origin_end, max_expand_length=50)
entities = self.pipelines['common'](expand_text)
objs = []
for entity in entities:
if entity['entity']=="ORG":
obj = {
'start': index_mapper[start+entity['start']],
'end': index_mapper[start+entity['end']-1]+1,
'entity': 'COMPANY',
'text': entity['word'],
'dis': min(
abs(origin_start-start-entity['end']+1),
abs(origin_end-start-entity['start'])
),
}
obj['origin'] = text[obj['start']:obj['end']]
objs.append(obj)
                objs.sort(key=lambda x: x['dis'])
                # Skip this time span if no nearby ORG was found, or if the
                # nearest one looks like a school rather than an employer.
                if len(objs) == 0 or self.school_pattern.search(objs[0]['text']) is not None:
                    continue
                del objs[0]['dis']
                from_date = {
                    'start': index_mapper[fs],
                    'end': index_mapper[fe-1]+1,
                    'text': fr.isoformat(),
                    'entity': 'DATE',
                    'origin': text[index_mapper[fs]:index_mapper[fe-1]+1]
                }
                to_date = {
                    'start': index_mapper[ts],
                    'end': index_mapper[te-1]+1,
                    'text': to.isoformat(),
                    'entity': 'DATE',
                    'origin': text[index_mapper[ts]:index_mapper[te-1]+1]
                }
                return_obj['jobs'].append([objs[0], from_date, to_date])
return_obj["jobs"].sort(key=lambda x:date.fromisoformat(x[1]['text']))
        # Compute total working time (in months, clipping overlapping spans)
last_end = None
work_month = 0
for i in range(0,len(return_obj["jobs"])):
start = date.fromisoformat(return_obj["jobs"][i][1]['text'])
end = date.fromisoformat(return_obj["jobs"][i][2]['text'])
if last_end is not None and start<last_end:
start = last_end
diff_y = end.year-start.year
diff_m = end.month-start.month
work_month += diff_y * 12 + diff_m
last_end = end
return_obj['work_time'] = max(math.ceil(work_month/12),0)
end_time = time.perf_counter()
self.logger.info(f"process work time: {end_time-start_time}")
start_time = time.perf_counter()
        # Extract phone numbers
for phone_match in self.phone_pattern.finditer(text):
start,end = phone_match.span()
return_obj['phone'].append({
'start': start,
'end': end,
'entity': 'PHONE',
'origin': text[start:end],
                'text': re.sub(r'\s', '', text[start:end])
})
end_time = time.perf_counter()
self.logger.info(f"process phone time: {end_time-start_time}")
start_time = time.perf_counter()
for email_match in self.email_pattern.finditer(text):
start,end = email_match.span()
return_obj['email'].append({
'start': start,
'end': end,
'entity': 'EMAIL',
'origin': text[start:end],
                'text': re.sub(r'\s', '', text[start:end])
})
end_time = time.perf_counter()
self.logger.info(f"process email time: {end_time-start_time}")
start_time = time.perf_counter()
for gender_match in self.gender_pattern.finditer(text):
start,end = gender_match.span(2)
return_obj['gender'].append({
'start': start,
'end': end,
'entity': 'GENDER',
'origin': text[start:end],
'text': text[start:end]
})
end_time = time.perf_counter()
self.logger.info(f"process gender time: {end_time-start_time}")
start_time = time.perf_counter()
for block in self.split_to_blocks(remove_blanks_text):
entities = self.pipelines["common"](block["text"])
for entity in entities:
if entity['entity']=='TITLE':
obj = {
'start': index_mapper[block['start']+entity['start']],
'end': index_mapper[block['start']+entity['end']-1]+1,
'text': entity['word'],
'entity': 'TITLE',
}
obj['origin'] = text[obj['start']:obj['end']]
repeat = False
for o in return_obj['titles']:
if obj['start']==o['start'] and obj['end']==o['end']:
repeat = True
break
if not repeat:
return_obj['titles'].append(obj)
end_time = time.perf_counter()
self.logger.info(f"process title time: {end_time-start_time}")
return return_obj
def __call__(self, *args: Any, **kwds: Any) -> Any:
return self.process(*args, **kwds)
class PositionPredictor():
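    """Rank job positions against a resume by encoding text blocks with a
    feature-extraction pipeline and comparing them via cosine similarity."""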
def __init__(self, pipeline: Pipeline) -> None:
self.pipeline = pipeline
self.__init_split_data()
self.logger = logging.getLogger(__name__)
def __split_blocks(self, text: str) -> List[str]:
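        """Split text into plain blocks at separator characters."""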
start,end = 0,0
blocks = []
while end<len(text):
if text[end] in self.splits:
if end>start:
blocks.append(text[start:end])
start = end+1
end += 1
if end>start:
blocks.append(text[start:end])
return blocks
def __init_split_data(
self
):
self.splits = {'\\', '_', '"', '%', '{', '《', ')', '$', '(', '\n', '~', '*', ':', '!', ';', '”', '’', '\t', '?', '-', ';', '》', '】', '`', '、', '+', '“', '[', '—', '·', ')', '=', '‘', '}', '?', ',', '&', '@', '#', ']', '——', ' ', '.', '【', "'", '>', ',', '/', ':', '。', '...', '^', '(', '<', '|', '……', '!'}
def predict(self,
positions: List[Dict[str,Union[str,List[str]]]],
resume: str
) -> List[Dict[str, Union[str, float]]]:
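        """Score each position's 'required' texts against the resume; return
        [{'position': name, 'score': s}, ...] sorted by descending score."""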
ans = []
resume_blocks = self.__split_blocks(resume)
resume_encoding = []
for block_resume in resume_blocks:
resume_encoding.append(torch.tensor(self.pipeline(block_resume)[0]))
resume_encoding = torch.stack(resume_encoding,dim=0)
        for position in positions:
            block_encodings = []
            for required in position['required']:
                for block in self.__split_blocks(required):
                    block_encodings.append(torch.tensor(self.pipeline(block)[0]))
            if len(block_encodings) == 0:
                # No encodable requirement text: give this position a zero score.
                score = 0.0
            else:
                block_encodings = torch.stack(block_encodings, dim=0)
                # Cosine similarity between every resume block and every
                # requirement block; the best match becomes the score.
                cos_sims = F.cosine_similarity(resume_encoding.unsqueeze(1), block_encodings.unsqueeze(0), dim=-1)
                score = cos_sims.max().item()
self.logger.info(f"position: {position['name']}, score: {score}")
ans.append({
'position': position['name'],
'score': score
})
ans.sort(key=lambda x:x['score'], reverse=True)
return ans
def __call__(self, *args: Any, **kwds: Any) -> Any:
return self.predict(*args, **kwds)
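

# --- Usage sketch (illustrative, not part of the original module) ---
# Predictor expects two token-classification pipelines whose entities carry
# 'entity', 'word', 'start' and 'end' keys with plain labels such as
# NAME/ORG/EDU/TITLE; the model paths and the schools file below are
# placeholder assumptions, not references to real artifacts.
if __name__ == "__main__":
    from transformers import pipeline

    predictor = Predictor(
        pipelines={
            "name": pipeline("token-classification", model="path/to/name-ner-model"),
            "common": pipeline("token-classification", model="path/to/common-ner-model"),
        },
        # Excel sheet(s): school name in column 1, level in column 5.
        paths=["schools.xlsx"],
    )
    result = predictor("张三 男 25岁 电话 13800138000 邮箱 zhangsan@example.com")
    print(result["name"], result["phone"], result["email"])

    # PositionPredictor expects a feature-extraction pipeline producing one
    # embedding per text block; the model path is likewise a placeholder.
    matcher = PositionPredictor(pipeline("feature-extraction", model="path/to/embedding-model"))
    print(matcher(
        [{"name": "backend engineer", "required": ["familiar with Python"]}],
        "three years of Python development experience",
    ))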