Spaces:
Sleeping
Sleeping
Commit
·
96bca50
1
Parent(s):
ca2f3e7
update: get_wandb_artifact
Browse files
medrag_multi_modal/retrieval/common.py
CHANGED
@@ -1,7 +1,5 @@
|
|
1 |
from enum import Enum
|
2 |
|
3 |
-
import wandb
|
4 |
-
|
5 |
|
6 |
class SimilarityMetric(Enum):
|
7 |
COSINE = "cosine"
|
@@ -14,18 +12,6 @@ def mean_pooling(token_embeddings, mask):
|
|
14 |
return sentence_embeddings
|
15 |
|
16 |
|
17 |
-
def get_wandb_artifact(artifact_address: str, artifact_type: str):
|
18 |
-
if wandb.run:
|
19 |
-
artifact = wandb.run.use_artifact(artifact_address, type=artifact_type)
|
20 |
-
artifact_dir = artifact.download()
|
21 |
-
else:
|
22 |
-
api = wandb.Api()
|
23 |
-
artifact = api.artifact(artifact_address)
|
24 |
-
artifact_dir = artifact.download()
|
25 |
-
metadata = artifact.metadata
|
26 |
-
return artifact_dir, metadata
|
27 |
-
|
28 |
-
|
29 |
def argsort_scores(scores: list[float], descending: bool = False):
|
30 |
return [
|
31 |
{"item": item, "original_index": idx}
|
|
|
1 |
from enum import Enum
|
2 |
|
|
|
|
|
3 |
|
4 |
class SimilarityMetric(Enum):
|
5 |
COSINE = "cosine"
|
|
|
12 |
return sentence_embeddings
|
13 |
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
def argsort_scores(scores: list[float], descending: bool = False):
|
16 |
return [
|
17 |
{"item": item, "original_index": idx}
|
medrag_multi_modal/retrieval/contriever_retrieval.py
CHANGED
@@ -15,7 +15,8 @@ from transformers import (
|
|
15 |
|
16 |
import wandb
|
17 |
|
18 |
-
from
|
|
|
19 |
|
20 |
|
21 |
class ContrieverRetriever(weave.Model):
|
@@ -143,7 +144,7 @@ class ContrieverRetriever(weave.Model):
|
|
143 |
and chunk dataset.
|
144 |
"""
|
145 |
artifact_dir, metadata = get_wandb_artifact(
|
146 |
-
index_artifact_address, "contriever-index"
|
147 |
)
|
148 |
with safetensors.torch.safe_open(
|
149 |
os.path.join(artifact_dir, "vector_index.safetensors"), framework="pt"
|
|
|
15 |
|
16 |
import wandb
|
17 |
|
18 |
+
from ..utils import get_wandb_artifact
|
19 |
+
from .common import SimilarityMetric, argsort_scores, mean_pooling
|
20 |
|
21 |
|
22 |
class ContrieverRetriever(weave.Model):
|
|
|
144 |
and chunk dataset.
|
145 |
"""
|
146 |
artifact_dir, metadata = get_wandb_artifact(
|
147 |
+
index_artifact_address, "contriever-index", get_metadata=True
|
148 |
)
|
149 |
with safetensors.torch.safe_open(
|
150 |
os.path.join(artifact_dir, "vector_index.safetensors"), framework="pt"
|
medrag_multi_modal/utils.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
import wandb
|
2 |
|
3 |
|
4 |
-
def get_wandb_artifact(
|
|
|
|
|
5 |
if wandb.run:
|
6 |
artifact = wandb.use_artifact(artifact_name, type=artifact_type)
|
7 |
artifact_dir = artifact.download()
|
@@ -9,4 +11,6 @@ def get_wandb_artifact(artifact_name: str, artifact_type: str) -> str:
|
|
9 |
api = wandb.Api()
|
10 |
artifact = api.artifact(artifact_name)
|
11 |
artifact_dir = artifact.download()
|
|
|
|
|
12 |
return artifact_dir
|
|
|
1 |
import wandb
|
2 |
|
3 |
|
4 |
+
def get_wandb_artifact(
|
5 |
+
artifact_name: str, artifact_type: str, get_metadata: bool = False
|
6 |
+
) -> str:
|
7 |
if wandb.run:
|
8 |
artifact = wandb.use_artifact(artifact_name, type=artifact_type)
|
9 |
artifact_dir = artifact.download()
|
|
|
11 |
api = wandb.Api()
|
12 |
artifact = api.artifact(artifact_name)
|
13 |
artifact_dir = artifact.download()
|
14 |
+
if get_metadata:
|
15 |
+
return artifact_dir, artifact.metadata
|
16 |
return artifact_dir
|