patruff commited on
Commit
05e1f73
·
verified ·
1 Parent(s): af5f95f

Upload tool

Browse files
Files changed (1) hide show
  1. tool.py +71 -24
tool.py CHANGED
@@ -1,8 +1,8 @@
1
  from smolagents.tools import Tool
2
  import json
 
3
  import string
4
  import pronouncing
5
- import difflib
6
 
7
  class ParodyWordSuggestionTool(Tool):
8
  name = "parody_word_suggester"
@@ -11,21 +11,57 @@ class ParodyWordSuggestionTool(Tool):
11
  inputs = {'target': {'type': 'string', 'description': 'The word you want to find rhyming alternatives for'}, 'word_list_str': {'type': 'string', 'description': 'JSON string of word list (e.g. \'["word1", "word2"]\')'}, 'min_similarity': {'type': 'string', 'description': 'Minimum similarity threshold (0.0-1.0)', 'nullable': True}}
12
  output_type = "string"
13
 
14
- def _has_vowel(self, phone: str) -> bool:
15
- """Check if a phone contains a vowel."""
16
- VOWELS = ['A', 'E', 'I', 'O', 'U']
17
- for vowel in VOWELS:
18
- if vowel in phone:
19
- return True
20
- return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
 
23
- def _get_rhyme_part(self, phones: list) -> list:
24
- """Get the rhyming part of a word (last vowel onwards)."""
25
- for i, phone in enumerate(reversed(phones)):
26
- if self._has_vowel(phone):
27
- return phones[-(i+1):]
28
- return phones
 
 
 
 
 
 
29
 
30
 
31
  def forward(self, target: str, word_list_str: str, min_similarity: str = "0.5") -> str:
@@ -35,9 +71,14 @@ class ParodyWordSuggestionTool(Tool):
35
  import json
36
  from difflib import SequenceMatcher
37
 
 
38
  target = target.lower().strip(string.punctuation)
39
  min_similarity = float(min_similarity)
40
  suggestions = []
 
 
 
 
41
 
42
  # Parse JSON string to list
43
  try:
@@ -58,7 +99,7 @@ class ParodyWordSuggestionTool(Tool):
58
 
59
  target_phones = target_phones[0]
60
  target_phone_list = target_phones.split()
61
- target_rhyme_part = self._get_rhyme_part(target_phone_list)
62
 
63
  # Check each word
64
  for word in words:
@@ -67,11 +108,19 @@ class ParodyWordSuggestionTool(Tool):
67
  if phones:
68
  word_phones = phones[0]
69
  word_phone_list = word_phones.split()
70
- word_rhyme_part = self._get_rhyme_part(word_phone_list)
71
 
72
  # 1. Rhyme score (most important - 60%)
73
- # Perfect rhyme if the rhyming parts match exactly
74
- rhyme_score = 1.0 if word_rhyme_part == target_rhyme_part else 0.0
 
 
 
 
 
 
 
 
75
 
76
  # 2. Syllable match (25%)
77
  target_syl = pronouncing.syllable_count(target_phones)
@@ -94,7 +143,8 @@ class ParodyWordSuggestionTool(Tool):
94
  "string_similarity": round(string_similarity, 3),
95
  "syllables": word_syl,
96
  "phones": word_phones,
97
- "rhyme_part": " ".join(word_rhyme_part)
 
98
  })
99
 
100
  # Sort by similarity score descending
@@ -104,13 +154,10 @@ class ParodyWordSuggestionTool(Tool):
104
  "target": target,
105
  "target_syllables": pronouncing.syllable_count(target_phones),
106
  "target_phones": target_phones,
107
- "target_rhyme_part": " ".join(target_rhyme_part),
 
108
  "suggestions": suggestions
109
  }
110
 
111
  return json.dumps(result, indent=2)
112
 
113
-
114
- def __init__(self, *args, **kwargs):
115
- self.is_initialized = False
116
-
 
1
  from smolagents.tools import Tool
2
  import json
3
+ import difflib
4
  import string
5
  import pronouncing
 
6
 
7
  class ParodyWordSuggestionTool(Tool):
8
  name = "parody_word_suggester"
 
11
  inputs = {'target': {'type': 'string', 'description': 'The word you want to find rhyming alternatives for'}, 'word_list_str': {'type': 'string', 'description': 'JSON string of word list (e.g. \'["word1", "word2"]\')'}, 'min_similarity': {'type': 'string', 'description': 'Minimum similarity threshold (0.0-1.0)', 'nullable': True}}
