Last commit not found
raw
history blame
1.7 kB
# Updated NamedEntityRecognitionTool in ner_tool.py
from transformers import pipeline
from transformers import Tool
class NamedEntityRecognitionTool(Tool):
    """Tool that labels named entities in a text, one record per token.

    Calling the tool runs a ``pipeline("ner")`` (no model is specified, so
    the library's default for that task is used) over the input and returns
    the tagged tokens.
    """

    name = "ner_tool"
    description = "Identifies and labels various entities in a given text."
    inputs = ["text"]
    # NOTE(review): ``outputs`` declares "text" but ``__call__`` returns a
    # dict; left as-is so existing callers keep working -- confirm intent.
    outputs = ["text"]

    def __call__(self, text: str):
        """Run named entity recognition on ``text``.

        Args:
            text: The raw text to analyse.

        Returns:
            A dict ``{"entities": [...]}`` where each item is a dict with:
            ``token`` (the token string with any ``##`` sub-word markers
            removed), ``label`` (the predicted tag, ``"UNKNOWN"`` if the
            pipeline did not supply one) and ``entity_text`` (the slice of
            ``text`` covered by the token, or ``""`` when the pipeline gave
            no character offsets).

        Side effects:
            Prints the token-level entities to stdout.
        """
        # Building the pipeline loads a model, so create it once per
        # instance and reuse it on later calls instead of on every call.
        ner_analyzer = getattr(self, "_ner_analyzer", None)
        if ner_analyzer is None:
            ner_analyzer = pipeline("ner")
            self._ner_analyzer = ner_analyzer

        # Perform named entity recognition on the input text
        entities = ner_analyzer(text)

        token_entities = []
        for entity in entities:
            label = entity.get("entity", "UNKNOWN")
            word = entity.get("word", "")
            start = entity.get("start")
            end = entity.get("end")

            # Offsets can be absent or None; slicing with None bounds would
            # return the whole input, so fall back to an empty string.
            if start is None or end is None:
                entity_text = ""
            else:
                entity_text = text[start:end].strip()

            if "##" in word:
                # "##" marks sub-word pieces. Splitting "##ing" yields
                # ["", "ing"]; drop the empty fragments so no blank token
                # records are emitted.
                pieces = [piece for piece in word.split("##") if piece]
            else:
                pieces = [word]

            token_entities.extend(
                {"token": piece, "label": label, "entity_text": entity_text}
                for piece in pieces
            )

        # Print the identified token-level entities
        print(f"Token-level Entities: {token_entities}")

        # Wrap the list in a dictionary under the "entities" key
        return {"entities": token_entities}