geekyrakshit commited on
Commit
3bd446c
·
1 Parent(s): 353a440

fix: catch byaldi import

Browse files
medrag_multi_modal/retrieval/colpali_retrieval.py CHANGED
@@ -1,12 +1,10 @@
1
  import os
2
- from typing import Any, Optional
3
 
4
  import weave
5
 
6
- try:
7
  from byaldi import RAGMultiModalModel
8
- except ImportError:
9
- pass
10
 
11
  from PIL import Image
12
 
@@ -77,18 +75,20 @@ class CalPaliRetriever(weave.Model):
77
  """
78
 
79
  model_name: str
80
- _docs_retrieval_model: Optional[RAGMultiModalModel] = None
81
  _metadata: Optional[dict] = None
82
  _data_artifact_dir: Optional[str] = None
83
 
84
  def __init__(
85
  self,
86
  model_name: str = "vidore/colpali-v1.2",
87
- docs_retrieval_model: Optional[RAGMultiModalModel] = None,
88
  data_artifact_dir: Optional[str] = None,
89
  metadata_dataset_name: Optional[str] = None,
90
  ):
91
  super().__init__(model_name=model_name)
 
 
92
  self._docs_retrieval_model = (
93
  docs_retrieval_model or RAGMultiModalModel.from_pretrained(self.model_name)
94
  )
@@ -106,6 +106,8 @@ class CalPaliRetriever(weave.Model):
106
  metadata_dataset_name: str,
107
  data_artifact_name: str,
108
  ):
 
 
109
  index_artifact_dir = get_wandb_artifact(index_artifact_name, "colpali-index")
110
  data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset")
111
  docs_retrieval_model = RAGMultiModalModel.from_index(
 
1
  import os
2
+ from typing import TYPE_CHECKING, Any, Optional
3
 
4
  import weave
5
 
6
+ if TYPE_CHECKING:
7
  from byaldi import RAGMultiModalModel
 
 
8
 
9
  from PIL import Image
10
 
 
75
  """
76
 
77
  model_name: str
78
+ _docs_retrieval_model: Optional["RAGMultiModalModel"] = None
79
  _metadata: Optional[dict] = None
80
  _data_artifact_dir: Optional[str] = None
81
 
82
  def __init__(
83
  self,
84
  model_name: str = "vidore/colpali-v1.2",
85
+ docs_retrieval_model: Optional["RAGMultiModalModel"] = None,
86
  data_artifact_dir: Optional[str] = None,
87
  metadata_dataset_name: Optional[str] = None,
88
  ):
89
  super().__init__(model_name=model_name)
90
+ from byaldi import RAGMultiModalModel
91
+
92
  self._docs_retrieval_model = (
93
  docs_retrieval_model or RAGMultiModalModel.from_pretrained(self.model_name)
94
  )
 
106
  metadata_dataset_name: str,
107
  data_artifact_name: str,
108
  ):
109
+ from byaldi import RAGMultiModalModel
110
+
111
  index_artifact_dir = get_wandb_artifact(index_artifact_name, "colpali-index")
112
  data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset")
113
  docs_retrieval_model = RAGMultiModalModel.from_index(