patruff committed on
Commit
2aab8aa
·
verified ·
1 Parent(s): 6507960

Upload tool

Browse files
Files changed (1) hide show
  1. tool.py +173 -190
tool.py CHANGED
@@ -1,17 +1,23 @@
1
  from smolagents.tools import Tool
2
  import string
3
- import pronouncing
4
  import json
 
5
 
6
  class ParodyWordSuggestionTool(Tool):
7
  name = "parody_word_suggester"
8
- description = "Suggests rhyming funny words using CMU dictionary pronunciations."
9
- inputs = {'target': {'type': 'string', 'description': 'The word you want to find rhyming alternatives for'}, 'word_list_str': {'type': 'string', 'description': 'JSON string of word list (e.g. \'["word1", "word2"]\')'}, 'min_similarity': {'type': 'string', 'description': 'Minimum similarity threshold (0.0-1.0)', 'nullable': True, 'default': '0.5'}, 'custom_phones': {'type': 'object', 'description': 'Optional dictionary of custom word pronunciations', 'nullable': True, 'default': None}}
 
10
  output_type = "string"
11
- RHYME_WEIGHT = 0.5
12
- PHONE_SEQUENCE_WEIGHT = 0.3
13
- LENGTH_WEIGHT = 0.2
14
- PHONE_GROUPS = "M,N,NG|P,B|T,D|K,G|F,V|TH,DH|S,Z|SH,ZH|L,R|W,Y|IY,IH|UW,UH|EH,AH|AO,AA|AE,AH|AY,EY|OW,UW"
 
 
 
 
 
15
 
16
  def _get_word_phones(self, word, custom_phones=None):
17
  """Get phones for a word, checking custom dictionary first."""
@@ -23,226 +29,186 @@ class ParodyWordSuggestionTool(Tool):
23
  return phones[0] if phones else None
24
 
25
 
26
- def _get_primary_vowel(self, phones: list) -> str:
27
- """Get the primary stressed vowel from phone list."""
28
- phone_str = ""
29
- vowel_char = ""
 
30
 
31
- for phone_str in phones:
32
- if '1' in phone_str and any(vowel_char in phone_str for vowel_char in 'AEIOU'):
33
- return phone_str.rstrip('012')
34
- return None
35
-
36
-
37
- def _phones_are_similar(self, phone1: str, phone2: str) -> bool:
38
- """Check if two phones are similar enough to be considered rhyming."""
39
- # Strip stress markers
40
- p1 = phone1.rstrip('012')
41
- p2 = phone2.rstrip('012')
42
- group_str = ""
43
  group = []
44
 
45
- # Exact match
46
- if p1 == p2:
47
- return True
48
-
49
- # Check similarity groups
50
- for group_str in self.PHONE_GROUPS.split('|'):
51
- group = group_str.split(',')
52
- if p1 in group and p2 in group:
53
- return True
54
-
55
- return False
56
-
57
-
58
- def _get_phone_type(self, phone: str) -> str:
59
- """Get the broad category of a phone."""
60
- # Strip stress markers
61
- phone = phone.rstrip('012')
62
- vowel_char = ""
63
-
64
- # Vowels
65
- if any(vowel_char in phone for vowel_char in 'AEIOU'):
66
- return 'vowel'
67
-
68
- # Initialize fixed sets for categories
69
- nasals = {'M', 'N', 'NG'}
70
- stops = {'P', 'B', 'T', 'D', 'K', 'G'}
71
- fricatives = {'F', 'V', 'TH', 'DH', 'S', 'Z', 'SH', 'ZH'}
72
- liquids = {'L', 'R'}
73
- glides = {'W', 'Y'}
74
 
75
- if phone in nasals:
76
- return 'nasal'
77
- if phone in stops:
78
- return 'stop'
79
- if phone in fricatives:
80
- return 'fricative'
81
- if phone in liquids:
82
- return 'liquid'
83
- if phone in glides:
84
- return 'glide'
85
 
86
- return 'other'
 
87
 
88
 
89
- def _get_rhyme_score(self, phones1: list, phones2: list) -> float:
90
- """Calculate rhyme score based on matching phones after primary stressed vowel."""
91
- # Initialize variables
92
- pos1 = -1
93
- pos2 = -1
94
- i = 0
95
  phone = ""
96
- vowel_char = ""
97
- rhyme_part1 = []
98
- rhyme_part2 = []
99
- similarity_count = 0
100
- p1 = ""
101
- p2 = ""
102
 
