bczhou committed on
Commit 89f73ad · 1 Parent(s): 0f673fc

Delete main.py

Files changed (1)
  1. main.py +0 -116
main.py DELETED
@@ -1,116 +0,0 @@
-from datasets import load_dataset
-from linear_mapping import LinearMapping, LinearMappingProcessor, LinearMappingConfig, Transform
-import torch
-from torchvision.io import ImageReadMode, read_image
-from transformers import Trainer, TrainingArguments
-import os
-from PIL import Image
-os.environ["WANDB_DISABLED"] = "true"
-
-DATA_DIR = os.path.join(os.getcwd(), "coco")
-CAPTION_COLUMN = "caption"
-IMAGE_COLUMN = "image_path"
-
-
-def main():
-    ds = load_dataset("ydshieh/coco_dataset_script", "2017", DATA_DIR)
-    config = LinearMappingConfig()
-    processor = LinearMappingProcessor(config)
-
-    def collate_fn(batch):
-        return {
-            'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
-            'input_ids': torch.tensor([x['input_ids'] for x in batch], dtype=torch.long),
-            'attention_mask': torch.stack([x["attention_mask"] for x in batch]),
-        }
-
-    def tokenize_fn(examples):
-        texts = list(examples[CAPTION_COLUMN])
-        if config.add_image_token:
-            texts = list(processor.tokenizer.cls_token + text for text in texts)
-        inputs = processor.tokenizer(
-            texts, padding="max_length", max_length=77,
-            return_tensors="pt", truncation=True
-        )
-        examples["input_ids"] = inputs.input_ids
-        examples["attention_mask"] = inputs.attention_mask
-        return examples
-
-    image_transformations = Transform(
-        config.image_resize,
-        [0.48145466, 0.4578275, 0.40821073],
-        [0.26862954, 0.26130258, 0.27577711]
-    )
-    image_transformations = torch.jit.script(image_transformations)
-
-    def transform_images(examples):
-        images = [read_image(image_file, mode=ImageReadMode.RGB) for image_file in examples[IMAGE_COLUMN]]
-        examples["pixel_values"] = [image_transformations(image) for image in images]
-
-        examples["attention_mask"] = torch.cat([
-            torch.ones(len(images), config.prefix_length),
-            torch.tensor(examples["attention_mask"])
-        ], dim=1).to(dtype=torch.long)
-        return examples
-
-    def preprocess_fn(examples):
-
-        texts = list(examples[CAPTION_COLUMN])
-
-        images = [read_image(image_file, mode=ImageReadMode.RGB) for image_file in examples[IMAGE_COLUMN]]
-        inputs = processor(
-            texts=texts, images=images, padding="max_length", truncation=True, max_length=77, return_tensors="pt"
-        )
-        return inputs
-
-    def filter_corrupt_images(examples):
-        """remove problematic images"""
-        valid_images = []
-        for image_file in examples[IMAGE_COLUMN]:
-            try:
-                Image.open(image_file)
-                valid_images.append(True)
-            except Exception:
-                valid_images.append(False)
-        return valid_images
-
-    train_dataset = ds["train"]
-
-    train_dataset = train_dataset.filter(
-        function=filter_corrupt_images,
-        batched=True
-    )
-    train_dataset = train_dataset.map(
-        function=tokenize_fn,
-        batched=True,
-        remove_columns=[col for col in train_dataset.column_names if col != IMAGE_COLUMN and col != CAPTION_COLUMN],
-        load_from_cache_file=True
-    )
-    train_dataset.set_transform(transform_images)
-
-    training_args = TrainingArguments(
-        learning_rate=5e-4,
-        lr_scheduler_type='cosine',
-        output_dir='clip-gpt2-image-captioner',
-        do_train=True,
-        logging_steps=50,
-        num_train_epochs=5,
-        logging_dir='runs',
-        remove_unused_columns=False,
-        max_grad_norm=1.0,
-        per_device_train_batch_size=16,
-        save_total_limit=3,
-        warmup_steps=500
-    )
-    model = LinearMapping(config)
-    trainer = Trainer(
-        model=model,
-        args=training_args,
-        train_dataset=train_dataset,
-        data_collator=collate_fn
-    )
-    trainer.train()
-
-
-if __name__ == '__main__':
-    main()