# IL-TUR-Leaderboard / ner_helpers.py
# Last commit shown by the file viewer: e1043c6 "add prediction submission" (abhinav-joshi); file size 3.66 kB.
from transformers import AutoTokenizer
import re
import string
class TF_Tokenizer:
    """Callable wrapper exposing a pretrained tokenizer's ``tokenize`` method.

    Instances can be passed wherever a ``tokenizer_func(txt) -> tokens``
    callable is expected.
    """

    def __init__(self, model_str):
        """Load the tokenizer named by ``model_str``.

        Args:
            model_str: Identifier forwarded unchanged to
                ``AutoTokenizer.from_pretrained``.
        """
        # The tokenizer must live on the instance: ``__call__`` reads
        # ``self.tok``. Binding it to a local variable (as before) made every
        # call fail with AttributeError.
        self.tok = AutoTokenizer.from_pretrained(model_str)

    def __call__(self, txt):
        """Return the result of ``self.tok.tokenize(txt)``."""
        return self.tok.tokenize(txt)
class WS_Tokenizer:
    """Regex tokenizer yielding single punctuation marks and word runs.

    Each token is either one character from the bracketed set built out of
    ``string.punctuation`` or a maximal run of ``\\w`` characters. Anything
    matched by neither alternative (whitespace, for instance) is skipped.
    """

    # Compiled once for all instances; the pattern text is unchanged.
    # NOTE(review): ``string.punctuation`` is inserted unescaped, so its
    # backslash escapes the following "]" inside the set and a literal
    # backslash is itself never matched - confirm whether that is intended.
    _TOKEN_RE = re.compile(r"[{}]|\w+".format(string.punctuation))

    def __init__(self):
        """Nothing to configure."""

    def __call__(self, txt):
        """Return the list of tokens found in ``txt``, in order."""
        return self._TOKEN_RE.findall(txt)
def convert_spans_to_bio(txt, roles, tokenizer_func):
    """Convert character-offset spans into one BIO tag per token.

    The text is walked with a character cursor that skips whitespace and then
    advances by the width of each token. A token whose start offset equals a
    span's ``"start"`` is tagged ``"B-<label>"``; later tokens whose start
    offset is still below that span's ``"end"`` are tagged ``"I-<label>"``;
    everything else is ``"O"``.

    Args:
        txt: The text the span offsets refer to.
        roles: Iterable of dicts with ``"start"`` and ``"end"`` character
            offsets and a ``"label"`` string.
        tokenizer_func: Callable mapping ``txt`` to a list of token strings.

    Returns:
        A list of tags with exactly one entry per token.
    """
    # First span wins when several share a start offset, which is what the
    # former stable sort by "start" followed by list.index() produced.
    span_by_start = {}
    for role in roles:
        span_by_start.setdefault(role["start"], role)

    tokens = tokenizer_func(txt)
    n_chars = len(txt)
    tags = []
    pos = 0  # character cursor into txt
    span_end = -1  # "end" offset of the most recently opened span
    inside_tag = "O"  # tag given to tokens that continue that span

    for tok in tokens:
        if pos >= n_chars:
            break
        # Bounds-checked and using isspace(), like convert_bio_to_spans:
        # tabs/newlines no longer stall the cursor and trailing whitespace
        # can no longer index past the end of txt.
        while pos < n_chars and txt[pos].isspace():
            pos += 1

        role = span_by_start.get(pos)
        if role is not None:  # a span starts exactly at this token
            span_end = role["end"]
            inside_tag = "I-" + role["label"]
            tags.append("B-" + role["label"])
        elif pos < span_end:  # still inside the current span
            tags.append(inside_tag)
        else:
            tags.append("O")

        # Advance by the width convert_bio_to_spans assumes for the same
        # token, so both directions agree on offsets: a "##" continuation
        # piece does not count its prefix and "[UNK]" counts as one character.
        if tok.startswith("##"):
            pos += len(tok) - 2
        elif tok == "[UNK]":
            pos += 1
        else:
            pos += len(tok)

    # Tokens left over once the cursor has passed the end of txt are outside
    # every span.
    tags += ["O"] * (len(tokens) - len(tags))
    return tags