103
- # Find primary stressed vowel position in both words
104
- for i, phone in enumerate(phones1):
105
- if '1' in phone and any(vowel_char in phone for vowel_char in 'AEIOU'):
106
- pos1 = i
107
- break
108
-
109
- for i, phone in enumerate(phones2):
110
- if '1' in phone and any(vowel_char in phone for vowel_char in 'AEIOU'):
111
- pos2 = i
112
- break
113
-
114
- if pos1 == -1 or pos2 == -1:
115
- return 0.0
116
-
117
- # Get all phones after and including the stressed vowel
118
- rhyme_part1 = phones1[pos1:]
119
- rhyme_part2 = phones2[pos2:]
120
-
121
- # Check if lengths match
122
- if len(rhyme_part1) != len(rhyme_part2):
123
- return 0.0
124
-
125
- # Calculate similarity score for rhyming part
126
- for p1, p2 in zip(rhyme_part1, rhyme_part2):
127
- if self._phones_are_similar(p1, p2):
128
- similarity_count += 1
129
-
130
- # Return score based on how many phones were similar
131
- return similarity_count / len(rhyme_part1) if rhyme_part1 else 0.0
132
 
133
 
134
- def _calculate_phone_sequence_similarity(self, phones1: list, phones2: list) -> float:
135
- """Calculate similarity based on matching phones in sequence."""
136
- if not phones1 or not phones2:
137
- return 0.0
138
 
139
- # Initialize variables
140
- total_similarity = 0.0
141
- i = 0
142
- similarity = 0.0
143
- comparisons = max(len(phones1), len(phones2))
144
 
145
- # Compare each position
146
- for i in range(min(len(phones1), len(phones2))):
147
- similarity = self._get_phone_similarity(phones1[i], phones2[i])
148
- total_similarity += similarity
149
-
150
- return total_similarity / comparisons if comparisons > 0 else 0.0
151
-
152
-
153
- def _get_phone_similarity(self, phone1: str, phone2: str) -> float:
154
- """Calculate similarity between two phones."""
155
- # Initialize variables
156
- p1 = phone1.rstrip('012')
157
- p2 = phone2.rstrip('012')
158
- group_str = ""
159
  group = []
160
 
161
- # Exact match
162
- if p1 == p2:
163
- return 1.0
164
-
165
- # Check similarity groups
166
- for group_str in self.PHONE_GROUPS.split('|'):
167
- group = group_str.split(',')
168
- if p1 in group and p2 in group:
169
- return 0.7
170
-
171
- # Check broader categories
172
- if self._get_phone_type(p1) == self._get_phone_type(p2):
173
- return 0.3
174
-
175
- return 0.0
176
-
177
-
178
- def _calculate_length_similarity(self, phones1: list, phones2: list) -> float:
179
- """Calculate similarity based on phone length."""
180
- max_length = max(len(phones1), len(phones2))
181
- length_diff = abs(len(phones1) - len(phones2))
182
- return 1.0 - (length_diff / max_length) if max_length > 0 else 0.0
183
 
184
 
185
  def _calculate_similarity(self, word1, phones1, word2, phones2):
186
- """Calculate similarity based on multiple factors."""
187
- # Initialize variables
188
  phone_list1 = phones1.split()
189
  phone_list2 = phones2.split()
 
 
190
  rhyme_score = 0.0
191
- phone_sequence_score = 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  length_score = 0.0
 
 
 
193
  similarity = 0.0
 
 
194
 
195
- # Get rhyme score using new method
196
- rhyme_score = self._get_rhyme_score(phone_list1, phone_list2)
 
 
 
197
 
198
- # If rhyme score is too low (e.g. below 0.8), consider it a non-rhyme
199
- if rhyme_score < 0.8:
200
- return {
201
- "similarity": 0.0,
202
- "rhyme_score": 0.0,
203
- "phone_sequence_score": 0.0,
204
- "length_score": 0.0,
205
- "details": {
206
- "primary_vowel1": self._get_primary_vowel(phone_list1),
207
- "primary_vowel2": self._get_primary_vowel(phone_list2),
208
- "phone_count1": len(phone_list1),
209
- "phone_count2": len(phone_list2),
210
- "matching_phones": 0
211
- }
212
- }
 
 
 
 
213
 
