patruff commited on
Commit
9bc766a
·
verified ·
1 Parent(s): b77eb42

Upload tool

Browse files
Files changed (1) hide show
  1. tool.py +121 -149
tool.py CHANGED
@@ -1,27 +1,21 @@
1
  from smolagents.tools import Tool
 
2
  import pronouncing
3
- import difflib
4
  import json
5
- import string
6
 
7
  class ParodyWordSuggestionTool(Tool):
8
  name = "parody_word_suggester"
9
- description = """Suggests rhyming funny words using CMU dictionary pronunciations.
10
- Returns similar-sounding words that rhyme, especially focusing on common vowel sounds."""
11
  inputs = {'target': {'type': 'string', 'description': 'The word you want to find rhyming alternatives for'}, 'word_list_str': {'type': 'string', 'description': 'JSON string of word list (e.g. \'["word1", "word2"]\')'}, 'min_similarity': {'type': 'string', 'description': 'Minimum similarity threshold (0.0-1.0)', 'nullable': True, 'default': '0.5'}, 'custom_phones': {'type': 'object', 'description': 'Optional dictionary of custom word pronunciations', 'nullable': True, 'default': None}}
12
  output_type = "string"
13
- VOWEL_REF = "AH,AX|UH|AE,EH|IY,IH|AO,AA|UW|AY,EY|OW,AO|AW,AO|OY,OW|ER,AXR"
 
 
 
14
  CONSONANT_REF = "M,N,NG|P,B|T,D|K,G|F,V|TH,DH|S,Z|SH,ZH|L,R|W,Y"
15
 
16
- def _get_vowel_groups(self):
17
- groups = []
18
- group_strs = self.VOWEL_REF.split("|")
19
- for group_str in group_strs:
20
- groups.append(group_str.split(","))
21
- return groups
22
-
23
-
24
  def _get_consonant_groups(self):
 
25
  groups = []
26
  group_strs = self.CONSONANT_REF.split("|")
27
  for group_str in group_strs:
@@ -39,146 +33,141 @@ class ParodyWordSuggestionTool(Tool):
39
  return phones[0] if phones else None
40
 
41
 
42
- def _get_last_syllable(self, phones: list) -> tuple:
43
- """Extract the last syllable (vowel + remaining consonants)."""
44
- last_vowel_idx = -1
45
- last_vowel = None
46
- vowel_groups = self._get_vowel_groups()
47
-
48
- for i, phone in enumerate(phones):
49
- base_phone = phone.rstrip('012')
50
- for group in vowel_groups:
51
- if base_phone in group:
52
- last_vowel_idx = i
53
- last_vowel = base_phone
54
- break
55
 
56
- if last_vowel_idx == -1:
57
- return None, []
58
-
59
- remaining = phones[last_vowel_idx + 1:]
60
- return last_vowel, remaining
61
 
62
 
63
- def _strip_stress(self, phones: list) -> list:
64
- result = []
65
- for phone in phones:
66
- result.append(phone.rstrip('012'))
67
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
 
70
- def _vowels_match(self, v1: str, v2: str) -> bool:
71
- v1 = v1.rstrip('012')
72
- v2 = v2.rstrip('012')
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
- if v1 == v2:
75
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
- vowel_groups = self._get_vowel_groups()
78
- for group in vowel_groups:
79
- if v1 in group and v2 in group:
80
- return True
81
- return False
 
 
82
 
83
 
84
  def _calculate_similarity(self, word1, phones1, word2, phones2):
85
- """Calculate similarity with heavy emphasis on rhyming."""
86
- from difflib import SequenceMatcher
87
- import pronouncing
88
-
89
- # Initialize all variables
90
  rhyme_score = 0.0
91
- string_score = 0.0
92
- pattern_score = 0.0
93
- phone_list1 = []
94
- phone_list2 = []
95
- vowel1 = None
96
- vowel2 = None
97
- end1 = []
98
- end2 = []
99
- end1_clean = []
100
- end2_clean = []
101
- matching_consonants = 0
102
 
 
103
  phone_list1 = phones1.split()
