Spaces:
Configuration error
Configuration error
Upload tool
Browse files
tool.py
CHANGED
@@ -1,19 +1,17 @@
|
|
1 |
from smolagents.tools import Tool
|
2 |
-
import string
|
3 |
import json
|
4 |
-
import difflib
|
5 |
import pronouncing
|
|
|
6 |
|
7 |
class ParodyWordSuggestionTool(Tool):
|
8 |
name = "parody_word_suggester"
|
9 |
-
description = """Suggests rhyming funny words using CMU dictionary pronunciations.
|
10 |
Returns similar-sounding words that rhyme, especially focusing on common vowel sounds."""
|
11 |
-
inputs = {'target': {'type': 'string', 'description': 'The word you want to find rhyming alternatives for'}, 'word_list_str': {'type': 'string', 'description': 'JSON string of word list (e.g. \'["word1", "word2"]\')'}, 'min_similarity': {'type': 'string', 'description': 'Minimum similarity threshold (0.0-1.0)', 'nullable': True}}
|
12 |
output_type = "string"
|
13 |
VOWEL_REF = "AH,UH,AX|AE,EH|IY,IH|AO,AA|UW,UH|AY,EY|OW,AO|AW,AO|OY,OW|ER,AXR"
|
14 |
|
15 |
def _get_vowel_groups(self):
|
16 |
-
"""Convert the simple string format to usable groups."""
|
17 |
groups = []
|
18 |
group_strs = self.VOWEL_REF.split("|")
|
19 |
for group_str in group_strs:
|
@@ -21,15 +19,23 @@ class ParodyWordSuggestionTool(Tool):
|
|
21 |
return groups
|
22 |
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
def _get_last_syllable(self, phones: list) -> tuple:
|
25 |
"""Extract the last syllable (vowel + remaining consonants)."""
|
26 |
last_vowel_idx = -1
|
27 |
last_vowel = None
|
28 |
vowel_groups = self._get_vowel_groups()
|
29 |
|
30 |
-
# Find the last vowel
|
31 |
for i, phone in enumerate(phones):
|
32 |
-
# Strip stress markers for checking
|
33 |
base_phone = phone.rstrip('012')
|
34 |
for group in vowel_groups:
|
35 |
if base_phone in group:
|
@@ -40,14 +46,11 @@ class ParodyWordSuggestionTool(Tool):
|
|
40 |
if last_vowel_idx == -1:
|
41 |
return None, []
|
42 |
|
43 |
-
# Get all consonants after the vowel
|
44 |
remaining = phones[last_vowel_idx + 1:]
|
45 |
-
|
46 |
return last_vowel, remaining
|
47 |
|
48 |
|
49 |
def _strip_stress(self, phones: list) -> list:
|
50 |
-
"""Remove stress markers from phones."""
|
51 |
result = []
|
52 |
for phone in phones:
|
53 |
result.append(phone.rstrip('012'))
|
@@ -55,7 +58,6 @@ class ParodyWordSuggestionTool(Tool):
|
|
55 |
|
56 |
|
57 |
def _vowels_match(self, v1: str, v2: str) -> bool:
|
58 |
-
"""Check if two vowels are in the same group."""
|
59 |
v1 = v1.rstrip('012')
|
60 |
v2 = v2.rstrip('012')
|
61 |
|
@@ -69,14 +71,87 @@ class ParodyWordSuggestionTool(Tool):
|
|
69 |
return False
|
70 |
|
71 |
|
72 |
-
def
|
73 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
import pronouncing
|
75 |
import string
|
76 |
import json
|
77 |
-
from difflib import SequenceMatcher
|
78 |
|
79 |
-
# Initialize variables
|
80 |
target = target.lower().strip(string.punctuation)
|
81 |
min_similarity = float(min_similarity)
|
82 |
suggestions = []
|
@@ -84,6 +159,9 @@ class ParodyWordSuggestionTool(Tool):
|
|
84 |
word_end = []
|
85 |
target_vowel = None
|
86 |
target_end = []
|
|
|
|
|
|
|
87 |
|
88 |
# Parse JSON string to list
|
89 |
try:
|
@@ -95,92 +173,54 @@ class ParodyWordSuggestionTool(Tool):
|
|
95 |
}, indent=2)
|
96 |
|
97 |
# Get target pronunciation
|
98 |
-
target_phones =
|
99 |
if not target_phones:
|
100 |
return json.dumps({
|
101 |
-
"error": f"Target word '{target}' not found in
|
102 |
"suggestions": []
|
103 |
}, indent=2)
|
104 |
|
105 |
-
# Filter word list
|
106 |
valid_words = []
|
107 |
invalid_words = []
|
108 |
for word in words:
|
109 |
word = word.lower().strip(string.punctuation)
|
110 |
-
if
|
111 |
valid_words.append(word)
|
112 |
else:
|
113 |
invalid_words.append(word)
|
114 |
|
115 |
if not valid_words:
|
116 |
return json.dumps({
|
117 |
-
"error": "No valid words found in
|
118 |
"invalid_words": invalid_words,
|
119 |
"suggestions": []
|
120 |
}, indent=2)
|
121 |
|
122 |
-
target_phones = target_phones[0]
|
123 |
target_phone_list = target_phones.split()
|
124 |
target_vowel, target_end = self._get_last_syllable(target_phone_list)
|
125 |
|
126 |
# Check each word
|
127 |
for word in valid_words:
|
128 |
-
|
129 |
-
if
|
130 |
-
|
131 |
-
word_phone_list = word_phones.split()
|
132 |
-
word_vowel, word_end = self._get_last_syllable(word_phone_list)
|
133 |
-
|
134 |
-
# 1. Rhyme score (most important - 60%)
|
135 |
-
rhyme_score = 0.0
|
136 |
-
if word_vowel and target_vowel:
|
137 |
-
# Check if the vowels are similar
|
138 |
-
if self._vowels_match(word_vowel, target_vowel):
|
139 |
-
# Check if endings match (ignoring stress numbers)
|
140 |
-
word_end_clean = self._strip_stress(word_end)
|
141 |
-
target_end_clean = self._strip_stress(target_end)
|
142 |
-
|
143 |
-
if word_end_clean == target_end_clean:
|
144 |
-
rhyme_score = 1.0
|
145 |
-
# Extra boost for exact match minus first letter
|
146 |
-
if len(word) == len(target) and word[1:] == target[1:]:
|
147 |
-
rhyme_score = 1.2
|
148 |
-
else:
|
149 |
-
rhyme_score = 0.6
|
150 |
-
|
151 |
-
# 2. Syllable match (25%)
|
152 |
-
target_syl = pronouncing.syllable_count(target_phones)
|
153 |
-
word_syl = pronouncing.syllable_count(word_phones)
|
154 |
-
syllable_score = 1.0 if target_syl == word_syl else 0.0
|
155 |
-
|
156 |
-
# 3. String similarity (15%)
|
157 |
-
# Higher weight for end of word similarity
|
158 |
-
if len(word) > 1 and len(target) > 1:
|
159 |
-
end_similarity = SequenceMatcher(None, word[1:], target[1:]).ratio()
|
160 |
-
string_similarity = end_similarity
|
161 |
-
else:
|
162 |
-
string_similarity = SequenceMatcher(None, target, word).ratio()
|
163 |
-
|
164 |
-
# Combined score
|
165 |
-
similarity = (rhyme_score * 0.6) + (syllable_score * 0.25) + (string_similarity * 0.15)
|
166 |
|
167 |
-
if similarity >= min_similarity:
|
|
|
|
|
|
|
168 |
suggestions.append({
|
169 |
"word": word,
|
170 |
-
"similarity":
|
171 |
-
"
|
172 |
-
"
|
173 |
-
"
|
174 |
-
"
|
175 |
-
"syllables": word_syl,
|
176 |
"phones": word_phones,
|
177 |
"last_vowel": word_vowel,
|
178 |
"ending": " ".join(word_end) if word_end else "",
|
179 |
-
"
|
180 |
-
"word_end_clean": word_end_clean,
|
181 |
-
"target_end_clean": target_end_clean,
|
182 |
-
"exact_match": word_end_clean == target_end_clean
|
183 |
-
}
|
184 |
})
|
185 |
|
186 |
# Sort by similarity score descending
|
@@ -188,11 +228,10 @@ class ParodyWordSuggestionTool(Tool):
|
|
188 |
|
189 |
result = {
|
190 |
"target": target,
|
191 |
-
"target_syllables": pronouncing.syllable_count(target_phones),
|
192 |
"target_phones": target_phones,
|
193 |
"target_last_vowel": target_vowel,
|
194 |
"target_ending": " ".join(target_end) if target_end else "",
|
195 |
-
"invalid_words": invalid_words,
|
196 |
"suggestions": suggestions
|
197 |
}
|
198 |
|
|
|
1 |
from smolagents.tools import Tool
|
|
|
2 |
import json
|
|
|
3 |
import pronouncing
|
4 |
+
import string
|
5 |
|
6 |
class ParodyWordSuggestionTool(Tool):
|
7 |
name = "parody_word_suggester"
|
8 |
+
description = """Suggests rhyming funny words using CMU dictionary and custom pronunciations.
|
9 |
Returns similar-sounding words that rhyme, especially focusing on common vowel sounds."""
|
10 |
+
inputs = {'target': {'type': 'string', 'description': 'The word you want to find rhyming alternatives for'}, 'word_list_str': {'type': 'string', 'description': 'JSON string of word list (e.g. \'["word1", "word2"]\')'}, 'min_similarity': {'type': 'string', 'description': 'Minimum similarity threshold (0.0-1.0)', 'nullable': True}, 'custom_phones': {'type': 'object', 'description': 'Optional dictionary of custom word pronunciations', 'nullable': True}}
|
11 |
output_type = "string"
|
12 |
VOWEL_REF = "AH,UH,AX|AE,EH|IY,IH|AO,AA|UW,UH|AY,EY|OW,AO|AW,AO|OY,OW|ER,AXR"
|
13 |
|
14 |
def _get_vowel_groups(self):
|
|
|
15 |
groups = []
|
16 |
group_strs = self.VOWEL_REF.split("|")
|
17 |
for group_str in group_strs:
|
|
|
19 |
return groups
|
20 |
|
21 |
|
22 |
+
def _get_word_phones(self, word, custom_phones=None):
|
23 |
+
"""Get phones for a word, checking custom dictionary first."""
|
24 |
+
if custom_phones and word in custom_phones:
|
25 |
+
return custom_phones[word]["primary_phones"]
|
26 |
+
|
27 |
+
import pronouncing
|
28 |
+
phones = pronouncing.phones_for_word(word)
|
29 |
+
return phones[0] if phones else None
|
30 |
+
|
31 |
+
|
32 |
def _get_last_syllable(self, phones: list) -> tuple:
|
33 |
"""Extract the last syllable (vowel + remaining consonants)."""
|
34 |
last_vowel_idx = -1
|
35 |
last_vowel = None
|
36 |
vowel_groups = self._get_vowel_groups()
|
37 |
|
|
|
38 |
for i, phone in enumerate(phones):
|
|
|
39 |
base_phone = phone.rstrip('012')
|
40 |
for group in vowel_groups:
|
41 |
if base_phone in group:
|
|
|
46 |
if last_vowel_idx == -1:
|
47 |
return None, []
|
48 |
|
|
|
49 |
remaining = phones[last_vowel_idx + 1:]
|
|
|
50 |
return last_vowel, remaining
|
51 |
|
52 |
|
53 |
def _strip_stress(self, phones: list) -> list:
|
|
|
54 |
result = []
|
55 |
for phone in phones:
|
56 |
result.append(phone.rstrip('012'))
|
|
|
58 |
|
59 |
|
60 |
def _vowels_match(self, v1: str, v2: str) -> bool:
|
|
|
61 |
v1 = v1.rstrip('012')
|
62 |
v2 = v2.rstrip('012')
|
63 |
|
|
|
71 |
return False
|
72 |
|
73 |
|
74 |
+
def _calculate_similarity(self, word1, phones1, word2, phones2):
|
75 |
+
"""Calculate similarity score using improved metrics."""
|
76 |
+
# Initialize all variables
|
77 |
+
word_vowel = None
|
78 |
+
word_end = []
|
79 |
+
target_vowel = None
|
80 |
+
target_end = []
|
81 |
+
phone_diff = 0
|
82 |
+
max_phones = 0
|
83 |
+
length_score = 0.0
|
84 |
+
rhyme_score = 0.0
|
85 |
+
stress_score = 0.0
|
86 |
+
i = 0 # For loop counter
|
87 |
+
word_end_clean = []
|
88 |
+
target_end_clean = []
|
89 |
+
matched = 0
|
90 |
+
common_length = 0
|
91 |
+
|
92 |
+
phone_list1 = phones1.split()
|
93 |
+
phone_list2 = phones2.split()
|
94 |
+
|
95 |
+
# Calculate length similarity score
|
96 |
+
phone_diff = abs(len(phone_list1) - len(phone_list2))
|
97 |
+
max_phones = max(len(phone_list1), len(phone_list2))
|
98 |
+
length_score = 1.0 if phone_diff == 0 else 1.0 - (phone_diff / max_phones)
|
99 |
+
|
100 |
+
# Get last syllable components
|
101 |
+
result1 = self._get_last_syllable(phone_list1)
|
102 |
+
result2 = self._get_last_syllable(phone_list2)
|
103 |
+
word_vowel, word_end = result1
|
104 |
+
target_vowel, target_end = result2
|
105 |
+
|
106 |
+
# Calculate rhyme score
|
107 |
+
rhyme_score = 0.0
|
108 |
+
if word_vowel and target_vowel:
|
109 |
+
if self._vowels_match(word_vowel, target_vowel):
|
110 |
+
word_end_clean = self._strip_stress(word_end)
|
111 |
+
target_end_clean = self._strip_stress(target_end)
|
112 |
+
|
113 |
+
if word_end_clean == target_end_clean:
|
114 |
+
rhyme_score = 1.0
|
115 |
+
else:
|
116 |
+
# Partial rhyme based on ending similarity
|
117 |
+
common_length = min(len(word_end_clean), len(target_end_clean))
|
118 |
+
matched = 0
|
119 |
+
for i in range(common_length):
|
120 |
+
if word_end_clean[i] == target_end_clean[i]:
|
121 |
+
matched += 1
|
122 |
+
rhyme_score = 0.6 * (matched / max(len(word_end_clean), len(target_end_clean)))
|
123 |
+
|
124 |
+
# Calculate stress pattern similarity
|
125 |
+
import pronouncing
|
126 |
+
stress1 = pronouncing.stresses(phones1)
|
127 |
+
stress2 = pronouncing.stresses(phones2)
|
128 |
+
stress_score = 1.0 if stress1 == stress2 else 0.5
|
129 |
+
|
130 |
+
# Weighted combination (60% rhyme, 30% length, 10% stress)
|
131 |
+
similarity = (
|
132 |
+
(rhyme_score * 0.6) +
|
133 |
+
(length_score * 0.3) +
|
134 |
+
(stress_score * 0.1)
|
135 |
+
)
|
136 |
+
|
137 |
+
# Cap at 1.0
|
138 |
+
similarity = min(1.0, similarity)
|
139 |
+
|
140 |
+
return {
|
141 |
+
"similarity": round(similarity, 3),
|
142 |
+
"rhyme_score": round(rhyme_score, 3),
|
143 |
+
"length_score": round(length_score, 3),
|
144 |
+
"stress_score": round(stress_score, 3),
|
145 |
+
"phone_length_difference": phone_diff
|
146 |
+
}
|
147 |
+
|
148 |
+
|
149 |
+
def forward(self, target: str, word_list_str: str, min_similarity: str = "0.5", custom_phones: dict = None) -> str:
|
150 |
import pronouncing
|
151 |
import string
|
152 |
import json
|
|
|
153 |
|
154 |
+
# Initialize all variables
|
155 |
target = target.lower().strip(string.punctuation)
|
156 |
min_similarity = float(min_similarity)
|
157 |
suggestions = []
|
|
|
159 |
word_end = []
|
160 |
target_vowel = None
|
161 |
target_end = []
|
162 |
+
valid_words = []
|
163 |
+
invalid_words = []
|
164 |
+
target_phone_list = []
|
165 |
|
166 |
# Parse JSON string to list
|
167 |
try:
|
|
|
173 |
}, indent=2)
|
174 |
|
175 |
# Get target pronunciation
|
176 |
+
target_phones = self._get_word_phones(target, custom_phones)
|
177 |
if not target_phones:
|
178 |
return json.dumps({
|
179 |
+
"error": f"Target word '{target}' not found in dictionary or custom phones",
|
180 |
"suggestions": []
|
181 |
}, indent=2)
|
182 |
|
183 |
+
# Filter word list
|
184 |
valid_words = []
|
185 |
invalid_words = []
|
186 |
for word in words:
|
187 |
word = word.lower().strip(string.punctuation)
|
188 |
+
if self._get_word_phones(word, custom_phones):
|
189 |
valid_words.append(word)
|
190 |
else:
|
191 |
invalid_words.append(word)
|
192 |
|
193 |
if not valid_words:
|
194 |
return json.dumps({
|
195 |
+
"error": "No valid words found in dictionary or custom phones",
|
196 |
"invalid_words": invalid_words,
|
197 |
"suggestions": []
|
198 |
}, indent=2)
|
199 |
|
|
|
200 |
target_phone_list = target_phones.split()
|
201 |
target_vowel, target_end = self._get_last_syllable(target_phone_list)
|
202 |
|
203 |
# Check each word
|
204 |
for word in valid_words:
|
205 |
+
word_phones = self._get_word_phones(word, custom_phones)
|
206 |
+
if word_phones:
|
207 |
+
similarity_result = self._calculate_similarity(word, word_phones, target, target_phones)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
|
209 |
+
if similarity_result["similarity"] >= min_similarity:
|
210 |
+
word_phone_list = word_phones.split()
|
211 |
+
word_vowel, word_end = self._get_last_syllable(word_phone_list)
|
212 |
+
|
213 |
suggestions.append({
|
214 |
"word": word,
|
215 |
+
"similarity": similarity_result["similarity"],
|
216 |
+
"rhyme_score": similarity_result["rhyme_score"],
|
217 |
+
"length_score": similarity_result["length_score"],
|
218 |
+
"stress_score": similarity_result["stress_score"],
|
219 |
+
"phone_length_difference": similarity_result["phone_length_difference"],
|
|
|
220 |
"phones": word_phones,
|
221 |
"last_vowel": word_vowel,
|
222 |
"ending": " ".join(word_end) if word_end else "",
|
223 |
+
"is_custom": word in custom_phones if custom_phones else False
|
|
|
|
|
|
|
|
|
224 |
})
|
225 |
|
226 |
# Sort by similarity score descending
|
|
|
228 |
|
229 |
result = {
|
230 |
"target": target,
|
|
|
231 |
"target_phones": target_phones,
|
232 |
"target_last_vowel": target_vowel,
|
233 |
"target_ending": " ".join(target_end) if target_end else "",
|
234 |
+
"invalid_words": invalid_words,
|
235 |
"suggestions": suggestions
|
236 |
}
|
237 |
|