Thouph commited on
Commit
997f017
·
1 Parent(s): 9d65247

Upload train.py

Browse files
Files changed (1) hide show
  1. train.py +25 -26
train.py CHANGED
@@ -64,17 +64,19 @@ def feature_extraction_fn(image_paths, check_image=True):
64
 
65
  return encoder_inputs.pixel_values
66
 
 
 
 
67
 
68
- def preprocess_fn(examples, max_target_length, check_image=True):
69
- """Run tokenization + image feature extraction"""
70
- image_paths = examples["image_path"]
71
- captions = examples['tags']
72
 
 
 
73
  model_inputs = {}
74
- # This contains image path column
75
- model_inputs['labels'] = tokenization_fn(captions, max_target_length)
76
- model_inputs['pixel_values'] = feature_extraction_fn(image_paths, check_image=check_image)
77
-
78
  return model_inputs
79
 
80
  def postprocess_text(preds, labels):
@@ -164,8 +166,6 @@ print(dataset)
164
  def add_image_path(example):
165
  image_name = [i + '.jpg' for i in example["image_id"]]
166
  folder_name=example["folder_name"]
167
- #image_name = example['image_id'] + '.jpg'
168
- #image_path = os.path.join(r"D:\dump384_224x224_384\384", image_name)
169
  image_path = [os.path.join(rf"/home/user/dump_small/{folder_name[i]}", image_name[i]) for i in range(len(image_name))]
170
  example['image_path'] = image_path
171
  return example
@@ -174,36 +174,35 @@ ds = dataset.map(add_image_path, batched=True, batch_size=8192)["train"]
174
  print(ds)
175
 
176
  ds = ds.train_test_split(test_size=0.02)
 
 
 
177
 
178
- print(ds['train'][0])
179
- processed_dataset = ds.map(
180
- function=preprocess_fn,
181
- batched=True,
182
- fn_kwargs={"max_target_length": 128},
183
- batch_size=8192,
184
- num_proc=16,
185
- #remove_columns=ds['train'].column_names
186
- )
187
 
188
  training_args = Seq2SeqTrainingArguments(
189
  predict_with_generate=True,
190
  evaluation_strategy="steps",
191
  eval_steps=100,
192
  gradient_accumulation_steps=4,
193
- per_device_train_batch_size=1,
194
  weight_decay=0.1,
195
- max_steps=1000,
196
  warmup_steps=1000,
197
  logging_strategy="steps",
198
- save_steps=200,
199
  fp16=True,
200
  tpu_num_cores=8,
201
- per_device_eval_batch_size=1,
202
  output_dir="image-captioning-output",
203
  learning_rate=5e-4,
204
  lr_scheduler_type="cosine",
205
  )
206
 
 
 
 
 
 
207
 
208
  metric = evaluate.load("rouge")
209
  ignore_pad_token_for_loss = True
@@ -214,9 +213,9 @@ trainer = Seq2SeqTrainer(
214
  tokenizer=feature_extractor,
215
  args=training_args,
216
  compute_metrics=compute_metrics,
217
- train_dataset=processed_dataset['train'],
218
- eval_dataset=processed_dataset['test'],
219
- data_collator=default_data_collator,
220
  )
221
 
222
 
 
64
 
65
  return encoder_inputs.pixel_values
66
 
67
+ def transform(example_batch):
68
+ # Take a list of PIL images and turn them to pixel values
69
+ inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')
70
 
71
+ # Don't forget to include the labels!
72
+ inputs['labels'] = example_batch['labels']
73
+ return inputs
 
74
 
75
+ def preprocess_fn(example_batch):
76
+ """Run tokenization + image feature extraction"""
77
  model_inputs = {}
78
+ model_inputs['pixel_values'] = feature_extraction_fn([x for x in example_batch['image_path']])
79
+ model_inputs['labels'] = tokenization_fn([x for x in example_batch['tags']], 128)
 
 
80
  return model_inputs
81
 
82
  def postprocess_text(preds, labels):
 
166
  def add_image_path(example):
167
  image_name = [i + '.jpg' for i in example["image_id"]]
168
  folder_name=example["folder_name"]
 
 
169
  image_path = [os.path.join(rf"/home/user/dump_small/{folder_name[i]}", image_name[i]) for i in range(len(image_name))]
170
  example['image_path'] = image_path
171
  return example
 
174
  print(ds)
175
 
176
  ds = ds.train_test_split(test_size=0.02)
177
+ print(ds['train'][0:2])
178
+ ds.set_transform(preprocess_fn)
179
+ print(ds['train'][0:2])
180
 
 
 
 
 
 
 
 
 
 
181
 
182
  training_args = Seq2SeqTrainingArguments(
183
  predict_with_generate=True,
184
  evaluation_strategy="steps",
185
  eval_steps=100,
186
  gradient_accumulation_steps=4,
187
+ per_device_train_batch_size=128,
188
  weight_decay=0.1,
189
+ max_steps=10000,
190
  warmup_steps=1000,
191
  logging_strategy="steps",
192
+ save_steps=5000,
193
  fp16=True,
194
  tpu_num_cores=8,
195
+ per_device_eval_batch_size=128,
196
  output_dir="image-captioning-output",
197
  learning_rate=5e-4,
198
  lr_scheduler_type="cosine",
199
  )
200
 
201
+ def collate_fn(batch):
202
+ return {
203
+ 'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
204
+ 'labels': torch.tensor([x['labels'] for x in batch])
205
+ }
206
 
207
  metric = evaluate.load("rouge")
208
  ignore_pad_token_for_loss = True
 
213
  tokenizer=feature_extractor,
214
  args=training_args,
215
  compute_metrics=compute_metrics,
216
+ train_dataset=ds['train'],
217
+ eval_dataset=ds['test'],
218
+ data_collator=collate_fn,
219
  )
220
 
221