104
  phone_list2 = phones2.split()
105
 
106
- # Get last syllables
107
- vowel1, end1 = self._get_last_syllable(phone_list1)
108
- vowel2, end2 = self._get_last_syllable(phone_list2)
109
-
110
- # Calculate rhyme score (60%)
111
- if vowel1 and vowel2:
112
- # Perfect vowel match
113
- if vowel1.rstrip('012') == vowel2.rstrip('012'):
114
- rhyme_score = 1.0
115
- # Similar vowel match
116
- elif self._vowels_match(vowel1, vowel2):
117
- rhyme_score = 0.8
118
-
119
- # Check endings
120
- if end1 and end2:
121
- end1_clean = self._strip_stress(end1)
122
- end2_clean = self._strip_stress(end2)
123
-
124
- # Perfect ending match
125
- if end1_clean == end2_clean:
126
- rhyme_score = min(1.0, rhyme_score + 0.2)
127
- # Partial ending match
128
- else:
129
- consonant_groups = self._get_consonant_groups()
130
- matching_consonants = 0
131
- for c1, c2 in zip(end1_clean, end2_clean):
132
- if c1 == c2:
133
- matching_consonants += 1
134
- else:
135
- # Check if consonants are in same group
136
- for group in consonant_groups:
137
- if c1 in group and c2 in group:
138
- matching_consonants += 0.5
139
- break
140
-
141
- if matching_consonants > 0:
142
- rhyme_score = min(1.0, rhyme_score + (0.1 * matching_consonants))
143
 
144
- # String similarity (25%)
145
- if len(word1) > 1 and len(word2) > 1:
146
- end_similarity = SequenceMatcher(None, word1[1:], word2[1:]).ratio()
147
- string_score = end_similarity
148
- else:
149
- string_score = SequenceMatcher(None, word1, word2).ratio()
150
 
151
- # Pattern/Length score (15%)
152
  if len(phone_list1) == len(phone_list2):
153
- pattern_score = 1.0
154
  else:
155
- pattern_score = 1.0 - (abs(len(phone_list1) - len(phone_list2)) / max(len(phone_list1), len(phone_list2)))
156
 
157
- # Final weighted score
158
- similarity = (
159
- (rhyme_score * 0.60) +
160
- (string_score * 0.25) +
161
- (pattern_score * 0.15)
162
- )
163
 
164
- # Extra boost for exact matches minus first letter
165
- if len(word1) == len(word2) and word1[1:] == word2[1:]:
166
- similarity = min(1.0, similarity * 1.2)
167
 
168
- # Extra penalty for very different lengths
169
- if abs(len(word1) - len(word2)) > 2:
170
- similarity *= 0.7
 
 
 
 
171
 
172
  return {
173
  "similarity": round(similarity, 3),
174
  "rhyme_score": round(rhyme_score, 3),
175
- "string_score": round(string_score, 3),
176
- "pattern_score": round(pattern_score, 3),
 
177
  "details": {
178
- "last_vowel_match": vowel1.rstrip('012') == vowel2.rstrip('012') if vowel1 and vowel2 else False,
179
- "similar_vowels": self._vowels_match(vowel1, vowel2) if vowel1 and vowel2 else False,
180
- "ending_match": " ".join(end1_clean) == " ".join(end2_clean) if end1 and end2 else False,
181
- "string_length_diff": abs(len(word1) - len(word2))
 
182
  }
183
  }
184
 
@@ -188,22 +177,16 @@ class ParodyWordSuggestionTool(Tool):
188
  import string
189
  import json
190
 
191
- # Initialize all variables
192
  target = target.lower().strip(string.punctuation)
193
  min_similarity = float(min_similarity)
194
  suggestions = []
195
- word_vowel = None
196
- word_end = []
197
- target_vowel = None
198
- target_end = []
199
  valid_words = []
200
  invalid_words = []
