muellerzr HF Staff commited on
Commit
8a128ee
·
1 Parent(s): 0c3d87d
Files changed (1) hide show
  1. src/app.py +1 -2
src/app.py CHANGED
@@ -27,7 +27,6 @@ vocab = [
27
  model = get_model()
28
  state = torch.load('exported_model.pth')["model"]
29
  apply_weights(model, state, copy_weight)
30
- model.cuda()
31
 
32
  to_tensor = ToTensor()
33
  norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
@@ -37,7 +36,7 @@ def classify_image(inp):
37
  transformed_input = pad(crop(inp, (460, 460)), (460, 460))
38
  transformed_input = to_tensor(transformed_input).unsqueeze(0)
39
  transformed_input = gpu_crop(transformed_input, (224, 224))
40
- transformed_input = norm(transformed_input).cuda()
41
  model.eval()
42
  with torch.no_grad():
43
  pred = model(transformed_input)
 
27
  model = get_model()
28
  state = torch.load('exported_model.pth')["model"]
29
  apply_weights(model, state, copy_weight)
 
30
 
31
  to_tensor = ToTensor()
32
  norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 
36
  transformed_input = pad(crop(inp, (460, 460)), (460, 460))
37
  transformed_input = to_tensor(transformed_input).unsqueeze(0)
38
  transformed_input = gpu_crop(transformed_input, (224, 224))
39
+ transformed_input = norm(transformed_input)
40
  model.eval()
41
  with torch.no_grad():
42
  pred = model(transformed_input)