Update linear_mapping.py
Browse files- linear_mapping.py +9 -4
linear_mapping.py
CHANGED
@@ -2,6 +2,7 @@ from config import LinearMappingConfig
|
|
2 |
from transformers import (
|
3 |
GPT2TokenizerFast, GPT2LMHeadModel, AutoModel,
|
4 |
CLIPVisionModel, AutoProcessor, BatchEncoding,
|
|
|
5 |
)
|
6 |
from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModelOutput
|
7 |
import torch
|
@@ -104,9 +105,11 @@ class ImagePrefix(nn.Module):
|
|
104 |
|
105 |
def __init__(self, config: LinearMappingConfig):
|
106 |
super().__init__()
|
107 |
-
|
108 |
-
|
109 |
-
|
|
|
|
|
110 |
|
111 |
if config.freeze_image_model:
|
112 |
for param in self.encoder.parameters():
|
@@ -128,7 +131,9 @@ class LinearMapping(nn.Module):
|
|
128 |
def __init__(self, config: LinearMappingConfig):
|
129 |
super().__init__()
|
130 |
self.image_prefix = ImagePrefix(config)
|
131 |
-
self.language_model = GPT2LMHeadModel.from_pretrained(config.text_model)
|
|
|
|
|
132 |
self.processor = LinearMappingProcessor(config)
|
133 |
self.tokenizer = self.processor.tokenizer
|
134 |
self.image_processor = self.processor.image_processor
|
|
|
2 |
from transformers import (
|
3 |
GPT2TokenizerFast, GPT2LMHeadModel, AutoModel,
|
4 |
CLIPVisionModel, AutoProcessor, BatchEncoding,
|
5 |
+
AutoConfig, CLIPVisionConfig
|
6 |
)
|
7 |
from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModelOutput
|
8 |
import torch
|
|
|
105 |
|
106 |
def __init__(self, config: LinearMappingConfig):
|
107 |
super().__init__()
|
108 |
+
clip_config = CLIPVisionConfig.from_pretrained(config.image_model)
|
109 |
+
|
110 |
+
self.encoder = CLIPVisionModel(clip_config)
|
111 |
+
if config.image_from_pretrained:
|
112 |
+
self.encoder = self.encoder.from_pretrained(config.image_model)
|
113 |
|
114 |
if config.freeze_image_model:
|
115 |
for param in self.encoder.parameters():
|
|
|
131 |
def __init__(self, config: LinearMappingConfig):
|
132 |
super().__init__()
|
133 |
self.image_prefix = ImagePrefix(config)
|
134 |
+
self.language_model = GPT2LMHeadModel(AutoConfig.from_pretrained(config.text_model))
|
135 |
+
if config.text_from_pretrained:
|
136 |
+
self.language_model = self.language_model.from_pretrained(config.text_model)
|
137 |
self.processor = LinearMappingProcessor(config)
|
138 |
self.tokenizer = self.processor.tokenizer
|
139 |
self.image_processor = self.processor.image_processor
|