patruff committed on
Commit
49ca653
·
verified ·
1 Parent(s): 488bbfc

Upload tool

Browse files
Files changed (1) hide show
  1. tool.py +87 -29
tool.py CHANGED
@@ -1,13 +1,13 @@
1
  from smolagents.tools import Tool
2
- import json
3
  import string
4
  import pronouncing
 
5
 
6
  class ParodyWordSuggestionTool(Tool):
7
  name = "parody_word_suggester"
8
  description = """Suggests rhyming funny words using CMU dictionary and custom pronunciations.
9
  Returns similar-sounding words that rhyme, especially focusing on common vowel sounds."""
10
- inputs = {'target': {'type': 'string', 'description': 'The word you want to find rhyming alternatives for'}, 'word_list_str': {'type': 'string', 'description': 'JSON string of word list (e.g. \'["word1", "word2"]\')'}, 'min_similarity': {'type': 'string', 'description': 'Minimum similarity threshold (0.0-1.0)', 'nullable': True}, 'custom_phones': {'type': 'object', 'description': 'Optional dictionary of custom word pronunciations', 'nullable': True}}
11
  output_type = "string"
12
  VOWEL_REF = "AH,AX|UH|AE,EH|IY,IH|AO,AA|UW|AY,EY|OW,AO|AW,AO|OY,OW|ER,AXR"
13
 
@@ -71,35 +71,78 @@ class ParodyWordSuggestionTool(Tool):
71
  return False
72
 
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  def _calculate_similarity(self, word1, phones1, word2, phones2):
75
- """Calculate similarity score using improved metrics."""
76
- # Initialize all variables
 
 
 
 
 
 
77
  word_vowel = None
78
  word_end = []
79
  target_vowel = None
80
  target_end = []
81
- phone_diff = 0
82
- max_phones = 0
83
  length_score = 0.0
84
  rhyme_score = 0.0
85
  stress_score = 0.0
86
- i = 0 # For loop counter
87
  word_end_clean = []
88
  target_end_clean = []
89
- matched = 0
90
  common_length = 0
 
 
 
 
 
 
91
 
 
92
  phone_list1 = phones1.split()
93
  phone_list2 = phones2.split()
94
 
95
- # Calculate length similarity score
96
- phone_diff = abs(len(phone_list1) - len(phone_list2))
97
- max_phones = max(len(phone_list1), len(phone_list2))
98
- length_score = 1.0 if phone_diff == 0 else 1.0 - (phone_diff / max_phones)
 
 
 
 
99
 
100
- # Get last syllable components
101
- result1 = self._get_last_syllable(phone_list1)
102
- result2 = self._get_last_syllable(phone_list2)
103
  word_vowel, word_end = result1
104
  target_vowel, target_end = result2
105
 
@@ -111,30 +154,35 @@ class ParodyWordSuggestionTool(Tool):
111
  target_end_clean = self._strip_stress(target_end)
112
 
113
  if word_end_clean == target_end_clean:
114
- rhyme_score = 1.0
 
 
 
115
  else:
116
- # Partial rhyme based on ending similarity
117
  common_length = min(len(word_end_clean), len(target_end_clean))
118
  matched = 0
119
  for i in range(common_length):
120
  if word_end_clean[i] == target_end_clean[i]:
121
  matched += 1
122
- rhyme_score = 0.6 * (matched / max(len(word_end_clean), len(target_end_clean)))
123
 
124
- # Calculate stress pattern similarity
125
  import pronouncing
126
- stress1 = pronouncing.stresses(phones1)
127
- stress2 = pronouncing.stresses(phones2)
128
- stress_score = 1.0 if stress1 == stress2 else 0.5
 
 
 
129
 
130
- # Weighted combination (60% rhyme, 30% length, 10% stress)
131
  similarity = (
132
- (rhyme_score * 0.6) +
133
- (length_score * 0.3) +
134
- (stress_score * 0.1)
 
135
  )
136
 
137
- # Cap at 1.0
138
  similarity = min(1.0, similarity)
139
 
