|
import json
from dataclasses import dataclass, asdict
from functools import lru_cache
from typing import Dict, List, Union, Optional

from llms import LLM
|
|
|
@dataclass
class Entity:
    """A single named-entity mention extracted from a piece of text.

    The dataclass performs no validation or type coercion of its own; it
    simply records the values it is given. Field order is significant
    because it defines the positional constructor signature.
    """

    # Surface text of the mention.
    text: str
    # Entity label, e.g. "PERSON" or "ORG".
    type: str
    # Start character position of the mention.
    start: int
    # End character position of the mention.
    end: int
    # Optional confidence score; None when no score was supplied.
    confidence: Optional[float] = None
    # Optional brief description; None when no description was supplied.
    description: Optional[str] = None
|
|
|
def named_entity_recognition(
    text: str,
    model: str = "gemini-2.0-flash",
    use_llm: bool = True,
    entity_types: Optional[List[str]] = None
) -> Union[str, List[Dict]]:
    """Perform named entity recognition using either an LLM or a traditional NER model.

    Args:
        text: Input text to analyze. ``None``, an empty string, or a
            whitespace-only string yields ``[]`` without invoking any model.
        model: Model to use for NER. It is forwarded unchanged to whichever
            backend is selected by ``use_llm``.
        use_llm: Whether to use the LLM backend (more accurate but slower)
            instead of the traditional backend.
        entity_types: List of entity types to extract (only used with LLM).

    Returns:
        List of entity dicts, each carrying at least the ``text``, ``type``,
        ``start``, ``end`` and ``confidence`` keys. The list is empty when
        the input is blank.
    """
    # Guard against None as well as blank strings so callers passing a
    # missing value get an empty result instead of an AttributeError.
    if not text or not text.strip():
        return []

    if use_llm:
        return _ner_with_llm(text, model, entity_types)
    return _ner_with_traditional(text, model)
|
|
|
def _ner_with_llm(
    text: str,
    model_name: str,
    entity_types: Optional[List[str]] = None
) -> List[Dict]:
    """Use an LLM for more accurate and flexible NER.

    Builds a prompt asking for a JSON array of entities, sends it to the
    LLM (temperature 0.1), and converts every well-formed item into an
    ``Entity`` dict. Malformed items are reported and skipped. If anything
    else goes wrong (LLM call, JSON decoding, non-array reply) the error is
    reported and the text is re-processed with the traditional backend
    using ``en_core_web_md``.

    Args:
        text: Text to analyze.
        model_name: Name of the LLM to instantiate.
        entity_types: Entity labels to request. ``None`` or an empty list
            selects the built-in default label set.

    Returns:
        List of entity dicts with the keys ``text``, ``type``, ``start``,
        ``end``, ``confidence`` and ``description``.
    """
    # An empty list would produce a prompt that names no types at all, so
    # it is treated the same as "not specified".
    if not entity_types:
        entity_types = [
            "PERSON", "ORG", "GPE", "LOC", "PRODUCT", "EVENT",
            "WORK_OF_ART", "LAW", "LANGUAGE", "DATE", "TIME",
            "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL",
        ]

    entity_types_str = ", ".join(entity_types)
    # Only the first literal is an f-string; the user text is concatenated
    # so braces inside it can never be interpreted as format fields.
    prompt = f"""
    Extract named entities from the following text and categorize them into these types: {entity_types_str}.
    For each entity, provide:
    - The entity text
    - The entity type (from the list above)
    - The start and end character positions
    - (Optional) A brief description of the entity
    - (Optional) Confidence score (0-1)

    Return the entities as a JSON array of objects with these fields:
    - text: The entity text
    - type: The entity type
    - start: Start character position
    - end: End character position
    - description: (Optional) Brief description
    - confidence: (Optional) Confidence score (0-1)

    Text: """ + text + """

    JSON response (only the array, no other text):
    ["""

    try:
        llm = LLM(model=model_name, temperature=0.1)
        # NOTE(review): generate() is assumed to return a str - confirm
        # against the llms package.
        response = llm.generate(prompt)

        entities = json.loads(_extract_json_array(response))
        if not isinstance(entities, list):
            # A bare object or scalar is not a usable answer; let the
            # outer handler fall back to the traditional backend.
            raise ValueError("LLM response is not a JSON array")

        valid_entities = []
        for ent in entities:
            try:
                entity = Entity(
                    text=ent['text'],
                    type=ent['type'],
                    start=int(ent['start']),
                    end=int(ent['end']),
                    confidence=ent.get('confidence'),
                    description=ent.get('description')
                )
                valid_entities.append(asdict(entity))
            except (KeyError, ValueError, TypeError) as e:
                # TypeError covers non-dict items and None offsets, so one
                # bad item no longer discards the whole LLM result.
                print(f"Error parsing entity: {e}")

        # NOTE: start/end are taken from the LLM as-is; they are not
        # checked against ``text``.
        return valid_entities

    except Exception as e:
        # Deliberate best-effort boundary: any failure degrades to the
        # traditional backend rather than propagating.
        print(f"Error in LLM NER: {str(e)}")

    return _ner_with_traditional(text, "en_core_web_md")


def _extract_json_array(response: str) -> str:
    """Return the JSON-array portion of a raw LLM reply.

    Handles a complete array (optionally wrapped in a code fence or
    surrounded by prose) as well as a reply that merely continues the ``[``
    the prompt ends with. When no closing bracket is present the stripped
    reply is returned unchanged so that ``json.loads`` reports the problem.
    """
    payload = response.strip()
    open_idx = payload.find('[')
    brace_idx = payload.find('{')
    close_idx = payload.rfind(']')

    if close_idx == -1:
        return payload

    # A '[' that precedes the first object opens the array itself.
    if open_idx != -1 and (brace_idx == -1 or open_idx < brace_idx):
        return payload[open_idx:close_idx + 1]

    # Otherwise the reply continues the prompt's trailing '[': restore it.
    start = brace_idx if brace_idx != -1 else close_idx
    return '[' + payload[start:close_idx + 1]
|
|
|
def _ner_with_traditional(text: str, model: str) -> List[Dict]: |
|
"""Fallback to traditional NER models.""" |
|
try: |
|
import spacy |
|
|
|
|
|
if model == "en_core_web_sm" or model == "en_core_web_md" or model == "en_core_web_lg": |
|
nlp = spacy.load(model) |
|
else: |
|
nlp = spacy.load("en_core_web_md") |
|
|
|
|
|
doc = nlp(text) |
|
|
|
|
|
entities = [] |
|
for ent in doc.ents: |
|
entities.append({ |
|
'text': ent.text, |
|
'type': ent.label_, |
|
'start': ent.start_char, |
|
'end': ent.end_char, |
|
'confidence': 1.0 |
|
}) |
|
|
|
return entities |
|
|
|
except Exception as e: |
|
print(f"Error in traditional NER: {str(e)}") |
|
return [] |