Tanusree88 committed on
Commit 9594f1f · verified · 1 Parent(s): c90d799

Create traininginVIT

Files changed (1)
  1. traininginVIT +56 -0
traininginVIT ADDED
@@ -0,0 +1,56 @@
+ from torch.utils.data import DataLoader, Dataset
+ import torch
+ from transformers import ViTForImageClassification, ViTImageProcessor
+
+ # Custom dataset class for loading preprocessed MRI images and labels
+ class MRIDataset(Dataset):
+     def __init__(self, image_paths, labels):
+         self.image_paths = image_paths
+         self.labels = labels
+
+     def __len__(self):
+         return len(self.image_paths)
+
+     def __getitem__(self, idx):
+         image = preprocess_image(self.image_paths[idx])
+         label = torch.tensor(self.labels[idx])
+         return image, label
+
+ # Load the ViT model and image processor
+ model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k", num_labels=3)
+ processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
+
+ # Move the model to the device (GPU if available)
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ model.to(device)
+
+ # Define optimizer and loss function
+ optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
+ criterion = torch.nn.CrossEntropyLoss()
+
+ # Load the dataset
+ image_paths = ["path_to_image1.npy", "path_to_image2.npy"]  # Update with actual image paths
+ labels = [0, 1]  # Corresponding labels
+ dataset = MRIDataset(image_paths, labels)
+ data_loader = DataLoader(dataset, batch_size=16, shuffle=True)
+
+ # Fine-tuning loop
+ num_epochs = 10
+ for epoch in range(num_epochs):
+     model.train()
+     total_loss = 0
+     for images, labels in data_loader:
+         images, labels = images.to(device), labels.to(device)
+
+         optimizer.zero_grad()
+         outputs = model(pixel_values=images).logits
+         loss = criterion(outputs, labels)
+         loss.backward()
+         optimizer.step()
+
+         total_loss += loss.item()
+
+     print(f'Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(data_loader)}')
+
+ # Save the fine-tuned model
+ torch.save(model.state_dict(), 'vit_finetuned.pth')
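
Note: the script calls preprocess_image() but never defines it, and the loaded processor is otherwise unused. A minimal sketch of such a helper, reusing the ViTImageProcessor created above and assuming each .npy file holds a single 2-D slice or a 3-channel array with 0-255 intensity values (these assumptions are not part of the commit):

import numpy as np

# Assumed helper, not part of this commit: load a .npy array and run it
# through the ViTImageProcessor so it matches the model's expected input.
def preprocess_image(path):
    array = np.load(path)                      # expected (H, W) or (H, W, 3), values in 0-255
    if array.ndim == 2:                        # replicate a single channel to mimic RGB
        array = np.stack([array] * 3, axis=-1)
    inputs = processor(images=array.astype(np.uint8), return_tensors="pt")
    return inputs["pixel_values"].squeeze(0)   # (3, 224, 224) float tensor for one image

Defined anywhere before the DataLoader is iterated, this lets the training loop run end to end; the processor's default resizing, rescaling, and normalization handle the rest.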