File size: 4,716 Bytes
ea99abb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from typing import Dict, List, Union, Optional
from llms import LLM
from dataclasses import dataclass, asdict
import json

@dataclass
class Entity:
    """One named-entity mention found in a piece of text.

    Instances are converted with ``dataclasses.asdict`` before being returned
    to callers, so the declaration order below is the key order of those dicts.
    """

    # Surface string of the mention.
    text: str
    # Entity label, e.g. "PERSON" or "ORG".
    type: str
    # Start and end character positions of the mention in the analysed text.
    start: int
    end: int
    # Optional confidence score (0-1); None when the recognizer gives none.
    confidence: Optional[float] = None
    # Optional brief description of the entity.
    description: Optional[str] = None

def named_entity_recognition(
    text: str,
    model: str = "gemini-2.0-flash",
    use_llm: bool = True,
    entity_types: Optional[List[str]] = None
) -> List[Dict]:
    """
    Perform named entity recognition using either an LLM or spaCy.

    Args:
        text: Input text to analyze. ``None``, empty and whitespace-only
            input yield no entities without calling any model.
        model: Model to use for NER. With ``use_llm=True`` this is the LLM
            name; with ``use_llm=False`` it is a spaCy pipeline name, and any
            name other than ``en_core_web_sm``/``en_core_web_md``/
            ``en_core_web_lg`` (including the default LLM name) is replaced
            by ``en_core_web_md``.
        use_llm: Whether to use the LLM (more flexible, slower) instead of spaCy.
        entity_types: Entity labels to extract. Only used on the LLM path;
            spaCy always returns its own labels.

    Returns:
        List of entity dicts with ``text``, ``type``, ``start``, ``end`` and
        ``confidence`` keys (plus ``description`` on the LLM path).
    """
    # Blank or missing input has no entities; skip the (expensive) model calls.
    if not text or not text.strip():
        return []

    if use_llm:
        return _ner_with_llm(text, model, entity_types)
    return _ner_with_traditional(text, model)

def _ner_with_llm(
    text: str,
    model_name: str,
    entity_types: Optional[List[str]] = None
) -> List[Dict]:
    """Use an LLM for more accurate and flexible NER.

    The model is asked for a JSON array of entities. The reply is cleaned and
    parsed, and each item is validated; malformed items are skipped
    individually rather than discarding the whole response.

    Args:
        text: Input text to analyze.
        model_name: Model identifier passed to ``LLM``.
        entity_types: Entity labels the model may use; defaults to the label
            list below.

    Returns:
        A list of dicts with keys ``text``, ``type``, ``start``, ``end``,
        ``confidence`` and ``description``. If the LLM call fails or its reply
        is not a JSON array, the error is printed and the result of
        ``_ner_with_traditional(text, "en_core_web_md")`` is returned instead.
    """
    # Default entity types if none provided
    if entity_types is None:
        entity_types = [
            "PERSON", "ORG", "GPE", "LOC", "PRODUCT", "EVENT",
            "WORK_OF_ART", "LAW", "LANGUAGE", "DATE", "TIME",
            "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"
        ]

    # Create the prompt. It ends with an opening "[" to prime a bare JSON
    # array; a model may continue from that bracket without repeating it,
    # which _extract_json_array accounts for.
    entity_types_str = ", ".join(entity_types)
    prompt = f"""
    Extract named entities from the following text and categorize them into these types: {entity_types_str}.
    For each entity, provide:
    - The entity text
    - The entity type (from the list above)
    - The start and end character positions
    - (Optional) A brief description of the entity
    - (Optional) Confidence score (0-1)
    
    Return the entities as a JSON array of objects with these fields:
    - text: The entity text
    - type: The entity type
    - start: Start character position
    - end: End character position
    - description: (Optional) Brief description
    - confidence: (Optional) Confidence score (0-1)
    
    Text: """ + text + """
    
    JSON response (only the array, no other text):
    ["""

    try:
        llm = LLM(model=model_name, temperature=0.1)
        response = llm.generate(prompt)

        entities = json.loads(_extract_json_array(response))
        if not isinstance(entities, list):
            # A JSON object/scalar here means the model ignored the format.
            raise ValueError(
                f"expected a JSON array, got {type(entities).__name__}"
            )

        return _entities_from_llm_json(entities, text)

    except Exception as e:
        print(f"Error in LLM NER: {str(e)}")
        # Fall back to traditional NER if LLM fails
        return _ner_with_traditional(text, "en_core_web_md")