140
  return {
@@ -142,7 +190,9 @@ class ParodyWordSuggestionTool(Tool):
142
  "rhyme_score": round(rhyme_score, 3),
143
  "length_score": round(length_score, 3),
144
  "stress_score": round(stress_score, 3),
145
- "phone_length_difference": phone_diff
 
 
146
  }
147
 
148
 
@@ -162,6 +212,12 @@ class ParodyWordSuggestionTool(Tool):
162
  valid_words = []
163
  invalid_words = []
164
  target_phone_list = []
 
 
 
 
 
 
165
 
166
  # Parse JSON string to list
167
  try:
@@ -216,7 +272,9 @@ class ParodyWordSuggestionTool(Tool):
216
  "rhyme_score": similarity_result["rhyme_score"],
217
  "length_score": similarity_result["length_score"],
218
  "stress_score": similarity_result["stress_score"],
219
- "phone_length_difference": similarity_result["phone_length_difference"],
 
 
220
  "phones": word_phones,
221
  "last_vowel": word_vowel,
222
  "ending": " ".join(word_end) if word_end else "",
 
1
  from smolagents.tools import Tool
 
2
  import string
3
  import pronouncing
4
+ import json
5
 
6
  class ParodyWordSuggestionTool(Tool):
7
  name = "parody_word_suggester"
8
  description = """Suggests rhyming funny words using CMU dictionary and custom pronunciations.
9
  Returns similar-sounding words that rhyme, especially focusing on common vowel sounds."""
10
+ inputs = {'target': {'type': 'string', 'description': 'The word you want to find rhyming alternatives for'}, 'word_list_str': {'type': 'string', 'description': 'JSON string of word list (e.g. \'["word1", "word2"]\')'}, 'min_similarity': {'type': 'string', 'description': 'Minimum similarity threshold (0.0-1.0)', 'default': '0.5', 'nullable': True}, 'custom_phones': {'type': 'object', 'description': 'Optional dictionary of custom word pronunciations', 'default': None, 'nullable': True}}
11
  output_type = "string"
12
  VOWEL_REF = "AH,AX|UH|AE,EH|IY,IH|AO,AA|UW|AY,EY|OW,AO|AW,AO|OY,OW|ER,AXR"
13
 
 
71
  return False
72
 
73
 
74
+ def _strip_common_suffix(self, phones: list) -> tuple:
75
+ """Strip common suffixes and return base and suffix phones."""
76
+ # Initialize variables
77
+ suffix_name = ""
78
+ suffix_phones = []
79
+ phone1 = ""
80
+ phone2 = ""
81
+
82
+ # Common suffix patterns in CMU phonetic representation
83
+ SUFFIXES = {
84
+ 'ING': ['IH0', 'NG'], # -ing
85
+ 'ED': ['EH0', 'D'], # -ed
86
+ 'ER': ['ER0'], # -er
87
+ 'EST': ['EH0', 'S', 'T'], # -est
88
+ 'LY': ['L', 'IY0'], # -ly
89
+ 'NESS': ['N', 'EH0', 'S'], # -ness
90
+ }
91
+
92
+ for suffix_name, suffix_phones in SUFFIXES.items():
93
+ if len(phones) > len(suffix_phones):
94
+ if all(phone1.rstrip('012') == phone2.rstrip('012')
95
+ for phone1, phone2 in zip(phones[-len(suffix_phones):], suffix_phones)):
96
+ return phones[:-len(suffix_phones)], suffix_phones
97
+
98
+ return phones, []
99
+
100
+
101
  def _calculate_similarity(self, word1, phones1, word2, phones2):
102
+ """Calculate similarity score using improved metrics and suffix handling."""
103
+ # Initialize all variables first
104
+ phone_list1 = []
105
+ phone_list2 = []
106
+ base1 = []
107
+ base2 = []
108
+ suffix1 = []
109
+ suffix2 = []
110
  word_vowel = None
111
  word_end = []
112
  target_vowel = None
113
  target_end = []
114
+ base_length_diff = 0
115
+ max_base_length = 0
116
  length_score = 0.0