201
- target_phone_list = []
202
  words = []
203
  target_phones = ""
204
  word_phones = ""
205
  word = ""
206
- word_phone_list = []
207
  similarity_result = {}
208
 
209
  # Parse JSON string to list
@@ -215,7 +198,7 @@ class ParodyWordSuggestionTool(Tool):
215
  "suggestions": []
216
  }, indent=2)
217
 
218
- # Get target pronunciation using custom phones
219
  target_phones = self._get_word_phones(target, custom_phones)
220
  if not target_phones:
221
  return json.dumps({
@@ -223,9 +206,7 @@ class ParodyWordSuggestionTool(Tool):
223
  "suggestions": []
224
  }, indent=2)
225
 
226
- # Filter word list checking both CMU and custom phones
227
- valid_words = []
228
- invalid_words = []
229
  for word in words:
230
  word = word.lower().strip(string.punctuation)
231
  if self._get_word_phones(word, custom_phones):
@@ -240,9 +221,6 @@ class ParodyWordSuggestionTool(Tool):
240
  "suggestions": []
241
  }, indent=2)
242
 
243
- target_phone_list = target_phones.split()
244
- target_vowel, target_end = self._get_last_syllable(target_phone_list)
245
-
246
  # Check each word
247
  for word in valid_words:
248
  word_phones = self._get_word_phones(word, custom_phones)
@@ -250,18 +228,14 @@ class ParodyWordSuggestionTool(Tool):
250
  similarity_result = self._calculate_similarity(word, word_phones, target, target_phones)
251
 
252
  if similarity_result["similarity"] >= min_similarity:
253
- word_phone_list = word_phones.split()
254
- word_vowel, word_end = self._get_last_syllable(word_phone_list)
255
-
256
  suggestions.append({
257
  "word": word,
258
  "similarity": similarity_result["similarity"],
259
  "rhyme_score": similarity_result["rhyme_score"],
260
- "string_score": similarity_result["string_score"],
261
- "pattern_score": similarity_result["pattern_score"],
 
262
  "phones": word_phones,
263
- "last_vowel": word_vowel,
264
- "ending": " ".join(word_end) if word_end else "",
265
  "is_custom": word in custom_phones if custom_phones else False,
266
  "details": similarity_result["details"]
267
  })
@@ -272,8 +246,6 @@ class ParodyWordSuggestionTool(Tool):
272
  result = {
273
  "target": target,
274
  "target_phones": target_phones,
275
- "target_last_vowel": target_vowel,
276
- "target_ending": " ".join(target_end) if target_end else "",
277
  "invalid_words": invalid_words,
278
  "suggestions": suggestions
279
  }
 
1
  from smolagents.tools import Tool
2
+ import string
3
  import pronouncing
 
4
  import json
 
5
 
6
  class ParodyWordSuggestionTool(Tool):
7
  name = "parody_word_suggester"
8
+ description = "Suggests rhyming funny words using CMU dictionary pronunciations."
 
9
  inputs = {'target': {'type': 'string', 'description': 'The word you want to find rhyming alternatives for'}, 'word_list_str': {'type': 'string', 'description': 'JSON string of word list (e.g. \'["word1", "word2"]\')'}, 'min_similarity': {'type': 'string', 'description': 'Minimum similarity threshold (0.0-1.0)', 'nullable': True, 'default': '0.5'}, 'custom_phones': {'type': 'object', 'description': 'Optional dictionary of custom word pronunciations', 'nullable': True, 'default': None}}
10
  output_type = "string"
11
+ RHYME_WEIGHT = 0.6
12
+ PHONE_PATTERN_WEIGHT = 0.2
13
+ CHAR_DIFF_WEIGHT = 0.1
14
+ CONSONANT_WEIGHT = 0.1
15
  CONSONANT_REF = "M,N,NG|P,B|T,D|K,G|F,V|TH,DH|S,Z|SH,ZH|L,R|W,Y"
16
 
 
 
 
 
 
 
 
 
