File size: 5,416 Bytes
ea99abb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
from typing import Dict, List, Union, Optional
from llms import LLM
import json
import re

def pos_tagging(
    text: str, 
    model: str = "en_core_web_sm",
    use_llm: bool = False,
    custom_instructions: str = ""
) -> Dict[str, List[str]]:
    """
    Perform Part-of-Speech tagging on the input text.

    Dispatches to an LLM-backed tagger when ``use_llm`` is True, otherwise
    to a traditional (spaCy) tagger. Both paths return the same shape, so
    the return annotation is ``Dict[str, List[str]]`` (the previous
    ``List[Union[str, List[str]]]`` annotation did not match either branch).

    Args:
        text: The input text to tag.
        model: The model to use for tagging — a spaCy pipeline name such as
            'en_core_web_sm', or an LLM model name (e.g. 'gpt-4',
            'gemini-pro') when ``use_llm`` is True.
        use_llm: Whether to use an LLM for more accurate but slower tagging.
        custom_instructions: Extra instructions prepended to the LLM prompt;
            ignored when ``use_llm`` is False.

    Returns:
        A dict with parallel 'tokens' and 'tags' lists; both empty when the
        input is blank.
    """
    # Blank or whitespace-only input short-circuits before any model loads.
    if not text.strip():
        return {"tokens": [], "tags": []}

    if use_llm:
        return _pos_tagging_with_llm(text, model, custom_instructions)
    return _pos_tagging_traditional(text, model)

def _extract_json_array(text: str) -> str:
    """Extract JSON array from text, handling various formats."""
    import re
    
    # Try to find JSON array pattern
    json_match = re.search(r'\[\s*\{.*\}\s*\]', text, re.DOTALL)
    if json_match:
        return json_match.group(0)
    
    # If not found, try to find array between square brackets
    start = text.find('[')
    end = text.rfind(']')
    if start >= 0 and end > start:
        return text[start:end+1]
    
    return text

def _pos_tagging_with_llm(
    text: str,
    model_name: str,
    custom_instructions: str = ""
) -> Dict[str, List[str]]:
    """Use an LLM for more accurate and flexible POS tagging.

    Prompts the model to emit a JSON array of ``{"token", "tag"}`` objects
    using Universal Dependencies tags, parses the reply (repairing common
    JSON mistakes if needed), and falls back to the traditional tagger on
    any failure.

    Args:
        text: The input text to tag.
        model_name: Name of the LLM to query.
        custom_instructions: Optional instructions prepended to the prompt.

    Returns:
        A dict with parallel 'tokens' and 'tags' lists.
    """
    # Create the prompt with clear instructions
    prompt = """Analyze the following text and provide Part-of-Speech (POS) tags for each token.
Return the result as a JSON array of objects with 'token' and 'tag' keys.

Use standard Universal Dependencies POS tags:
- ADJ: adjective
- ADP: adposition
- ADV: adverb
- AUX: auxiliary verb
- CONJ: coordinating conjunction
- DET: determiner
- INTJ: interjection
- NOUN: noun
- NUM: numeral
- PART: particle
- PRON: pronoun
- PROPN: proper noun
- PUNCT: punctuation
- SCONJ: subordinating conjunction
- SYM: symbol
- VERB: verb
- X: other

Example output format:
[
  {"token": "Hello", "tag": "INTJ"},
  {"token": "world", "tag": "NOUN"},
  {"token": ".", "tag": "PUNCT"}
]

Text to analyze:
"""
    
    if custom_instructions:
        prompt = f"{custom_instructions}\n\n{prompt}"
    
    prompt += f'"{text}"'
    
    try:
        # Low temperature for more deterministic, parseable output.
        llm = LLM(model=model_name, temperature=0.1, max_tokens=2000)
        
        response = llm.generate(prompt)
        print(f"LLM Raw Response: {response[:500]}...")  # Log first 500 chars
        
        if not response.strip():
            raise ValueError("Empty response from LLM")
            
        # Extract JSON array from response
        json_str = _extract_json_array(response)
        if not json_str:
            raise ValueError("No JSON array found in response")
            
        # Parse the JSON, repairing common LLM mistakes on failure.
        try:
            pos_tags = json.loads(json_str)
        except json.JSONDecodeError:
            # Quote bare keys first and retry BEFORE touching quote
            # characters: blindly replacing every "'" with '"' corrupts
            # tokens that contain apostrophes (e.g. "don't"), so that
            # substitution is kept as a last resort.
            repaired = re.sub(r'(\w+):', r'"\1":', json_str)
            try:
                pos_tags = json.loads(repaired)
            except json.JSONDecodeError:
                pos_tags = json.loads(repaired.replace("'", '"'))
        
        if not isinstance(pos_tags, list):
            raise ValueError(f"Expected list, got {type(pos_tags).__name__}")
            
        # Collect only well-formed entries with non-empty token AND tag,
        # keeping the two output lists parallel.
        tokens = []
        tags = []
        for item in pos_tags:
            if not isinstance(item, dict):
                continue
                
            token = item.get('token', '')
            tag = item.get('tag', '')
            
            if token and tag:
                tokens.append(str(token).strip())
                tags.append(str(tag).strip())
        
        if not tokens or not tags:
            raise ValueError("No valid tokens and tags found in response")
            
        return {
            'tokens': tokens,
            'tags': tags
        }
        
    except Exception as e:
        # Best-effort fallback: any failure (network, parsing, validation)
        # degrades to the traditional tagger rather than raising.
        print(f"Error in LLM POS tagging: {str(e)}")
        print(f"Falling back to traditional POS tagging...")
        return _pos_tagging_traditional(text, "en_core_web_sm")

def _pos_tagging_traditional(text: str, model: str) -> Dict[str, List[str]]:
    """Use traditional POS tagging models."""
    try:
        import spacy
        
        # Load the appropriate model
        try:
            nlp = spacy.load(model)
        except OSError:
            # Fallback to small English model if specified model is not found
            nlp = spacy.load("en_core_web_sm")
            
        # Process the text
        doc = nlp(text)
        
        # Extract tokens and POS tags
        tokens = []
        tags = []
        for token in doc:
            tokens.append(token.text)
            tags.append(token.pos_)
            
        return {
            'tokens': tokens,
            'tags': tags
        }
        
    except Exception as e:
        print(f"Error in traditional POS tagging: {str(e)}")
        return {"tokens": [], "tags": []}