Spaces:
Configuration error
Configuration error
from smolagents.tools import Tool | |
import string | |
import pronouncing | |
import json | |
class ParodyWordSuggestionTool(Tool):
    """Rank candidate words by how well they rhyme with a target word.

    Pronunciations are space-separated phone strings (for example
    ``"K AE1 T"``).  They are taken from the caller-supplied ``custom_phones``
    mapping when it has a usable entry for the word, otherwise from
    ``pronouncing.phones_for_word``.  A candidate must first pass a rhyme
    gate (``RHYME_THRESHOLD``); candidates that pass receive a weighted score
    combining rhyme, position-wise phone similarity and length similarity.

    NOTE(review): methods import what they need locally, class attributes
    are plain literals and no comprehension-scoped names are used.  This
    presumably keeps the class acceptable to the tool framework's
    source-level validation -- confirm before hoisting imports or adding
    computed class attributes.
    """

    name = "parody_word_suggester"
    description = "Suggests rhyming funny words using CMU dictionary pronunciations."
    inputs = {
        'target': {
            'type': 'string',
            'description': 'The word you want to find rhyming alternatives for',
        },
        'word_list_str': {
            'type': 'string',
            'description': 'JSON string of word list (e.g. \'["word1", "word2"]\')',
        },
        'min_similarity': {
            'type': 'string',
            'description': 'Minimum similarity threshold (0.0-1.0)',
            'nullable': True,
            'default': '0.5',
        },
        'custom_phones': {
            'type': 'object',
            'description': 'Optional dictionary of custom word pronunciations',
            'nullable': True,
            'default': None,
        },
    }
    output_type = "string"

    # Weights of the three partial scores in the combined similarity
    # (they sum to 1.0).
    RHYME_WEIGHT = 0.5
    PHONE_SEQUENCE_WEIGHT = 0.3
    LENGTH_WEIGHT = 0.2
    # Phones treated as "close": groups are separated by "|", members by ",".
    PHONE_GROUPS = "M,N,NG|P,B|T,D|K,G|F,V|TH,DH|S,Z|SH,ZH|L,R|W,Y|IY,IH|UW,UH|EH,AH|AO,AA|AE,AH|AY,EY|OW,UW"
    # Candidates whose rhyme score is below this value get all-zero scores.
    RHYME_THRESHOLD = 0.8
    # Threshold used by forward() when min_similarity is passed as None.
    DEFAULT_MIN_SIMILARITY = 0.5
    # A phone containing any of these letters is classified as a vowel.
    VOWEL_LETTERS = "AEIOU"
    # Stress digits that may trail a phone and are ignored when comparing.
    STRESS_MARKS = "012"

    def _get_custom_phones(self, word, custom_phones):
        """Return the usable custom pronunciation of ``word``, or None.

        An entry is usable when ``custom_phones`` is a dict, ``word`` maps to
        a dict and that dict's ``"primary_phones"`` value is a non-blank
        string.  Anything else is treated as "no custom pronunciation"
        instead of raising.
        """
        if not isinstance(custom_phones, dict):
            return None
        entry = custom_phones.get(word)
        if not isinstance(entry, dict):
            return None
        phones = entry.get("primary_phones")
        if isinstance(phones, str) and phones.strip():
            return phones
        return None

    def _get_word_phones(self, word, custom_phones=None):
        """Return the phone string for ``word``, or None when it is unknown.

        The custom mapping is consulted first; the dictionary lookup (first
        pronunciation returned by ``pronouncing.phones_for_word``) is the
        fallback.
        """
        custom = self._get_custom_phones(word, custom_phones)
        if custom is not None:
            return custom
        # Imported lazily so that purely custom lookups never touch the package.
        import pronouncing
        pronunciations = pronouncing.phones_for_word(word)
        if pronunciations:
            return pronunciations[0]
        return None

    def _has_vowel_letter(self, phone):
        """Return True when ``phone`` contains any letter of VOWEL_LETTERS."""
        return not set(self.VOWEL_LETTERS).isdisjoint(phone)

    def _find_primary_stress(self, phones):
        """Return the index of the first primary-stressed vowel, or -1.

        A phone qualifies when it contains the digit ``1`` and a vowel letter.
        """
        for index, phone in enumerate(phones):
            if '1' in phone and self._has_vowel_letter(phone):
                return index
        return -1

    def _get_primary_vowel(self, phones: list) -> str:
        """Return the primary stressed vowel without its stress digit.

        Returns None when ``phones`` has no primary-stressed vowel.
        """
        index = self._find_primary_stress(phones)
        if index == -1:
            return None
        return phones[index].rstrip(self.STRESS_MARKS)

    def _share_phone_group(self, base1, base2):
        """Return True when both stress-free phones sit in one PHONE_GROUPS group."""
        for group_spec in self.PHONE_GROUPS.split('|'):
            members = group_spec.split(',')
            if base1 in members and base2 in members:
                return True
        return False

    def _phones_are_similar(self, phone1: str, phone2: str) -> bool:
        """Return True when two phones are equal or grouped, ignoring stress."""
        base1 = phone1.rstrip(self.STRESS_MARKS)
        base2 = phone2.rstrip(self.STRESS_MARKS)
        return base1 == base2 or self._share_phone_group(base1, base2)

    def _get_phone_type(self, phone: str) -> str:
        """Return the broad category of a phone.

        One of ``'vowel'``, ``'nasal'``, ``'stop'``, ``'fricative'``,
        ``'liquid'``, ``'glide'`` or ``'other'`` (anything not listed).
        """
        base = phone.rstrip(self.STRESS_MARKS)
        if self._has_vowel_letter(base):
            return 'vowel'
        consonant_types = {
            'M': 'nasal', 'N': 'nasal', 'NG': 'nasal',
            'P': 'stop', 'B': 'stop', 'T': 'stop',
            'D': 'stop', 'K': 'stop', 'G': 'stop',
            'F': 'fricative', 'V': 'fricative', 'TH': 'fricative',
            'DH': 'fricative', 'S': 'fricative', 'Z': 'fricative',
            'SH': 'fricative', 'ZH': 'fricative',
            'L': 'liquid', 'R': 'liquid',
            'W': 'glide', 'Y': 'glide',
        }
        return consonant_types.get(base, 'other')

    def _get_rhyme_score(self, phones1: list, phones2: list) -> float:
        """Score how well the tails of two pronunciations rhyme (0.0-1.0).

        The tail starts at the primary stressed vowel.  The score is the
        fraction of tail positions whose phones are similar; it is 0.0 when
        either word lacks a primary stress or the tails differ in length.
        """
        start1 = self._find_primary_stress(phones1)
        start2 = self._find_primary_stress(phones2)
        if start1 == -1 or start2 == -1:
            return 0.0
        tail1 = phones1[start1:]
        tail2 = phones2[start2:]
        if len(tail1) != len(tail2):
            return 0.0
        matches = 0
        for left, right in zip(tail1, tail2):
            if self._phones_are_similar(left, right):
                matches += 1
        # tail1 always holds at least the stressed vowel, so no zero division.
        return matches / len(tail1)

    def _calculate_phone_sequence_similarity(self, phones1: list, phones2: list) -> float:
        """Average position-wise phone similarity, normalised by the longer list.

        Positions beyond the shorter list contribute nothing, so a length
        mismatch lowers the score.  Returns 0.0 when either list is empty.
        """
        if not phones1 or not phones2:
            return 0.0
        total = 0.0
        for left, right in zip(phones1, phones2):
            total += self._get_phone_similarity(left, right)
        return total / max(len(phones1), len(phones2))

    def _get_phone_similarity(self, phone1: str, phone2: str) -> float:
        """Grade two phones, ignoring stress.

        1.0 for identical phones, 0.7 for phones sharing a PHONE_GROUPS
        group, 0.3 for phones of the same broad type, otherwise 0.0.
        """
        base1 = phone1.rstrip(self.STRESS_MARKS)
        base2 = phone2.rstrip(self.STRESS_MARKS)
        if base1 == base2:
            return 1.0
        if self._share_phone_group(base1, base2):
            return 0.7
        if self._get_phone_type(base1) == self._get_phone_type(base2):
            return 0.3
        return 0.0

    def _calculate_length_similarity(self, phones1: list, phones2: list) -> float:
        """Return 1.0 minus the relative length difference (0.0 if both are empty)."""
        longest = max(len(phones1), len(phones2))
        if longest == 0:
            return 0.0
        return 1.0 - abs(len(phones1) - len(phones2)) / longest

    def _calculate_similarity(self, word1, phones1, word2, phones2):
        """Combine rhyme, sequence and length scores for two pronunciations.

        Args:
            word1: First word; accepted for interface compatibility, unused.
            phones1: Space-separated phone string of the first word.
            word2: Second word; accepted for interface compatibility, unused.
            phones2: Space-separated phone string of the second word.

        Returns:
            A dict with ``similarity``, ``rhyme_score``,
            ``phone_sequence_score``, ``length_score`` (rounded to three
            decimals) and a ``details`` dict.  All four scores are 0.0 when
            the rhyme score is below ``RHYME_THRESHOLD``.
        """
        phone_list1 = phones1.split()
        phone_list2 = phones2.split()
        details = {
            "primary_vowel1": self._get_primary_vowel(phone_list1),
            "primary_vowel2": self._get_primary_vowel(phone_list2),
            "phone_count1": len(phone_list1),
            "phone_count2": len(phone_list2),
            "matching_phones": 0
        }

        rhyme_score = self._get_rhyme_score(phone_list1, phone_list2)
        if rhyme_score < self.RHYME_THRESHOLD:
            # Not a close enough rhyme: report it as a non-match.
            return {
                "similarity": 0.0,
                "rhyme_score": 0.0,
                "phone_sequence_score": 0.0,
                "length_score": 0.0,
                "details": details
            }

        phone_sequence_score = self._calculate_phone_sequence_similarity(phone_list1, phone_list2)
        length_score = self._calculate_length_similarity(phone_list1, phone_list2)
        similarity = (
            (rhyme_score * self.RHYME_WEIGHT) +
            (phone_sequence_score * self.PHONE_SEQUENCE_WEIGHT) +
            (length_score * self.LENGTH_WEIGHT)
        )
        # The sequence score is normalised by the longer list, so scale it
        # back by that same length; this keeps the figure independent of
        # argument order.
        details["matching_phones"] = round(
            phone_sequence_score * max(len(phone_list1), len(phone_list2))
        )
        return {
            "similarity": round(similarity, 3),
            "rhyme_score": round(rhyme_score, 3),
            "phone_sequence_score": round(phone_sequence_score, 3),
            "length_score": round(length_score, 3),
            "details": details
        }

    def forward(self, target: str, word_list_str: str, min_similarity: str = "0.5", custom_phones: dict = None) -> str:
        """Return a JSON report of the candidates that rhyme with ``target``.

        Args:
            target: Word to find rhymes for; lower-cased and stripped of
                surrounding punctuation before lookup.
            word_list_str: JSON array of candidate words.  An actual list or
                tuple is accepted as well.
            min_similarity: Minimum combined score, as a number or numeric
                string.  None selects ``DEFAULT_MIN_SIMILARITY``.
            custom_phones: Optional mapping of word to a dict whose
                ``"primary_phones"`` value is a space-separated phone string.

        Returns:
            A JSON string.  On success it has ``target``, ``target_phones``,
            ``invalid_words`` and ``suggestions`` (sorted by descending
            similarity).  On bad input it has ``error`` and an empty
            ``suggestions`` list instead of raising.
        """
        import string
        import json

        if not isinstance(target, str):
            return json.dumps({
                "error": "Target word must be a string",
                "suggestions": []
            }, indent=2)
        target = target.lower().strip(string.punctuation)

        # The input is declared nullable, so None must select the default.
        if min_similarity is None:
            threshold = self.DEFAULT_MIN_SIMILARITY
        else:
            try:
                threshold = float(min_similarity)
            except (TypeError, ValueError):
                return json.dumps({
                    "error": f"Invalid min_similarity '{min_similarity}'; expected a number between 0.0 and 1.0",
                    "suggestions": []
                }, indent=2)

        if isinstance(word_list_str, (list, tuple)):
            words = list(word_list_str)
        else:
            try:
                words = json.loads(word_list_str)
            except (TypeError, ValueError):
                return json.dumps({
                    "error": "Invalid JSON string for word_list_str",
                    "suggestions": []
                }, indent=2)
        if not isinstance(words, list):
            return json.dumps({
                "error": "word_list_str must decode to a JSON array of words",
                "suggestions": []
            }, indent=2)

        target_phones = self._get_word_phones(target, custom_phones)
        if not target_phones:
            return json.dumps({
                "error": f"Target word '{target}' not found in dictionary or custom phones",
                "suggestions": []
            }, indent=2)

        # Resolve every pronunciation once; unknown or non-string entries are
        # reported back rather than aborting the whole request.
        pronounceable = []
        invalid_words = []
        for entry in words:
            if not isinstance(entry, str):
                invalid_words.append(str(entry))
                continue
            word = entry.lower().strip(string.punctuation)
            word_phones = self._get_word_phones(word, custom_phones)
            if word_phones:
                pronounceable.append((word, word_phones))
            else:
                invalid_words.append(word)

        if not pronounceable:
            return json.dumps({
                "error": "No valid words found in dictionary or custom phones",
                "invalid_words": invalid_words,
                "suggestions": []
            }, indent=2)

        suggestions = []
        for word, word_phones in pronounceable:
            scores = self._calculate_similarity(word, word_phones, target, target_phones)
            if scores["similarity"] >= threshold:
                suggestions.append({
                    "word": word,
                    "similarity": scores["similarity"],
                    "rhyme_score": scores["rhyme_score"],
                    "phone_sequence_score": scores["phone_sequence_score"],
                    "length_score": scores["length_score"],
                    "phones": word_phones,
                    "is_custom": self._get_custom_phones(word, custom_phones) is not None,
                    "details": scores["details"]
                })

        # Best matches first; the sort is stable, so ties keep input order.
        suggestions.sort(key=lambda item: item["similarity"], reverse=True)
        return json.dumps({
            "target": target,
            "target_phones": target_phones,
            "invalid_words": invalid_words,
            "suggestions": suggestions
        }, indent=2)

    def __init__(self, *args, **kwargs):
        # NOTE(review): the base-class initializer is intentionally not
        # called here, matching the original; presumably the framework only
        # needs this flag at construction time -- confirm.
        self.is_initialized = False