---
language: en
datasets:
  - wnut_17
license: mit
metrics:
  - f1
widget:
  - text: My name is Sylvain and I live in Paris
    example_title: Parisian
  - text: My name is Sarah and I live in London
    example_title: Londoner
---

# Reddit NER for place names

Fine-tuned `twitter-roberta-base` model for named entity recognition, trained on `wnut_17` plus 498 additional Reddit comments. The model is intended solely for extracting place names from social media text, so all other entity types have been removed from the labels.

## Use in transformers

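```python
from transformers import pipeline

# Token-classification pipeline. By default it returns one dict per token
# and omits tokens predicted as "O".
generator = pipeline(
    task="ner",
    model="cjber/reddit-ner-place_names",
    tokenizer="cjber/reddit-ner-place_names",
)

out = generator("I live north of liverpool in Waterloo")

# Raw subword tokens (RoBERTa marks word-initial tokens with "Ġ") and their BILUO tags
entities = [item["word"] for item in out]
labels = [item["entity"] for item in out]
```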

The following steps need a mapping between label names and their index values:

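```python
# BILUO scheme: B(egin), I(nside), L(ast) of a multi-token name,
# U(nit) for a single-token name, O(utside) for everything else
class Label:
    labels: dict[str, int] = {
        "O": 0,
        "B-location": 1,
        "I-location": 2,
        "L-location": 3,
        "U-location": 4,
    }

    # Reverse lookup: index -> label name
    idx: dict[int, str] = {v: k for k, v in labels.items()}
    count: int = len(labels)
```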
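To make the tag set concrete: in BILUO, multi-word names are tagged B- (begin), I- (inside) and L- (last), single-word names U- (unit), and everything else O. As an illustration only, the hypothetical helper `spans_to_biluo` below encodes word-index spans into these tags and then into the index values above. It is a sketch, not the code used to prepare the training data.

```python
# Label-to-index mapping copied from the Label class above
LABEL_TO_ID: dict[str, int] = {
    "O": 0,
    "B-location": 1,
    "I-location": 2,
    "L-location": 3,
    "U-location": 4,
}


def spans_to_biluo(n_words: int, spans: list[tuple[int, int]], label: str = "location") -> list[str]:
    """Hypothetical helper: encode (start, end) word spans, end exclusive, as BILUO tags."""
    tags = ["O"] * n_words
    for start, end in spans:
        if end - start == 1:
            tags[start] = f"U-{label}"  # single-word name
        else:
            tags[start] = f"B-{label}"  # first word
            for i in range(start + 1, end - 1):
                tags[i] = f"I-{label}"  # inner words
            tags[end - 1] = f"L-{label}"  # last word
    return tags


words = ["Newcastle", "upon", "Tyne", "and", "Leeds"]
tags = spans_to_biluo(len(words), [(0, 3), (4, 5)])
ids = [LABEL_TO_ID[t] for t in tags]
# tags == ['B-location', 'I-location', 'L-location', 'O', 'U-location']
# ids  == [1, 2, 3, 0, 4]
```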

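Merge subword tokens back into whole words. RoBERTa's byte-level BPE marks a token that starts a new word with "Ġ" (an encoded leading space); any other token continues the previous word. The first token of a text has no preceding space, so it carries no marker: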

```python
def combine_subwords(tokens: list[str], tags: list[int]) -> tuple[list[str], list[str]]:
    """Merge BPE subword tokens into words; each word keeps its first token's tag."""
    words: list[str] = []
    word_tags: list[str] = []
    for token, tag in zip(tokens, tags):
        if token in ("<s>", "<pad>", "</s>"):
            continue  # skip special tokens
        if token.startswith("Ġ") or not words:
            # "Ġ" marks the start of a new word; the first token of a text
            # has no marker, so it also starts a word instead of being dropped
            words.append(token.removeprefix("Ġ"))
            word_tags.append(Label.idx[tag])
        else:
            # Continuation subword: append it to the previous word
            words[-1] += token
    return words, word_tags


names, labels = combine_subwords(entities, [Label.labels[lb] for lb in labels])
```
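Gluing strings on "Ġ" markers is one way to rebuild words. An alternative, sketched below, assumes each pipeline item carries `start`/`end` character offsets (the transformers token-classification pipeline includes them when a fast tokenizer is loaded): a token that begins exactly where the previous one ended is treated as part of the same word, and the word is sliced straight from the input text. `merge_by_offsets` and the sample input are hypothetical; the token split and tags are illustrative, not real model output.

```python
from typing import Any


def merge_by_offsets(text: str, entities: list[dict[str, Any]]) -> tuple[list[str], list[str]]:
    """Hypothetical helper: merge adjacent tokens into words via character offsets.

    A token that starts exactly where the previous one ended (no whitespace
    between them) continues the same word; each word keeps its first token's tag.
    """
    spans: list[list[Any]] = []  # each item: [start, end, tag]
    for ent in entities:
        if spans and ent["start"] == spans[-1][1]:
            spans[-1][1] = ent["end"]  # extend the current word
        else:
            spans.append([ent["start"], ent["end"], ent["entity"]])
    words = [text[start:end] for start, end, _ in spans]
    tags = [tag for _, _, tag in spans]
    return words, tags


# Illustrative input shaped like the pipeline's per-token output;
# the token split and tags are assumed, not real model predictions.
text = "I live north of liverpool in Waterloo"
sample = [
    {"word": "Ġliver", "entity": "U-location", "start": 16, "end": 21},
    {"word": "pool", "entity": "U-location", "start": 21, "end": 25},
    {"word": "ĠWaterloo", "entity": "U-location", "start": 29, "end": 37},
]
words, tags = merge_by_offsets(text, sample)
# words == ['liverpool', 'Waterloo']
# tags  == ['U-location', 'U-location']
```

Like `combine_subwords`, this only sees tokens the pipeline kept, so a subword predicted as O is not recovered by either approach.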

Join multi-word names tagged B-/I-/L- into single entries and strip the BILUO prefixes:

```python
def combine_biluo(tokens: list[str], tags: list[str]) -> tuple[list[str], list[str]]:
    """Join B-, I-, L- runs into one entity and drop the BILUO prefixes.

    U- and O tokens pass through unchanged. I-/L- tokens without a preceding
    B- are dropped. A span ends at its L- tag, so an immediately following
    U- entity stays separate.
    """
    out_tokens: list[str] = []
    out_tags: list[str] = []
    in_span = False  # True while inside an unfinished B-...L- span
    for token, tag in zip(tokens, tags):
        prefix = tag[0]
        if prefix in ("I", "L"):
            if in_span:
                out_tokens[-1] = f"{out_tokens[-1]} {token}"
            in_span = in_span and prefix == "I"  # L- closes the span
            continue
        out_tokens.append(token)
        out_tags.append(tag if tag == "O" else tag[2:])
        in_span = prefix == "B"
    return out_tokens, out_tags


names, labels = combine_biluo(names, labels)
```
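`combine_biluo` returns joined strings. If positions are needed instead, for example to highlight names in the original comment, the tags can be decoded into word-index spans. The hypothetical `biluo_to_spans` below is one such sketch; it skips malformed sequences rather than trying to repair them.

```python
from typing import Optional


def biluo_to_spans(tags: list[str]) -> list[tuple[int, int, str]]:
    """Hypothetical helper: decode BILUO tags into (start, end, label) spans, end exclusive.

    Malformed runs (I-/L- with no open B-, or a B- never closed by L-) are skipped.
    """
    spans: list[tuple[int, int, str]] = []
    open_start: Optional[int] = None  # index of the pending B- tag, if any
    for i, tag in enumerate(tags):
        prefix, _, label = tag.partition("-")
        if prefix == "U":
            spans.append((i, i + 1, label))  # single-word entity
            open_start = None
        elif prefix == "B":
            open_start = i  # start a new multi-word entity
        elif prefix == "I" and open_start is not None:
            continue  # still inside the current entity
        elif prefix == "L" and open_start is not None:
            spans.append((open_start, i + 1, label))  # close the entity
            open_start = None
        else:
            open_start = None  # "O" or a stray I-/L- discards any pending B-
    return spans


words = ["Newcastle", "upon", "Tyne", "and", "Leeds"]
tags = ["B-location", "I-location", "L-location", "O", "U-location"]
spans = biluo_to_spans(tags)
place_names = [" ".join(words[start:end]) for start, end, _ in spans]
# spans       == [(0, 3, 'location'), (4, 5, 'location')]
# place_names == ['Newcastle upon Tyne', 'Leeds']
```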

This gives:

```python
>>> names
['liverpool', 'Waterloo']
>>> labels
['location', 'location']
```