# Ling/tasks/ner.py
# Author: Nam Fam
# Commit: update files (ea99abb)
from typing import Dict, List, Union, Optional
from llms import LLM
from dataclasses import dataclass, asdict
import json
@dataclass
class Entity:
    """A single named entity extracted from input text."""

    text: str  # surface form of the entity as it appears in the input
    type: str  # entity label, e.g. "PERSON", "ORG", "GPE"
    start: int  # start character offset in the source text
    end: int  # end character offset (exclusive in spaCy output — TODO confirm for LLM output)
    confidence: Optional[float] = None  # confidence score in [0, 1]; None when not provided
    description: Optional[str] = None  # brief free-text description; None when not provided
def named_entity_recognition(
    text: str,
    model: str = "gemini-2.0-flash",
    use_llm: bool = True,
    entity_types: Optional[List[str]] = None
) -> List[Dict]:
    """
    Perform named entity recognition using either an LLM or a traditional NER model.

    Args:
        text: Input text to analyze.
        model: Model identifier — an LLM name when ``use_llm`` is True,
            otherwise a spaCy pipeline name (e.g. "en_core_web_md").
        use_llm: Whether to use an LLM for more flexible (but slower) NER.
        entity_types: Entity types to extract (only used with the LLM path;
            a default set is used when None).

    Returns:
        List of entity dicts with keys ``text``, ``type``, ``start``, ``end``
        and, when available, ``confidence`` and ``description``.
    """
    # Empty or whitespace-only input has no entities; avoid any model call.
    if not text.strip():
        return []
    if use_llm:
        return _ner_with_llm(text, model, entity_types)
    return _ner_with_traditional(text, model)
def _extract_json_array(response: str) -> str:
    """Best-effort extraction of a JSON array from a raw LLM reply.

    The prompt primes the model with an opening ``[``, so the reply may be:
    the array tail with the leading bracket missing, a complete array, or an
    array wrapped in markdown fences / surrounding prose. All three forms are
    normalized to a plain ``[...]`` string; anything else is returned as-is
    and left for ``json.loads`` to reject.
    """
    response = response.strip()
    first = response.find('[')
    last = response.rfind(']')
    if first != -1 and last > first:
        # Take the outermost bracketed span, dropping fences/prose around it.
        return response[first:last + 1]
    if first == -1 and last != -1:
        # The opening '[' was consumed by the prompt priming; restore it.
        return '[' + response[:last + 1]
    return response


def _ner_with_llm(
    text: str,
    model_name: str,
    entity_types: Optional[List[str]] = None
) -> List[Dict]:
    """Use an LLM for more accurate and flexible NER.

    Args:
        text: Input text to analyze.
        model_name: LLM model identifier passed to ``LLM``.
        entity_types: Entity labels to extract; defaults to a broad,
            spaCy-style label set when None.

    Returns:
        List of validated entity dicts (see ``Entity``). Falls back to
        ``_ner_with_traditional`` if the LLM call or parsing fails outright.
    """
    # Default entity types if none provided (spaCy-style label inventory).
    if entity_types is None:
        entity_types = [
            "PERSON", "ORG", "GPE", "LOC", "PRODUCT", "EVENT",
            "WORK_OF_ART", "LAW", "LANGUAGE", "DATE", "TIME",
            "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"
        ]

    entity_types_str = ", ".join(entity_types)
    # The trailing '[' primes the model to emit a bare JSON array.
    prompt = f"""
Extract named entities from the following text and categorize them into these types: {entity_types_str}.
For each entity, provide:
- The entity text
- The entity type (from the list above)
- The start and end character positions
- (Optional) A brief description of the entity
- (Optional) Confidence score (0-1)
Return the entities as a JSON array of objects with these fields:
- text: The entity text
- type: The entity type
- start: Start character position
- end: End character position
- description: (Optional) Brief description
- confidence: (Optional) Confidence score (0-1)
Text: """ + text + """
JSON response (only the array, no other text):
["""

    try:
        # Low temperature keeps the structured output deterministic.
        llm = LLM(model=model_name, temperature=0.1)
        response = llm.generate(prompt)

        entities = json.loads(_extract_json_array(response))

        # Convert to Entity objects and validate; skip malformed items
        # rather than failing the whole batch. TypeError covers non-dict
        # list elements (e.g. a stray string in the array).
        valid_entities = []
        for ent in entities:
            try:
                entity = Entity(
                    text=ent['text'],
                    type=ent['type'],
                    start=int(ent['start']),
                    end=int(ent['end']),
                    confidence=ent.get('confidence'),
                    description=ent.get('description')
                )
                valid_entities.append(asdict(entity))
            except (KeyError, TypeError, ValueError) as e:
                print(f"Error parsing entity: {e}")
                continue
        return valid_entities
    except Exception as e:
        print(f"Error in LLM NER: {str(e)}")
        # Fall back to traditional NER if LLM fails
        return _ner_with_traditional(text, "en_core_web_md")
def _ner_with_traditional(text: str, model: str) -> List[Dict]:
"""Fallback to traditional NER models."""
try:
import spacy
# Load the appropriate model
if model == "en_core_web_sm" or model == "en_core_web_md" or model == "en_core_web_lg":
nlp = spacy.load(model)
else:
nlp = spacy.load("en_core_web_md")
# Process the text
doc = nlp(text)
# Convert to our entity format
entities = []
for ent in doc.ents:
entities.append({
'text': ent.text,
'type': ent.label_,
'start': ent.start_char,
'end': ent.end_char,
'confidence': 1.0 # Traditional NER doesn't provide confidence
})
return entities
except Exception as e:
print(f"Error in traditional NER: {str(e)}")
return []