patruff's picture
Upload tool
49ca653 verified
raw
history blame
11.7 kB
from smolagents.tools import Tool
import string
import pronouncing
import json
class ParodyWordSuggestionTool(Tool):
    """Rank candidate words by how closely they rhyme with a target word.

    Pronunciations come from the caller-supplied ``custom_phones`` mapping or,
    failing that, from the CMU Pronouncing Dictionary via ``pronouncing``.
    Each candidate gets a weighted score built from its last-syllable rhyme,
    base-word length, stress pattern and shared suffix; ``forward`` returns
    the ranked results as a JSON string.

    Every method imports the modules it needs locally and ``VOWEL_REF`` is an
    encoded string rather than a nested list -- presumably so the class stays
    self-contained for smolagents tool serialisation/validation. Keep that
    pattern when editing.
    """

    name = "parody_word_suggester"
    description = """Suggests rhyming funny words using CMU dictionary and custom pronunciations.
Returns similar-sounding words that rhyme, especially focusing on common vowel sounds."""
    inputs = {
        'target': {'type': 'string', 'description': 'The word you want to find rhyming alternatives for'},
        'word_list_str': {'type': 'string', 'description': 'JSON string of word list (e.g. \'["word1", "word2"]\')'},
        'min_similarity': {'type': 'string', 'description': 'Minimum similarity threshold (0.0-1.0)', 'default': '0.5', 'nullable': True},
        'custom_phones': {'type': 'object', 'description': 'Optional dictionary of custom word pronunciations', 'default': None, 'nullable': True},
    }
    output_type = "string"

    # Vowel groups treated as near-rhymes: "|" separates groups, "," separates
    # ARPAbet vowels inside a group. AO appears in several groups, so "same
    # group" is not transitive (AA~AO and AO~OW, but AA and OW never match).
    VOWEL_REF = "AH,AX|UH|AE,EH|IY,IH|AO,AA|UW|AY,EY|OW,AO|AW,AO|OY,OW|ER,AXR"

    def _get_vowel_groups(self):
        """Decode VOWEL_REF into a list of groups, e.g. [["AH", "AX"], ["UH"], ...]."""
        return [group.split(",") for group in self.VOWEL_REF.split("|")]

    def _is_vowel(self, phone: str) -> bool:
        """Return True if ``phone`` (stress digit optional) is listed in any vowel group."""
        bare = phone.rstrip("012")
        return any(bare in group for group in self._get_vowel_groups())

    def _get_word_phones(self, word, custom_phones=None):
        """Return the pronunciation of ``word`` as a space-separated phone string, or None.

        ``custom_phones`` is consulted first; a matching entry must be a dict
        with a ``"primary_phones"`` key. Otherwise the first CMU dictionary
        pronunciation is used. Returns None when neither source knows the word.
        """
        if custom_phones and word in custom_phones:
            return custom_phones[word]["primary_phones"]
        import pronouncing
        phones = pronouncing.phones_for_word(word)
        return phones[0] if phones else None

    def _get_last_syllable(self, phones: list) -> tuple:
        """Return ``(last_vowel, trailing_phones)`` for a list of phones.

        ``last_vowel`` is the final vowel with its stress digit removed and
        ``trailing_phones`` are the phones after it (stress digits kept).
        Returns ``(None, [])`` when ``phones`` contains no vowel.
        """
        for index in range(len(phones) - 1, -1, -1):
            if self._is_vowel(phones[index]):
                return phones[index].rstrip("012"), phones[index + 1:]
        return None, []

    def _strip_stress(self, phones: list) -> list:
        """Return a new list with trailing stress digits (0/1/2) removed from each phone."""
        return [phone.rstrip("012") for phone in phones]

    def _vowels_match(self, v1: str, v2: str) -> bool:
        """Return True if two vowels (stress ignored) are equal or share a VOWEL_REF group."""
        v1 = v1.rstrip("012")
        v2 = v2.rstrip("012")
        if v1 == v2:
            return True
        return any(v1 in group and v2 in group for group in self._get_vowel_groups())

    def _strip_common_suffix(self, phones: list) -> tuple:
        """Split a phone list into ``(base, suffix)`` when it ends in a known suffix.

        Suffix patterns are compared with stress digits ignored, and the
        returned ``suffix`` is the canonical pattern from the table below (not
        the word's own phones). A suffix is only stripped when the remaining
        base still contains a vowel; otherwise stressed monosyllables such as
        "sing" (S IH1 NG) or "king" (K IH1 NG) would be cut down to a single
        consonant and could never score as rhymes.

        Returns ``(phones, [])`` when no suffix applies.
        """
        # Common suffix patterns in CMU phonetic representation.
        # NOTE(review): CMU typically transcribes -ed as D / T / IH0 D and uses
        # AH0 in -est and -ness, so the ED, EST and NESS patterns rarely match
        # dictionary words -- confirm whether they should be widened.
        SUFFIXES = {
            'ING': ['IH0', 'NG'],        # -ing
            'ED': ['EH0', 'D'],          # -ed
            'ER': ['ER0'],               # -er
            'EST': ['EH0', 'S', 'T'],    # -est
            'LY': ['L', 'IY0'],          # -ly
            'NESS': ['N', 'EH0', 'S'],   # -ness
        }
        for suffix_phones in SUFFIXES.values():
            size = len(suffix_phones)
            # The word must be strictly longer than the suffix to leave a base.
            if len(phones) <= size:
                continue
            if self._strip_stress(phones[-size:]) != self._strip_stress(suffix_phones):
                continue
            base = phones[:-size]
            if any(self._is_vowel(phone) for phone in base):
                return base, suffix_phones
        return phones, []

    def _rhyme_score(self, base1: list, base2: list) -> float:
        """Score the last-syllable rhyme of two suffix-stripped phone lists (0.0-1.0).

        * 1.0 -- same final vowel and identical trailing phones.
        * 0.7 -- final vowels differ but share a VOWEL_REF group; trailing
          phones identical.
        * up to 0.3 -- final vowels match (same group) but trailing phones
          differ; scaled by the share of position-wise equal trailing phones.
        * 0.0 -- either list has no vowel, or the final vowels don't match.
        """
        vowel1, end1 = self._get_last_syllable(base1)
        vowel2, end2 = self._get_last_syllable(base2)
        if not (vowel1 and vowel2) or not self._vowels_match(vowel1, vowel2):
            return 0.0
        end1 = self._strip_stress(end1)
        end2 = self._strip_stress(end2)
        if end1 == end2:
            # A different vowel from the same group is a weaker rhyme.
            return 1.0 if vowel1 == vowel2 else 0.7
        matched = sum(1 for a, b in zip(end1, end2) if a == b)
        # end1 != end2 here, so at least one is non-empty and max() > 0.
        return 0.3 * (matched / max(len(end1), len(end2)))

    def _calculate_similarity(self, word1, phones1, word2, phones2):
        """Score how well two pronunciations rhyme.

        Args:
            word1, word2: Spelled words. Not used in the score; kept so the
                signature stays compatible with existing callers.
            phones1, phones2: Space-separated phone strings.

        Returns:
            A dict with ``similarity`` (weighted total, capped at 1.0),
            ``rhyme_score``, ``length_score`` and ``stress_score`` (all
            rounded to 3 decimals), ``base_word_diff`` (phone-count
            difference of the suffix-stripped bases), ``has_common_suffix``
            (both words had a recognised suffix, not necessarily the same
            one) and ``suffix_match`` (suffixes identical, including both
            having none).
        """
        import pronouncing

        base1, suffix1 = self._strip_common_suffix(phones1.split())
        base2, suffix2 = self._strip_common_suffix(phones2.split())

        # Equal-length bases score 1.0; otherwise drop linearly with the
        # difference. A non-zero difference implies max_base_length > 0.
        base_length_diff = abs(len(base1) - len(base2))
        max_base_length = max(len(base1), len(base2))
        length_score = 1.0 if base_length_diff == 0 else 1.0 - (base_length_diff / max_base_length)

        rhyme_score = self._rhyme_score(base1, base2)

        # Compare the stress-digit patterns of the base words.
        stress1 = pronouncing.stresses(" ".join(base1))
        stress2 = pronouncing.stresses(" ".join(base2))
        stress_score = 1.0 if stress1 == stress2 else 0.3

        suffix_score = 1.0 if suffix1 == suffix2 else 0.0

        # Weighted combination with emphasis on base word similarity.
        similarity = min(1.0, (
            (rhyme_score * 0.6)       # base-word rhyme
            + (length_score * 0.1)    # base-word length
            + (stress_score * 0.2)    # base-word stress pattern
            + (suffix_score * 0.1)    # suffix match as a small bonus
        ))
        return {
            "similarity": round(similarity, 3),
            "rhyme_score": round(rhyme_score, 3),
            "length_score": round(length_score, 3),
            "stress_score": round(stress_score, 3),
            "base_word_diff": base_length_diff,
            "has_common_suffix": bool(suffix1 and suffix2),
            "suffix_match": suffix_score == 1.0
        }

    def forward(self, target: str, word_list_str: str, min_similarity: str = "0.5", custom_phones: dict = None) -> str:
        """Rank the words in ``word_list_str`` by how well they rhyme with ``target``.

        Args:
            target: Word to rhyme with. Lower-cased, with surrounding
                punctuation and whitespace stripped before lookup.
            word_list_str: JSON array of candidate words. An already-parsed
                list or tuple is accepted as well.
            min_similarity: Minimum ``similarity`` a candidate needs to be
                returned, as a string or number; None means "0.5".
            custom_phones: Optional mapping of word -> {"primary_phones": "..."}
                consulted before the CMU dictionary. Words are lower-cased
                before lookup, so keys should be lower-case.

        Returns:
            A JSON string (indent=2). On success it holds ``target``,
            ``target_phones``, ``target_last_vowel``, ``target_ending``,
            ``invalid_words`` and ``suggestions`` (highest similarity first).
            On bad input it holds ``error`` and an empty ``suggestions`` list
            (plus ``invalid_words`` when no candidate could be pronounced).
        """
        import json
        import string

        # Strip whitespace too, so inputs such as " sing! " still resolve.
        strip_chars = string.punctuation + string.whitespace
        target = target.lower().strip(strip_chars)

        # The input schema marks min_similarity as nullable; treat None as the default.
        if min_similarity is None:
            min_similarity = "0.5"
        try:
            threshold = float(min_similarity)
        except (TypeError, ValueError):
            return json.dumps({
                "error": f"Invalid min_similarity value: {min_similarity!r}",
                "suggestions": []
            }, indent=2)

        # Accept an already-parsed sequence as well as the documented JSON string.
        if isinstance(word_list_str, (list, tuple)):
            words = list(word_list_str)
        else:
            try:
                words = json.loads(word_list_str)
            except (json.JSONDecodeError, TypeError):
                return json.dumps({
                    "error": "Invalid JSON string for word_list_str",
                    "suggestions": []
                }, indent=2)
        if not isinstance(words, list):
            # e.g. '"hello"' would otherwise be iterated character by character.
            return json.dumps({
                "error": "word_list_str must be a JSON array of words",
                "suggestions": []
            }, indent=2)

        target_phones = self._get_word_phones(target, custom_phones)
        if not target_phones:
            return json.dumps({
                "error": f"Target word '{target}' not found in dictionary or custom phones",
                "suggestions": []
            }, indent=2)

        # Resolve each candidate's pronunciation once; keep (word, phones) pairs.
        candidates = []
        invalid_words = []
        for raw_word in words:
            if not isinstance(raw_word, str):
                invalid_words.append(str(raw_word))
                continue
            word = raw_word.lower().strip(strip_chars)
            word_phones = self._get_word_phones(word, custom_phones)
            if word_phones:
                candidates.append((word, word_phones))
            else:
                invalid_words.append(word)

        if not candidates:
            return json.dumps({
                "error": "No valid words found in dictionary or custom phones",
                "invalid_words": invalid_words,
                "suggestions": []
            }, indent=2)

        target_vowel, target_end = self._get_last_syllable(target_phones.split())

        suggestions = []
        for word, word_phones in candidates:
            scores = self._calculate_similarity(word, word_phones, target, target_phones)
            if scores["similarity"] >= threshold:
                word_vowel, word_end = self._get_last_syllable(word_phones.split())
                suggestions.append({
                    "word": word,
                    "similarity": scores["similarity"],
                    "rhyme_score": scores["rhyme_score"],
                    "length_score": scores["length_score"],
                    "stress_score": scores["stress_score"],
                    "base_word_diff": scores["base_word_diff"],
                    "has_common_suffix": scores["has_common_suffix"],
                    "suffix_match": scores["suffix_match"],
                    "phones": word_phones,
                    "last_vowel": word_vowel,
                    "ending": " ".join(word_end),
                    "is_custom": word in custom_phones if custom_phones else False
                })

        suggestions.sort(key=lambda item: item["similarity"], reverse=True)
        return json.dumps({
            "target": target,
            "target_phones": target_phones,
            "target_last_vowel": target_vowel,
            "target_ending": " ".join(target_end),
            "invalid_words": invalid_words,
            "suggestions": suggestions
        }, indent=2)

    def __init__(self, *args, **kwargs):
        # Only records that the tool has not been set up yet; the flag is
        # presumably consulted by the smolagents Tool base class -- confirm.
        # NOTE(review): the base-class __init__ is not called here.
        self.is_initialized = False