jacob-c commited on
Commit
c9ff2a7
·
1 Parent(s): 6dfd6d2
Files changed (1) hide show
  1. app.py +22 -3
app.py CHANGED
@@ -4,6 +4,8 @@ import os
4
  import torch
5
  import json
6
  import time
 
 
7
  from transformers import AutoTokenizer, AutoModelForCausalLM
8
 
9
  # Check if CUDA is available and set the device accordingly
@@ -128,13 +130,30 @@ def classify_and_generate(audio_file):
128
  if not token:
129
  return "Error: HF_TOKEN environment variable is not set. Please set your Hugging Face API token."
130
 
131
- # First, classify the audio
132
- with open(audio_file, "rb") as f:
133
- data = f.read()
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  print("Sending request to Audio Classification API...")
136
  response = requests.post(AUDIO_API_URL, headers=headers, data=data)
137
 
 
 
 
 
 
 
138
  if response.status_code == 200:
139
  classification_results = response.json()
140
  # Format classification results
 
4
  import torch
5
  import json
6
  import time
7
+ import tempfile
8
+ import shutil
9
  from transformers import AutoTokenizer, AutoModelForCausalLM
10
 
11
  # Check if CUDA is available and set the device accordingly
 
130
  if not token:
131
  return "Error: HF_TOKEN environment variable is not set. Please set your Hugging Face API token."
132
 
133
+ # Create a temporary file to handle the audio data
134
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as temp_audio:
135
+ # If audio_file is a tuple (file path and sampling rate)
136
+ if isinstance(audio_file, tuple):
137
+ audio_path = audio_file[0]
138
+ else:
139
+ audio_path = audio_file
140
+
141
+ # Copy the audio file to our temporary file
142
+ shutil.copy2(audio_path, temp_audio.name)
143
+
144
+ # Read the temporary file
145
+ with open(temp_audio.name, "rb") as f:
146
+ data = f.read()
147
 
148
  print("Sending request to Audio Classification API...")
149
  response = requests.post(AUDIO_API_URL, headers=headers, data=data)
150
 
151
+ # Clean up the temporary file
152
+ try:
153
+ os.unlink(temp_audio.name)
154
+ except:
155
+ pass
156
+
157
  if response.status_code == 200:
158
  classification_results = response.json()
159
  # Format classification results