Thouph commited on
Commit
9086e4d
·
1 Parent(s): 9f65be4

Upload train_k.py

Browse files
Files changed (1) hide show
  1. train_k.py +4 -2
train_k.py CHANGED
@@ -156,8 +156,10 @@ for j in range(1, 179+1):
156
  transform = transforms
157
  )
158
 
159
-
160
- model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(config.ENCODER, config.DECODER)
 
 
161
 
162
 
163
  model.config.decoder_start_token_id = tokenizer.cls_token_id
 
156
  transform = transforms
157
  )
158
 
159
+ if os.path.exists('VIT_large_gpt2_model'):
160
+ model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained('VIT_large_gpt2_model')
161
+ else:
162
+ model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(config.ENCODER, config.DECODER)
163
 
164
 
165
  model.config.decoder_start_token_id = tokenizer.cls_token_id