patruff commited on
Commit
84640d3
·
verified ·
1 Parent(s): d7ee0e9

Upload tool

Browse files
Files changed (1) hide show
  1. tool.py +253 -109
tool.py CHANGED
@@ -1,7 +1,7 @@
1
  from smolagents.tools import Tool
2
- import string
3
- import json
4
  import pronouncing
 
 
5
 
6
  class WordPhoneTool(Tool):
7
  name = "word_phonetic_analyzer"
@@ -9,13 +9,23 @@ class WordPhoneTool(Tool):
9
  Can also compare two words for phonetic similarity and rhyming."""
10
  inputs = {'word': {'type': 'string', 'description': 'Primary word to analyze for pronunciation patterns'}, 'compare_to': {'type': 'string', 'description': 'Optional word to compare against for similarity scoring', 'nullable': True}, 'custom_phones': {'type': 'object', 'description': 'Optional dictionary of custom word pronunciations', 'nullable': True}}
11
  output_type = "string"
12
- VOWEL_REF = "AH,UH,AX|AE,EH|IY,IH|AO,AA|UW,UH|AY,EY|OW,AO|AW,AO|OY,OW|ER,AXR"
 
 
 
13
 
14
  def _get_vowel_groups(self):
 
15
  groups = []
16
- group_strs = self.VOWEL_REF.split("|")
17
- for group_str in group_strs:
18
- groups.append(group_str.split(","))
 
 
 
 
 
 
19
  return groups
20
 
21
 
@@ -29,24 +39,112 @@ class WordPhoneTool(Tool):
29
  return phones[0] if phones else None
30
 
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def _get_last_syllable(self, phones):
 
33
  last_vowel_idx = -1
34
  last_vowel = None
35
- vowel_groups = self._get_vowel_groups()
36
 
37
- for i in range(len(phones)):
38
- phone = phones[i]
39
- base_phone = ""
40
- for j in range(len(phone)):
41
- char = phone[j]
42
- if char not in "012":
43
- base_phone += char
44
-
45
- for group in vowel_groups:
46
- if base_phone in group:
47
- last_vowel_idx = i
48
- last_vowel = base_phone
49
  break
 
 
 
50
 
51
  if last_vowel_idx == -1:
52
  return None, []
@@ -59,122 +157,153 @@ class WordPhoneTool(Tool):
59
 
60
 
61
  def _strip_stress(self, phones):
 
62
  result = []
63
  for phone in phones:
64
- stripped = ""
65
- for char in phone:
66
- if char not in "012":
67
- stripped += char
68
  result.append(stripped)
69
  return result
70
 
71
 
72
  def _vowels_match(self, v1, v2):
73
- v1_stripped = ""
74
- v2_stripped = ""
75
-
76
- for char in v1:
77
- if char not in "012":
78
- v1_stripped += char
79
-
80
- for char in v2:
81
- if char not in "012":
82
- v2_stripped += char
83
 
84
  if v1_stripped == v2_stripped:
85
  return True
86
 
87
- vowel_groups = self._get_vowel_groups()
88
- for group in vowel_groups:
89
  if v1_stripped in group and v2_stripped in group:
90
  return True
91
  return False
92
 
93
 
94
- def _calculate_similarity(self, word1, phones1, word2, phones2):
95
- import pronouncing
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
- # Initialize variables before use
98
- last_vowel1 = None
99
- last_vowel2 = None
100
- word1_end = []
101
- word2_end = []
102
- matched = 0
103
- common_length = 0
104
- end1_clean = []
105
- end2_clean = []
106
- i = 0 # Initialize i for loop variable
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  phone_list1 = phones1.split()
109
  phone_list2 = phones2.split()
110
 
111
- # Get last syllable components
112
- result1 = self._get_last_syllable(phone_list1)
113
- result2 = self._get_last_syllable(phone_list2)
114
- last_vowel1, word1_end = result1
115
- last_vowel2, word2_end = result2
116
 
117
- # Calculate length similarity score first
118
- phone_diff = abs(len(phone_list1) - len(phone_list2))
119
- max_phones = max(len(phone_list1), len(phone_list2))
120
- length_score = 1.0 if phone_diff == 0 else 1.0 - (phone_diff / max_phones)
121
 
122
- # Calculate rhyme score (most important)
123
- rhyme_score = 0.0
124
- if last_vowel1 and last_vowel2:
125
- if self._vowels_match(last_vowel1, last_vowel2):
126
- end1_clean = self._strip_stress(word1_end)
127
- end2_clean = self._strip_stress(word2_end)
128
-
129
- if end1_clean == end2_clean:
130
- rhyme_score = 1.0 # Perfect rhyme, capped at 1.0
131
- else:
132
- # Partial rhyme based on ending similarity
133
- common_length = min(len(end1_clean), len(end2_clean))
134
- matched = 0
135
- for i in range(common_length):
136
- if end1_clean[i] == end2_clean[i]:
137
- matched += 1
138
- rhyme_score = 0.6 * (matched / max(len(end1_clean), len(end2_clean)))
139
-
140
- # Calculate stress pattern similarity
141
- stress1 = pronouncing.stresses(phones1)
142
- stress2 = pronouncing.stresses(phones2)
143
- stress_score = 1.0 if stress1 == stress2 else 0.5
144
 
145
- # Weighted combination prioritizing rhyming and length
146
- total_similarity = (
147
- (rhyme_score * 0.6) + # Rhyming most important (60%)
148
- (length_score * 0.3) + # Length similarity next (30%)
149
- (stress_score * 0.1) # Stress pattern least important (10%)
150
  )
151
 
152
- # Ensure total similarity is capped at 1.0
153
- total_similarity = min(1.0, total_similarity)
154
-
155
  return {
156
- "similarity": round(total_similarity, 3),
157
  "rhyme_score": round(rhyme_score, 3),
 
158
  "length_score": round(length_score, 3),
159
- "stress_score": round(stress_score, 3),
160
- "phone_length_difference": phone_diff
 
 
 
 
 
 
 
161
  }
162
 
163
 
 
 
 
 
 
 
164
  def forward(self, word, compare_to=None, custom_phones=None):
165
  import json
166
  import string
167
  import pronouncing
168
 
169
- # Initialize variables before use
170
- word_last_vowel = None
171
- compare_last_vowel = None
172
- word_end = []
173
- compare_end = []
174
- is_rhyme = False
175
-
176
- word_clean = word.lower()
177
- word_clean = word_clean.strip(string.punctuation)
178
  primary_phones = self._get_word_phones(word_clean, custom_phones)
179
 
180
  if not primary_phones:
@@ -192,13 +321,13 @@ class WordPhoneTool(Tool):
192
  'syllable_count': pronouncing.syllable_count(primary_phones),
193
  'phones': primary_phones.split(),
194
  'stresses': pronouncing.stresses(primary_phones),
195
- 'phone_count': len(primary_phones.split())
 
196
  }
197
  }
198
 
199
  if compare_to:
200
- compare_clean = compare_to.lower()
201
- compare_clean = compare_clean.strip(string.punctuation)
202
  compare_phones = self._get_word_phones(compare_clean, custom_phones)
203
 
204
  if not compare_phones:
@@ -206,13 +335,27 @@ class WordPhoneTool(Tool):
206
  'error': f'Comparison word "{compare_clean}" not found in dictionary or custom phones'
207
  }
208
  else:
209
- # Get rhyme components
 
 
 
 
 
 
210
  word_result = self._get_last_syllable(primary_phones.split())
211
  compare_result = self._get_last_syllable(compare_phones.split())
212
- word_last_vowel, word_end = word_result
213
- compare_last_vowel, compare_end = compare_result
214
 
215
- # Calculate if words rhyme
 
 
 
 
 
 
 
 
 
 
216
  if word_last_vowel and compare_last_vowel:
217
  if self._vowels_match(word_last_vowel, compare_last_vowel):
218
  word_end_clean = self._strip_stress(word_end)
@@ -220,7 +363,7 @@ class WordPhoneTool(Tool):
220
  if word_end_clean == compare_end_clean:
221
  is_rhyme = True
222
 
223
- # Calculate detailed comparison stats
224
  word_syl_count = pronouncing.syllable_count(primary_phones)
225
  compare_syl_count = pronouncing.syllable_count(compare_phones)
226
 
@@ -230,10 +373,11 @@ class WordPhoneTool(Tool):
230
  'syllable_count': compare_syl_count,
231
  'phones': compare_phones.split(),
232
  'stresses': pronouncing.stresses(compare_phones),
233
- 'phone_count': len(compare_phones.split())
 
234
  },
235
  'comparison_stats': {
236
- 'is_rhyme': is_rhyme,
237
  'same_syllable_count': word_syl_count == compare_syl_count,
238
  'same_stress_pattern': pronouncing.stresses(primary_phones) == pronouncing.stresses(compare_phones),
239
  'syllable_difference': abs(word_syl_count - compare_syl_count),
@@ -241,7 +385,7 @@ class WordPhoneTool(Tool):
241
  }
242
  }
243
 
244
- # Calculate detailed similarity scores
245
  similarity_result = self._calculate_similarity(
246
  word_clean, primary_phones,
247
  compare_clean, compare_phones
 
1
  from smolagents.tools import Tool
 
 
2
  import pronouncing
3
+ import json
4
+ import string
5
 
6
  class WordPhoneTool(Tool):
7
  name = "word_phonetic_analyzer"
 
9
  Can also compare two words for phonetic similarity and rhyming."""
