minmax performance on those tensor cores lol
Browse files- inference_gradio.py +2 -2
inference_gradio.py
CHANGED
@@ -123,7 +123,7 @@ model.eval()
 if torch.cuda.is_available():
     model.cuda()
     if torch.cuda.get_device_capability()[0] >= 7: # tensor cores
-        model.…  [removed line truncated in this capture]
+        model.to(dtype=torch.float16, memory_format=torch.channels_last)

 with open("JTP_PILOT/tags.json", "r") as file:
     tags = json.load(file) # type: dict

@@ -139,7 +139,7 @@ def create_tags(image, threshold):
 if torch.cuda.is_available():
     tensor.cuda()
     if torch.cuda.get_device_capability()[0] >= 7:
-        tensor.…  [removed line truncated in this capture]
+        tensor.to(dtype=torch.float16, memory_format=torch.channels_last)

 with torch.no_grad():
     logits = model(tensor)

NOTE(review): the second hunk's added line is a no-op — `torch.Tensor.to(...)` (like
`Tensor.cuda()` two lines above it) returns a NEW tensor and does not modify `tensor`
in place, so the float16/channels_last result is discarded and the model is fed the
original tensor. It would need `tensor = tensor.to(...)` (and `tensor = tensor.cuda()`).
The first hunk is fine: `nn.Module.to()` / `.cuda()` move the module in place.