214
- # Calculate other scores only if words rhyme closely enough
215
- phone_sequence_score = self._calculate_phone_sequence_similarity(phone_list1, phone_list2)
216
- length_score = self._calculate_length_similarity(phone_list1, phone_list2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
- # Combined weighted score
 
 
 
 
 
 
219
  similarity = (
220
- (rhyme_score * self.RHYME_WEIGHT) +
221
- (phone_sequence_score * self.PHONE_SEQUENCE_WEIGHT) +
222
- (length_score * self.LENGTH_WEIGHT)
 
223
  )
224
 
 
 
 
 
 
 
 
225
  return {
226
  "similarity": round(similarity, 3),
227
  "rhyme_score": round(rhyme_score, 3),
228
- "phone_sequence_score": round(phone_sequence_score, 3),
229
  "length_score": round(length_score, 3),
230
- "details": {
231
- "primary_vowel1": self._get_primary_vowel(phone_list1),
232
- "primary_vowel2": self._get_primary_vowel(phone_list2),
233
- "phone_count1": len(phone_list1),
234
- "phone_count2": len(phone_list2),
235
- "matching_phones": round(phone_sequence_score * len(phone_list1))
236
- }
237
  }
238
 
239
 
240
- def forward(self, target: str, word_list_str: str, min_similarity: str = "0.5", custom_phones: dict = None) -> str:
241
  import pronouncing
242
  import string
243
  import json
244
 
245
- # Initialize variables
246
  target = target.lower().strip(string.punctuation)
247
  min_similarity = float(min_similarity)
248
  suggestions = []
@@ -250,8 +216,14 @@ class ParodyWordSuggestionTool(Tool):
250
  invalid_words = []
251
  words = []
252
  target_phones = ""
253
- word_phones = ""
 
 
254
  word = ""
 
 
 
 
255
  similarity_result = {}
256
 
257
  # Parse JSON string to list
@@ -271,6 +243,10 @@ class ParodyWordSuggestionTool(Tool):
271
  "suggestions": []
272
  }, indent=2)
273
 
 
 
 
 
274
  # Filter word list
275
  for word in words:
276
  word = word.lower().strip(string.punctuation)
@@ -293,15 +269,20 @@ class ParodyWordSuggestionTool(Tool):
293
  similarity_result = self._calculate_similarity(word, word_phones, target, target_phones)
294
 
295
  if similarity_result["similarity"] >= min_similarity:
 
 
 
296
  suggestions.append({
297
  "word": word,
298
  "similarity": similarity_result["similarity"],
299
  "rhyme_score": similarity_result["rhyme_score"],
300
- "phone_sequence_score": similarity_result["phone_sequence_score"],
301
  "length_score": similarity_result["length_score"],
 
302
  "phones": word_phones,
303
- "is_custom": word in custom_phones if custom_phones else False,
304
- "details": similarity_result["details"]
 
305
  })
306
 
307
  # Sort by similarity score descending
@@ -310,6 +291,8 @@ class ParodyWordSuggestionTool(Tool):
310
  result = {
311
  "target": target,
312
  "target_phones": target_phones,
 
 
313
  "invalid_words": invalid_words,
314
  "suggestions": suggestions
315
  }
 
1
  from smolagents.tools import Tool
2
  import string
 
3
  import json
4
+ import pronouncing
5
 
6
  class ParodyWordSuggestionTool(Tool):
7
  name = "parody_word_suggester"
8
+ description = """Suggests rhyming funny words using CMU dictionary and custom pronunciations.
9
+ Returns similar-sounding words that rhyme, especially focusing on common vowel sounds."""
10
+ inputs = {'target': {'type': 'string', 'description': 'The word you want to find rhyming alternatives for'}, 'word_list_str': {'type': 'string', 'description': 'JSON string of word list (e.g. \'["word1", "word2"]\')'}, 'min_similarity': {'type': 'string', 'description': 'Minimum similarity threshold (0.0-1.0)', 'nullable': True, 'default': '0.6'}, 'custom_phones': {'type': 'object', 'description': 'Optional dictionary of custom word pronunciations', 'nullable': True, 'default': None}}
11
  output_type = "string"
12
+ VOWEL_REF = "AH,UH,AX|AE,EH|IY,IH|AO,AA|UW,UH|AY,EY|OW,AO|AW,AO|OY,OW|ER,AXR"
13
+
14
+ def _get_vowel_groups(self):
15
+ groups = []
16
+ group_strs = self.VOWEL_REF.split("|")
17
+ for group_str in group_strs:
18
+ groups.append(group_str.split(","))
19
+ return groups
20
+
21
 
22
  def _get_word_phones(self, word, custom_phones=None):
23
  """Get phones for a word, checking custom dictionary first."""
 
29
  return phones[0] if phones else None
30
 
31
 
32
+ def _get_last_syllable(self, phones: list) -> tuple:
33
+ """Extract the last syllable (vowel + remaining consonants)."""
34
+ last_vowel_idx = -1
35
+ last_vowel = None
36
+ vowel_groups = self._get_vowel_groups()
37
 
