gmastrapas commited on
Commit
6d8e609
·
1 Parent(s): 942a5da

fix: kwargs in custom Sentence Transformer

Browse files
Files changed (1) hide show
  1. custom_st.py +8 -2
custom_st.py CHANGED
@@ -22,6 +22,8 @@ class Transformer(nn.Module):
22
  model_kwargs: Optional[Dict[str, Any]] = None,
23
  tokenizer_kwargs: Optional[Dict[str, Any]] = None,
24
  image_processor_kwargs: Optional[Dict[str, Any]] = None,
 
 
25
  ) -> None:
26
  super(Transformer, self).__init__()
27
 
@@ -30,19 +32,23 @@ class Transformer(nn.Module):
30
  tokenizer_kwargs = tokenizer_kwargs or {}
31
  image_processor_kwargs = image_processor_kwargs or {}
32
 
33
- config = AutoConfig.from_pretrained(model_name_or_path, **config_kwargs)
 
 
34
  self.model = AutoModel.from_pretrained(
35
- model_name_or_path, config=config, **model_kwargs
36
  )
37
  if max_seq_length is not None and 'model_max_length' not in tokenizer_kwargs:
38
  tokenizer_kwargs['model_max_length'] = max_seq_length
39
 
40
  self.tokenizer = AutoTokenizer.from_pretrained(
41
  tokenizer_name_or_path or model_name_or_path,
 
42
  **tokenizer_kwargs,
43
  )
44
  self.image_processor = AutoImageProcessor.from_pretrained(
45
  image_processor_name_or_path or model_name_or_path,
 
46
  **image_processor_kwargs,
47
  )
48
 
 
22
  model_kwargs: Optional[Dict[str, Any]] = None,
23
  tokenizer_kwargs: Optional[Dict[str, Any]] = None,
24
  image_processor_kwargs: Optional[Dict[str, Any]] = None,
25
+ cache_dir: str = None,
26
+ **_,
27
  ) -> None:
28
  super(Transformer, self).__init__()
29
 
 
32
  tokenizer_kwargs = tokenizer_kwargs or {}
33
  image_processor_kwargs = image_processor_kwargs or {}
34
 
35
+ config = AutoConfig.from_pretrained(
36
+ model_name_or_path, cache_dir=cache_dir, **config_kwargs
37
+ )
38
  self.model = AutoModel.from_pretrained(
39
+ model_name_or_path, config=config, cache_dir=cache_dir, **model_kwargs
40
  )
41
  if max_seq_length is not None and 'model_max_length' not in tokenizer_kwargs:
42
  tokenizer_kwargs['model_max_length'] = max_seq_length
43
 
44
  self.tokenizer = AutoTokenizer.from_pretrained(
45
  tokenizer_name_or_path or model_name_or_path,
46
+ cache_dir=cache_dir,
47
  **tokenizer_kwargs,
48
  )
49
  self.image_processor = AutoImageProcessor.from_pretrained(
50
  image_processor_name_or_path or model_name_or_path,
51
+ cache_dir=cache_dir,
52
  **image_processor_kwargs,
53
  )
54