pengdaqian commited on
Commit
ca13e69
1 Parent(s): f0c5f90
Files changed (1) hide show
  1. audio_to_text.py +2 -1
audio_to_text.py CHANGED
@@ -10,6 +10,7 @@ from transformers import pipeline
10
 
11
  class AudioPipeline(object):
12
  def __init__(self, audio_text_path, audio_text_embeddings_path):
 
13
  self.model = laion_clap.CLAP_Module(enable_fusion=False)
14
  self.model.load_ckpt() # download the default pretrained checkpoint.
15
  self.audio_text_path = audio_text_path
@@ -39,7 +40,7 @@ class AudioPipeline(object):
39
  texts = json.load(f)
40
 
41
  tensors = {}
42
- with safe_open(self.audio_text_embeddings_path, framework="pt", device=0) as f:
43
  for k in f.keys():
44
  tensors[k] = f.get_tensor(k)
45
  text_embed = tensors["text_embed"]
 
10
 
11
  class AudioPipeline(object):
12
  def __init__(self, audio_text_path, audio_text_embeddings_path):
13
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
14
  self.model = laion_clap.CLAP_Module(enable_fusion=False)
15
  self.model.load_ckpt() # download the default pretrained checkpoint.
16
  self.audio_text_path = audio_text_path
 
40
  texts = json.load(f)
41
 
42
  tensors = {}
43
+ with safe_open(self.audio_text_embeddings_path, framework="pt", device=self.device) as f:
44
  for k in f.keys():
45
  tensors[k] = f.get_tensor(k)
46
  text_embed = tensors["text_embed"]