File size: 4,716 Bytes
ea99abb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import json
from dataclasses import dataclass, asdict
from functools import lru_cache
from typing import Dict, List, Union, Optional

from llms import LLM
@dataclass()
class Entity:
    """One named entity extracted from a piece of text.

    Attributes:
        text: The entity's surface string.
        type: The entity label, e.g. "PERSON" or "ORG".
        start: Start character position of the entity in the source text.
        end: End character position of the entity in the source text.
        confidence: Optional confidence score; None when not supplied.
        description: Optional short description; None when not supplied.
    """

    # Required fields, in positional-constructor order.
    text: str
    type: str
    start: int
    end: int
    # Optional metadata; both default to None.
    confidence: Optional[float] = None
    description: Optional[str] = None
def named_entity_recognition(
    text: str,
    model: str = "gemini-2.0-flash",
    use_llm: bool = True,
    entity_types: Optional[List[str]] = None,
) -> List[Dict]:
    """
    Perform named entity recognition using either an LLM or a spaCy model.

    Args:
        text: Input text to analyze. Empty, whitespace-only, or None text
            yields an empty list without invoking any model.
        model: Model to use for NER. With ``use_llm=True`` this is the LLM
            model name. Otherwise it is a spaCy pipeline name; anything other
            than en_core_web_sm/md/lg is replaced by en_core_web_md.
        use_llm: Whether to use the LLM (more accurate but slower). If the
            LLM call or its parsing fails, the LLM path falls back to spaCy.
        entity_types: Entity labels to extract (only used with the LLM).

    Returns:
        List of entity dicts with at least ``text``, ``type``, ``start``,
        ``end`` and ``confidence`` keys.
    """
    # Guard None as well as blank strings: ``None.strip()`` would raise.
    if not text or not text.strip():
        return []
    if use_llm:
        return _ner_with_llm(text, model, entity_types)
    return _ner_with_traditional(text, model)
def _ner_with_llm(
    text: str,
    model_name: str,
    entity_types: Optional[List[str]] = None
) -> List[Dict]:
    """Use an LLM for more accurate and flexible NER.

    Args:
        text: Input text to analyze.
        model_name: Model name passed to ``LLM(model=...)``.
        entity_types: Entity labels to request; defaults to the label set below.

    Returns:
        A list of ``asdict(Entity)`` dicts. Array entries that are not objects,
        lack ``text``/``type``/``start``/``end``, or have non-integer offsets
        are reported and skipped. When the reported offsets do not select the
        entity text, they are snapped to the nearest real occurrence of it.
        If the LLM call or JSON parsing fails, the spaCy ``en_core_web_md``
        path is used instead.
    """
    # Default entity types if none provided
    if entity_types is None:
        entity_types = [
            "PERSON", "ORG", "GPE", "LOC", "PRODUCT", "EVENT",
            "WORK_OF_ART", "LAW", "LANGUAGE", "DATE", "TIME",
            "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"
        ]
    # Create the prompt. ``text`` is concatenated rather than interpolated so
    # braces inside it are never treated as f-string fields.
    entity_types_str = ", ".join(entity_types)
    prompt = f"""
Extract named entities from the following text and categorize them into these types: {entity_types_str}.
For each entity, provide:
- The entity text
- The entity type (from the list above)
- The start and end character positions
- (Optional) A brief description of the entity
- (Optional) Confidence score (0-1)
Return the entities as a JSON array of objects with these fields:
- text: The entity text
- type: The entity type
- start: Start character position
- end: End character position
- description: (Optional) Brief description
- confidence: (Optional) Confidence score (0-1)
Text: """ + text + """
JSON response (only the array, no other text):
["""
    try:
        llm = LLM(model=model_name, temperature=0.1)
        response = llm.generate(prompt)
        entities = _extract_json_array(response)
        valid_entities = []
        for ent in entities:
            try:
                start, end = _align_span(
                    text, ent['text'], int(ent['start']), int(ent['end'])
                )
                entity = Entity(
                    text=ent['text'],
                    type=ent['type'],
                    start=start,
                    end=end,
                    confidence=ent.get('confidence'),
                    description=ent.get('description'),
                )
                valid_entities.append(asdict(entity))
            except (KeyError, TypeError, ValueError) as e:
                # TypeError covers non-object entries (e.g. a bare string) and
                # null offsets; skip just that entry instead of the whole reply.
                print(f"Error parsing entity: {e}")
                continue
        return valid_entities
    except Exception as e:
        print(f"Error in LLM NER: {str(e)}")
        # Fall back to traditional NER if LLM fails
        return _ner_with_traditional(text, "en_core_web_md")


def _extract_json_array(response: str) -> list:
    """Parse the JSON array out of a raw LLM reply.

    Tolerates Markdown code fences (with or without a language tag), text
    around the array, an object wrapping the array, and replies that continue
    the ``[`` the prompt ends with instead of repeating it.

    Raises:
        ValueError: if no candidate substring parses as a JSON array.
    """
    cleaned = response.strip()
    if cleaned.startswith("```"):
        # Drop the opening fence line (``` or ```json) and any closing fence.
        newline = cleaned.find("\n")
        if newline != -1:
            cleaned = cleaned[newline + 1:]
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3]
        cleaned = cleaned.strip()

    candidates = [cleaned]
    first, last = cleaned.find("["), cleaned.rfind("]")
    if first != -1 and last > first:
        candidates.append(cleaned[first:last + 1])
    if last != -1:
        # The model may have continued the array the prompt already opened.
        candidates.append("[" + cleaned[:last + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            continue
        if isinstance(parsed, list):
            return parsed
    raise ValueError("LLM response did not contain a JSON array")


def _align_span(text: str, ent_text, start: int, end: int) -> tuple:
    """Return (start, end) offsets that select ``ent_text`` in ``text`` if possible.

    Reported offsets are kept when they already match. Otherwise they move to
    the occurrence of ``ent_text`` closest to the reported start. If the text
    never occurs (or is not a non-empty string), the offsets are returned unchanged.
    """
    if not isinstance(ent_text, str) or not ent_text or text[start:end] == ent_text:
        return start, end
    best = None
    pos = text.find(ent_text)
    while pos != -1:
        if best is None or abs(pos - start) < abs(best - start):
            best = pos
        pos = text.find(ent_text, pos + 1)
    if best is None:
        return start, end
    return best, best + len(ent_text)
def _ner_with_traditional(text: str, model: str) -> List[Dict]:
"""Fallback to traditional NER models."""
try:
import spacy
# Load the appropriate model
if model == "en_core_web_sm" or model == "en_core_web_md" or model == "en_core_web_lg":
nlp = spacy.load(model)
else:
nlp = spacy.load("en_core_web_md")
# Process the text
doc = nlp(text)
# Convert to our entity format
entities = []
for ent in doc.ents:
entities.append({
'text': ent.text,
'type': ent.label_,
'start': ent.start_char,
'end': ent.end_char,
'confidence': 1.0 # Traditional NER doesn't provide confidence
})
return entities
except Exception as e:
print(f"Error in traditional NER: {str(e)}")
return [] |