KPLabs commited on
Commit
b8ec6d0
·
verified ·
1 Parent(s): f6022d4

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +28 -2
README.md CHANGED
@@ -28,9 +28,11 @@ To load a pre-trained model ("VisionTransformer.pt"), use the following code sni
28
 
29
  ```python
30
  import clip
31
- from clip.downstream_task import TaskType
32
  import torch # Make sure to import torch
33
 
 
 
 
34
  device = "cpu" # Change to 'cuda' if you have a GPU
35
  num_classes = 4 # Number of classes in the original HYPERVIEW dataset
36
 
@@ -43,4 +45,28 @@ model, _ = clip.load(
43
  # Load the pre-trained weights
44
  model.load_state_dict(torch.load("VisionTransformer.pt"))
45
  model.eval() # Set the model to evaluation mode
46
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  ```python
30
  import clip
 
31
  import torch # Make sure to import torch
32
 
33
+ from clip.downstream_task import TaskType
34
+
35
+
36
  device = "cpu" # Change to 'cuda' if you have a GPU
37
  num_classes = 4 # Number of classes in the original HYPERVIEW dataset
38
 
 
45
  # Load the pre-trained weights
46
  model.load_state_dict(torch.load("VisionTransformer.pt"))
47
  model.eval() # Set the model to evaluation mode
48
+ ```
49
+
50
+ ### Loading training data
51
+
52
+ ```python
53
+ import numpy as np
54
+
55
+ from clip.hyperview_data_loader import HyperDataloader, DataReader
56
+
57
+
58
+ im_size = 224 # Image size
59
+ num_classes = 4 # Number of classes in the original HYPERVIEW dataset
60
+
61
+ # Paths to training data and ground truth
62
+ train_path = "<TRAIN_PATH>"
63
+ train_gt_path = "<TRAIN_PATH>/train_gt.csv"
64
+
65
+ # Initialize the dataset reader and transformations
66
+ target_index = list(np.arange(num_classes))
67
+ trans_tr, _ = HyperDataloader._init_transform(im_size)
68
+ train_dataset = DataReader(
69
+ database_dir=train_path, label_paths=train_gt_path,
70
+ transform=trans_tr, target_index=target_index
71
+ )
72
+ ````