jacob-c commited on
Commit
a19410a
·
1 Parent(s): 3dc83fd
Files changed (1) hide show
  1. src/lyric_generator.py +5 -1
src/lyric_generator.py CHANGED
@@ -1,4 +1,5 @@
1
  from transformers import pipeline
 
2
  from typing import Dict, List, Optional
3
 
4
  class LyricGenerator:
@@ -9,10 +10,13 @@ class LyricGenerator:
9
  Args:
10
  model_name: The name of the pre-trained model to use
11
  """
 
 
 
12
  self.generator = pipeline(
13
  "text-generation",
14
  model=model_name,
15
- device=0 if pipeline.device.type == "cuda" else -1
16
  )
17
 
18
  # Genre-specific prompts to guide generation
 
1
  from transformers import pipeline
2
+ import torch
3
  from typing import Dict, List, Optional
4
 
5
  class LyricGenerator:
 
10
  Args:
11
  model_name: The name of the pre-trained model to use
12
  """
13
+ # Check if CUDA is available
14
+ device = 0 if torch.cuda.is_available() else -1
15
+
16
  self.generator = pipeline(
17
  "text-generation",
18
  model=model_name,
19
+ device=device
20
  )
21
 
22
  # Genre-specific prompts to guide generation