Upload train_k.py
Browse files- train_k.py +4 -2
train_k.py
CHANGED
@@ -156,8 +156,10 @@ for j in range(1, 179+1):
|
|
156 |
transform = transforms
|
157 |
)
|
158 |
|
159 |
-
|
160 |
-
|
|
|
|
|
161 |
|
162 |
|
163 |
model.config.decoder_start_token_id = tokenizer.cls_token_id
|
|
|
156 |
transform = transforms
|
157 |
)
|
158 |
|
159 |
+
if os.path.exists('VIT_large_gpt2_model'):
|
160 |
+
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained('VIT_large_gpt2_model')
|
161 |
+
else:
|
162 |
+
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(config.ENCODER, config.DECODER)
|
163 |
|
164 |
|
165 |
model.config.decoder_start_token_id = tokenizer.cls_token_id
|