# patruff's picture
# Upload tool
# 294224b verified
from smolagents.tools import Tool
import string
import pronouncing
import json
class ParodyWordSuggestionTool(Tool):
    """Suggest words from a candidate list that rhyme with a target word.

    Pronunciations are space-separated phone strings (e.g. ``"K AE1 T"``).
    They come from the caller-supplied ``custom_phones`` mapping when the word
    is present there, and from ``pronouncing.phones_for_word`` otherwise.

    A candidate is scored against the target in three parts:

    * rhyme score: share of similar phones from the primary-stressed vowel to
      the end of the word (0.0 when those tails differ in length);
    * phone sequence score: position-by-position phone similarity;
    * length score: how close the two phone counts are.

    Candidates whose rhyme score is below ``MIN_RHYME_SCORE`` get an overall
    similarity of 0.0; otherwise the three parts are combined using the
    ``*_WEIGHT`` constants.
    """

    name = "parody_word_suggester"
    description = "Suggests rhyming funny words using CMU dictionary pronunciations."
    inputs = {
        'target': {
            'type': 'string',
            'description': 'The word you want to find rhyming alternatives for',
        },
        'word_list_str': {
            'type': 'string',
            'description': 'JSON string of word list (e.g. \'["word1", "word2"]\')',
        },
        'min_similarity': {
            'type': 'string',
            'description': 'Minimum similarity threshold (0.0-1.0)',
            'nullable': True,
            'default': '0.5',
        },
        'custom_phones': {
            'type': 'object',
            'description': 'Optional dictionary of custom word pronunciations',
            'nullable': True,
            'default': None,
        },
    }
    output_type = "string"

    # Weights of the three partial scores in the combined similarity.
    RHYME_WEIGHT = 0.5
    PHONE_SEQUENCE_WEIGHT = 0.3
    LENGTH_WEIGHT = 0.2
    # Rhyme scores below this value are treated as "does not rhyme".
    MIN_RHYME_SCORE = 0.8
    # Groups of phones treated as similar: groups are separated by "|",
    # members by ",". Kept as a plain string, as in the original.
    PHONE_GROUPS = "M,N,NG|P,B|T,D|K,G|F,V|TH,DH|S,Z|SH,ZH|L,R|W,Y|IY,IH|UW,UH|EH,AH|AO,AA|AE,AH|AY,EY|OW,UW"

    def _get_word_phones(self, word, custom_phones=None):
        """Return the phone string for ``word``, or None if it is unknown.

        Args:
            word: The (already normalised) word to look up.
            custom_phones: Optional mapping ``word -> {"primary_phones": str}``
                that takes precedence over the pronouncing dictionary.

        Returns:
            The custom ``"primary_phones"`` entry if the word is in
            ``custom_phones``; otherwise the first pronunciation returned by
            ``pronouncing.phones_for_word``, or None when there is none.
        """
        if custom_phones and word in custom_phones:
            return custom_phones[word]["primary_phones"]
        import pronouncing
        pronunciations = pronouncing.phones_for_word(word)
        return pronunciations[0] if pronunciations else None

    def _has_primary_stress(self, phone: str) -> bool:
        """Return True if ``phone`` contains ``1`` and one of the letters AEIOU."""
        return '1' in phone and not set(phone).isdisjoint('AEIOU')

    def _find_primary_vowel_index(self, phones: list) -> int:
        """Return the index of the first primary-stressed vowel, or -1."""
        for index, phone in enumerate(phones):
            if self._has_primary_stress(phone):
                return index
        return -1

    def _get_primary_vowel(self, phones: list) -> str:
        """Get the primary stressed vowel from a phone list.

        Returns:
            The first primary-stressed vowel with its stress digit removed,
            or None when the list contains no such phone.
        """
        index = self._find_primary_vowel_index(phones)
        if index == -1:
            return None
        return phones[index].rstrip('012')

    def _in_same_phone_group(self, p1: str, p2: str) -> bool:
        """Return True if both stress-free phones belong to one PHONE_GROUPS group."""
        for group_str in self.PHONE_GROUPS.split('|'):
            members = group_str.split(',')
            if p1 in members and p2 in members:
                return True
        return False

    def _phones_are_similar(self, phone1: str, phone2: str) -> bool:
        """Check if two phones are similar enough to be considered rhyming.

        Stress digits are ignored. Phones are similar when they are identical
        or share a group in ``PHONE_GROUPS``.
        """
        p1 = phone1.rstrip('012')
        p2 = phone2.rstrip('012')
        return p1 == p2 or self._in_same_phone_group(p1, p2)

    def _get_phone_type(self, phone: str) -> str:
        """Get the broad category of a phone.

        Returns:
            ``'vowel'`` if the stress-free phone contains any of A/E/I/O/U,
            else one of ``'nasal'``, ``'stop'``, ``'fricative'``, ``'liquid'``,
            ``'glide'``, or ``'other'`` for anything not listed.
        """
        base = phone.rstrip('012')
        if not set(base).isdisjoint('AEIOU'):
            return 'vowel'
        categories = {
            'nasal': {'M', 'N', 'NG'},
            'stop': {'P', 'B', 'T', 'D', 'K', 'G'},
            'fricative': {'F', 'V', 'TH', 'DH', 'S', 'Z', 'SH', 'ZH'},
            'liquid': {'L', 'R'},
            'glide': {'W', 'Y'},
        }
        for label, members in categories.items():
            if base in members:
                return label
        return 'other'

    def _get_rhyme_score(self, phones1: list, phones2: list) -> float:
        """Calculate rhyme score based on phones from the primary stressed vowel on.

        Returns:
            0.0 if either word has no primary-stressed vowel or if the two
            rhyming tails differ in length; otherwise the fraction of tail
            positions whose phones are similar.
        """
        pos1 = self._find_primary_vowel_index(phones1)
        pos2 = self._find_primary_vowel_index(phones2)
        if pos1 == -1 or pos2 == -1:
            return 0.0
        # The rhyming part runs from the stressed vowel to the end of the word.
        rhyme_part1 = phones1[pos1:]
        rhyme_part2 = phones2[pos2:]
        if len(rhyme_part1) != len(rhyme_part2):
            return 0.0
        similar_count = 0
        for p1, p2 in zip(rhyme_part1, rhyme_part2):
            if self._phones_are_similar(p1, p2):
                similar_count += 1
        # rhyme_part1 always holds at least the stressed vowel, so no zero division.
        return similar_count / len(rhyme_part1)

    def _calculate_phone_sequence_similarity(self, phones1: list, phones2: list) -> float:
        """Calculate similarity based on matching phones in sequence.

        Phones are compared position by position from the start. The summed
        similarity is divided by the longer length, so unmatched trailing
        phones lower the score.
        """
        if not phones1 or not phones2:
            return 0.0
        total_similarity = 0.0
        for first, second in zip(phones1, phones2):
            total_similarity += self._get_phone_similarity(first, second)
        return total_similarity / max(len(phones1), len(phones2))

    def _get_phone_similarity(self, phone1: str, phone2: str) -> float:
        """Calculate similarity between two phones (stress digits ignored).

        Returns:
            1.0 for identical phones, 0.7 for phones sharing a PHONE_GROUPS
            group, 0.3 for phones of the same broad type, else 0.0.
        """
        p1 = phone1.rstrip('012')
        p2 = phone2.rstrip('012')
        if p1 == p2:
            return 1.0
        if self._in_same_phone_group(p1, p2):
            return 0.7
        if self._get_phone_type(p1) == self._get_phone_type(p2):
            return 0.3
        return 0.0

    def _calculate_length_similarity(self, phones1: list, phones2: list) -> float:
        """Calculate similarity based on phone count.

        Returns:
            ``1 - |len1 - len2| / max(len1, len2)``, or 0.0 if both are empty.
        """
        max_length = max(len(phones1), len(phones2))
        if max_length == 0:
            return 0.0
        return 1.0 - abs(len(phones1) - len(phones2)) / max_length

    def _calculate_similarity(self, word1, phones1, word2, phones2):
        """Calculate similarity of two pronunciations based on multiple factors.

        Args:
            word1: First word (not used in the computation).
            phones1: Space-separated phone string of the first word.
            word2: Second word (not used in the computation).
            phones2: Space-separated phone string of the second word.

        Returns:
            A dict with ``similarity``, ``rhyme_score``, ``phone_sequence_score``
            and ``length_score`` (rounded to 3 decimals) plus a ``details`` dict.
            All four scores are 0.0 when the rhyme score is below
            ``MIN_RHYME_SCORE``.
        """
        phone_list1 = phones1.split()
        phone_list2 = phones2.split()
        details = {
            "primary_vowel1": self._get_primary_vowel(phone_list1),
            "primary_vowel2": self._get_primary_vowel(phone_list2),
            "phone_count1": len(phone_list1),
            "phone_count2": len(phone_list2),
            "matching_phones": 0,
        }

        rhyme_score = self._get_rhyme_score(phone_list1, phone_list2)
        if rhyme_score < self.MIN_RHYME_SCORE:
            # Not a close enough rhyme: report a zeroed result.
            return {
                "similarity": 0.0,
                "rhyme_score": 0.0,
                "phone_sequence_score": 0.0,
                "length_score": 0.0,
                "details": details,
            }

        # The remaining scores only matter once the words rhyme closely enough.
        phone_sequence_score = self._calculate_phone_sequence_similarity(phone_list1, phone_list2)
        length_score = self._calculate_length_similarity(phone_list1, phone_list2)
        similarity = (
            (rhyme_score * self.RHYME_WEIGHT)
            + (phone_sequence_score * self.PHONE_SEQUENCE_WEIGHT)
            + (length_score * self.LENGTH_WEIGHT)
        )
        # Approximation: the sequence score scaled by the first word's phone count.
        details["matching_phones"] = round(phone_sequence_score * len(phone_list1))
        return {
            "similarity": round(similarity, 3),
            "rhyme_score": round(rhyme_score, 3),
            "phone_sequence_score": round(phone_sequence_score, 3),
            "length_score": round(length_score, 3),
            "details": details,
        }

    def forward(self, target: str, word_list_str: str, min_similarity: str = "0.5", custom_phones: dict = None) -> str:
        """Rank the words of a JSON word list by how well they rhyme with ``target``.

        Args:
            target: Word to find rhymes for; lower-cased and stripped of
                surrounding punctuation before lookup.
            word_list_str: JSON array of candidate words.
            min_similarity: Minimum combined similarity (as a string) a
                candidate needs to be reported. None means the default "0.5".
            custom_phones: Optional mapping ``word -> {"primary_phones": str}``
                consulted before the pronouncing dictionary.

        Returns:
            A JSON string. On success it holds ``target``, ``target_phones``,
            ``invalid_words`` and ``suggestions`` (sorted by similarity,
            highest first). On failure it holds an ``error`` message and an
            empty ``suggestions`` list.
        """
        # Function-local imports are kept from the original implementation.
        import pronouncing
        import string
        import json

        target = target.lower().strip(string.punctuation)

        # The input schema declares min_similarity nullable, so None must be
        # accepted and mapped to the documented default.
        if min_similarity is None:
            min_similarity = "0.5"
        try:
            threshold = float(min_similarity)
        except (TypeError, ValueError):
            return json.dumps({
                "error": f"Invalid min_similarity value: {min_similarity!r}",
                "suggestions": []
            }, indent=2)

        # Parse JSON string to list
        try:
            words = json.loads(word_list_str)
        except (json.JSONDecodeError, TypeError):
            return json.dumps({
                "error": "Invalid JSON string for word_list_str",
                "suggestions": []
            }, indent=2)
        if not isinstance(words, list):
            return json.dumps({
                "error": "word_list_str must be a JSON array of words",
                "suggestions": []
            }, indent=2)

        # Get target pronunciation
        target_phones = self._get_word_phones(target, custom_phones)
        if not target_phones:
            return json.dumps({
                "error": f"Target word '{target}' not found in dictionary or custom phones",
                "suggestions": []
            }, indent=2)

        # Split candidates into known words (kept with their phones, so each
        # word is looked up only once) and unknown or malformed entries.
        candidates = []
        invalid_words = []
        for raw_word in words:
            if not isinstance(raw_word, str):
                invalid_words.append(raw_word)
                continue
            word = raw_word.lower().strip(string.punctuation)
            word_phones = self._get_word_phones(word, custom_phones)
            if word_phones:
                candidates.append((word, word_phones))
            else:
                invalid_words.append(word)

        if not candidates:
            return json.dumps({
                "error": "No valid words found in dictionary or custom phones",
                "invalid_words": invalid_words,
                "suggestions": []
            }, indent=2)

        # Score each known candidate against the target.
        suggestions = []
        for word, word_phones in candidates:
            similarity_result = self._calculate_similarity(word, word_phones, target, target_phones)
            if similarity_result["similarity"] >= threshold:
                suggestions.append({
                    "word": word,
                    "similarity": similarity_result["similarity"],
                    "rhyme_score": similarity_result["rhyme_score"],
                    "phone_sequence_score": similarity_result["phone_sequence_score"],
                    "length_score": similarity_result["length_score"],
                    "phones": word_phones,
                    "is_custom": bool(custom_phones) and word in custom_phones,
                    "details": similarity_result["details"]
                })

        # Sort by similarity score descending
        suggestions.sort(key=lambda item: item["similarity"], reverse=True)

        result = {
            "target": target,
            "target_phones": target_phones,
            "invalid_words": invalid_words,
            "suggestions": suggestions
        }
        return json.dumps(result, indent=2)

    def __init__(self, *args, **kwargs):
        """Create the tool and flag it as not yet initialized."""
        self.is_initialized = False