jacob-c commited on
Commit
8599ceb
·
1 Parent(s): a19410a
Files changed (2) hide show
  1. src/classifier.py +28 -11
  2. src/lyric_generator.py +19 -8
src/classifier.py CHANGED
@@ -7,17 +7,34 @@ from typing import Union, Tuple, List
7
 
8
  class MusicGenreClassifier:
9
  def __init__(self):
10
- # Initialize both audio and text classification pipelines
11
- self.text_classifier = pipeline(
12
- "zero-shot-classification",
13
- model="facebook/bart-large-mnli"
14
- )
15
-
16
- # For audio classification, we'll use MIT's music classification model
17
- self.audio_classifier = pipeline(
18
- "audio-classification",
19
- model="mit/ast-finetuned-audioset-10-10-0.4593"
20
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  # Define standard genres for classification
23
  self.genres = [
 
7
 
8
  class MusicGenreClassifier:
9
  def __init__(self):
10
+ try:
11
+ # Initialize both audio and text classification pipelines with auto device mapping
12
+ self.text_classifier = pipeline(
13
+ "zero-shot-classification",
14
+ model="facebook/bart-large-mnli",
15
+ device_map="auto"
16
+ )
17
+
18
+ # For audio classification, we'll use MIT's music classification model
19
+ self.audio_classifier = pipeline(
20
+ "audio-classification",
21
+ model="mit/ast-finetuned-audioset-10-10-0.4593",
22
+ device_map="auto"
23
+ )
24
+ except Exception as e:
25
+ print(f"Warning: GPU initialization failed, falling back to CPU. Error: {str(e)}")
26
+ # Fall back to CPU if GPU initialization fails
27
+ self.text_classifier = pipeline(
28
+ "zero-shot-classification",
29
+ model="facebook/bart-large-mnli",
30
+ device="cpu"
31
+ )
32
+
33
+ self.audio_classifier = pipeline(
34
+ "audio-classification",
35
+ model="mit/ast-finetuned-audioset-10-10-0.4593",
36
+ device="cpu"
37
+ )
38
 
39
  # Define standard genres for classification
40
  self.genres = [
src/lyric_generator.py CHANGED
@@ -10,14 +10,25 @@ class LyricGenerator:
10
  Args:
11
  model_name: The name of the pre-trained model to use
12
  """
13
- # Check if CUDA is available
14
- device = 0 if torch.cuda.is_available() else -1
15
-
16
- self.generator = pipeline(
17
- "text-generation",
18
- model=model_name,
19
- device=device
20
- )
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  # Genre-specific prompts to guide generation
23
  self.genre_prompts = {
 
10
  Args:
11
  model_name: The name of the pre-trained model to use
12
  """
13
+ try:
14
+ # Try to use CUDA if available
15
+ if torch.cuda.is_available():
16
+ device = "cuda"
17
+ else:
18
+ device = "cpu"
19
+
20
+ self.generator = pipeline(
21
+ "text-generation",
22
+ model=model_name,
23
+ device_map="auto" # Let transformers handle device mapping
24
+ )
25
+ except Exception as e:
26
+ print(f"Warning: GPU initialization failed, falling back to CPU. Error: {str(e)}")
27
+ self.generator = pipeline(
28
+ "text-generation",
29
+ model=model_name,
30
+ device="cpu"
31
+ )
32
 
33
  # Genre-specific prompts to guide generation
34
  self.genre_prompts = {