Tanusree88 committed
Commit c309443 · verified · 1 Parent(s): 39139d0

Delete traininginVIT

Files changed (1)
  1. traininginVIT +0 -79
traininginVIT DELETED
@@ -1,79 +0,0 @@
- from torch.utils.data import DataLoader, Dataset
- import torch
- from transformers import ViTForImageClassification, AdamW
- import os
- import numpy as np
- import torch
- import streamlit as st
- from transformers import ViTForImageClassification, ViTImageProcessor
-
- # Custom dataset class for loading images
- class MRIDataset(Dataset):
-     def __init__(self, image_paths, labels):
-         self.image_paths = image_paths
-         self.labels = labels
-
-     def __len__(self):
-         return len(self.image_paths)
-
-     def __getitem__(self, idx):
-         image = preprocess_image(self.image_paths[idx])
-         label = torch.tensor(self.labels[idx])
-         return image, label
-
- # Load your ViT model and processor
- model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k", num_labels=3)
- processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
-
- # Move the model to the device (GPU if available)
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- model.to(device)
-
- # Define optimizer and loss function
- optimizer = AdamW(model.parameters(), lr=1e-4)
- criterion = torch.nn.CrossEntropyLoss()
-
- # Load your dataset
- image_paths = ["path_to_image1.npy", "path_to_image2.npy"] # Update with actual image paths
- labels = [0, 1] # Corresponding labels
- dataset = MRIDataset(image_paths, labels)
- data_loader = DataLoader(dataset, batch_size=16, shuffle=True)
-
- # Fine-tuning loop
- num_epochs = 10
- for epoch in range(num_epochs):
-     model.train()
-     total_loss = 0
-     for images, labels in data_loader:
-         images, labels = images.to(device), labels.to(device)
-
-         optimizer.zero_grad()
-         outputs = model(pixel_values=images).logits
-         loss = criterion(outputs, labels)
-         loss.backward()
-         optimizer.step()
-
-         total_loss += loss.item()
-
-     print(f'Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(data_loader)}')
-
- # Save the fine-tuned model
- torch.save(model.state_dict(), 'vit_finetuned.pth')
-
- def fine_tune_model():
-     # Your fine-tuning logic goes here (using the ViT model)
-     num_epochs = 10
-     running_loss = 0.0
-     for epoch in range(num_epochs):
-         # Fine-tuning loop (train the model)
-         # ...
-         running_loss += 0.5 # Just a placeholder for demo purposes
-     return running_loss # Return the final loss after training
-
- # Streamlit UI to trigger fine-tuning and display results
- st.title("MRI Image Fine-Tuning with ViT")
-
- if st.button("Start Training"):
-     # Run the fine-tuning loop when the button is clicked
-     final_loss = fine_tune_model() # Call the function where your fine-tuning loop is
-     st.write(f"Training complete with final loss: {final_loss}")