38
+ # Initialize loop variables
39
+ i = 0
40
+ phone = ""
41
+ base_phone = ""
 
 
 
 
 
 
 
 
42
  group = []
43
 
44
+ for i, phone in enumerate(phones):
45
+ base_phone = phone.rstrip('012')
46
+ for group in vowel_groups:
47
+ if base_phone in group:
48
+ last_vowel_idx = i
49
+ last_vowel = base_phone
50
+ break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
+ if last_vowel_idx == -1:
53
+ return None, []
 
 
 
 
 
 
 
 
54
 
55
+ remaining = phones[last_vowel_idx + 1:]
56
+ return last_vowel, remaining
57
 
58
 
59
+ def _strip_stress(self, phones: list) -> list:
60
+ """Remove stress markers from phones."""
61
+ result = []
62
+ # Initialize loop variable
 
 
63
  phone = ""
 
 
 
 
 
 
64
 
65
+ for phone in phones:
66
+ result.append(phone.rstrip('012'))
67
+ return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
 
70
+ def _vowels_match(self, v1: str, v2: str) -> bool:
71
+ """Check if vowels belong to the same sound group."""
72
+ v1 = v1.rstrip('012')
73
+ v2 = v2.rstrip('012')
74
 
75
+ if v1 == v2:
76
+ return True
 
 
 
77
 
78
+ # Initialize loop variables
79
+ vowel_groups = self._get_vowel_groups()
 
 
 
 
 
 
 
 
 
 
 
 
80
  group = []
81
 
82
+ for group in vowel_groups:
83
+ if v1 in group and v2 in group:
84
+ return True
85
+ return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
 
88
  def _calculate_similarity(self, word1, phones1, word2, phones2):
89
+ """Calculate similarity score using both perfect and near-rhyme detection."""
90
+ # Initialize all variables
91
  phone_list1 = phones1.split()
92
  phone_list2 = phones2.split()
93
+
94
+ # Variables for rhyme scoring
95
  rhyme_score = 0.0
96
+ word_vowel = None
97
+ word_end = []
98
+ target_vowel = None
99
+ target_end = []
100
+ word_end_clean = []
101
+ target_end_clean = []
102
+ common_length = 0
103
+ matched = 0
104
+ i = 0
105
+
106
+ # Variables for near-rhyme scoring
107
+ near_rhyme_score = 0.0
108
+ consonants1 = []
109
+ consonants2 = []
110
+ matches = 0
111
+
112
+ # Variables for length and stress scoring
113
+ phone_diff = 0
114
+ max_phones = 0
115
  length_score = 0.0
116
+ stress_score = 0.0
117
+ stress1 = ""
118
+ stress2 = ""
119
  similarity = 0.0
120
+ p = ""
121
+ v = ""
122
 
123
+ # Get last syllable components
124
+ result1 = self._get_last_syllable(phone_list1)
125
+ result2 = self._get_last_syllable(phone_list2)
126
+ word_vowel, word_end = result1
127
+ target_vowel, target_end = result2
128
 
129
+ # Perfect rhyme check (60% of score)
130
+ if word_vowel and target_vowel:
131
+ if self._vowels_match(word_vowel, target_vowel):
132
+ word_end_clean = self._strip_stress(word_end)
133
+ target_end_clean = self._strip_stress(target_end)
134
+
135
+ if word_end_clean == target_end_clean:
136
+ rhyme_score = 1.0
137
+ else:
138
+ # Partial rhyme based on ending similarity
139
+ common_length = min(len(word_end_clean), len(target_end_clean))
140
+ matched = 0
141
+ for i in range(common_length):
142
+ if word_end_clean[i] == target_end_clean[i]:
143
+ matched += 1
144
+ if max(len(word_end_clean), len(target_end_clean)) > 0:
145
+ rhyme_score = 0.6 * (matched / max(1, max(len(word_end_clean), len(target_end_clean))))
146
+ else:
147
+ rhyme_score = 0.0
148
 
