huytofu92 commited on
Commit
1fdec9b
·
1 Parent(s): 1cb8f4e

Enhance audio tools

Browse files
Files changed (1) hide show
  1. audio_tools.py +25 -7
audio_tools.py CHANGED
@@ -11,18 +11,36 @@ class TranscribeAudioTool(Tool):
11
  name = "transcribe_audio"
12
  description = "Transcribe an audio file"
13
  inputs = {
14
- "audio": {"type": "string", "description": "The audio file in base64 format"}
15
  }
16
  output_type = "string"
17
 
18
  def setup(self):
19
  self.model = InferenceClient(model="openai/whisper-large-v3", provider="hf-inference", token=os.getenv("HUGGINGFACE_API_KEY"))
20
 
21
- def forward(self, audio: str) -> str:
22
- audio_data = base64.b64decode(audio)
23
- audio_segment = AudioSegment.from_file(BytesIO(audio_data))
24
- result = self.model.automatic_speech_recognition(audio_segment)
25
- return result["text"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  transcribe_audio_tool = TranscribeAudioTool()
28
 
@@ -31,7 +49,7 @@ def audio_to_base64(file_path: str) -> str:
31
  """
32
  Convert an audio file to base64 format
33
  Args:
34
- file_path: Path to the audio file
35
  Returns:
36
  The audio file in base64 format
37
  """
 
11
  name = "transcribe_audio"
12
  description = "Transcribe an audio file"
13
  inputs = {
14
+ "audio": {"type": "any", "description": "The audio file in base64 format or as an AudioSegment object"}
15
  }
16
  output_type = "string"
17
 
18
  def setup(self):
19
  self.model = InferenceClient(model="openai/whisper-large-v3", provider="hf-inference", token=os.getenv("HUGGINGFACE_API_KEY"))
20
 
21
+ def forward(self, audio: any) -> str:
22
+ try:
23
+ # Handle AudioSegment object
24
+ if isinstance(audio, AudioSegment):
25
+ # Convert AudioSegment to base64
26
+ buffer = BytesIO()
27
+ audio.export(buffer, format="wav")
28
+ audio_data = buffer.getvalue()
29
+ # Handle base64 string
30
+ elif isinstance(audio, str):
31
+ audio_data = base64.b64decode(audio)
32
+ else:
33
+ raise ValueError(f"Unsupported audio type: {type(audio)}. Expected base64 string or AudioSegment object.")
34
+
35
+ # Create audio segment from the data
36
+ audio_segment = AudioSegment.from_file(BytesIO(audio_data))
37
+
38
+ # Transcribe using the model
39
+ result = self.model.automatic_speech_recognition(audio_segment)
40
+ return result["text"]
41
+
42
+ except Exception as e:
43
+ raise RuntimeError(f"Error in transcription: {str(e)}")
44
 
45
  transcribe_audio_tool = TranscribeAudioTool()
46
 
 
49
  """
50
  Convert an audio file to base64 format
51
  Args:
52
+ file_path: Path to the audio file (should be in mp3 format)
53
  Returns:
54
  The audio file in base64 format
55
  """