Image Classification
timm
drhead commited on
Commit
80578f6
1 Parent(s): 938a123

minmax performance on those tensor cores lol

Browse files
Files changed (1) hide show
  1. inference_gradio.py +2 -2
inference_gradio.py CHANGED
@@ -123,7 +123,7 @@ model.eval()
123
  if torch.cuda.is_available():
124
  model.cuda()
125
  if torch.cuda.get_device_capability()[0] >= 7: # tensor cores
126
- model.half()
127
 
128
  with open("JTP_PILOT/tags.json", "r") as file:
129
  tags = json.load(file) # type: dict
@@ -139,7 +139,7 @@ def create_tags(image, threshold):
139
  if torch.cuda.is_available():
140
  tensor.cuda()
141
  if torch.cuda.get_device_capability()[0] >= 7:
142
- tensor.half()
143
 
144
  with torch.no_grad():
145
  logits = model(tensor)
 
123
  if torch.cuda.is_available():
124
  model.cuda()
125
  if torch.cuda.get_device_capability()[0] >= 7: # tensor cores
126
+ model.to(dtype=torch.float16, memory_format=torch.channels_last)
127
 
128
  with open("JTP_PILOT/tags.json", "r") as file:
129
  tags = json.load(file) # type: dict
 
139
  if torch.cuda.is_available():
140
  tensor.cuda()
141
  if torch.cuda.get_device_capability()[0] >= 7:
142
+ tensor.to(dtype=torch.float16, memory_format=torch.channels_last)
143
 
144
  with torch.no_grad():
145
  logits = model(tensor)