# Spaces:
# Configuration error
# Configuration error
import json
import string

import pronouncing
from smolagents.tools import Tool
class ParodyWordSuggestionTool(Tool):
    """Rank candidate words by how closely they rhyme with a target word.

    Pronunciations come from an optional ``custom_phones`` mapping first and
    from the ``pronouncing`` package otherwise. Each candidate is scored on
    the rhyme of its last syllable, phone-count difference, stress pattern
    and shared suffix; the result is returned as a JSON string.

    Imports are repeated inside the methods that need them so each method
    stays self-contained, as in the original code.
    """

    name = "parody_word_suggester"
    description = (
        "Suggests rhyming funny words using CMU dictionary and custom pronunciations.\n"
        "Returns similar-sounding words that rhyme, especially focusing on common vowel sounds."
    )
    inputs = {
        'target': {
            'type': 'string',
            'description': 'The word you want to find rhyming alternatives for',
        },
        'word_list_str': {
            'type': 'string',
            'description': 'JSON string of word list (e.g. \'["word1", "word2"]\')',
        },
        'min_similarity': {
            'type': 'string',
            'description': 'Minimum similarity threshold (0.0-1.0)',
            'default': '0.5',
            'nullable': True,
        },
        'custom_phones': {
            'type': 'object',
            'description': 'Optional dictionary of custom word pronunciations',
            'default': None,
            'nullable': True,
        },
    }
    output_type = "string"

    # Vowel phones (stress digits removed). "|" separates groups, "," separates
    # the members of a group; two vowels in the same group count as a near match.
    VOWEL_REF = "AH,AX|UH|AE,EH|IY,IH|AO,AA|UW|AY,EY|OW,AO|AW,AO|OY,OW|ER,AXR"

    def _get_vowel_groups(self):
        """Return VOWEL_REF parsed into a list of vowel groups (lists of phone names)."""
        return [group.split(",") for group in self.VOWEL_REF.split("|")]

    def _get_word_phones(self, word, custom_phones=None):
        """Return the phone string for ``word``, or None when it is unknown.

        ``custom_phones`` is consulted first; an entry is expected to look like
        ``{word: {"primary_phones": "K AE1 T"}}``. The lookup uses ``word``
        exactly as given, so keys must match the lower-cased, punctuation-stripped
        form that ``forward`` produces. Otherwise the first pronunciation
        reported by ``pronouncing.phones_for_word`` is used.
        """
        if custom_phones and word in custom_phones:
            return custom_phones[word]["primary_phones"]

        import pronouncing

        pronunciations = pronouncing.phones_for_word(word)
        return pronunciations[0] if pronunciations else None

    def _get_last_syllable(self, phones: list) -> tuple:
        """Return ``(last_vowel, following_phones)`` for a list of phones.

        ``last_vowel`` is the final vowel with its stress digit removed and
        ``following_phones`` is everything after it. ``(None, [])`` is returned
        when no phone is listed in VOWEL_REF.
        """
        known_vowels = {vowel for group in self._get_vowel_groups() for vowel in group}
        last_idx = -1
        last_vowel = None
        for idx, phone in enumerate(phones):
            bare = phone.rstrip('012')
            if bare in known_vowels:
                last_idx = idx
                last_vowel = bare
        if last_vowel is None:
            return None, []
        return last_vowel, phones[last_idx + 1:]

    def _strip_stress(self, phones: list) -> list:
        """Return ``phones`` with the trailing stress digits (0/1/2) removed."""
        return [phone.rstrip('012') for phone in phones]

    def _vowels_match(self, v1: str, v2: str) -> bool:
        """Return True when two vowels are identical or share a VOWEL_REF group."""
        first = v1.rstrip('012')
        second = v2.rstrip('012')
        if first == second:
            return True
        return any(
            first in group and second in group
            for group in self._get_vowel_groups()
        )

    def _strip_common_suffix(self, phones: list) -> tuple:
        """Split ``phones`` into ``(stem, suffix)`` for a known unstressed suffix.

        Returns ``(phones, [])`` when no suffix applies. On a match the suffix
        half is the canonical pattern from the table below, so two words with
        the same suffix compare equal regardless of their own stress digits.

        An ending is only treated as a suffix when it carries no primary or
        secondary stress and the remaining stem still contains a vowel.
        Without these guards monosyllables such as "sing" (S IH1 NG) or "best"
        (B EH1 S T) were reduced to a vowel-less stem, which zeroed the rhyme
        score of otherwise perfect rhymes.
        """
        # Common suffix patterns in CMU phonetic representation.
        # NOTE(review): CMUdict usually reduces the vowel of -ed/-est/-ness to
        # AH0 or IH0, so the EH0 patterns may rarely match - confirm before
        # changing the table.
        suffix_table = {
            'ING': ['IH0', 'NG'],        # -ing
            'ED': ['EH0', 'D'],          # -ed
            'ER': ['ER0'],               # -er
            'EST': ['EH0', 'S', 'T'],    # -est
            'LY': ['L', 'IY0'],          # -ly
            'NESS': ['N', 'EH0', 'S'],   # -ness
        }
        for pattern in suffix_table.values():
            size = len(pattern)
            if len(phones) <= size:
                continue
            ending = phones[-size:]
            # A stressed ending belongs to the stem, not to a suffix.
            if any(phone[-1:] in ('1', '2') for phone in ending):
                continue
            if self._strip_stress(ending) != self._strip_stress(pattern):
                continue
            stem = phones[:-size]
            # Never strip the only syllable of a word.
            if self._get_last_syllable(stem)[0] is None:
                continue
            return stem, pattern
        return phones, []

    def _calculate_similarity(self, word1, phones1, word2, phones2):
        """Score how similar two pronunciations sound.

        ``phones1`` and ``phones2`` are space-separated phone strings;
        ``word1`` and ``word2`` are accepted for interface compatibility but
        are not used in the computation.

        Returns a dict with the weighted ``similarity`` (rhyme 0.6, stress 0.2,
        length 0.1, suffix 0.1, capped at 1.0), the individual ``rhyme_score``,
        ``length_score`` and ``stress_score`` (all rounded to 3 decimals),
        ``base_word_diff`` (phone-count difference of the stems),
        ``has_common_suffix`` and ``suffix_match``.
        """
        import pronouncing

        # Compare stems so a shared suffix does not dominate the rhyme.
        base1, suffix1 = self._strip_common_suffix(phones1.split())
        base2, suffix2 = self._strip_common_suffix(phones2.split())

        base_length_diff = abs(len(base1) - len(base2))
        longest_base = max(len(base1), len(base2))
        # A zero difference short-circuits, so empty stems never divide by zero.
        length_score = 1.0 if base_length_diff == 0 else 1.0 - (base_length_diff / longest_base)

        word_vowel, word_end = self._get_last_syllable(base1)
        target_vowel, target_end = self._get_last_syllable(base2)

        rhyme_score = 0.0
        if word_vowel and target_vowel and self._vowels_match(word_vowel, target_vowel):
            word_coda = self._strip_stress(word_end)
            target_coda = self._strip_stress(target_end)
            if word_coda == target_coda:
                if word_vowel.rstrip('012') == target_vowel.rstrip('012'):
                    rhyme_score = 1.0
                else:
                    rhyme_score = 0.7  # Penalize different vowels in same group
            else:
                # Partial credit for positions that agree; the codas differ here,
                # so the longer one is never empty.
                matched = sum(1 for a, b in zip(word_coda, target_coda) if a == b)
                rhyme_score = 0.3 * (matched / max(len(word_coda), len(target_coda)))

        # Stress pattern similarity of the stems.
        stress1 = pronouncing.stresses(' '.join(base1))
        stress2 = pronouncing.stresses(' '.join(base2))
        stress_score = 1.0 if stress1 == stress2 else 0.3

        # Small bonus when both words end in the same suffix (or both have none).
        suffix_score = 1.0 if suffix1 == suffix2 else 0.0

        similarity = min(
            1.0,
            (rhyme_score * 0.6)
            + (length_score * 0.1)
            + (stress_score * 0.2)
            + (suffix_score * 0.1),
        )
        return {
            "similarity": round(similarity, 3),
            "rhyme_score": round(rhyme_score, 3),
            "length_score": round(length_score, 3),
            "stress_score": round(stress_score, 3),
            "base_word_diff": base_length_diff,
            "has_common_suffix": bool(suffix1 and suffix2),
            "suffix_match": suffix_score == 1.0,
        }

    def forward(self, target: str, word_list_str: str, min_similarity: str = "0.5", custom_phones: dict = None) -> str:
        """Return a JSON report of the words that rhyme with ``target``.

        Args:
            target: Word to rhyme with; lower-cased and stripped of surrounding
                punctuation before lookup.
            word_list_str: JSON array of candidate words. An already-parsed
                list or tuple is accepted as well.
            min_similarity: Minimum similarity a candidate needs to be
                reported. ``None`` falls back to 0.5.
            custom_phones: Optional ``{word: {"primary_phones": "..."}}``
                mapping consulted before the dictionary.

        Returns:
            A JSON string. On success it holds the target's pronunciation
            details, ``invalid_words`` (entries without a pronunciation or that
            are not strings) and ``suggestions`` sorted by descending
            similarity. On failure it holds an ``error`` message and an empty
            ``suggestions`` list.
        """
        import json
        import string

        target = target.lower().strip(string.punctuation)

        # The input is declared nullable, so None must not reach float().
        if min_similarity is None:
            threshold = 0.5
        else:
            try:
                threshold = float(min_similarity)
            except (TypeError, ValueError):
                return json.dumps({
                    "error": f"Invalid min_similarity '{min_similarity}': expected a number between 0.0 and 1.0",
                    "suggestions": []
                }, indent=2)

        # Parse the candidate list.
        if isinstance(word_list_str, (list, tuple)):
            words = list(word_list_str)
        else:
            try:
                words = json.loads(word_list_str)
            except (json.JSONDecodeError, TypeError):
                return json.dumps({
                    "error": "Invalid JSON string for word_list_str",
                    "suggestions": []
                }, indent=2)
            # Valid JSON that is not an array (a bare string, number, object)
            # would otherwise be iterated item by item or crash.
            if not isinstance(words, list):
                return json.dumps({
                    "error": "word_list_str must be a JSON array of words",
                    "suggestions": []
                }, indent=2)

        # Get target pronunciation.
        target_phones = self._get_word_phones(target, custom_phones)
        if not target_phones:
            return json.dumps({
                "error": f"Target word '{target}' not found in dictionary or custom phones",
                "suggestions": []
            }, indent=2)

        # Resolve each candidate once and keep its phones for scoring.
        candidates = []
        invalid_words = []
        for entry in words:
            if not isinstance(entry, str):
                invalid_words.append(entry)
                continue
            word = entry.lower().strip(string.punctuation)
            word_phones = self._get_word_phones(word, custom_phones)
            if word_phones:
                candidates.append((word, word_phones))
            else:
                invalid_words.append(word)

        if not candidates:
            return json.dumps({
                "error": "No valid words found in dictionary or custom phones",
                "invalid_words": invalid_words,
                "suggestions": []
            }, indent=2)

        target_vowel, target_end = self._get_last_syllable(target_phones.split())

        suggestions = []
        for word, word_phones in candidates:
            scores = self._calculate_similarity(word, word_phones, target, target_phones)
            if scores["similarity"] < threshold:
                continue
            word_vowel, word_end = self._get_last_syllable(word_phones.split())
            suggestions.append({
                "word": word,
                "similarity": scores["similarity"],
                "rhyme_score": scores["rhyme_score"],
                "length_score": scores["length_score"],
                "stress_score": scores["stress_score"],
                "base_word_diff": scores["base_word_diff"],
                "has_common_suffix": scores["has_common_suffix"],
                "suffix_match": scores["suffix_match"],
                "phones": word_phones,
                "last_vowel": word_vowel,
                "ending": " ".join(word_end) if word_end else "",
                "is_custom": word in custom_phones if custom_phones else False
            })

        # Sort by similarity score descending (stable for equal scores).
        suggestions.sort(key=lambda item: item["similarity"], reverse=True)

        result = {
            "target": target,
            "target_phones": target_phones,
            "target_last_vowel": target_vowel,
            "target_ending": " ".join(target_end) if target_end else "",
            "invalid_words": invalid_words,
            "suggestions": suggestions
        }
        return json.dumps(result, indent=2)

    def __init__(self, *args, **kwargs):
        """Accept and ignore any arguments; only flags the tool as not initialized."""
        self.is_initialized = False