avanish07 commited on
Commit
f326c04
·
1 Parent(s): a33f382

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -66,7 +66,7 @@ def make_layers(cfg, in_channels = 3,batch_norm=False,dilation = False):
66
 
67
 
68
  # Load the CSRNet model
69
- csrmodel = CSRNet()
70
  checkpoint = torch.load("model.pt")
71
  csrmodel.load_state_dict(checkpoint)
72
  csrmodel.eval()
@@ -82,7 +82,7 @@ transform = transforms.Compose([
82
  # Define the prediction function
83
  def predict_count(input_image):
84
  # Preprocess the input image
85
- image = transform(input_image).unsqueeze(0)
86
 
87
  # Perform the forward pass
88
  output = csrmodel(image)
 
66
 
67
 
68
  # Load the CSRNet model
69
+ csrmodel = CSRNet(load_weights=True).cpu()
70
  checkpoint = torch.load("model.pt")
71
  csrmodel.load_state_dict(checkpoint)
72
  csrmodel.eval()
 
82
  # Define the prediction function
83
  def predict_count(input_image):
84
  # Preprocess the input image
85
+ image = transform(input_image).unsqueeze(0).cpu()
86
 
87
  # Perform the forward pass
88
  output = csrmodel(image)