149
+ # Near rhyme check (for words like "running"/"cunning") - 20% of score
150
+ # Check if words have similar length and pattern
151
+ if abs(len(phone_list1) - len(phone_list2)) <= 1:
152
+ # Check consonant patterns are similar
153
+ consonants1 = [p for p in self._strip_stress(phone_list1) if not any(v in p for v in 'AEIOU')]
154
+ consonants2 = [p for p in self._strip_stress(phone_list2) if not any(v in p for v in 'AEIOU')]
155
+
156
+ if len(consonants1) == len(consonants2):
157
+ matches = 0
158
+ for a, b in zip(consonants1, consonants2):
159
+ if a == b:
160
+ matches += 1
161
+ if len(consonants1) > 0:
162
+ near_rhyme_score = matches / max(1, len(consonants1))
163
+
164
+ # Additional check for -ing endings (special case for English)
165
+ if len(phone_list1) >= 3 and len(phone_list2) >= 3:
166
+ if (self._strip_stress(phone_list1[-2:]) == ['IH', 'NG'] and
167
+ self._strip_stress(phone_list2[-2:]) == ['IH', 'NG']):
168
+ near_rhyme_score = max(near_rhyme_score, 0.8) # Boost for -ing endings
169
+
170
+ # Calculate length similarity score (10% of total)
171
+ phone_diff = abs(len(phone_list1) - len(phone_list2))
172
+ max_phones = max(len(phone_list1), len(phone_list2))
173
+ length_score = 1.0 if phone_diff == 0 else 1.0 - (phone_diff / max_phones)
174
 
175
+ # Calculate stress pattern similarity (10% of total)
176
+ import pronouncing
177
+ stress1 = pronouncing.stresses(phones1)
178
+ stress2 = pronouncing.stresses(phones2)
179
+ stress_score = 1.0 if stress1 == stress2 else 0.5
180
+
181
+ # Weighted combination
182
  similarity = (
183
+ (rhyme_score * 0.6) + # Perfect rhyme (60%)
184
+ (near_rhyme_score * 0.2) + # Near rhyme (20%)
185
+ (length_score * 0.1) + # Length similarity (10%)
186
+ (stress_score * 0.1) # Stress pattern (10%)
187
  )
188
 
189
+ # Special case: Boost very similar-sounding words
190
+ if near_rhyme_score > 0.7 and length_score > 0.8 and stress_score > 0.8:
191
+ similarity = max(similarity, 0.75) # Ensure these get a high enough score
192
+
193
+ # Cap at 1.0
194
+ similarity = min(1.0, similarity)
195
+
196
  return {
197
  "similarity": round(similarity, 3),
198
  "rhyme_score": round(rhyme_score, 3),
199
+ "near_rhyme_score": round(near_rhyme_score, 3),
200
  "length_score": round(length_score, 3),
201
+ "stress_score": round(stress_score, 3),
202
+ "phone_length_difference": phone_diff
 
 
 
 
 
203
  }
204
 
205
 
206
+ def forward(self, target: str, word_list_str: str, min_similarity: str = "0.6", custom_phones: dict = None) -> str:
207
  import pronouncing
208
  import string
209
  import json
210
 
211
+ # Initialize all variables
212
  target = target.lower().strip(string.punctuation)
213
  min_similarity = float(min_similarity)
214
  suggestions = []
 
216
  invalid_words = []
217
  words = []
218
  target_phones = ""
219
+ target_phone_list = []
220
+ target_vowel = None
221
+ target_end = []
222
  word = ""
223
+ word_phones = ""
224
+ word_phone_list = []
225
+ word_vowel = None
226
+ word_end = []
227
  similarity_result = {}
228
 
229
  # Parse JSON string to list
 
243
  "suggestions": []
244
  }, indent=2)
245
 
246
+ # Parse target phones
247
+ target_phone_list = target_phones.split()
248
+ target_vowel, target_end = self._get_last_syllable(target_phone_list)
249
+
250
  # Filter word list
251
  for word in words:
252
  word = word.lower().strip(string.punctuation)
 
269
  similarity_result = self._calculate_similarity(word, word_phones, target, target_phones)
270
 
271
  if similarity_result["similarity"] >= min_similarity:
272
+ word_phone_list = word_phones.split()
273
+ word_vowel, word_end = self._get_last_syllable(word_phone_list)
274
+
275
  suggestions.append({
276
  "word": word,
277
  "similarity": similarity_result["similarity"],
278
  "rhyme_score": similarity_result["rhyme_score"],
279
+ "near_rhyme_score": similarity_result["near_rhyme_score"],
280
  "length_score": similarity_result["length_score"],
281
+ "stress_score": similarity_result["stress_score"],
282
  "phones": word_phones,
283
+ "last_vowel": word_vowel,
284
+ "ending": " ".join(word_end) if word_end else "",
285
+ "is_custom": word in custom_phones if custom_phones else False
286
  })
287
 
288
  # Sort by similarity score descending
 
291
  result = {
292
  "target": target,
293
  "target_phones": target_phones,
294
+ "target_last_vowel": target_vowel,
295
+ "target_ending": " ".join(target_end) if target_end else "",
296
  "invalid_words": invalid_words,
297
  "suggestions": suggestions
298
  }