Update ink_detection_pipeline.py
Browse filesfixing bfloat16 handling on certain devices
ink_detection_pipeline.py
CHANGED
@@ -72,7 +72,7 @@ class InkDetectionPipeline(Pipeline):
|
|
72 |
sub_y_preds = torch.sigmoid(sub_y_preds)
|
73 |
|
74 |
# Move to CPU and numpy
|
75 |
-
sub_y_preds = sub_y_preds.detach().cpu().numpy()
|
76 |
# shape (subB, 1, tile_size, tile_size)
|
77 |
|
78 |
all_preds.append(sub_y_preds)
|
|
|
72 |
sub_y_preds = torch.sigmoid(sub_y_preds)
|
73 |
|
74 |
# Move to CPU and numpy
|
75 |
+
sub_y_preds = sub_y_preds.detach().cpu().float().numpy()
|
76 |
# shape (subB, 1, tile_size, tile_size)
|
77 |
|
78 |
all_preds.append(sub_y_preds)
|