anas-awadalla committed
Commit 05bbe54 · 1 Parent(s): b2e96cb

Update open_flamingo/open_flamingo/src/factory.py

open_flamingo/open_flamingo/src/factory.py CHANGED
@@ -15,6 +15,7 @@ def create_model_and_transforms(
     use_local_files: bool = False,
     decoder_layers_attr_name: str = None,
     freeze_lm_embeddings: bool = False,
+    device: int = 0,
     **flamingo_kwargs,
 ):
     """
@@ -39,6 +40,7 @@ def create_model_and_transforms(
     )
     # set the vision encoder to output the visual features
     vision_encoder.visual.output_tokens = True
+    vision_encoder.to(device, dtype=torch.bfloat16) if device > -1 else None
 
     text_tokenizer = AutoTokenizer.from_pretrained(
         tokenizer_path,
@@ -58,7 +60,8 @@ def create_model_and_transforms(
         lang_encoder_path,
         local_files_only=use_local_files,
         trust_remote_code=True,
-    )
+
+    ).to(device, dtype=torch.bfloat16) if device > -1 else None
 
     # hacks for MPT-1B, which doesn't have a get_input_embeddings method
     if "mpt-1b-redpajama-200b" in lang_encoder_path:
@@ -90,7 +93,7 @@ def create_model_and_transforms(
         ],
         cross_attn_every_n_layers=cross_attn_every_n_layers,
         **flamingo_kwargs,
-    )
+    ).to(device, dtype=torch.bfloat16) if device > -1 else None
 
     # Freeze all parameters
     model.requires_grad_(False)
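
For reference, a minimal sketch of a call site using the new device argument, assuming the usual three-tuple return of create_model_and_transforms. The vision-encoder parameter names and checkpoint identifiers below are illustrative assumptions, not part of this commit; only lang_encoder_path, tokenizer_path, cross_attn_every_n_layers, and device appear in the diff above (which also presumes torch is already imported in factory.py, since the added lines reference torch.bfloat16):

from open_flamingo import create_model_and_transforms

# Hypothetical call site. With device=0, the guarded .to(...) calls above
# move the vision encoder, language encoder, and assembled Flamingo model
# to cuda:0 in bfloat16.
model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",      # assumed parameter name
    clip_vision_encoder_pretrained="openai",  # assumed parameter name
    lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",
    tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
    cross_attn_every_n_layers=1,
    device=0,  # new in this commit; negative values are meant to keep CPU placement
)

One caveat on the committed pattern: if, as in upstream OpenFlamingo, the from_pretrained and Flamingo calls are assigned to lang_encoder and model, then `x = f(...).to(device, dtype=torch.bfloat16) if device > -1 else None` parses the conditional over the entire right-hand side, so on the device <= -1 path the variable is bound to None rather than left as a CPU module. The bare `vision_encoder.to(...) if ... else None` statement is unaffected, since nn.Module.to modifies the module in place. An explicit `if device > -1:` block around each .to(...) call would preserve the CPU path.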