intfloat committed on
Commit
7f44eb9
·
verified ·
1 Parent(s): b620de4

Update custom_st.py

Browse files
Files changed (1) hide show
  1. custom_st.py +28 -23
custom_st.py CHANGED
@@ -1,38 +1,43 @@
 
1
  from typing import Any, Dict, Optional, List
2
  import torch
3
  from PIL import Image
4
  from transformers import AutoProcessor, MllamaForConditionalGeneration
5
  from sentence_transformers.models import Transformer as BaseTransformer
6
 
 
7
  class MultiModalTransformer(BaseTransformer):
8
  def __init__(
9
- self,
10
- model_name_or_path: str,
11
- cache_dir: Optional[str] = None,
12
- tokenizer_args: Optional[Dict[str, Any]] = None,
13
- **kwargs,
14
  ):
15
  super().__init__(model_name_or_path, **kwargs)
16
  if tokenizer_args is None:
17
  tokenizer_args = {}
18
-
19
- # Initialize processor and set padding side
20
  self.processor = AutoProcessor.from_pretrained(
21
  model_name_or_path, cache_dir=cache_dir, **tokenizer_args
22
  )
23
-
24
- # Configure model settings
25
- config = self.auto_model.config
26
- if hasattr(config, 'use_cache'):
27
- config.use_cache = False
28
 
29
- padding_side = "right"
30
- self.processor.tokenizer.padding_side = padding_side
31
- config.padding_side = padding_side
32
- self.auto_model.padding_side = padding_side
 
 
 
 
 
 
 
 
33
 
34
  def forward(
35
- self, features: Dict[str, torch.Tensor], **kwargs
36
  ) -> Dict[str, torch.Tensor]:
37
  # Process inputs through the model
38
  outputs = self.auto_model(
@@ -41,12 +46,12 @@ class MultiModalTransformer(BaseTransformer):
41
  output_hidden_states=True,
42
  **kwargs
43
  )
44
-
45
  # Apply last pooling and normalization
46
  last_hidden_state = outputs.hidden_states[-1]
47
  attention_mask = features["attention_mask"]
48
  sentence_embedding = self._last_pooling(last_hidden_state, attention_mask)
49
-
50
  features.update({"sentence_embedding": sentence_embedding})
51
  return features
52
 
@@ -57,11 +62,11 @@ class MultiModalTransformer(BaseTransformer):
57
  reps = last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
58
  return torch.nn.functional.normalize(reps, p=2, dim=-1)
59
 
60
- def tokenize(self, texts: List[Dict] | List[str]) -> Dict[str, torch.Tensor]:
61
  def process_text_item(item):
62
  if isinstance(item, str):
63
  return item, []
64
-
65
  text, images = "", []
66
  for sub_item in item:
67
  if sub_item["type"] == "text":
@@ -101,5 +106,5 @@ class MultiModalTransformer(BaseTransformer):
101
  max_length=self.max_seq_length,
102
  return_tensors="pt"
103
  )
104
-
105
- return inputs
 
1
+ from io import BytesIO
2
  from typing import Any, Dict, Optional, List
3
  import torch
4
  from PIL import Image
5
  from transformers import AutoProcessor, MllamaForConditionalGeneration
6
  from sentence_transformers.models import Transformer as BaseTransformer
7
 
8
+
9
  class MultiModalTransformer(BaseTransformer):
10
  def __init__(
11
+ self,
12
+ model_name_or_path: str,
13
+ cache_dir: Optional[str] = None,
14
+ tokenizer_args: Optional[Dict[str, Any]] = None,
15
+ **kwargs,
16
  ):
17
  super().__init__(model_name_or_path, **kwargs)
18
  if tokenizer_args is None:
19
  tokenizer_args = {}
20
+
21
+ # Initialize processor
22
  self.processor = AutoProcessor.from_pretrained(
23
  model_name_or_path, cache_dir=cache_dir, **tokenizer_args
24
  )
 
 
 
 
 
25
 
26
+ def _load_model(
27
+ self,
28
+ model_name_or_path: str,
29
+ config,
30
+ cache_dir: str,
31
+ backend: str,
32
+ is_peft_model: bool,
33
+ **model_args,
34
+ ) -> None:
35
+ self.auto_model = MllamaForConditionalGeneration.from_pretrained(
36
+ model_name_or_path, torch_dtype=torch.bfloat16, cache_dir=cache_dir, **model_args
37
+ )
38
 
39
  def forward(
40
+ self, features: Dict[str, torch.Tensor], **kwargs
41
  ) -> Dict[str, torch.Tensor]:
42
  # Process inputs through the model
43
  outputs = self.auto_model(
 
46
  output_hidden_states=True,
47
  **kwargs
48
  )
49
+
50
  # Apply last pooling and normalization
51
  last_hidden_state = outputs.hidden_states[-1]
52
  attention_mask = features["attention_mask"]
53
  sentence_embedding = self._last_pooling(last_hidden_state, attention_mask)
54
+
55
  features.update({"sentence_embedding": sentence_embedding})
56
  return features
57
 
 
62
  reps = last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
63
  return torch.nn.functional.normalize(reps, p=2, dim=-1)
64
 
65
+ def tokenize(self, texts: List[List[Dict]] | List[str]) -> Dict[str, torch.Tensor]:
66
  def process_text_item(item):
67
  if isinstance(item, str):
68
  return item, []
69
+
70
  text, images = "", []
71
  for sub_item in item:
72
  if sub_item["type"] == "text":
 
106
  max_length=self.max_seq_length,
107
  return_tensors="pt"
108
  )
109
+
110
+ return inputs