jacob-c committed on
Commit
46e5e67
·
1 Parent(s): ed7741a
Files changed (2) hide show
  1. README.md +2 -2
  2. src/classifier.py +10 -4
README.md CHANGED
@@ -62,7 +62,7 @@ python app.py
62
  2. Open your web browser and navigate to the provided URL (typically http://localhost:7860)
63
 
64
  3. Choose your input method:
65
- - Upload an audio file
66
  - Enter lyrics text
67
 
68
  4. Adjust generation parameters:
@@ -75,7 +75,7 @@ python app.py
75
  ## Models Used
76
 
77
  - **Genre Classification**:
78
- - Audio: `anton-l/wav2vec2-base-superb-gc`
79
  - Text: `facebook/bart-large-mnli` (Zero-shot classification)
80
  - **Lyric Generation**: `gpt2-medium`
81
 
 
62
  2. Open your web browser and navigate to the provided URL (typically http://localhost:7860)
63
 
64
  3. Choose your input method:
65
+ - Upload an audio file (supports .mp3, .wav, .ogg, .flac)
66
  - Enter lyrics text
67
 
68
  4. Adjust generation parameters:
 
75
  ## Models Used
76
 
77
  - **Genre Classification**:
78
+ - Audio: `superb/wav2vec2-base-superb-gc` (Pre-trained on music genre classification)
79
  - Text: `facebook/bart-large-mnli` (Zero-shot classification)
80
  - **Lyric Generation**: `gpt2-medium`
81
 
src/classifier.py CHANGED
@@ -13,10 +13,10 @@ class MusicGenreClassifier:
13
  model="facebook/bart-large-mnli"
14
  )
15
 
16
- # For audio classification, we'll use a pre-trained model
17
  self.audio_classifier = pipeline(
18
  "audio-classification",
19
- model="anton-l/wav2vec2-base-superb-gc"
20
  )
21
 
22
  self.genres = [
@@ -29,7 +29,11 @@ class MusicGenreClassifier:
29
  try:
30
  # Load audio using librosa (handles more formats)
31
  waveform, sample_rate = librosa.load(audio_path, sr=16000)
32
- return torch.from_numpy(waveform)
 
 
 
 
33
  except Exception as e:
34
  raise ValueError(f"Error processing audio file: {str(e)}")
35
 
@@ -37,8 +41,10 @@ class MusicGenreClassifier:
37
  """Classify genre from audio file."""
38
  try:
39
  waveform = self.process_audio(audio_path)
40
- predictions = self.audio_classifier(waveform)
41
  # Get the top prediction
 
 
42
  top_pred = max(predictions, key=lambda x: x['score'])
43
  return top_pred['label'], top_pred['score']
44
  except Exception as e:
 
13
  model="facebook/bart-large-mnli"
14
  )
15
 
16
+ # For audio classification, we'll use a different pre-trained model
17
  self.audio_classifier = pipeline(
18
  "audio-classification",
19
+ model="superb/wav2vec2-base-superb-gc"
20
  )
21
 
22
  self.genres = [
 
29
  try:
30
  # Load audio using librosa (handles more formats)
31
  waveform, sample_rate = librosa.load(audio_path, sr=16000)
32
+ # Convert to torch tensor and ensure proper shape
33
+ waveform = torch.from_numpy(waveform).float()
34
+ if len(waveform.shape) == 1:
35
+ waveform = waveform.unsqueeze(0)
36
+ return waveform
37
  except Exception as e:
38
  raise ValueError(f"Error processing audio file: {str(e)}")
39
 
 
41
  """Classify genre from audio file."""
42
  try:
43
  waveform = self.process_audio(audio_path)
44
+ predictions = self.audio_classifier(waveform, top_k=1)
45
  # Get the top prediction
46
+ if isinstance(predictions, list):
47
+ predictions = predictions[0]
48
  top_pred = max(predictions, key=lambda x: x['score'])
49
  return top_pred['label'], top_pred['score']
50
  except Exception as e: