import json
import re
from typing import Dict, List

from llms import LLM


def pos_tagging(
    text: str,
    model: str = "en_core_web_sm",
    use_llm: bool = False,
    custom_instructions: str = ""
) -> Dict[str, List[str]]:
    """
    Perform Part-of-Speech tagging on the input text using either an LLM or a
    traditional model.

    Args:
        text: The input text to tag.
        model: The model to use for tagging (e.g., 'en_core_web_sm', 'gpt-4',
            'gemini-pro').
        use_llm: Whether to use an LLM for more flexible but slower POS tagging.
        custom_instructions: Custom instructions for LLM-based tagging.

    Returns:
        A dictionary with parallel 'tokens' and 'tags' lists.
    """
    if not text.strip():
        return {"tokens": [], "tags": []}

    if use_llm:
        return _pos_tagging_with_llm(text, model, custom_instructions)
    return _pos_tagging_traditional(text, model)


def _extract_json_array(text: str) -> str:
    """Extract a JSON array from text, handling various response formats."""
    # Prefer a well-formed array of objects anywhere in the response.
    json_match = re.search(r'\[\s*\{.*\}\s*\]', text, re.DOTALL)
    if json_match:
        return json_match.group(0)

    # Fall back to the outermost bracket pair.
    start = text.find('[')
    end = text.rfind(']')
    if start >= 0 and end > start:
        return text[start:end + 1]

    # Nothing recognizable; return the input unchanged and let the caller fail.
    return text

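# A rough sketch of `_extract_json_array`'s behavior; the chatty wrapper text
# below is hypothetical, but the bracketed payload is what the regex matches:
#
#   _extract_json_array('Sure! Here it is: [{"token": "Hi", "tag": "INTJ"}]')
#   -> '[{"token": "Hi", "tag": "INTJ"}]'
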
def _pos_tagging_with_llm(
    text: str,
    model_name: str,
    custom_instructions: str = ""
) -> Dict[str, List[str]]:
    """Use an LLM for more flexible (but slower) POS tagging."""
    prompt = """Analyze the following text and provide Part-of-Speech (POS) tags for each token.
Return the result as a JSON array of objects with 'token' and 'tag' keys.

Use standard Universal Dependencies POS tags:
- ADJ: adjective
- ADP: adposition
- ADV: adverb
- AUX: auxiliary verb
- CCONJ: coordinating conjunction
- DET: determiner
- INTJ: interjection
- NOUN: noun
- NUM: numeral
- PART: particle
- PRON: pronoun
- PROPN: proper noun
- PUNCT: punctuation
- SCONJ: subordinating conjunction
- SYM: symbol
- VERB: verb
- X: other

Example output format:
[
    {"token": "Hello", "tag": "INTJ"},
    {"token": "world", "tag": "NOUN"},
    {"token": ".", "tag": "PUNCT"}
]

Text to analyze:
"""

    if custom_instructions:
        prompt = f"{custom_instructions}\n\n{prompt}"

    prompt += f'"{text}"'

    try:
        llm = LLM(model=model_name, temperature=0.1, max_tokens=2000)
        response = llm.generate(prompt)
        print(f"LLM Raw Response: {response[:500]}...")

        if not response.strip():
            raise ValueError("Empty response from LLM")

        json_str = _extract_json_array(response)
        if not json_str:
            raise ValueError("No JSON array found in response")

        try:
            pos_tags = json.loads(json_str)
        except json.JSONDecodeError:
            # Heuristic repair for common LLM formatting slips: single quotes
            # instead of double quotes, and unquoted object keys.
            json_str = json_str.replace("'", '"')
            json_str = re.sub(r'(\w+):', r'"\1":', json_str)
            pos_tags = json.loads(json_str)

        if not isinstance(pos_tags, list):
            raise ValueError(f"Expected list, got {type(pos_tags).__name__}")

        tokens = []
        tags = []
        for item in pos_tags:
            if not isinstance(item, dict):
                continue
            token = item.get('token', '')
            tag = item.get('tag', '')
            if token and tag:
                tokens.append(str(token).strip())
                tags.append(str(tag).strip())

        if not tokens or not tags:
            raise ValueError("No valid tokens and tags found in response")

        return {'tokens': tokens, 'tags': tags}

    except Exception as e:
        print(f"Error in LLM POS tagging: {e}")
        print("Falling back to traditional POS tagging...")
        return _pos_tagging_traditional(text, "en_core_web_sm")
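
# Worked example of the JSON repair heuristic above, on a hypothetical
# single-quoted reply:
#
#   "[{'token': 'Hi', 'tag': 'INTJ'}]"
#   -> quote swap -> '[{"token": "Hi", "tag": "INTJ"}]'  (now valid JSON)
#
# The key-quoting regex only fires on bare keys like {token: ...}; it leaves
# already-quoted keys alone, since '"' is not a word character.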


def _pos_tagging_traditional(text: str, model: str) -> Dict[str, List[str]]:
    """Use a traditional spaCy pipeline for POS tagging."""
    try:
        import spacy

        # Fall back to the small English model if the requested model
        # is not installed.
        try:
            nlp = spacy.load(model)
        except OSError:
            nlp = spacy.load("en_core_web_sm")

        doc = nlp(text)
        tokens = [token.text for token in doc]
        tags = [token.pos_ for token in doc]

        return {'tokens': tokens, 'tags': tags}

    except Exception as e:
        print(f"Error in traditional POS tagging: {e}")
        return {"tokens": [], "tags": []}