def _extract_json_array(response: str) -> str:
    """Return the part of an LLM reply that should contain the JSON array.

    Tolerates Markdown code fences (with or without a language tag), prose
    before or after the array, and replies that continue the prompt's opening
    ``[`` without repeating it. If the reply has no ``]`` at all, the stripped
    reply is returned unchanged so that ``json.loads`` reports the problem.
    """
    cleaned = response.strip()
    close = cleaned.rfind(']')
    if close == -1:
        return cleaned

    open_bracket = cleaned.find('[')
    first_object = cleaned.find('{')
    if open_bracket == -1 or (first_object != -1 and first_object < open_bracket):
        # Continuation of the primed "[": the reply starts at the first object,
        # or is only the closing bracket when no entities were found.
        begin = first_object if first_object != -1 else close
        return "[" + cleaned[begin:close + 1]
    return cleaned[open_bracket:close + 1]

def _align_span(text: str, surface: object, start: int, end: int):
    """Snap a reported ``(start, end)`` span onto a real occurrence of ``surface``.

    LLM-reported character offsets are not guaranteed to match ``text``. When
    ``text[start:end]`` differs from ``surface``, the occurrence of ``surface``
    whose start is closest to the reported ``start`` is used instead. If
    ``surface`` is not a non-empty string or does not occur in ``text``, the
    reported offsets are returned unchanged.
    """
    if not isinstance(surface, str) or not surface or text[start:end] == surface:
        return start, end

    best = None
    pos = text.find(surface)
    while pos != -1:
        if best is None or abs(pos - start) < abs(best - start):
            best = pos
        pos = text.find(surface, pos + 1)

    if best is None:
        return start, end
    return best, best + len(surface)

def _entities_from_llm_json(raw_entities: List, text: str) -> List[Dict]:
    """Validate parsed LLM items and convert them to plain entity dicts.

    Items that are not mappings with ``text``/``type``/``start``/``end`` keys,
    or whose offsets cannot be converted with ``int``, are reported and skipped.
    """
    valid_entities = []
    for ent in raw_entities:
        try:
            surface = ent['text']
            label = ent['type']
            # TypeError covers null offsets (int(None)) and non-dict items.
            start = int(ent['start'])
            end = int(ent['end'])
            confidence = ent.get('confidence')
            description = ent.get('description')
        except (KeyError, ValueError, TypeError, AttributeError) as e:
            print(f"Error parsing entity: {e}")
            continue

        start, end = _align_span(text, surface, start, end)
        valid_entities.append(asdict(Entity(
            text=surface,
            type=label,
            start=start,
            end=end,
            confidence=confidence,
            description=description,
        )))
    return valid_entities

# spaCy pipelines already loaded by _ner_with_traditional, keyed by model name,
# so that spacy.load (which reads the pipeline from disk) runs once per model.
_SPACY_PIPELINES: Dict[str, object] = {}
_SUPPORTED_SPACY_MODELS = ("en_core_web_sm", "en_core_web_md", "en_core_web_lg")
_DEFAULT_SPACY_MODEL = "en_core_web_md"

def _ner_with_traditional(text: str, model: str) -> List[Dict]:
    """Run spaCy NER, directly or as the fallback for the LLM path.

    Args:
        text: Input text to analyze.
        model: spaCy pipeline name. Only ``en_core_web_sm``, ``en_core_web_md``
            and ``en_core_web_lg`` are honoured; any other value (e.g. an LLM
            model name) is replaced by ``en_core_web_md``.

    Returns:
        Entity dicts with the same keys as the LLM path. ``confidence`` is a
        fixed 1.0 placeholder because this path has no per-entity score, and
        ``description`` is None. On any error (spaCy not importable, model
        not installed, processing failure) the error is printed and ``[]`` is
        returned.
    """
    try:
        import spacy

        # Load the appropriate model, reusing a previously loaded pipeline.
        name = model if model in _SUPPORTED_SPACY_MODELS else _DEFAULT_SPACY_MODEL
        nlp = _SPACY_PIPELINES.get(name)
        if nlp is None:
            nlp = _SPACY_PIPELINES[name] = spacy.load(name)

        doc = nlp(text)

        # Build dicts through Entity so both NER paths return the same schema.
        return [
            asdict(Entity(
                text=ent.text,
                type=ent.label_,
                start=ent.start_char,
                end=ent.end_char,
                confidence=1.0,  # Traditional NER doesn't provide confidence
            ))
            for ent in doc.ents
        ]

    except Exception as e:
        print(f"Error in traditional NER: {str(e)}")
        return []