def convert_bio_to_spans(txt, troles, tokenizer_func):
    """Turn per-token BIO tags back into character-offset spans.

    Tokens and tags are paired positionally; whichever sequence is longer is
    cut to the length of the other. A ``"B-<label>"`` tag opens a span at the
    current character offset (emitting any span that is still open), and an
    ``"O"`` tag emits the open span. Any other tag leaves the span state
    untouched. Tokens starting with ``"##"`` or equal to ``"[UNK]"`` only move
    the cursor (by ``len(tok) - 2`` and ``1`` respectively); their tags are
    not inspected.

    Args:
        txt: The text that was tokenized.
        troles: Sequence of tag strings, one per token.
        tokenizer_func: Callable mapping ``txt`` to a list of token strings.

    Returns:
        A list of ``{"start", "end", "label"}`` dicts in order of appearance,
        where ``"end"`` is the cursor position after the last token consumed
        before the span was emitted.
    """
    spans = []
    n_chars = len(txt)
    pos = 0  # character cursor into txt
    last_end = 0  # cursor position after the previously consumed token
    span_start = -1  # start offset of the latest span, -1 before any
    flushed_at = -1  # a span is pending while span_start >= flushed_at
    label = "O"  # label of the latest span

    for tok, tag in zip(tokenizer_func(txt), troles):
        if pos >= n_chars:
            break
        while pos < n_chars and txt[pos].isspace():
            pos += 1

        if tok.startswith("##"):
            pos += len(tok) - 2
        elif tok == "[UNK]":
            pos += 1
        else:
            opens_span = tag.startswith("B-")
            # Both "B-" and "O" close whatever span is still pending.
            if (opens_span or tag == "O") and span_start >= flushed_at:
                flushed_at = pos
                if span_start >= 0:
                    spans.append({"start": span_start, "end": last_end, "label": label})
            if opens_span:
                span_start = pos
                label = tag[2:]
            pos += len(tok)
        last_end = pos

    # Emit a span that is still pending when the tokens run out.
    if span_start >= flushed_at and span_start >= 0:
        spans.append({"start": span_start, "end": last_end, "label": label})
    return spans
def span2bio(txt, labels):
    """Tokenize ``txt`` and tag each token with BIO labels from ``labels``.

    Tokens are single characters from the ``string.punctuation`` set or runs
    of ``\\w`` characters. A token whose start offset equals a span's
    ``"start"`` is tagged ``"B-<label>"``; later tokens starting below that
    span's ``"end"`` are tagged ``"I-<label>"``; all others are ``"O"``.

    Args:
        txt: The text the span offsets refer to.
        labels: Iterable of dicts with ``"start"`` and ``"end"`` character
            offsets and a ``"label"`` string.

    Returns:
        A ``(tokens, tags)`` tuple of two equally long lists.
    """
    # When several spans share a start offset, the one with the smallest
    # label wins - the same tie-break the former label-sorted list lookup had.
    span_by_start = {}
    for span in sorted(labels, key=lambda s: s["label"]):
        span_by_start.setdefault(span["start"], span)

    tokens = []
    tags = []
    span_end = -1  # "end" offset of the most recently opened span
    inside_tag = "O"  # tag given to tokens that continue that span

    # Use the real match offsets rather than re-deriving them from token
    # lengths: the pattern silently skips more than the space character
    # (tabs, newlines, symbols outside the set), and every such character
    # used to shift all later offsets so that span starts were missed.
    for match in re.finditer(r"[{}]|\w+".format(string.punctuation), txt):
        pos = match.start()
        tokens.append(match.group())
        span = span_by_start.get(pos)
        if span is not None:  # a span starts exactly at this token
            span_end = span["end"]
            inside_tag = "I-" + span["label"]
            tags.append("B-" + span["label"])
        elif pos < span_end:  # still inside the current span
            tags.append(inside_tag)
        else:
            tags.append("O")
    return tokens, tags