17
  def _get_consonant_groups(self):
18
+ """Get consonant similarity groups."""
19
  groups = []
20
  group_strs = self.CONSONANT_REF.split("|")
21
  for group_str in group_strs:
 
33
  return phones[0] if phones else None
34
 
35
 
36
+ def _get_primary_vowel(self, phones: list) -> str:
37
+ """Get the primary stressed vowel from phone list."""
38
+ vowel_chars = 'AEIOU' # Initialize the vowel characters set
39
+ phone_str = "" # Initialize phone string
40
+ vowel_char = ""
 
 
 
 
 
 
 
 
41
 
42
+ for phone_str in phones:
43
+ if '1' in phone_str and any(vowel_char in phone_str for vowel_char in vowel_chars):
44
+ return phone_str.rstrip('012')
45
+ return None
 
46
 
47
 
48
+ def _calculate_char_difference(self, word1: str, word2: str) -> float:
49
+ """Calculate character difference score."""
50
+ if not word1 or not word2:
51
+ return 0.0
52
+
53
+ # Initialize variables
54
+ changes = 0
55
+ char1 = ""
56
+ char2 = ""
57
+
58
+ # Count character differences
59
+ for char1, char2 in zip(word1, word2):
60
+ if char1 != char2:
61
+ changes += 1
62
+
63
+ # Add difference for length mismatch
64
+ changes += abs(len(word1) - len(word2))
65
+
66
+ # Score based on changes (0 changes = 1.0, more changes = lower score)
67
+ max_changes = max(len(word1), len(word2))
68
+ return 1.0 - (changes / max_changes) if max_changes > 0 else 0.0
69
 
70
 
71
+ def _calculate_consonant_similarity(self, phone_list1: list, phone_list2: list) -> float:
72
+ """Calculate consonant similarity score."""
73
+ # Initialize variables
74
+ consonant_score = 0.0
75
+ consonant_groups = self._get_consonant_groups()
76
+ vowel_chars = 'AEIOU'
77
+ phone_str = ""
78
+ vowel_char = ""
79
+ consonants1 = []
80
+ consonants2 = []
81
+ matches = 0
82
+ comparisons = 0
83
+ cons1 = ""
84
+ cons2 = ""
85
+ group = []
86
 
87
+ # Get consonants (non-vowel phones)
88
+ consonants1 = [phone_str for phone_str in phone_list1
89
+ if not any(vowel_char in phone_str for vowel_char in vowel_chars)]
90
+ consonants2 = [phone_str for phone_str in phone_list2
91
+ if not any(vowel_char in phone_str for vowel_char in vowel_chars)]
92
+
93
+ if not consonants1 or not consonants2:
94
+ return 0.0
95
+
96
+ # Compare each consonant
97
+ matches = 0
98
+ comparisons = min(len(consonants1), len(consonants2))
99
+
100
+ for cons1, cons2 in zip(consonants1, consonants2):
101
+ cons1 = cons1.rstrip('012')
102
+ cons2 = cons2.rstrip('012')
103
+
104
+ if cons1 == cons2:
105
+ matches += 1
106
+ continue
107
 
108
+ # Check if in same group
109
+ for group in consonant_groups:
110
+ if cons1 in group and cons2 in group:
111
+ matches += 0.5
112
+ break
113
+
114
+ return matches / comparisons if comparisons > 0 else 0.0
115
 
116
 
117
  def _calculate_similarity(self, word1, phones1, word2, phones2):
118
+ """Calculate similarity based on multiple factors."""
119
+ # Initialize scores
 
 
 
120
  rhyme_score = 0.0
121
+ phone_score = 0.0
122
+ char_diff_score = 0.0
123
+ consonant_score = 0.0
 
 
 
 
 
 
 
 
124
 
125
+ # Initialize phone lists
126
  phone_list1 = phones1.split()
127
  phone_list2 = phones2.split()
128
 
