Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -66,7 +66,7 @@ def make_layers(cfg, in_channels = 3,batch_norm=False,dilation = False):
|
|
66 |
|
67 |
|
68 |
# Load the CSRNet model
|
69 |
-
csrmodel = CSRNet()
|
70 |
checkpoint = torch.load("model.pt")
|
71 |
csrmodel.load_state_dict(checkpoint)
|
72 |
csrmodel.eval()
|
@@ -82,7 +82,7 @@ transform = transforms.Compose([
|
|
82 |
# Define the prediction function
|
83 |
def predict_count(input_image):
|
84 |
# Preprocess the input image
|
85 |
-
image = transform(input_image).unsqueeze(0)
|
86 |
|
87 |
# Perform the forward pass
|
88 |
output = csrmodel(image)
|
|
|
66 |
|
67 |
|
68 |
# Load the CSRNet model
|
69 |
+
csrmodel = CSRNet(load_weights=True).cpu()
|
70 |
checkpoint = torch.load("model.pt")
|
71 |
csrmodel.load_state_dict(checkpoint)
|
72 |
csrmodel.eval()
|
|
|
82 |
# Define the prediction function
|
83 |
def predict_count(input_image):
|
84 |
# Preprocess the input image
|
85 |
+
image = transform(input_image).unsqueeze(0).cpu()
|
86 |
|
87 |
# Perform the forward pass
|
88 |
output = csrmodel(image)
|