patruff commited on
Commit
91c8f98
·
verified ·
1 Parent(s): 69a3821

Upload tool

Browse files
Files changed (1) hide show
  1. tool.py +112 -73
tool.py CHANGED
@@ -1,19 +1,17 @@
1
  from smolagents.tools import Tool
2
- import string
3
  import json
4
- import difflib
5
  import pronouncing
 
6
 
7
  class ParodyWordSuggestionTool(Tool):
8
  name = "parody_word_suggester"
9
- description = """Suggests rhyming funny words using CMU dictionary pronunciations.
10
  Returns similar-sounding words that rhyme, especially focusing on common vowel sounds."""
11
- inputs = {'target': {'type': 'string', 'description': 'The word you want to find rhyming alternatives for'}, 'word_list_str': {'type': 'string', 'description': 'JSON string of word list (e.g. \'["word1", "word2"]\')'}, 'min_similarity': {'type': 'string', 'description': 'Minimum similarity threshold (0.0-1.0)', 'nullable': True}}
12
  output_type = "string"
13
  VOWEL_REF = "AH,UH,AX|AE,EH|IY,IH|AO,AA|UW,UH|AY,EY|OW,AO|AW,AO|OY,OW|ER,AXR"
14
 
15
  def _get_vowel_groups(self):
16
- """Convert the simple string format to usable groups."""
17
  groups = []
18
  group_strs = self.VOWEL_REF.split("|")
19
  for group_str in group_strs:
@@ -21,15 +19,23 @@ class ParodyWordSuggestionTool(Tool):
21
  return groups
22
 
23
 
 
 
 
 
 
 
 
 
 
 
24
  def _get_last_syllable(self, phones: list) -> tuple:
25
  """Extract the last syllable (vowel + remaining consonants)."""
26
  last_vowel_idx = -1
27
  last_vowel = None
28
  vowel_groups = self._get_vowel_groups()
29
 
30
- # Find the last vowel
31
  for i, phone in enumerate(phones):
32
- # Strip stress markers for checking
33
  base_phone = phone.rstrip('012')
34
  for group in vowel_groups:
35
  if base_phone in group:
@@ -40,14 +46,11 @@ class ParodyWordSuggestionTool(Tool):
40
  if last_vowel_idx == -1:
41
  return None, []
42
 
43
- # Get all consonants after the vowel
44
  remaining = phones[last_vowel_idx + 1:]
45
-
46
  return last_vowel, remaining
47
 
48
 
49
  def _strip_stress(self, phones: list) -> list:
50
- """Remove stress markers from phones."""
51
  result = []
52
  for phone in phones:
53
  result.append(phone.rstrip('012'))
@@ -55,7 +58,6 @@ class ParodyWordSuggestionTool(Tool):
55
 
56
 
57
  def _vowels_match(self, v1: str, v2: str) -> bool:
58
- """Check if two vowels are in the same group."""
59
  v1 = v1.rstrip('012')
60
  v2 = v2.rstrip('012')
61
 
@@ -69,14 +71,87 @@ class ParodyWordSuggestionTool(Tool):
69
  return False
70
 
71
 