117
  rhyme_score = 0.0
118
  stress_score = 0.0
119
+ suffix_score = 0.0
120
  word_end_clean = []
121
  target_end_clean = []
 
122
  common_length = 0
123
+ matched = 0
124
+ stress1 = ""
125
+ stress2 = ""
126
+ similarity = 0.0
127
+ result1 = (None, [])
128
+ result2 = (None, [])
129
 
130
+ # Main logic
131
  phone_list1 = phones1.split()
132
  phone_list2 = phones2.split()
133
 
134
+ # Strip common suffixes first
135
+ base1, suffix1 = self._strip_common_suffix(phone_list1)
136
+ base2, suffix2 = self._strip_common_suffix(phone_list2)
137
+
138
+ # Calculate base word similarity
139
+ base_length_diff = abs(len(base1) - len(base2))
140
+ max_base_length = max(len(base1), len(base2))
141
+ length_score = 1.0 if base_length_diff == 0 else 1.0 - (base_length_diff / max_base_length)
142
 
143
+ # Get last syllable components of base words
144
+ result1 = self._get_last_syllable(base1)
145
+ result2 = self._get_last_syllable(base2)
146
  word_vowel, word_end = result1
147
  target_vowel, target_end = result2
148
 
 
154
  target_end_clean = self._strip_stress(target_end)
155
 
156
  if word_end_clean == target_end_clean:
157
+ if word_vowel.rstrip('012') == target_vowel.rstrip('012'):
158
+ rhyme_score = 1.0
159
+ else:
160
+ rhyme_score = 0.7 # Penalize different vowels in same group
161
  else:
 
162
  common_length = min(len(word_end_clean), len(target_end_clean))
163
  matched = 0
164
  for i in range(common_length):
165
  if word_end_clean[i] == target_end_clean[i]:
166
  matched += 1
167
+ rhyme_score = 0.3 * (matched / max(len(word_end_clean), len(target_end_clean)))
168
 
169
+ # Calculate stress pattern similarity using base words
170
  import pronouncing
171
+ stress1 = pronouncing.stresses(' '.join(base1))
172
+ stress2 = pronouncing.stresses(' '.join(base2))
173
+ stress_score = 1.0 if stress1 == stress2 else 0.3
174
+
175
+ # Add suffix matching bonus
176
+ suffix_score = 1.0 if suffix1 == suffix2 else 0.0
177
 
178
+ # Weighted combination with emphasis on base word similarity
179
  similarity = (
180
+ (rhyme_score * 0.6) + # Base word rhyme
181
+ (length_score * 0.1) + # Base word length
182
+ (stress_score * 0.2) + # Base word stress
183
+ (suffix_score * 0.1) # Suffix match as small bonus
184
  )
185
 
 
186
  similarity = min(1.0, similarity)
187
 
188
  return {
 
190
  "rhyme_score": round(rhyme_score, 3),
191
  "length_score": round(length_score, 3),
192
  "stress_score": round(stress_score, 3),
193
+ "base_word_diff": base_length_diff,
194
+ "has_common_suffix": bool(suffix1 and suffix2),
195
+ "suffix_match": suffix_score == 1.0
196
  }
197
 
198
 
 
212
  valid_words = []
213
  invalid_words = []
214
  target_phone_list = []
215
+ words = []
216
+ target_phones = ""
217
+ word_phones = ""
218
+ word = ""
219
+ word_phone_list = []
220
+ similarity_result = {}
221
 
222
  # Parse JSON string to list
223
  try:
 
272
  "rhyme_score": similarity_result["rhyme_score"],
273
  "length_score": similarity_result["length_score"],
274
  "stress_score": similarity_result["stress_score"],
275
+ "base_word_diff": similarity_result["base_word_diff"],
276
+ "has_common_suffix": similarity_result["has_common_suffix"],
277
+ "suffix_match": similarity_result["suffix_match"],
278
  "phones": word_phones,
279
  "last_vowel": word_vowel,
280
  "ending": " ".join(word_end) if word_end else "",