Upload train.py
train.py CHANGED
@@ -64,17 +64,19 @@ def feature_extraction_fn(image_paths, check_image=True):
 
     return encoder_inputs.pixel_values
 
+def transform(example_batch):
+    # Take a list of PIL images and turn them to pixel values
+    inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')
 
-
-
-
-    captions = examples['tags']
+    # Don't forget to include the labels!
+    inputs['labels'] = example_batch['labels']
+    return inputs
 
+def preprocess_fn(example_batch):
+    """Run tokenization + image feature extraction"""
     model_inputs = {}
-
-    model_inputs['labels'] = tokenization_fn(
-    model_inputs['pixel_values'] = feature_extraction_fn(image_paths, check_image=check_image)
-
+    model_inputs['pixel_values'] = feature_extraction_fn([x for x in example_batch['image_path']])
+    model_inputs['labels'] = tokenization_fn([x for x in example_batch['tags']], 128)
     return model_inputs
 
 def postprocess_text(preds, labels):
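Note: `tokenization_fn` is defined earlier in train.py and is not part of this hunk. Judging from the call site above, `tokenization_fn(captions, 128)`, it presumably looks something like the sketch below; the `gpt2` tokenizer choice is an assumption for illustration, not taken from this commit.

```python
from transformers import AutoTokenizer

# Assumption: the decoder's tokenizer; the actual checkpoint is not shown in this diff.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

def tokenization_fn(captions, max_target_length):
    """Tokenize caption strings into fixed-length lists of label ids."""
    labels = tokenizer(
        captions,
        padding="max_length",
        max_length=max_target_length,
        truncation=True,
    ).input_ids
    return labels
```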
@@ -164,8 +166,6 @@ print(dataset)
 def add_image_path(example):
     image_name = [i + '.jpg' for i in example["image_id"]]
     folder_name=example["folder_name"]
-    #image_name = example['image_id'] + '.jpg'
-    #image_path = os.path.join(r"D:\dump384_224x224_384\384", image_name)
     image_path = [os.path.join(rf"/home/user/dump_small/{folder_name[i]}", image_name[i]) for i in range(len(image_name))]
     example['image_path'] = image_path
     return example
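Since `add_image_path` is mapped with `batched=True`, each column arrives as a list covering the whole batch. A standalone check of the helper (the ids and folder names below are made up):

```python
import os

def add_image_path(example):
    # Batched map: every column is a list, one entry per row in the batch.
    image_name = [i + '.jpg' for i in example["image_id"]]
    folder_name = example["folder_name"]
    image_path = [os.path.join(f"/home/user/dump_small/{folder_name[i]}", image_name[i])
                  for i in range(len(image_name))]
    example['image_path'] = image_path
    return example

batch = {"image_id": ["0001", "0002"], "folder_name": ["cats", "dogs"]}
print(add_image_path(batch)["image_path"])
# ['/home/user/dump_small/cats/0001.jpg', '/home/user/dump_small/dogs/0002.jpg']
```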
@@ -174,36 +174,35 @@ ds = dataset.map(add_image_path, batched=True, batch_size=8192)["train"]
 print(ds)
 
 ds = ds.train_test_split(test_size=0.02)
+print(ds['train'][0:2])
+ds.set_transform(preprocess_fn)
+print(ds['train'][0:2])
 
-print(ds['train'][0])
-processed_dataset = ds.map(
-    function=preprocess_fn,
-    batched=True,
-    fn_kwargs={"max_target_length": 128},
-    batch_size=8192,
-    num_proc=16,
-    #remove_columns=ds['train'].column_names
-)
 
 training_args = Seq2SeqTrainingArguments(
     predict_with_generate=True,
     evaluation_strategy="steps",
     eval_steps=100,
     gradient_accumulation_steps=4,
-    per_device_train_batch_size=
+    per_device_train_batch_size=128,
     weight_decay=0.1,
-    max_steps=
+    max_steps=10000,
     warmup_steps=1000,
     logging_strategy="steps",
-    save_steps=
+    save_steps=5000,
     fp16=True,
     tpu_num_cores=8,
-    per_device_eval_batch_size=
+    per_device_eval_batch_size=128,
     output_dir="image-captioning-output",
     learning_rate=5e-4,
     lr_scheduler_type="cosine",
 )
 
+def collate_fn(batch):
+    return {
+        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
+        'labels': torch.tensor([x['labels'] for x in batch])
+    }
 
 metric = evaluate.load("rouge")
 ignore_pad_token_for_loss = True
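Unlike the removed `ds.map(...)`, which preprocessed every row eagerly (across `num_proc=16` workers) and cached the result, `set_transform` registers `preprocess_fn` to run on the fly each time rows are read; that is why the second `print(ds['train'][0:2])` shows transformed outputs while the first still shows the raw columns. A minimal, self-contained illustration of that behavior (toy data, not from train.py):

```python
from datasets import Dataset

def double(batch):
    # Like preprocess_fn: receives a dict of columns for the requested rows.
    return {"x2": [v * 2 for v in batch["x"]]}

d = Dataset.from_dict({"x": [1, 2, 3]})
d.set_transform(double)  # registered only; nothing is computed or cached
print(d[0:2])            # double() runs here -> {'x2': [2, 4]}
```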
@@ -214,9 +213,9 @@ trainer = Seq2SeqTrainer(
     tokenizer=feature_extractor,
     args=training_args,
     compute_metrics=compute_metrics,
-    train_dataset=
-    eval_dataset=
-    data_collator=
+    train_dataset=ds['train'],
+    eval_dataset=ds['test'],
+    data_collator=collate_fn,
 )
 
 
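For reference, during training the `Trainer` hands `data_collator` a list of individual dataset items. A self-contained check of the stacking behavior of the `collate_fn` added above (the 3x224x224 shape is an assumption about the feature extractor's output, not something fixed by this diff):

```python
import torch

def collate_fn(batch):  # same shape as the collator added in this commit
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch]),
    }

items = [{'pixel_values': torch.zeros(3, 224, 224), 'labels': [0, 1, 2]}
         for _ in range(4)]
out = collate_fn(items)
print(out['pixel_values'].shape)  # torch.Size([4, 3, 224, 224])
print(out['labels'].shape)        # torch.Size([4, 3])
```

Note that `collate_fn` relies on `torch`, so train.py needs a `torch` import if it does not already have one.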