Spaces:
Sleeping
Sleeping
Commit
·
70d9de4
1
Parent(s):
96bca50
add: utility for torch backend
Browse files
medrag_multi_modal/retrieval/contriever_retrieval.py
CHANGED
@@ -15,7 +15,7 @@ from transformers import (
|
|
15 |
|
16 |
import wandb
|
17 |
|
18 |
-
from ..utils import get_wandb_artifact
|
19 |
from .common import SimilarityMetric, argsort_scores, mean_pooling
|
20 |
|
21 |
|
@@ -150,7 +150,7 @@ class ContrieverRetriever(weave.Model):
|
|
150 |
os.path.join(artifact_dir, "vector_index.safetensors"), framework="pt"
|
151 |
) as f:
|
152 |
vector_index = f.get_tensor("vector_index")
|
153 |
-
device = torch.device(
|
154 |
vector_index = vector_index.to(device)
|
155 |
chunk_dataset = [dict(row) for row in weave.ref(chunk_dataset_name).get().rows]
|
156 |
return cls(
|
@@ -199,7 +199,7 @@ class ContrieverRetriever(weave.Model):
|
|
199 |
list: A list of dictionaries, each containing a retrieved chunk and its relevance score.
|
200 |
"""
|
201 |
query = [query]
|
202 |
-
device = torch.device(
|
203 |
with torch.no_grad():
|
204 |
query_embedding = self.encode(query).to(device)
|
205 |
if metric == SimilarityMetric.EUCLIDEAN:
|
|
|
15 |
|
16 |
import wandb
|
17 |
|
18 |
+
from ..utils import get_wandb_artifact, get_torch_backend
|
19 |
from .common import SimilarityMetric, argsort_scores, mean_pooling
|
20 |
|
21 |
|
|
|
150 |
os.path.join(artifact_dir, "vector_index.safetensors"), framework="pt"
|
151 |
) as f:
|
152 |
vector_index = f.get_tensor("vector_index")
|
153 |
+
device = torch.device(get_torch_backend())
|
154 |
vector_index = vector_index.to(device)
|
155 |
chunk_dataset = [dict(row) for row in weave.ref(chunk_dataset_name).get().rows]
|
156 |
return cls(
|
|
|
199 |
list: A list of dictionaries, each containing a retrieved chunk and its relevance score.
|
200 |
"""
|
201 |
query = [query]
|
202 |
+
device = torch.device(get_torch_backend())
|
203 |
with torch.no_grad():
|
204 |
query_embedding = self.encode(query).to(device)
|
205 |
if metric == SimilarityMetric.EUCLIDEAN:
|
medrag_multi_modal/utils.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import wandb
|
2 |
|
3 |
|
@@ -14,3 +15,13 @@ def get_wandb_artifact(
|
|
14 |
if get_metadata:
|
15 |
return artifact_dir, artifact.metadata
|
16 |
return artifact_dir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
import wandb
|
3 |
|
4 |
|
|
|
15 |
if get_metadata:
|
16 |
return artifact_dir, artifact.metadata
|
17 |
return artifact_dir
|
18 |
+
|
19 |
+
|
20 |
+
def get_torch_backend():
|
21 |
+
if torch.cuda.is_available():
|
22 |
+
return "cuda"
|
23 |
+
if torch.backends.mps.is_available():
|
24 |
+
if torch.backends.mps.is_built():
|
25 |
+
return "mps"
|
26 |
+
return "cpu"
|
27 |
+
return "cpu"
|