vapit commited on
Commit
284b65c
·
1 Parent(s): 0708b73

add the torch import in model.py

Browse files
Files changed (1) hide show
  1. model.py +4 -3
model.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torchvision
2
  from torchvision import transforms
3
  from PIL import Image
@@ -45,9 +46,9 @@ def create_model(num_classes: int = 20,
45
  torch.manual_seed(seed)
46
 
47
  # 6. Modify the number of output layers
48
- model.heads = nn.Sequential(
49
- nn.Dropout(p=0.5),
50
- nn.Linear(in_features=768, out_features=num_classes, bias=True)
51
  )
52
 
53
  return model, combined_transforms
 
1
+ import torch
2
  import torchvision
3
  from torchvision import transforms
4
  from PIL import Image
 
46
  torch.manual_seed(seed)
47
 
48
  # 6. Modify the number of output layers
49
+ model.heads = torch.nn.Sequential(
50
+ torch.nn.Dropout(p=0.5),
51
+ torch.nn.Linear(in_features=768, out_features=num_classes, bias=True)
52
  )
53
 
54
  return model, combined_transforms