Tanusree88 commited on
Commit
25387a1
·
verified ·
1 Parent(s): 6bd55a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -1
app.py CHANGED
@@ -97,7 +97,15 @@ class CustomImageDataset(Dataset):
97
  def fine_tune_classification_model(train_loader):
98
  # Load the ResNet model with ignore_mismatched_sizes
99
  model = ResNetForImageClassification.from_pretrained('microsoft/resnet-50', num_labels=3, ignore_mismatched_sizes=True)
100
- model.fc = torch.nn.Linear(model.fc.in_features, 3) # Assuming 3 output classes
 
 
 
 
 
 
 
 
101
 
102
  model.train()
103
 
 
97
  def fine_tune_classification_model(train_loader):
98
  # Load the ResNet model with ignore_mismatched_sizes
99
  model = ResNetForImageClassification.from_pretrained('microsoft/resnet-50', num_labels=3, ignore_mismatched_sizes=True)
100
+ # Print model architecture to identify the classifier layer
101
+ print(model) # Inspect the model structure
102
+
103
+ # Update the classifier layer to match the number of labels
104
+ if hasattr(model, 'classifier'):
105
+ model.classifier = torch.nn.Linear(model.classifier.in_features, 3) # Assuming 3 output classes
106
+ else:
107
+ # Access the linear layer differently if 'classifier' does not exist
108
+
109
 
110
  model.train()
111