bczhou committed on
Commit
84f4f69
·
1 Parent(s): 9827c34

Update linear_mapping.py

Browse files
Files changed (1) hide show
  1. linear_mapping.py +9 -4
linear_mapping.py CHANGED
@@ -2,6 +2,7 @@ from config import LinearMappingConfig
2
  from transformers import (
3
  GPT2TokenizerFast, GPT2LMHeadModel, AutoModel,
4
  CLIPVisionModel, AutoProcessor, BatchEncoding,
 
5
  )
6
  from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModelOutput
7
  import torch
@@ -104,9 +105,11 @@ class ImagePrefix(nn.Module):
104
 
105
  def __init__(self, config: LinearMappingConfig):
106
  super().__init__()
107
- self.encoder = AutoModel.from_pretrained(config.image_model)
108
- if "clip" in config.image_model:
109
- self.encoder = CLIPVisionModel.from_pretrained(config.image_model)
 
 
110
 
111
  if config.freeze_image_model:
112
  for param in self.encoder.parameters():
@@ -128,7 +131,9 @@ class LinearMapping(nn.Module):
128
  def __init__(self, config: LinearMappingConfig):
129
  super().__init__()
130
  self.image_prefix = ImagePrefix(config)
131
- self.language_model = GPT2LMHeadModel.from_pretrained(config.text_model)
 
 
132
  self.processor = LinearMappingProcessor(config)
133
  self.tokenizer = self.processor.tokenizer
134
  self.image_processor = self.processor.image_processor
 
2
  from transformers import (
3
  GPT2TokenizerFast, GPT2LMHeadModel, AutoModel,
4
  CLIPVisionModel, AutoProcessor, BatchEncoding,
5
+ AutoConfig, CLIPVisionConfig
6
  )
7
  from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModelOutput
8
  import torch
 
105
 
106
  def __init__(self, config: LinearMappingConfig):
107
  super().__init__()
108
+ clip_config = CLIPVisionConfig.from_pretrained(config.image_model)
109
+
110
+ self.encoder = CLIPVisionModel(clip_config)
111
+ if config.image_from_pretrained:
112
+ self.encoder = self.encoder.from_pretrained(config.image_model)
113
 
114
  if config.freeze_image_model:
115
  for param in self.encoder.parameters():
 
131
  def __init__(self, config: LinearMappingConfig):
132
  super().__init__()
133
  self.image_prefix = ImagePrefix(config)
134
+ self.language_model = GPT2LMHeadModel(AutoConfig.from_pretrained(config.text_model))
135
+ if config.text_from_pretrained:
136
+ self.language_model = self.language_model.from_pretrained(config.text_model)
137
  self.processor = LinearMappingProcessor(config)
138
  self.tokenizer = self.processor.tokenizer
139
  self.image_processor = self.processor.image_processor