---
language: en
datasets:
  - wnut_17
license: mit
metrics:
  - f1
widget:
  - text: My name is Sylvain and I live in Paris
    example_title: Parisian
  - text: My name is Sarah and I live in London
    example_title: Londoner
---

Reddit NER for place names

Use in transformers

```python
from transformers import pipeline

generator = pipeline(
    task="ner",
    model="cjber/reddit-ner-place_names",
    tokenizer="cjber/reddit-ner-place_names",
)

out = generator(
    "I live in Gothenburg, and long queues aside I definitely prefer the housing situation here compared to Edinburgh."
)

# Each item describes one predicted token (tokens labelled "O" are left out
# by default): "word" is the token string and "entity" its predicted label.
entities = [item["word"] for item in out]
labels = [item["entity"] for item in out]
```

The steps below need the mapping between label names and their integer ids:

```python
class Label:
    labels: dict[str, int] = {
        "O": 0,
        "B-location": 1,
        "I-location": 2,
        "L-location": 3,
        "U-location": 4,
    }

    idx: dict[int, str] = {v: k for k, v in labels.items()}
    count: int = len(labels)
```
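To make the five labels concrete, here is a minimal sketch of how word-level entity spans map onto BILUO tags: a single-word place gets `U-`, a multi-word place gets `B-`, then `I-` for any inner words, then `L-` for its last word, and every other word is `O`. The helper `spans_to_biluo` and its `(start, end)` span format (word indices, end exclusive) are hypothetical illustrations, not part of this model's code; `LABELS` is a copy of `Label.labels` so the snippet runs on its own.

```python
LABELS: dict[str, int] = {
    "O": 0,
    "B-location": 1,
    "I-location": 2,
    "L-location": 3,
    "U-location": 4,
}


def spans_to_biluo(n_words: int, spans: list[tuple[int, int]], label: str = "location") -> list[str]:
    """Hypothetical helper: BILUO-tag n_words words given entity spans [start, end)."""
    tags = ["O"] * n_words  # words outside any entity
    for start, end in spans:
        if end - start == 1:
            tags[start] = f"U-{label}"  # single-word (unit) entity
        else:
            tags[start] = f"B-{label}"  # first word
            for i in range(start + 1, end - 1):
                tags[i] = f"I-{label}"  # inner words
            tags[end - 1] = f"L-{label}"  # last word
    return tags


words = ["I", "moved", "from", "New", "York", "City", "to", "Paris"]
tags = spans_to_biluo(len(words), [(3, 6), (7, 8)])
print(tags)
# ['O', 'O', 'O', 'B-location', 'I-location', 'L-location', 'O', 'U-location']
print([LABELS[t] for t in tags])
# [0, 0, 0, 1, 2, 3, 0, 4]
```

The integer ids in the last line are the values the `Label` class above assigns to each tag.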

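Combine subword tokens into whole words. A token that starts with `Ġ` begins a new word; any other token is appended to the word before it: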

```python
def combine_subwords(tokens: list[str], tags: list[int]) -> tuple[list[str], list[str]]:
    """Merge subword tokens into whole words.

    A token starting with "Ġ" begins a new word; any other token is appended
    to the previous word. Each word keeps the tag of its first subword.
    """
    words: list[str] = []
    word_tags: list[int] = []
    for token, tag in zip(tokens, tags):
        if token in ("<s>", "<pad>", "</s>"):
            continue  # skip special tokens
        if token.startswith("Ġ") or not words:
            words.append(token.removeprefix("Ġ"))  # drop the word-start marker
            word_tags.append(tag)
        else:
            words[-1] += token  # continuation subword
    return words, [Label.idx[t] for t in word_tags]


names, labels = combine_subwords(entities, [Label.labels[lb] for lb in labels])
```
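The `Ġ` check depends on how the tokenizer marks word starts. As an alternative, here is a sketch that merges subwords using the `start` and `end` character offsets that the `ner` pipeline also includes in each output item (they are filled in when a fast tokenizer is used): a token that begins exactly where the previous one ended continues the same word, and the word text is sliced straight from the input string. `merge_by_offsets` is a hypothetical helper, and `sample_out` is a hand-written stand-in for pipeline output, not real model predictions. The sketch assumes the offsets exclude the leading space of each word.

```python
def merge_by_offsets(text: str, out: list[dict]) -> tuple[list[str], list[str]]:
    """Hypothetical helper: merge pipeline tokens whose character spans touch.

    Assumes `start`/`end` offsets exclude the leading space, so a token that
    starts exactly where the previous one ended continues the same word.
    Each word keeps the "entity" label of its first token.
    """
    spans: list[list[int]] = []  # [start, end) character span of each word
    labels: list[str] = []
    for item in out:
        if spans and item["start"] == spans[-1][1]:
            spans[-1][1] = item["end"]  # same word: extend its span
        else:
            spans.append([item["start"], item["end"]])
            labels.append(item["entity"])
    return [text[s:e] for s, e in spans], labels


# Hand-written stand-in for pipeline output; only the keys used above are filled in.
text = "I moved from Gothenburg to Edinburgh."
g, e = text.index("Gothenburg"), text.index("Edinburgh")
sample_out = [
    {"entity": "U-location", "start": g, "end": g + 4},       # "Goth"
    {"entity": "U-location", "start": g + 4, "end": g + 10},  # "enburg"
    {"entity": "U-location", "start": e, "end": e + 9},       # "Edinburgh"
]
print(merge_by_offsets(text, sample_out))
# (['Gothenburg', 'Edinburgh'], ['U-location', 'U-location'])
```

Slicing the original text also keeps the word's exact spelling and casing, independent of how the tokenizer displays tokens.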

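Combine BILUO tags, so that each multi-word place name (`B-` through `L-`) becomes a single entry and the tag prefixes are stripped: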

```python
def combine_biluo(tokens: list[str], tags: list[str]) -> tuple[list[str], list[str]]:
    """Join each B- ... L- run into one entry and strip the BILUO prefixes.

    B- starts a multi-word entity, I- continues it, L- ends it and U- marks a
    single-word entity. I-/L- tags without a preceding B- are dropped.
    """
    names: list[str] = []
    labels: list[str] = []
    inside = False  # True between a B- tag and its closing L- tag
    for token, tag in zip(tokens, tags):
        prefix = tag[0]
        if prefix in ("I", "L"):
            if inside:
                names[-1] = f"{names[-1]} {token}"
                inside = prefix == "I"  # an L- tag closes the entity
            continue
        names.append(token)
        labels.append(tag if tag == "O" else tag[2:])  # "B-location" -> "location"
        inside = prefix == "B"
    return names, labels


names, labels = combine_biluo(names, labels)
```
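If you also need to know where each place occurs, the same BILUO logic can return word-index spans instead of joined strings, which lets you map entities back to positions in the word list. This is a minimal sketch; `biluo_to_spans` is a hypothetical helper, and the sample tags below are written by hand rather than produced by the model.

```python
from typing import Optional


def biluo_to_spans(tags: list[str]) -> list[tuple[int, int, str]]:
    """Hypothetical helper: decode BILUO tags into (start, end, label) word spans.

    `end` is exclusive. A B- run with no closing L- still yields a span, and
    I-/L- tags with no preceding B- are ignored.
    """
    spans: list[tuple[int, int, str]] = []
    start: Optional[int] = None  # first word of the entity currently open
    end = 0
    label = ""
    for i, tag in enumerate(tags):
        prefix = tag[0]
        if prefix in ("I", "L"):
            if start is not None:
                end = i + 1  # extend the open entity to this word
                if prefix == "L":
                    spans.append((start, end, label))
                    start = None
            continue  # stray I-/L- tags are ignored
        if start is not None:
            # A B-, U- or O tag ends an unfinished entity.
            spans.append((start, end, label))
            start = None
        if prefix == "B":
            start, end, label = i, i + 1, tag[2:]
        elif prefix == "U":
            spans.append((i, i + 1, tag[2:]))
    if start is not None:
        spans.append((start, end, label))  # entity still open at the end
    return spans


words = ["I", "moved", "from", "New", "York", "City", "to", "Paris"]
tags = ["O", "O", "O", "B-location", "I-location", "L-location", "O", "U-location"]
for s, e, lab in biluo_to_spans(tags):
    print(" ".join(words[s:e]), lab)
# New York City location
# Paris location
```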

This gives:

```python
>>> names
['Gothenburg', 'Edinburgh']

>>> labels
['location', 'location']
```