129
+ # Initialize variables for details
130
+ vowel1 = None
131
+ vowel2 = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
+ # 1. Rhyme score (60%) - based on primary vowel
134
+ vowel1 = self._get_primary_vowel(phone_list1)
135
+ vowel2 = self._get_primary_vowel(phone_list2)
136
+ if vowel1 and vowel2 and vowel1 == vowel2:
137
+ rhyme_score = 1.0
 
138
 
139
+ # 2. Phone pattern score (20%) - based on number of phones
140
  if len(phone_list1) == len(phone_list2):
141
+ phone_score = 1.0
142
  else:
143
+ phone_score = 1.0 - (abs(len(phone_list1) - len(phone_list2)) / max(len(phone_list1), len(phone_list2)))
144
 
145
+ # 3. Character difference score (10%)
146
+ char_diff_score = self._calculate_char_difference(word1, word2)
 
 
 
 
147
 
148
+ # 4. Consonant similarity score (10%)
149
+ consonant_score = self._calculate_consonant_similarity(phone_list1, phone_list2)
 
150
 
151
+ # Combined weighted score
152
+ similarity = (
153
+ (rhyme_score * self.RHYME_WEIGHT) +
154
+ (phone_score * self.PHONE_PATTERN_WEIGHT) +
155
+ (char_diff_score * self.CHAR_DIFF_WEIGHT) +
156
+ (consonant_score * self.CONSONANT_WEIGHT)
157
+ )
158
 
159
  return {
160
  "similarity": round(similarity, 3),
161
  "rhyme_score": round(rhyme_score, 3),
162
+ "phone_score": round(phone_score, 3),
163
+ "char_diff_score": round(char_diff_score, 3),
164
+ "consonant_score": round(consonant_score, 3),
165
  "details": {
166
+ "primary_vowel1": vowel1,
167
+ "primary_vowel2": vowel2,
168
+ "phone_count1": len(phone_list1),
169
+ "phone_count2": len(phone_list2),
170
+ "char_differences": abs(len(word1) - len(word2))
171
  }
172
  }
173
 
 
177
  import string
178
  import json
179
 
180
+ # Initialize variables
181
  target = target.lower().strip(string.punctuation)
182
  min_similarity = float(min_similarity)
183
  suggestions = []
 
 
 
 
184
  valid_words = []
185
  invalid_words = []
 
186
  words = []
187
  target_phones = ""
188
  word_phones = ""
189
  word = ""
 
190
  similarity_result = {}
191
 
192
  # Parse JSON string to list
 
198
  "suggestions": []
199
  }, indent=2)
200
 
201
+ # Get target pronunciation
202
  target_phones = self._get_word_phones(target, custom_phones)
203
  if not target_phones:
204
  return json.dumps({
 
206
  "suggestions": []
207
  }, indent=2)
208
 
209
+ # Filter word list
 
 
210
  for word in words:
211
  word = word.lower().strip(string.punctuation)
212
  if self._get_word_phones(word, custom_phones):
 
221
  "suggestions": []
222
  }, indent=2)
223
 
 
 
 
224
  # Check each word
225
  for word in valid_words:
226
  word_phones = self._get_word_phones(word, custom_phones)
 
228
  similarity_result = self._calculate_similarity(word, word_phones, target, target_phones)
229
 
230
  if similarity_result["similarity"] >= min_similarity:
 
 
 
231
  suggestions.append({
232
  "word": word,
233
  "similarity": similarity_result["similarity"],
234
  "rhyme_score": similarity_result["rhyme_score"],
235
+ "phone_score": similarity_result["phone_score"],
236
+ "char_diff_score": similarity_result["char_diff_score"],
237
+ "consonant_score": similarity_result["consonant_score"],
238
  "phones": word_phones,
 
 
239
  "is_custom": word in custom_phones if custom_phones else False,
240
  "details": similarity_result["details"]
241
  })
 
246
  result = {
247
  "target": target,
248
  "target_phones": target_phones,
 
 
249
  "invalid_words": invalid_words,
250
  "suggestions": suggestions
251
  }