geekyrakshit commited on
Commit
70d9de4
·
1 Parent(s): 96bca50

add: utility for torch backend

Browse files
medrag_multi_modal/retrieval/contriever_retrieval.py CHANGED
@@ -15,7 +15,7 @@ from transformers import (
15
 
16
  import wandb
17
 
18
- from ..utils import get_wandb_artifact
19
  from .common import SimilarityMetric, argsort_scores, mean_pooling
20
 
21
 
@@ -150,7 +150,7 @@ class ContrieverRetriever(weave.Model):
150
  os.path.join(artifact_dir, "vector_index.safetensors"), framework="pt"
151
  ) as f:
152
  vector_index = f.get_tensor("vector_index")
153
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
154
  vector_index = vector_index.to(device)
155
  chunk_dataset = [dict(row) for row in weave.ref(chunk_dataset_name).get().rows]
156
  return cls(
@@ -199,7 +199,7 @@ class ContrieverRetriever(weave.Model):
199
  list: A list of dictionaries, each containing a retrieved chunk and its relevance score.
200
  """
201
  query = [query]
202
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
203
  with torch.no_grad():
204
  query_embedding = self.encode(query).to(device)
205
  if metric == SimilarityMetric.EUCLIDEAN:
 
15
 
16
  import wandb
17
 
18
+ from ..utils import get_wandb_artifact, get_torch_backend
19
  from .common import SimilarityMetric, argsort_scores, mean_pooling
20
 
21
 
 
150
  os.path.join(artifact_dir, "vector_index.safetensors"), framework="pt"
151
  ) as f:
152
  vector_index = f.get_tensor("vector_index")
153
+ device = torch.device(get_torch_backend())
154
  vector_index = vector_index.to(device)
155
  chunk_dataset = [dict(row) for row in weave.ref(chunk_dataset_name).get().rows]
156
  return cls(
 
199
  list: A list of dictionaries, each containing a retrieved chunk and its relevance score.
200
  """
201
  query = [query]
202
+ device = torch.device(get_torch_backend())
203
  with torch.no_grad():
204
  query_embedding = self.encode(query).to(device)
205
  if metric == SimilarityMetric.EUCLIDEAN:
medrag_multi_modal/utils.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import wandb
2
 
3
 
@@ -14,3 +15,13 @@ def get_wandb_artifact(
14
  if get_metadata:
15
  return artifact_dir, artifact.metadata
16
  return artifact_dir
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
  import wandb
3
 
4
 
 
15
  if get_metadata:
16
  return artifact_dir, artifact.metadata
17
  return artifact_dir
18
+
19
+
20
+ def get_torch_backend():
21
+ if torch.cuda.is_available():
22
+ return "cuda"
23
+ if torch.backends.mps.is_available():
24
+ if torch.backends.mps.is_built():
25
+ return "mps"
26
+ return "cpu"
27
+ return "cpu"