geekyrakshit committed on
Commit
96bca50
·
1 Parent(s): ca2f3e7

update: get_wandb_artifact

Browse files
medrag_multi_modal/retrieval/common.py CHANGED
@@ -1,7 +1,5 @@
1
  from enum import Enum
2
 
3
- import wandb
4
-
5
 
6
  class SimilarityMetric(Enum):
7
  COSINE = "cosine"
@@ -14,18 +12,6 @@ def mean_pooling(token_embeddings, mask):
14
  return sentence_embeddings
15
 
16
 
17
- def get_wandb_artifact(artifact_address: str, artifact_type: str):
18
- if wandb.run:
19
- artifact = wandb.run.use_artifact(artifact_address, type=artifact_type)
20
- artifact_dir = artifact.download()
21
- else:
22
- api = wandb.Api()
23
- artifact = api.artifact(artifact_address)
24
- artifact_dir = artifact.download()
25
- metadata = artifact.metadata
26
- return artifact_dir, metadata
27
-
28
-
29
  def argsort_scores(scores: list[float], descending: bool = False):
30
  return [
31
  {"item": item, "original_index": idx}
 
1
  from enum import Enum
2
 
 
 
3
 
4
  class SimilarityMetric(Enum):
5
  COSINE = "cosine"
 
12
  return sentence_embeddings
13
 
14
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def argsort_scores(scores: list[float], descending: bool = False):
16
  return [
17
  {"item": item, "original_index": idx}
medrag_multi_modal/retrieval/contriever_retrieval.py CHANGED
@@ -15,7 +15,8 @@ from transformers import (
15
 
16
  import wandb
17
 
18
- from .common import SimilarityMetric, argsort_scores, get_wandb_artifact, mean_pooling
 
19
 
20
 
21
  class ContrieverRetriever(weave.Model):
@@ -143,7 +144,7 @@ class ContrieverRetriever(weave.Model):
143
  and chunk dataset.
144
  """
145
  artifact_dir, metadata = get_wandb_artifact(
146
- index_artifact_address, "contriever-index"
147
  )
148
  with safetensors.torch.safe_open(
149
  os.path.join(artifact_dir, "vector_index.safetensors"), framework="pt"
 
15
 
16
  import wandb
17
 
18
+ from ..utils import get_wandb_artifact
19
+ from .common import SimilarityMetric, argsort_scores, mean_pooling
20
 
21
 
22
  class ContrieverRetriever(weave.Model):
 
144
  and chunk dataset.
145
  """
146
  artifact_dir, metadata = get_wandb_artifact(
147
+ index_artifact_address, "contriever-index", get_metadata=True
148
  )
149
  with safetensors.torch.safe_open(
150
  os.path.join(artifact_dir, "vector_index.safetensors"), framework="pt"
medrag_multi_modal/utils.py CHANGED
@@ -1,7 +1,9 @@
1
  import wandb
2
 
3
 
4
- def get_wandb_artifact(artifact_name: str, artifact_type: str) -> str:
 
 
5
  if wandb.run:
6
  artifact = wandb.use_artifact(artifact_name, type=artifact_type)
7
  artifact_dir = artifact.download()
@@ -9,4 +11,6 @@ def get_wandb_artifact(artifact_name: str, artifact_type: str) -> str:
9
  api = wandb.Api()
10
  artifact = api.artifact(artifact_name)
11
  artifact_dir = artifact.download()
 
 
12
  return artifact_dir
 
1
  import wandb
2
 
3
 
4
+ def get_wandb_artifact(
5
+ artifact_name: str, artifact_type: str, get_metadata: bool = False
6
+ ) -> str:
7
  if wandb.run:
8
  artifact = wandb.use_artifact(artifact_name, type=artifact_type)
9
  artifact_dir = artifact.download()
 
11
  api = wandb.Api()
12
  artifact = api.artifact(artifact_name)
13
  artifact_dir = artifact.download()
14
+ if get_metadata:
15
+ return artifact_dir, artifact.metadata
16
  return artifact_dir