10
  inputs = {'word': {'type': 'string', 'description': 'Primary word to analyze for pronunciation patterns'}, 'compare_to': {'type': 'string', 'description': 'Optional word to compare against for similarity scoring', 'nullable': True}, 'custom_phones': {'type': 'object', 'description': 'Optional dictionary of custom word pronunciations', 'nullable': True}}
11
  output_type = "string"
12
+ RHYME_WEIGHT = 0.6
13
+ PHONE_SEQUENCE_WEIGHT = 0.3
14
+ LENGTH_WEIGHT = 0.1
15
+ PHONE_GROUPS = "M,N,NG|P,B|T,D|K,G|F,V|TH,DH|S,Z|SH,ZH|L,R|W,Y|IY,IH|UW,UH|EH,AH|AO,AA|AE,AH|AY,EY|OW,UW"
16
 
17
  def _get_vowel_groups(self):
18
+ """Get vowel groups for comparison."""
19
  groups = []
20
+ vowel_parts = self.PHONE_GROUPS.split('|')
21
+ for part in vowel_parts:
22
+ is_vowel = False
23
+ for vowel in 'AEIOU':
24
+ if vowel in part:
25
+ is_vowel = True
26
+ break
27
+ if is_vowel:
28
+ groups.append(part.split(','))
29
  return groups