72
- def forward(self, target: str, word_list_str: str, min_similarity: str = "0.5") -> str:
73
- """Get rhyming word suggestions."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  import pronouncing
75
  import string
76
  import json
77
- from difflib import SequenceMatcher
78
 
79
- # Initialize variables
80
  target = target.lower().strip(string.punctuation)
81
  min_similarity = float(min_similarity)
82
  suggestions = []
@@ -84,6 +159,9 @@ class ParodyWordSuggestionTool(Tool):
84
  word_end = []
85
  target_vowel = None
86
  target_end = []
 
 
 
87
 
88
  # Parse JSON string to list
89
  try:
@@ -95,92 +173,54 @@ class ParodyWordSuggestionTool(Tool):
95
  }, indent=2)
96
 
97
  # Get target pronunciation
98
- target_phones = pronouncing.phones_for_word(target)
99
  if not target_phones:
100
  return json.dumps({
101
- "error": f"Target word '{target}' not found in CMU dictionary",
102
  "suggestions": []
103
  }, indent=2)
104
 
105
- # Filter word list to only words in CMU dictionary
106
  valid_words = []
107
  invalid_words = []
108
  for word in words:
109
  word = word.lower().strip(string.punctuation)
110
- if pronouncing.phones_for_word(word):
111
  valid_words.append(word)
112
  else:
113
  invalid_words.append(word)
114
 
115
  if not valid_words:
116
  return json.dumps({
117
- "error": "No valid words found in CMU dictionary",
118
  "invalid_words": invalid_words,
119
  "suggestions": []
120
  }, indent=2)
121
 
122
- target_phones = target_phones[0]
123
  target_phone_list = target_phones.split()
124
  target_vowel, target_end = self._get_last_syllable(target_phone_list)
125
 
126
  # Check each word
127
  for word in valid_words:
128
- phones = pronouncing.phones_for_word(word)
129
- if phones:
130
- word_phones = phones[0]
131
- word_phone_list = word_phones.split()
132
- word_vowel, word_end = self._get_last_syllable(word_phone_list)
133
-
134
- # 1. Rhyme score (most important - 60%)
135
- rhyme_score = 0.0
136
- if word_vowel and target_vowel:
137
- # Check if the vowels are similar
138
- if self._vowels_match(word_vowel, target_vowel):
139
- # Check if endings match (ignoring stress numbers)
140
- word_end_clean = self._strip_stress(word_end)
141
- target_end_clean = self._strip_stress(target_end)
142
-
143
- if word_end_clean == target_end_clean:
144
- rhyme_score = 1.0
145
- # Extra boost for exact match minus first letter
146
- if len(word) == len(target) and word[1:] == target[1:]:
147
- rhyme_score = 1.2
148
- else:
149
- rhyme_score = 0.6
150
-
151
- # 2. Syllable match (25%)
152
- target_syl = pronouncing.syllable_count(target_phones)
153
- word_syl = pronouncing.syllable_count(word_phones)
154
- syllable_score = 1.0 if target_syl == word_syl else 0.0
155
-
156
- # 3. String similarity (15%)
157
- # Higher weight for end of word similarity
158
- if len(word) > 1 and len(target) > 1:
159
- end_similarity = SequenceMatcher(None, word[1:], target[1:]).ratio()
160
- string_similarity = end_similarity
161
- else:
162
- string_similarity = SequenceMatcher(None, target, word).ratio()
163
-
164
- # Combined score
165
- similarity = (rhyme_score * 0.6) + (syllable_score * 0.25) + (string_similarity * 0.15)
166
 
167
- if similarity >= min_similarity:
 
 
 
168
  suggestions.append({
169
  "word": word,
170
- "similarity": round(similarity, 3),
171
- "rhyme_match": rhyme_score > 0,
172
- "rhyme_score": round(rhyme_score, 3),
173
- "syllable_match": syllable_score == 1.0,
174
- "string_similarity": round(string_similarity, 3),
175
- "syllables": word_syl,
176
  "phones": word_phones,
177
  "last_vowel": word_vowel,
178
  "ending": " ".join(word_end) if word_end else "",
179
- "debug_info": {
180
- "word_end_clean": word_end_clean,
181
- "target_end_clean": target_end_clean,
182
- "exact_match": word_end_clean == target_end_clean
183
- }
184
  })
185
 
186
  # Sort by similarity score descending
@@ -188,11 +228,10 @@ class ParodyWordSuggestionTool(Tool):
188
 
189
  result = {
190
  "target": target,
191
- "target_syllables": pronouncing.syllable_count(target_phones),
192
  "target_phones": target_phones,
193
  "target_last_vowel": target_vowel,
194
  "target_ending": " ".join(target_end) if target_end else "",
195
- "invalid_words": invalid_words, # List of words not in CMU
196
  "suggestions": suggestions
197
  }
198
 
 
1
  from smolagents.tools import Tool
 
2
  import json
 
3
  import pronouncing
4
+ import string
5
 
6
  class ParodyWordSuggestionTool(Tool):
7
  name = "parody_word_suggester"
8
+ description = """Suggests rhyming funny words using CMU dictionary and custom pronunciations.
9
  Returns similar-sounding words that rhyme, especially focusing on common vowel sounds."""
10
+ inputs = {'target': {'type': 'string', 'description': 'The word you want to find rhyming alternatives for'}, 'word_list_str': {'type': 'string', 'description': 'JSON string of word list (e.g. \'["word1", "word2"]\')'}, 'min_similarity': {'type': 'string', 'description': 'Minimum similarity threshold (0.0-1.0)', 'nullable': True}, 'custom_phones': {'type': 'object', 'description': 'Optional dictionary of custom word pronunciations', 'nullable': True}}
11
  output_type = "string"
12
  VOWEL_REF = "AH,UH,AX|AE,EH|IY,IH|AO,AA|UW,UH|AY,EY|OW,AO|AW,AO|OY,OW|ER,AXR"
13
 
14
  def _get_vowel_groups(self):
 
15
  groups = []
16
  group_strs = self.VOWEL_REF.split("|")
17
  for group_str in group_strs:
 
19
  return groups
20
 
21
 
22
+ def _get_word_phones(self, word, custom_phones=None):
23
+ """Get phones for a word, checking custom dictionary first."""
24
+ if custom_phones and word in custom_phones:
25
+ return custom_phones[word]["primary_phones"]
26
+
27
+ import pronouncing
28
+ phones = pronouncing.phones_for_word(word)
29
+ return phones[0] if phones else None
30
+
31
+
32
  def _get_last_syllable(self, phones: list) -> tuple:
33
  """Extract the last syllable (vowel + remaining consonants)."""
34
  last_vowel_idx = -1
35
  last_vowel = None
36
  vowel_groups = self._get_vowel_groups()
37
 
 
38
  for i, phone in enumerate(phones):
 
39
  base_phone = phone.rstrip('012')
40
  for group in vowel_groups:
41
  if base_phone in group:
 
46
  if last_vowel_idx == -1:
47
  return None, []
48
 
 
49
  remaining = phones[last_vowel_idx + 1:]
 
50
  return last_vowel, remaining
51
 
52
 
53
  def _strip_stress(self, phones: list) -> list:
 
54
  result = []
55
  for phone in phones:
56
  result.append(phone.rstrip('012'))
 
58
 
59
 
60
  def _vowels_match(self, v1: str, v2: str) -> bool:
 
61
  v1 = v1.rstrip('012')
62
  v2 = v2.rstrip('012')
63
 
 
71
  return False
72
 
73
 
74
+ def _calculate_similarity(self, word1, phones1, word2, phones2):
75
+ """Calculate similarity score using improved metrics."""
76
+ # Initialize all variables
77
+ word_vowel = None
78
+ word_end = []
79
+ target_vowel = None
80
+ target_end = []
81
+ phone_diff = 0
82
+ max_phones = 0
83
+ length_score = 0.0
84
+ rhyme_score = 0.0
85
+ stress_score = 0.0
86
+ i = 0 # For loop counter
87
+ word_end_clean = []
88
+ target_end_clean = []
89
+ matched = 0
90
+ common_length = 0
91
+
92
+ phone_list1 = phones1.split()
93
+ phone_list2 = phones2.split()
94
+
95
+ # Calculate length similarity score
96
+ phone_diff = abs(len(phone_list1) - len(phone_list2))
97
+ max_phones = max(len(phone_list1), len(phone_list2))
98
+ length_score = 1.0 if phone_diff == 0 else 1.0 - (phone_diff / max_phones)
99
+
100
+ # Get last syllable components
101
+ result1 = self._get_last_syllable(phone_list1)
102
+ result2 = self._get_last_syllable(phone_list2)
103
+ word_vowel, word_end = result1
104
+ target_vowel, target_end = result2
105
+
106
+ # Calculate rhyme score
107
+ rhyme_score = 0.0
108
+ if word_vowel and target_vowel:
109
+ if self._vowels_match(word_vowel, target_vowel):
110
+ word_end_clean = self._strip_stress(word_end)
111
+ target_end_clean = self._strip_stress(target_end)
112
+
113
+ if word_end_clean == target_end_clean:
114
+ rhyme_score = 1.0
115
+ else:
116
+ # Partial rhyme based on ending similarity
117
+ common_length = min(len(word_end_clean), len(target_end_clean))
118
+ matched = 0
119
+ for i in range(common_length):
120
+ if word_end_clean[i] == target_end_clean[i]:
121
+ matched += 1
122
+ rhyme_score = 0.6 * (matched / max(len(word_end_clean), len(target_end_clean)))
123
+
124
+ # Calculate stress pattern similarity
125
+ import pronouncing
126
+ stress1 = pronouncing.stresses(phones1)
127
+ stress2 = pronouncing.stresses(phones2)
128
+ stress_score = 1.0 if stress1 == stress2 else 0.5
129
+
130
+ # Weighted combination (60% rhyme, 30% length, 10% stress)
131
+ similarity = (
132
+ (rhyme_score * 0.6) +
133
+ (length_score * 0.3) +
134
+ (stress_score * 0.1)
135
+ )
136
+
137
+ # Cap at 1.0
138
+ similarity = min(1.0, similarity)
139
+
140
+ return {
141
+ "similarity": round(similarity, 3),
142
+ "rhyme_score": round(rhyme_score, 3),
143
+ "length_score": round(length_score, 3),
144
+ "stress_score": round(stress_score, 3),
145
+ "phone_length_difference": phone_diff
146
+ }
147
+
148
+
149
+ def forward(self, target: str, word_list_str: str, min_similarity: str = "0.5", custom_phones: dict = None) -> str:
150
  import pronouncing
151
  import string
152
  import json
 
153
 
154
+ # Initialize all variables
155
  target = target.lower().strip(string.punctuation)
156
  min_similarity = float(min_similarity)
157
  suggestions = []
 
159
  word_end = []
160
  target_vowel = None
161
  target_end = []
162
+ valid_words = []
163
+ invalid_words = []
164
+ target_phone_list = []
165
 
166
  # Parse JSON string to list
167
  try:
 
173
  }, indent=2)
174
 
175
  # Get target pronunciation
176
+ target_phones = self._get_word_phones(target, custom_phones)
177
  if not target_phones:
178
  return json.dumps({
179
+ "error": f"Target word '{target}' not found in dictionary or custom phones",
180
  "suggestions": []
181
  }, indent=2)
182
 
183
+ # Filter word list
184
  valid_words = []
185
  invalid_words = []
186
  for word in words:
187
  word = word.lower().strip(string.punctuation)
188
+ if self._get_word_phones(word, custom_phones):
189
  valid_words.append(word)
190
  else:
191
  invalid_words.append(word)
192
 
193
  if not valid_words:
194
  return json.dumps({
195
+ "error": "No valid words found in dictionary or custom phones",
196
  "invalid_words": invalid_words,
197
  "suggestions": []
198
  }, indent=2)
199
 
 
200
  target_phone_list = target_phones.split()
201
  target_vowel, target_end = self._get_last_syllable(target_phone_list)
202
 
203
  # Check each word
204
  for word in valid_words:
205
+ word_phones = self._get_word_phones(word, custom_phones)
206
+ if word_phones:
207
+ similarity_result = self._calculate_similarity(word, word_phones, target, target_phones)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
+ if similarity_result["similarity"] >= min_similarity:
210
+ word_phone_list = word_phones.split()
211
+ word_vowel, word_end = self._get_last_syllable(word_phone_list)
212
+
213
  suggestions.append({
214
  "word": word,
215
+ "similarity": similarity_result["similarity"],
216
+ "rhyme_score": similarity_result["rhyme_score"],
217
+ "length_score": similarity_result["length_score"],
218
+ "stress_score": similarity_result["stress_score"],
219
+ "phone_length_difference": similarity_result["phone_length_difference"],
 
220
  "phones": word_phones,
221
  "last_vowel": word_vowel,
222
  "ending": " ".join(word_end) if word_end else "",
223
+ "is_custom": word in custom_phones if custom_phones else False
 
 
 
 
224
  })
225
 
226
  # Sort by similarity score descending
 
228
 
229
  result = {
230
  "target": target,
 
231
  "target_phones": target_phones,
232
  "target_last_vowel": target_vowel,
233
  "target_ending": " ".join(target_end) if target_end else "",
234
+ "invalid_words": invalid_words,
235
  "suggestions": suggestions
236
  }
237