File size: 5,416 Bytes
ea99abb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
from typing import Dict, List, Union, Optional
from llms import LLM
import json
import re
def pos_tagging(
    text: str,
    model: str = "en_core_web_sm",
    use_llm: bool = False,
    custom_instructions: str = ""
) -> Dict[str, List[Union[str, List[str]]]]:
    """Tokenize *text* and assign a part-of-speech tag to every token.

    Dispatches to an LLM-backed tagger when ``use_llm`` is set, otherwise
    to a traditional (spaCy) pipeline.

    Args:
        text: Input text to analyze.
        model: Model identifier (e.g. 'en_core_web_sm', 'gpt-4', 'gemini-pro').
        use_llm: Prefer the slower but more flexible LLM tagger.
        custom_instructions: Extra prompt instructions for the LLM path.

    Returns:
        Dict with parallel 'tokens' and 'tags' lists; both are empty when
        the input is blank.
    """
    # Blank or whitespace-only input short-circuits to an empty result
    # without touching either backend.
    if not text.strip():
        return {"tokens": [], "tags": []}
    if use_llm:
        return _pos_tagging_with_llm(text, model, custom_instructions)
    return _pos_tagging_traditional(text, model)
def _extract_json_array(text: str) -> str:
"""Extract JSON array from text, handling various formats."""
import re
# Try to find JSON array pattern
json_match = re.search(r'\[\s*\{.*\}\s*\]', text, re.DOTALL)
if json_match:
return json_match.group(0)
# If not found, try to find array between square brackets
start = text.find('[')
end = text.rfind(']')
if start >= 0 and end > start:
return text[start:end+1]
return text
def _pos_tagging_with_llm(
    text: str,
    model_name: str,
    custom_instructions: str = ""
) -> Dict[str, List[str]]:
    """Use LLM for more accurate and flexible POS tagging.

    Prompts *model_name* to emit a JSON array of ``{"token", "tag"}``
    objects using Universal Dependencies tags, then parses and validates
    the response. Any failure (network, parse, validation) falls back to
    the traditional spaCy tagger instead of raising.

    Args:
        text: Text to tag.
        model_name: Identifier passed to the project's LLM wrapper.
        custom_instructions: Optional instructions prepended to the prompt.

    Returns:
        Dict with parallel 'tokens' and 'tags' lists.
    """
    # Create the prompt with clear instructions
    prompt = """Analyze the following text and provide Part-of-Speech (POS) tags for each token.
Return the result as a JSON array of objects with 'token' and 'tag' keys.
Use standard Universal Dependencies POS tags:
- ADJ: adjective
- ADP: adposition
- ADV: adverb
- AUX: auxiliary verb
- CONJ: coordinating conjunction
- DET: determiner
- INTJ: interjection
- NOUN: noun
- NUM: numeral
- PART: particle
- PRON: pronoun
- PROPN: proper noun
- PUNCT: punctuation
- SCONJ: subordinating conjunction
- SYM: symbol
- VERB: verb
- X: other
Example output format:
[
{"token": "Hello", "tag": "INTJ"},
{"token": "world", "tag": "NOUN"},
{"token": ".", "tag": "PUNCT"}
]
Text to analyze:
"""
    if custom_instructions:
        prompt = f"{custom_instructions}\n\n{prompt}"
    prompt += f'"{text}"'
    try:
        # Low temperature keeps the output close to deterministic JSON.
        llm = LLM(model=model_name, temperature=0.1, max_tokens=2000)
        response = llm.generate(prompt)
        print(f"LLM Raw Response: {response[:500]}...")  # Log first 500 chars
        if not response.strip():
            raise ValueError("Empty response from LLM")
        # Extract JSON array from response
        json_str = _extract_json_array(response)
        if not json_str:
            raise ValueError("No JSON array found in response")
        try:
            pos_tags = json.loads(json_str)
        except json.JSONDecodeError:
            # Repair the two most common LLM JSON mistakes: single-quoted
            # strings and unquoted object keys. Key quoting is anchored to
            # '{' or ',' so colons inside string values (e.g. a token like
            # "http:") are not mangled — the original r'(\w+):' pattern
            # quoted ANY word before a colon.
            json_str = json_str.replace("'", '"')
            json_str = re.sub(r'([{,]\s*)(\w+)\s*:', r'\1"\2":', json_str)
            pos_tags = json.loads(json_str)
        # Validate and extract tokens and tags
        if not isinstance(pos_tags, list):
            raise ValueError(f"Expected list, got {type(pos_tags).__name__}")
        tokens: List[str] = []
        tags: List[str] = []
        for item in pos_tags:
            if not isinstance(item, dict):
                continue
            token = item.get('token', '')
            tag = item.get('tag', '')
            if token and tag:  # Only add if both token and tag are non-empty
                tokens.append(str(token).strip())
                tags.append(str(tag).strip())
        if not tokens or not tags:
            raise ValueError("No valid tokens and tags found in response")
        return {
            'tokens': tokens,
            'tags': tags
        }
    except Exception as e:
        # Degrade gracefully: log and hand off to the traditional tagger
        # rather than surfacing the error to the caller.
        print(f"Error in LLM POS tagging: {str(e)}")
        print("Falling back to traditional POS tagging...")
        return _pos_tagging_traditional(text, "en_core_web_sm")
def _pos_tagging_traditional(text: str, model: str) -> Dict[str, List[str]]:
"""Use traditional POS tagging models."""
try:
import spacy
# Load the appropriate model
try:
nlp = spacy.load(model)
except OSError:
# Fallback to small English model if specified model is not found
nlp = spacy.load("en_core_web_sm")
# Process the text
doc = nlp(text)
# Extract tokens and POS tags
tokens = []
tags = []
for token in doc:
tokens.append(token.text)
tags.append(token.pos_)
return {
'tokens': tokens,
'tags': tags
}
except Exception as e:
print(f"Error in traditional POS tagging: {str(e)}")
return {"tokens": [], "tags": []}
|