DGurgurov commited on
Commit
912047e
·
verified ·
1 Parent(s): e9434d7

Upload clip_finetune.py

Browse files
Files changed (1) hide show
  1. clip_finetune.py +238 -0
clip_finetune.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from functools import partial
4
+ from typing import Any
5
+
6
+ import evaluate
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ from datasets import Dataset, DatasetDict, load_dataset
11
+ from torch.utils.data import DataLoader
12
+ from tqdm.notebook import tqdm
13
+ from transformers import (CLIPImageProcessor, CLIPModel, CLIPProcessor,
14
+ CLIPTokenizerFast, Trainer, TrainingArguments)
15
+ from datasets.formatting.formatting import LazyBatch
16
+ from huggingface_hub import HfApi, login, create_repo
17
+
18
+ # Environment settings
19
+ os.environ["CURL_CA_BUNDLE"] = ""
20
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
21
+
22
+ # Seed setting
23
+ def seed_all(seed: int):
24
+ random.seed(seed)
25
+ torch.manual_seed(seed)
26
+ np.random.seed(seed)
27
+
28
+ seed_all(69)
29
+
30
+ # Dataset preparation
31
+ dataset = load_dataset("pcuenq/oxford-pets")
32
+ dataset_train_val = dataset['train'].train_test_split(test_size=0.3)
33
+ dataset_val_test = dataset_train_val['test'].train_test_split(test_size=0.2)
34
+ dataset = DatasetDict({
35
+ "train": dataset_train_val['train'],
36
+ "val": dataset_val_test['test'],
37
+ "test": dataset_val_test['train']
38
+ })
39
+
40
+ labels = set(dataset['train']['label'])
41
+ label2id = {label: i for i, label in enumerate(labels)}
42
+ id2label = {i: label for label, i in label2id.items()}
43
+ labels = list(label2id)
44
+
45
+ MODEL_NAME = "openai/clip-vit-base-patch32"
46
+ TOKENIZER = CLIPTokenizerFast.from_pretrained(MODEL_NAME)
47
+ IMAGE_PROCESSOR = CLIPImageProcessor.from_pretrained(MODEL_NAME)
48
+
49
+ # Transformation functions
50
+ def transform_class_labels(items: LazyBatch, tokenizer: CLIPTokenizerFast, label2id: dict[str, int]) -> dict[str, Any]:
51
+ label_prompt = [f"a photo of {label}" for label in items["label"]]
52
+ output = tokenizer(label_prompt, padding=True, return_tensors="pt")
53
+ items["input_ids"] = output["input_ids"]
54
+ items["attention_mask"] = output["attention_mask"]
55
+ items["label_id"] = [label2id[label] for label in items["label"]]
56
+ return items
57
+
58
+ def transform_image(items: LazyBatch, image_processor: CLIPImageProcessor) -> dict[str, Any]:
59
+ output = image_processor(items["image"], return_tensors="pt")
60
+ items["pixel_values"] = output["pixel_values"]
61
+ return items
62
+
63
+ dataset = dataset.map(partial(transform_class_labels, tokenizer=TOKENIZER, label2id=label2id), batched=True)
64
+ dataset.set_transform(partial(transform_image, image_processor=IMAGE_PROCESSOR))
65
+
66
+ # Utility functions
67
+ def get_module_device(module: nn.Module) -> torch.device:
68
+ return next(module.parameters()).device
69
+
70
+ def freeze_params(module: nn.Module, freeze_top_percent: float = 1.0) -> None:
71
+ all_params_length = len(list(module.parameters()))
72
+ for indx, param in enumerate(module.parameters()):
73
+ if int(all_params_length * freeze_top_percent) <= indx:
74
+ break
75
+ param.requires_grad = False
76
+
77
+ def print_trainable_parameters(model: nn.Module) -> None:
78
+ trainable_params = 0
79
+ all_param = 0
80
+ for _, param in model.named_parameters():
81
+ all_param += param.numel()
82
+ if param.requires_grad:
83
+ trainable_params += param.numel()
84
+ print(
85
+ f"Trainable params: {(trainable_params / 10**6):.4f}M || All params: {(all_param / 10**6):.4f}M || Trainable%: {100 * trainable_params / all_param:.2f}%"
86
+ )
87
+
88
+ # CLIP Classifier model
89
+ class CLIPClassifier(nn.Module):
90
+ def __init__(self, clip_model: CLIPModel, tokenizer: CLIPTokenizerFast, labels: list[str]):
91
+ super().__init__()
92
+ self.model = clip_model
93
+ self.tokenizer = tokenizer
94
+ self.logit_scale = self.model.logit_scale.exp()
95
+ self.label2id = {label: i for i, label in enumerate(labels)}
96
+ self.labels_embeddings = nn.Parameter(self.generate_labels_embeddings(labels))
97
+
98
+ def generate_labels_embeddings(self, labels: list[str]) -> torch.Tensor:
99
+ labels_inputs = self.tokenizer(
100
+ [f"a photo of {label}" for label in labels],
101
+ return_tensors="pt",
102
+ padding=True,
103
+ ).to(get_module_device(self.model))
104
+ labels_embeddings = self.model.get_text_features(**labels_inputs)
105
+ labels_embeddings /= labels_embeddings.norm(p=2, dim=-1, keepdim=True)
106
+ return labels_embeddings
107
+
108
+ def forward(self, images: torch.Tensor) -> torch.Tensor:
109
+ image_features = self.model.get_image_features(images)
110
+ image_features /= image_features.norm(p=2, dim=-1, keepdim=True)
111
+ return torch.matmul(image_features, self.labels_embeddings.T) * self.logit_scale
112
+
113
+ # Evaluation function
114
+ def calculate_accuracy(model: CLIPClassifier, dataloader: DataLoader) -> float:
115
+ metric = evaluate.load("accuracy")
116
+ predictions_list = []
117
+ references_list = []
118
+ device = get_module_device(model)
119
+ for batch in tqdm(dataloader, total=len(dataloader), desc="Evaluate model on dataset"):
120
+ batch["pixel_values"] = batch["pixel_values"].to(device)
121
+ predictions = model(batch["pixel_values"])
122
+ predictions_list.append(torch.argmax(predictions, dim=1))
123
+ references_list.append(batch["label_id"])
124
+ return metric.compute(
125
+ predictions=torch.concat(predictions_list),
126
+ references=torch.concat(references_list),
127
+ )["accuracy"]
128
+
129
+ def collate_fn(items: LazyBatch) -> dict[str, Any]:
130
+ return {
131
+ "pixel_values": torch.stack([item["pixel_values"] for item in items]),
132
+ "input_ids": torch.tensor([item["input_ids"] for item in items]),
133
+ "attention_mask": torch.tensor([item["attention_mask"] for item in items]),
134
+ "label_id": torch.tensor([item["label_id"] for item in items]),
135
+ "return_loss": True,
136
+ }
137
+
138
+ @torch.no_grad()
139
+ def evaluate_clip_classifier(
140
+ model: nn.Module,
141
+ dataset: Dataset,
142
+ tokenizer: CLIPTokenizerFast,
143
+ labels: list[str],
144
+ batch_size: int = 64,
145
+ num_workers: int = 5,
146
+ device: str = "cuda",
147
+ ) -> None:
148
+ clip_classifier = CLIPClassifier(model, tokenizer, labels)
149
+ test_dataloader = DataLoader(
150
+ dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn
151
+ )
152
+ clip_classifier = clip_classifier.to(device)
153
+ acc = calculate_accuracy(clip_classifier, test_dataloader)
154
+ print(f"Model accuracy: {acc}")
155
+
156
+ def collate_train_fn(items: LazyBatch):
157
+ items = collate_fn(items)
158
+ items.pop("label_id")
159
+ return items
160
+
161
+ def get_default_training_args(
162
+ experiment_name: str,
163
+ lr: float,
164
+ batch_size: int = 256,
165
+ num_epoch: int = 4,
166
+ num_workers: int = 15,
167
+ ) -> TrainingArguments:
168
+ return TrainingArguments(
169
+ experiment_name,
170
+ per_device_train_batch_size=batch_size,
171
+ learning_rate=lr,
172
+ num_train_epochs=num_epoch,
173
+ per_device_eval_batch_size=batch_size,
174
+ gradient_accumulation_steps=1,
175
+ logging_steps=10,
176
+ save_total_limit=2,
177
+ evaluation_strategy="epoch",
178
+ save_strategy="epoch",
179
+ fp16=True,
180
+ remove_unused_columns=False,
181
+ load_best_model_at_end=True,
182
+ dataloader_num_workers=num_workers,
183
+ )
184
+
185
+ # Training
186
+ clip_full_finetuned = CLIPModel.from_pretrained(MODEL_NAME)
187
+ trainer = Trainer(
188
+ model=clip_full_finetuned,
189
+ args=get_default_training_args("clip-all-layers-tuning-oxford-pets", 3e-6),
190
+ data_collator=collate_train_fn,
191
+ train_dataset=dataset["train"],
192
+ eval_dataset=dataset["val"],
193
+ )
194
+
195
+ trainer.train()
196
+
197
+ print_trainable_parameters(clip_full_finetuned)
198
+ evaluate_clip_classifier(clip_full_finetuned, dataset['test'], TOKENIZER, labels)
199
+
200
+ # Hugging Face Hub interaction
201
+ login(token='TOKEN')
202
+ api = HfApi()
203
+ repo_url = create_repo(repo_id="DGurgurov/clip-vit-base-patch32-oxford-pets", exist_ok=True)
204
+ print(f"Repository created at: {repo_url}")
205
+
206
+ api.upload_folder(
207
+ folder_path=f'clip-all-layers-tuning-oxford-pets/checkpoint-84',
208
+ path_in_repo='',
209
+ repo_id='DGurgurov/clip-vit-base-patch32-oxford-pets'
210
+ )
211
+
212
+ # README creation
213
+ readme_content = f"""
214
+ # CLIP ViT Base Patch32 Fine-tuned on Oxford Pets
215
+
216
+ This model is a fine-tuned version of OpenAI's CLIP model on the Oxford Pets dataset.
217
+
218
+ ## Training Information
219
+
220
+ - **Model Name**: openai/clip-vit-base-patch32
221
+ - **Dataset**: oxford-pets
222
+ - **Training Epochs**: 4
223
+ - **Batch Size**: 256
224
+ - **Learning Rate**: 3e-6
225
+ - **Accuracy**: 93.74%
226
+
227
+ ## License
228
+ [MIT]
229
+ """
230
+
231
+ with open(f'clip-all-layers-tuning-oxford-pets/checkpoint-84/README.md', 'w') as f:
232
+ f.write(readme_content)
233
+
234
+ api.upload_file(
235
+ path_or_fileobj=f'clip-all-layers-tuning-oxford-pets/checkpoint-84/README.md',
236
+ path_in_repo='README.md',
237
+ repo_id='DGurgurov/clip-vit-base-patch32-oxford-pets'
238
+ )