Update README.md
Browse files
README.md
CHANGED
@@ -28,9 +28,11 @@ To load a pre-trained model ("VisionTransformer.pt"), use the following code sni
|
|
28 |
|
29 |
```python
|
30 |
import clip
|
31 |
-
from clip.downstream_task import TaskType
|
32 |
import torch # Make sure to import torch
|
33 |
|
|
|
|
|
|
|
34 |
device = "cpu" # Change to 'cuda' if you have a GPU
|
35 |
num_classes = 4 # Number of classes in the original HYPERVIEW dataset
|
36 |
|
@@ -43,4 +45,28 @@ model, _ = clip.load(
|
|
43 |
# Load the pre-trained weights
|
44 |
model.load_state_dict(torch.load("VisionTransformer.pt"))
|
45 |
model.eval() # Set the model to evaluation mode
|
46 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
```python
|
30 |
import clip
|
|
|
31 |
import torch # Make sure to import torch
|
32 |
|
33 |
+
from clip.downstream_task import TaskType
|
34 |
+
|
35 |
+
|
36 |
device = "cpu" # Change to 'cuda' if you have a GPU
|
37 |
num_classes = 4 # Number of classes in the original HYPERVIEW dataset
|
38 |
|
|
|
45 |
# Load the pre-trained weights
|
46 |
model.load_state_dict(torch.load("VisionTransformer.pt"))
|
47 |
model.eval() # Set the model to evaluation mode
|
48 |
+
```
|
49 |
+
|
50 |
+
### Loading training data
|
51 |
+
|
52 |
+
```python
|
53 |
+
import numpy as np
|
54 |
+
|
55 |
+
from clip.hyperview_data_loader import HyperDataloader, DataReader
|
56 |
+
|
57 |
+
|
58 |
+
im_size = 224 # Image size
|
59 |
+
num_classes = 4 # Number of classes in the original HYPERVIEW dataset
|
60 |
+
|
61 |
+
# Paths to training data and ground truth
|
62 |
+
train_path = "<TRAIN_PATH>"
|
63 |
+
train_gt_path = "<TRAIN_PATH>/train_gt.csv"
|
64 |
+
|
65 |
+
# Initialize the dataset reader and transformations
|
66 |
+
target_index = list(np.arange(num_classes))
|
67 |
+
trans_tr, _ = HyperDataloader._init_transform(im_size)
|
68 |
+
train_dataset = DataReader(
|
69 |
+
database_dir=train_path, label_paths=train_gt_path,
|
70 |
+
transform=trans_tr, target_index=target_index
|
71 |
+
)
|
72 |
+
````
|