30
 
31
 
 
39
  return phones[0] if phones else None
40
 
41
 
42
+ def _get_primary_vowel(self, phones):
43
+ """Get the primary stressed vowel from phone list."""
44
+ for phone_str in phones:
45
+ is_vowel_with_stress = False
46
+ if '1' in phone_str:
47
+ for vowel in 'AEIOU':
48
+ if vowel in phone_str:
49
+ is_vowel_with_stress = True
50
+ break
51
+ if is_vowel_with_stress:
52
+ return phone_str.rstrip('012')
53
+ return None
54
+
55
+
56
+ def _get_phone_type(self, phone):
57
+ """Get the broad category of a phone."""
58
+ # Strip stress markers
59
+ phone = phone.rstrip('012')
60
+
61
+ # Vowels
62
+ is_vowel = False
63
+ for vowel in 'AEIOU':
64
+ if vowel in phone:
65
+ is_vowel = True
66
+ break
67
+ if is_vowel:
68
+ return 'vowel'
69
+
70
+ # Initialize fixed sets for categories
71
+ nasals = {'M', 'N', 'NG'}
72
+ stops = {'P', 'B', 'T', 'D', 'K', 'G'}
73
+ fricatives = {'F', 'V', 'TH', 'DH', 'S', 'Z', 'SH', 'ZH'}
74
+ liquids = {'L', 'R'}
75
+ glides = {'W', 'Y'}
76
+
77
+ if phone in nasals:
78
+ return 'nasal'
79
+ if phone in stops:
80
+ return 'stop'
81
+ if phone in fricatives:
82
+ return 'fricative'
83
+ if phone in liquids:
84
+ return 'liquid'
85
+ if phone in glides:
86
+ return 'glide'
87
+
88
+ return 'other'
89
+
90
+
91
+ def _phones_are_similar(self, phone1, phone2):
92
+ """Check if two phones are similar enough to be considered rhyming."""
93
+ # Strip stress markers
94
+ p1 = phone1.rstrip('012')
95
+ p2 = phone2.rstrip('012')
96
+
97
+ # Exact match
98
+ if p1 == p2:
99
+ return True
100
+
101
+ # Check similarity groups
102
+ for group_str in self.PHONE_GROUPS.split('|'):
103
+ group = group_str.split(',')
104
+ if p1 in group and p2 in group:
105
+ return True
106
+
107
+ return False
108
+
109
+
110
+ def _get_phone_similarity(self, phone1, phone2):
111
+ """Calculate similarity between two phones."""
112
+ # Initialize variables
113
+ p1 = phone1.rstrip('012')
114
+ p2 = phone2.rstrip('012')
115
+
116
+ # Exact match
117
+ if p1 == p2:
118
+ return 1.0
119
+
120
+ # Check similarity groups
121
+ for group_str in self.PHONE_GROUPS.split('|'):
122
+ group = group_str.split(',')
123
+ if p1 in group and p2 in group:
124
+ return 0.7
125
+
126
+ # Check broader categories
127
+ if self._get_phone_type(p1) == self._get_phone_type(p2):
128
+ return 0.3
129
+
130
+ return 0.0
131
+
132
+
133
  def _get_last_syllable(self, phones):
