Spaces:
Configuration error
Configuration error
Upload tool
Browse files
tool.py
CHANGED
@@ -1,13 +1,13 @@
|
|
1 |
from smolagents.tools import Tool
|
2 |
-
import json
|
3 |
import string
|
4 |
import pronouncing
|
|
|
5 |
|
6 |
class ParodyWordSuggestionTool(Tool):
|
7 |
name = "parody_word_suggester"
|
8 |
description = """Suggests rhyming funny words using CMU dictionary and custom pronunciations.
|
9 |
Returns similar-sounding words that rhyme, especially focusing on common vowel sounds."""
|
10 |
-
inputs = {'target': {'type': 'string', 'description': 'The word you want to find rhyming alternatives for'}, 'word_list_str': {'type': 'string', 'description': 'JSON string of word list (e.g. \'["word1", "word2"]\')'}, 'min_similarity': {'type': 'string', 'description': 'Minimum similarity threshold (0.0-1.0)', 'nullable': True}, 'custom_phones': {'type': 'object', 'description': 'Optional dictionary of custom word pronunciations', 'nullable': True}}
|
11 |
output_type = "string"
|
12 |
VOWEL_REF = "AH,AX|UH|AE,EH|IY,IH|AO,AA|UW|AY,EY|OW,AO|AW,AO|OY,OW|ER,AXR"
|
13 |
|
@@ -71,35 +71,78 @@ class ParodyWordSuggestionTool(Tool):
|
|
71 |
return False
|
72 |
|
73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
def _calculate_similarity(self, word1, phones1, word2, phones2):
|
75 |
-
"""Calculate similarity score using improved metrics."""
|
76 |
-
# Initialize all variables
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
word_vowel = None
|
78 |
word_end = []
|
79 |
target_vowel = None
|
80 |
target_end = []
|
81 |
-
|
82 |
-
|
83 |
length_score = 0.0
|
84 |
rhyme_score = 0.0
|
85 |
stress_score = 0.0
|
86 |
-
|
87 |
word_end_clean = []
|
88 |
target_end_clean = []
|
89 |
-
matched = 0
|
90 |
common_length = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
|
|
|
92 |
phone_list1 = phones1.split()
|
93 |
phone_list2 = phones2.split()
|
94 |
|
95 |
-
#
|
96 |
-
|
97 |
-
|
98 |
-
|
|
|
|
|
|
|
|
|
99 |
|
100 |
-
# Get last syllable components
|
101 |
-
result1 = self._get_last_syllable(
|
102 |
-
result2 = self._get_last_syllable(
|
103 |
word_vowel, word_end = result1
|
104 |
target_vowel, target_end = result2
|
105 |
|
@@ -111,30 +154,35 @@ class ParodyWordSuggestionTool(Tool):
|
|
111 |
target_end_clean = self._strip_stress(target_end)
|
112 |
|
113 |
if word_end_clean == target_end_clean:
|
114 |
-
|
|
|
|
|
|
|
115 |
else:
|
116 |
-
# Partial rhyme based on ending similarity
|
117 |
common_length = min(len(word_end_clean), len(target_end_clean))
|
118 |
matched = 0
|
119 |
for i in range(common_length):
|
120 |
if word_end_clean[i] == target_end_clean[i]:
|
121 |
matched += 1
|
122 |
-
rhyme_score = 0.
|
123 |
|
124 |
-
# Calculate stress pattern similarity
|
125 |
import pronouncing
|
126 |
-
stress1 = pronouncing.stresses(
|
127 |
-
stress2 = pronouncing.stresses(
|
128 |
-
stress_score = 1.0 if stress1 == stress2 else 0.
|
|
|
|
|
|
|
129 |
|
130 |
-
# Weighted combination
|
131 |
similarity = (
|
132 |
-
(rhyme_score * 0.6) +
|
133 |
-
(length_score * 0.
|
134 |
-
(stress_score * 0.
|
|
|
135 |
)
|
136 |
|
137 |
-
# Cap at 1.0
|
138 |
similarity = min(1.0, similarity)
|
139 |
|
140 |
return {
|
@@ -142,7 +190,9 @@ class ParodyWordSuggestionTool(Tool):
|
|
142 |
"rhyme_score": round(rhyme_score, 3),
|
143 |
"length_score": round(length_score, 3),
|
144 |
"stress_score": round(stress_score, 3),
|
145 |
-
"
|
|
|
|
|
146 |
}
|
147 |
|
148 |
|
@@ -162,6 +212,12 @@ class ParodyWordSuggestionTool(Tool):
|
|
162 |
valid_words = []
|
163 |
invalid_words = []
|
164 |
target_phone_list = []
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
|
166 |
# Parse JSON string to list
|
167 |
try:
|
@@ -216,7 +272,9 @@ class ParodyWordSuggestionTool(Tool):
|
|
216 |
"rhyme_score": similarity_result["rhyme_score"],
|
217 |
"length_score": similarity_result["length_score"],
|
218 |
"stress_score": similarity_result["stress_score"],
|
219 |
-
"
|
|
|
|
|
220 |
"phones": word_phones,
|
221 |
"last_vowel": word_vowel,
|
222 |
"ending": " ".join(word_end) if word_end else "",
|
|
|
1 |
from smolagents.tools import Tool
|
|
|
2 |
import string
|
3 |
import pronouncing
|
4 |
+
import json
|
5 |
|
6 |
class ParodyWordSuggestionTool(Tool):
|
7 |
name = "parody_word_suggester"
|
8 |
description = """Suggests rhyming funny words using CMU dictionary and custom pronunciations.
|
9 |
Returns similar-sounding words that rhyme, especially focusing on common vowel sounds."""
|
10 |
+
inputs = {'target': {'type': 'string', 'description': 'The word you want to find rhyming alternatives for'}, 'word_list_str': {'type': 'string', 'description': 'JSON string of word list (e.g. \'["word1", "word2"]\')'}, 'min_similarity': {'type': 'string', 'description': 'Minimum similarity threshold (0.0-1.0)', 'default': '0.5', 'nullable': True}, 'custom_phones': {'type': 'object', 'description': 'Optional dictionary of custom word pronunciations', 'default': None, 'nullable': True}}
|
11 |
output_type = "string"
|
12 |
VOWEL_REF = "AH,AX|UH|AE,EH|IY,IH|AO,AA|UW|AY,EY|OW,AO|AW,AO|OY,OW|ER,AXR"
|
13 |
|
|
|
71 |
return False
|
72 |
|
73 |
|
74 |
+
def _strip_common_suffix(self, phones: list) -> tuple:
|
75 |
+
"""Strip common suffixes and return base and suffix phones."""
|
76 |
+
# Initialize variables
|
77 |
+
suffix_name = ""
|
78 |
+
suffix_phones = []
|
79 |
+
phone1 = ""
|
80 |
+
phone2 = ""
|
81 |
+
|
82 |
+
# Common suffix patterns in CMU phonetic representation
|
83 |
+
SUFFIXES = {
|
84 |
+
'ING': ['IH0', 'NG'], # -ing
|
85 |
+
'ED': ['EH0', 'D'], # -ed
|
86 |
+
'ER': ['ER0'], # -er
|
87 |
+
'EST': ['EH0', 'S', 'T'], # -est
|
88 |
+
'LY': ['L', 'IY0'], # -ly
|
89 |
+
'NESS': ['N', 'EH0', 'S'], # -ness
|
90 |
+
}
|
91 |
+
|
92 |
+
for suffix_name, suffix_phones in SUFFIXES.items():
|
93 |
+
if len(phones) > len(suffix_phones):
|
94 |
+
if all(phone1.rstrip('012') == phone2.rstrip('012')
|
95 |
+
for phone1, phone2 in zip(phones[-len(suffix_phones):], suffix_phones)):
|
96 |
+
return phones[:-len(suffix_phones)], suffix_phones
|
97 |
+
|
98 |
+
return phones, []
|
99 |
+
|
100 |
+
|
101 |
def _calculate_similarity(self, word1, phones1, word2, phones2):
|
102 |
+
"""Calculate similarity score using improved metrics and suffix handling."""
|
103 |
+
# Initialize all variables first
|
104 |
+
phone_list1 = []
|
105 |
+
phone_list2 = []
|
106 |
+
base1 = []
|
107 |
+
base2 = []
|
108 |
+
suffix1 = []
|
109 |
+
suffix2 = []
|
110 |
word_vowel = None
|
111 |
word_end = []
|
112 |
target_vowel = None
|
113 |
target_end = []
|
114 |
+
base_length_diff = 0
|
115 |
+
max_base_length = 0
|
116 |
length_score = 0.0
|
117 |
rhyme_score = 0.0
|
118 |
stress_score = 0.0
|
119 |
+
suffix_score = 0.0
|
120 |
word_end_clean = []
|
121 |
target_end_clean = []
|
|
|
122 |
common_length = 0
|
123 |
+
matched = 0
|
124 |
+
stress1 = ""
|
125 |
+
stress2 = ""
|
126 |
+
similarity = 0.0
|
127 |
+
result1 = (None, [])
|
128 |
+
result2 = (None, [])
|
129 |
|
130 |
+
# Main logic
|
131 |
phone_list1 = phones1.split()
|
132 |
phone_list2 = phones2.split()
|
133 |
|
134 |
+
# Strip common suffixes first
|
135 |
+
base1, suffix1 = self._strip_common_suffix(phone_list1)
|
136 |
+
base2, suffix2 = self._strip_common_suffix(phone_list2)
|
137 |
+
|
138 |
+
# Calculate base word similarity
|
139 |
+
base_length_diff = abs(len(base1) - len(base2))
|
140 |
+
max_base_length = max(len(base1), len(base2))
|
141 |
+
length_score = 1.0 if base_length_diff == 0 else 1.0 - (base_length_diff / max_base_length)
|
142 |
|
143 |
+
# Get last syllable components of base words
|
144 |
+
result1 = self._get_last_syllable(base1)
|
145 |
+
result2 = self._get_last_syllable(base2)
|
146 |
word_vowel, word_end = result1
|
147 |
target_vowel, target_end = result2
|
148 |
|
|
|
154 |
target_end_clean = self._strip_stress(target_end)
|
155 |
|
156 |
if word_end_clean == target_end_clean:
|
157 |
+
if word_vowel.rstrip('012') == target_vowel.rstrip('012'):
|
158 |
+
rhyme_score = 1.0
|
159 |
+
else:
|
160 |
+
rhyme_score = 0.7 # Penalize different vowels in same group
|
161 |
else:
|
|
|
162 |
common_length = min(len(word_end_clean), len(target_end_clean))
|
163 |
matched = 0
|
164 |
for i in range(common_length):
|
165 |
if word_end_clean[i] == target_end_clean[i]:
|
166 |
matched += 1
|
167 |
+
rhyme_score = 0.3 * (matched / max(len(word_end_clean), len(target_end_clean)))
|
168 |
|
169 |
+
# Calculate stress pattern similarity using base words
|
170 |
import pronouncing
|
171 |
+
stress1 = pronouncing.stresses(' '.join(base1))
|
172 |
+
stress2 = pronouncing.stresses(' '.join(base2))
|
173 |
+
stress_score = 1.0 if stress1 == stress2 else 0.3
|
174 |
+
|
175 |
+
# Add suffix matching bonus
|
176 |
+
suffix_score = 1.0 if suffix1 == suffix2 else 0.0
|
177 |
|
178 |
+
# Weighted combination with emphasis on base word similarity
|
179 |
similarity = (
|
180 |
+
(rhyme_score * 0.6) + # Base word rhyme
|
181 |
+
(length_score * 0.1) + # Base word length
|
182 |
+
(stress_score * 0.2) + # Base word stress
|
183 |
+
(suffix_score * 0.1) # Suffix match as small bonus
|
184 |
)
|
185 |
|
|
|
186 |
similarity = min(1.0, similarity)
|
187 |
|
188 |
return {
|
|
|
190 |
"rhyme_score": round(rhyme_score, 3),
|
191 |
"length_score": round(length_score, 3),
|
192 |
"stress_score": round(stress_score, 3),
|
193 |
+
"base_word_diff": base_length_diff,
|
194 |
+
"has_common_suffix": bool(suffix1 and suffix2),
|
195 |
+
"suffix_match": suffix_score == 1.0
|
196 |
}
|
197 |
|
198 |
|
|
|
212 |
valid_words = []
|
213 |
invalid_words = []
|
214 |
target_phone_list = []
|
215 |
+
words = []
|
216 |
+
target_phones = ""
|
217 |
+
word_phones = ""
|
218 |
+
word = ""
|
219 |
+
word_phone_list = []
|
220 |
+
similarity_result = {}
|
221 |
|
222 |
# Parse JSON string to list
|
223 |
try:
|
|
|
272 |
"rhyme_score": similarity_result["rhyme_score"],
|
273 |
"length_score": similarity_result["length_score"],
|
274 |
"stress_score": similarity_result["stress_score"],
|
275 |
+
"base_word_diff": similarity_result["base_word_diff"],
|
276 |
+
"has_common_suffix": similarity_result["has_common_suffix"],
|
277 |
+
"suffix_match": similarity_result["suffix_match"],
|
278 |
"phones": word_phones,
|
279 |
"last_vowel": word_vowel,
|
280 |
"ending": " ".join(word_end) if word_end else "",
|