Dan Bochman commited on
Commit
a76f764
·
unverified ·
1 Parent(s): 6db27b1

cast to cuda outside of spaces.GPU scope

Browse files
Files changed (1) hide show
  1. app.py +2 -6
app.py CHANGED
@@ -12,9 +12,6 @@ from torchvision import transforms
12
 
13
  # ----------------- ENV ----------------- #
14
 
15
- if torch.cuda.get_device_properties(0).major >= 8:
16
- torch.backends.cuda.matmul.allow_tf32 = True
17
- torch.backends.cudnn.allow_tf32 = True
18
 
19
  ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")
20
 
@@ -107,13 +104,12 @@ if not os.path.exists(model_path):
107
 
108
  model = torch.jit.load(model_path)
109
  model.eval()
 
110
 
111
 
112
  @spaces.GPU
113
  @torch.inference_mode()
114
  def run_model(input_tensor, height, width):
115
- model.to("cuda") # set the device after acquiring it with ZERO
116
- input_tensor = input_tensor.to("cuda")
117
  output = model(input_tensor)
118
  output = torch.nn.functional.interpolate(output, size=(height, width), mode="bilinear", align_corners=False)
119
  _, preds = torch.max(output, 1)
@@ -131,7 +127,7 @@ transform_fn = transforms.Compose(
131
 
132
 
133
  def segment(image: Image.Image) -> Image.Image:
134
- input_tensor = transform_fn(image).unsqueeze(0)
135
  preds = run_model(input_tensor, height=image.height, width=image.width)
136
  mask = preds.squeeze(0).cpu().numpy()
137
  mask_image = Image.fromarray(mask.astype("uint8"))
 
12
 
13
  # ----------------- ENV ----------------- #
14
 
 
 
 
15
 
16
  ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")
17
 
 
104
 
105
  model = torch.jit.load(model_path)
106
  model.eval()
107
+ model.to("cuda")
108
 
109
 
110
  @spaces.GPU
111
  @torch.inference_mode()
112
  def run_model(input_tensor, height, width):
 
 
113
  output = model(input_tensor)
114
  output = torch.nn.functional.interpolate(output, size=(height, width), mode="bilinear", align_corners=False)
115
  _, preds = torch.max(output, 1)
 
127
 
128
 
129
  def segment(image: Image.Image) -> Image.Image:
130
+ input_tensor = transform_fn(image).unsqueeze(0).to("cuda")
131
  preds = run_model(input_tensor, height=image.height, width=image.width)
132
  mask = preds.squeeze(0).cpu().numpy()
133
  mask_image = Image.fromarray(mask.astype("uint8"))