134
+ """Get the last vowel and remaining consonants (for rhyming analysis)."""
135
  last_vowel_idx = -1
136
  last_vowel = None
 
137
 
138
+ for i, phone in enumerate(phones):
139
+ base_phone = phone.rstrip('012')
140
+ is_vowel = False
141
+ for vowel in 'AEIOU':
142
+ if vowel in base_phone:
143
+ is_vowel = True
 
 
 
 
 
 
144
  break
145
+ if is_vowel:
146
+ last_vowel_idx = i
147
+ last_vowel = base_phone
148
 
149
  if last_vowel_idx == -1:
150
  return None, []
 
157
 
158
 
159
  def _strip_stress(self, phones):
160
+ """Remove stress markers from phones."""
161
  result = []
162
  for phone in phones:
163
+ stripped = phone.rstrip('012')
 
 
 
164
  result.append(stripped)
165
  return result
166
 
167
 
168
  def _vowels_match(self, v1, v2):
169
+ """Check if vowels are similar enough to rhyme."""
170
+ v1_stripped = v1.rstrip('012')
171
+ v2_stripped = v2.rstrip('012')
 
 
 
 
 
 
 
172
 
173
  if v1_stripped == v2_stripped:
174
  return True
175
 
176
+ for group_str in self.PHONE_GROUPS.split('|'):
177
+ group = group_str.split(',')
178
  if v1_stripped in group and v2_stripped in group:
179
  return True
180
  return False
181
 
182
 
