m7n commited on
Commit
6bef292
·
1 Parent(s): 80e91af

Enhance device selection logic in setup_embedding_model to support MPS for Apple Silicon, in addition to CUDA and CPU fallback.

Browse files
Files changed (1) hide show
  1. data_setup.py +6 -1
data_setup.py CHANGED
@@ -111,7 +111,12 @@ def setup_embedding_model(model_name):
111
  model_name (str): Name or path of the SentenceTransformer model
112
  """
113
  print(f"Setting up language model: {time.strftime('%Y-%m-%d %H:%M:%S')}")
114
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
115
  print(f"Using device: {device}")
116
 
117
  model = SentenceTransformer(model_name)
 
111
  model_name (str): Name or path of the SentenceTransformer model
112
  """
113
  print(f"Setting up language model: {time.strftime('%Y-%m-%d %H:%M:%S')}")
114
+ if torch.cuda.is_available():
115
+ device = torch.device("cuda")
116
+ elif torch.backends.mps.is_available():
117
+ device = torch.device("mps")
118
+ else:
119
+ device = torch.device("cpu")
120
  print(f"Using device: {device}")
121
 
122
  model = SentenceTransformer(model_name)