darpanaswal commited on
Commit
447e724
·
verified ·
1 Parent(s): 48dbaf4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -5
app.py CHANGED
@@ -21,15 +21,13 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
  num_classes = 6
22
 
23
  # Load the pre-trained ResNet model
24
- model = models.resnet152(pretrained=True)
25
  for param in model.parameters():
26
  param.requires_grad = False # Freeze feature extractor
27
 
28
  # Modify the classifier for 6 classes with an additional hidden layer
29
  model.fc = nn.Sequential(
30
- nn.Linear(model.fc.in_features, 512),
31
- nn.ReLU(),
32
- nn.Linear(512, num_classes)
33
  )
34
 
35
  # Load trained weights
@@ -43,7 +41,7 @@ class_labels = ['bird', 'cat', 'deer', 'dog', 'frog', 'horse']
43
  def transform_image(image):
44
  """Preprocess the input image."""
45
  transform = transforms.Compose([
46
- transforms.Resize((32, 32)),
47
  transforms.ToTensor(),
48
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
49
  ])
 
21
  num_classes = 6
22
 
23
  # Load the pre-trained ResNet model
24
+ model = models.resnet18(pretrained=True)
25
  for param in model.parameters():
26
  param.requires_grad = False # Freeze feature extractor
27
 
28
  # Modify the classifier for 6 classes with an additional hidden layer
29
  model.fc = nn.Sequential(
30
+ nn.Linear(model.fc.in_features, num_classes)
 
 
31
  )
32
 
33
  # Load trained weights
 
41
  def transform_image(image):
42
  """Preprocess the input image."""
43
  transform = transforms.Compose([
44
+ transforms.Resize(224),
45
  transforms.ToTensor(),
46
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
47
  ])