183
+ def _get_rhyme_score(self, phones1, phones2):
184
+ """Calculate rhyme score based on matching phones after primary stressed vowel."""
185
+ # Find primary stressed vowel position in both words
186
+ pos1 = -1
187
+ pos2 = -1
188
+
189
+ for i, phone in enumerate(phones1):
190
+ is_stressed_vowel = False
191
+ if '1' in phone:
192
+ for vowel in 'AEIOU':
193
+ if vowel in phone:
194
+ is_stressed_vowel = True
195
+ break
196
+ if is_stressed_vowel:
197
+ pos1 = i
198
+ break
199
+
200
+ for i, phone in enumerate(phones2):
201
+ is_stressed_vowel = False
202
+ if '1' in phone:
203
+ for vowel in 'AEIOU':
204
+ if vowel in phone:
205
+ is_stressed_vowel = True
206
+ break
207
+ if is_stressed_vowel:
208
+ pos2 = i
209
+ break
210
+
211
+ if pos1 == -1 or pos2 == -1:
212
+ return 0.0
213
+
214
+ # Get all phones after and including the stressed vowel
215
+ rhyme_part1 = phones1[pos1:]
216
+ rhyme_part2 = phones2[pos2:]
217
 
218
+ # If lengths are too different, not a good rhyme
219
+ if abs(len(rhyme_part1) - len(rhyme_part2)) > 1:
220
+ return 0.0
 
 
 
 
 
 
 
221
 
222
+ # Calculate similarity score for rhyming part
223
+ similarity_count = 0
224
+ max_compare = min(len(rhyme_part1), len(rhyme_part2))
225
+
226
+ for i in range(max_compare):
227
+ if self._phones_are_similar(rhyme_part1[i], rhyme_part2[i]):
228
+ similarity_count += 1
229
+
230
+ # Return score based on how many phones were similar
231
+ return similarity_count / max(len(rhyme_part1), len(rhyme_part2)) if max(len(rhyme_part1), len(rhyme_part2)) > 0 else 0.0
232
+
233
+
234
+ def _calculate_phone_sequence_similarity(self, phones1, phones2):
235
+ """Calculate similarity based on matching phones in sequence."""
236
+ if not phones1 or not phones2:
237
+ return 0.0
238
+
239
+ total_similarity = 0.0
240
+ comparisons = max(len(phones1), len(phones2))
241
+
242
+ # Compare each position
243
+ for i in range(min(len(phones1), len(phones2))):
244
+ similarity = self._get_phone_similarity(phones1[i], phones2[i])
245
+ total_similarity += similarity
246
+
247
+ return total_similarity / comparisons if comparisons > 0 else 0.0
248
+
249
+
250
+ def _calculate_length_similarity(self, phones1, phones2):
251
+ """Calculate similarity based on phone length."""
252
+ max_length = max(len(phones1), len(phones2))
253
+ length_diff = abs(len(phones1) - len(phones2))
254
+ return 1.0 - (length_diff / max_length) if max_length > 0 else 0.0
255
+
256
+
257
+ def _calculate_similarity(self, word1, phones1, word2, phones2):
258
+ """Calculate similarity based on multiple factors."""
259
  phone_list1 = phones1.split()
260
  phone_list2 = phones2.split()
261
 
262
+ # Get rhyme score (most important)
263
+ rhyme_score = self._get_rhyme_score(phone_list1, phone_list2)
 
 
 
264
 
265
+ # Calculate phone sequence similarity
266
+ phone_sequence_score = self._calculate_phone_sequence_similarity(phone_list1, phone_list2)
 
 
267
 