12
  output_type = "string"
13
 
14
+ def __init__(self):
15
+ self.vowel_groups = [
16
+ {'AH', 'UH', 'AX'}, # 'luck', 'buck', 'cuck'
17
+ {'AE', 'EH'}, # 'bat', 'bet'
18
+ {'IY', 'IH'}, # 'seat', 'sit'
19
+ {'AO', 'AA'}, # 'caught', 'cot'
20
+ {'UW', 'UH'}, # 'boot', 'put'
21
+ {'AY', 'EY'}, # 'bite', 'bait'
22
+ {'OW', 'AO'}, # 'boat', 'bought'
23
+ {'AW', 'AO'}, # 'bout', 'bought'
24
+ {'OY', 'OW'}, # 'boy', 'bow'
25
+ {'ER', 'AXR'}, # 'bird', 'hurt'
26
+ ]
27
+
28
+
29
+ def _get_last_syllable(self, phones: list) -> tuple:
30
+ """Extract the last syllable (vowel + remaining consonants)."""
31
+ last_vowel_idx = -1
32
+ last_vowel = None
33
+
34
+ # Find the last vowel
35
+ for i, phone in enumerate(phones):
36
+ # Strip stress markers for checking
37
+ base_phone = phone.rstrip('012')
38
+ for group in self.vowel_groups:
39
+ if base_phone in group:
40
+ last_vowel_idx = i
41
+ last_vowel = base_phone
42
+ break
43
+
44
+ if last_vowel_idx == -1:
45
+ return None, []
46
+
47
+ # Get all consonants after the vowel
48
+ remaining = phones[last_vowel_idx + 1:]
49
+
50
+ return last_vowel, remaining
51
 
52
 
53
+ def _vowels_match(self, v1: str, v2: str) -> bool:
54
+ """Check if two vowels are in the same group."""
55
+ v1 = v1.rstrip('012')
56
+ v2 = v2.rstrip('012')
57
+
58
+ if v1 == v2:
59
+ return True
60
+
61
+ for group in self.vowel_groups:
62
+ if v1 in group and v2 in group:
63
+ return True
64
+ return False
65
 
66
 
67
  def forward(self, target: str, word_list_str: str, min_similarity: str = "0.5") -> str:
 
71
  import json
72
  from difflib import SequenceMatcher
73
 
74
+ # Initialize variables
75
  target = target.lower().strip(string.punctuation)
76
  min_similarity = float(min_similarity)
77
  suggestions = []
78
+ word_vowel = None
79
+ word_end = []
80
+ target_vowel = None
81
+ target_end = []
82
 
83
  # Parse JSON string to list
84
  try:
 
99
 
100
  target_phones = target_phones[0]
101
  target_phone_list = target_phones.split()
102
+ target_vowel, target_end = self._get_last_syllable(target_phone_list)
103
 
104
  # Check each word
105
  for word in words:
 
108
  if phones:
109
  word_phones = phones[0]
110
  word_phone_list = word_phones.split()
111
+ word_vowel, word_end = self._get_last_syllable(word_phone_list)
112
 
113
  # 1. Rhyme score (most important - 60%)
114
+ rhyme_score = 0.0
115
+ if word_vowel and target_vowel:
116
+ # Check vowel match
117
+ if self._vowels_match(word_vowel, target_vowel):
118
+ # Perfect rhyme if endings match too
119
+ if word_end == target_end:
120
+ rhyme_score = 1.0
121
+ # Partial rhyme if just the vowel matches
122
+ else:
123
+ rhyme_score = 0.6
124
 
125
  # 2. Syllable match (25%)
126
  target_syl = pronouncing.syllable_count(target_phones)
 
143
  "string_similarity": round(string_similarity, 3),
144
  "syllables": word_syl,
145
  "phones": word_phones,
146
+ "last_vowel": word_vowel,
147
+ "ending": " ".join(word_end) if word_end else ""
148
  })
149
 
150
  # Sort by similarity score descending
 
154
  "target": target,
155
  "target_syllables": pronouncing.syllable_count(target_phones),
156
  "target_phones": target_phones,
157
+ "target_last_vowel": target_vowel,
158
+ "target_ending": " ".join(target_end) if target_end else "",
159
  "suggestions": suggestions
160
  }
161
 
162
  return json.dumps(result, indent=2)
163