268
+ # Calculate length similarity score
269
+ length_score = self._calculate_length_similarity(phone_list1, phone_list2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
 
271
+ # Combined weighted score
272
+ similarity = (
273
+ (rhyme_score * self.RHYME_WEIGHT) +
274
+ (phone_sequence_score * self.PHONE_SEQUENCE_WEIGHT) +
275
+ (length_score * self.LENGTH_WEIGHT)
276
  )
277
 
 
 
 
278
  return {
279
+ "similarity": round(similarity, 3),
280
  "rhyme_score": round(rhyme_score, 3),
281
+ "phone_sequence_score": round(phone_sequence_score, 3),
282
  "length_score": round(length_score, 3),
283
+ "details": {
284
+ "primary_vowel1": self._get_primary_vowel(phone_list1),
285
+ "primary_vowel2": self._get_primary_vowel(phone_list2),
286
+ "phone_count1": len(phone_list1),
287
+ "phone_count2": len(phone_list2),
288
+ "is_rhyme": rhyme_score > 0.8,
289
+ "stress_pattern1": self._get_stress_pattern(phones1),
290
+ "stress_pattern2": self._get_stress_pattern(phones2)
291
+ }
292
  }
293
 
294
 
295
+ def _get_stress_pattern(self, phones):
296
+ """Extract stress pattern from phones."""
297
+ import pronouncing
298
+ return pronouncing.stresses(phones)
299
+
300
+
301
  def forward(self, word, compare_to=None, custom_phones=None):
302
  import json
303
  import string
304
  import pronouncing
305
 
306
+ word_clean = word.lower().strip(string.punctuation)
 
 
 
 
 
 
 
 
307
  primary_phones = self._get_word_phones(word_clean, custom_phones)
308
 
309
  if not primary_phones:
 
321
  'syllable_count': pronouncing.syllable_count(primary_phones),
322
  'phones': primary_phones.split(),
323
  'stresses': pronouncing.stresses(primary_phones),
324
+ 'phone_count': len(primary_phones.split()),
325
+ 'primary_vowel': self._get_primary_vowel(primary_phones.split())
326
  }
327
  }
328
 
329
  if compare_to:
330
+ compare_clean = compare_to.lower().strip(string.punctuation)
 
331
  compare_phones = self._get_word_phones(compare_clean, custom_phones)
332
 
333
  if not compare_phones:
 
335
  'error': f'Comparison word "{compare_clean}" not found in dictionary or custom phones'
336
  }
337
  else:
338
+ # Initialize variables for traditional rhyme analysis
339
+ word_last_vowel = None
340
+ compare_last_vowel = None
341
+ word_end = []
342
+ compare_end = []
343
+
344
+ # Get rhyme components for traditional rhyme analysis
345
  word_result = self._get_last_syllable(primary_phones.split())
346
  compare_result = self._get_last_syllable(compare_phones.split())
 
 
347
 
348
+ # Unpack results with explicit assignment
349
+ if word_result and len(word_result) == 2:
350
+ word_last_vowel = word_result[0]
351
+ word_end = word_result[1]
352
+
353
+ if compare_result and len(compare_result) == 2:
354
+ compare_last_vowel = compare_result[0]
355
+ compare_end = compare_result[1]
356
+
357
+ # Calculate if words rhyme using traditional method
358
+ is_rhyme = False
359
  if word_last_vowel and compare_last_vowel:
360
  if self._vowels_match(word_last_vowel, compare_last_vowel):
361
  word_end_clean = self._strip_stress(word_end)
 
363
  if word_end_clean == compare_end_clean:
364
  is_rhyme = True
365
 
366
+ # Basic analysis
367
  word_syl_count = pronouncing.syllable_count(primary_phones)
368
  compare_syl_count = pronouncing.syllable_count(compare_phones)
369
 
 
373
  'syllable_count': compare_syl_count,
374
  'phones': compare_phones.split(),
375
  'stresses': pronouncing.stresses(compare_phones),
376
+ 'phone_count': len(compare_phones.split()),
377
+ 'primary_vowel': self._get_primary_vowel(compare_phones.split())
378
  },
379
  'comparison_stats': {
380
+ 'traditional_rhyme': is_rhyme,
381
  'same_syllable_count': word_syl_count == compare_syl_count,
382
  'same_stress_pattern': pronouncing.stresses(primary_phones) == pronouncing.stresses(compare_phones),
383
  'syllable_difference': abs(word_syl_count - compare_syl_count),
 
385
  }
386
  }
387
 
388
+ # Calculate detailed similarity scores with new algorithm
389
  similarity_result = self._calculate_similarity(
390
  word_clean, primary_phones,
391
  compare_clean, compare_phones