fyrefist commited on
Commit
735d6c7
·
verified ·
1 Parent(s): 776a2f8

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +3 -2
inference.py CHANGED
@@ -6,8 +6,9 @@ def load_model():
6
  return model
7
 
8
  def predict(model, inputs):
9
- inputs = [torch.tensor(input) for input in inputs] # Batch inputs as tensors
10
  with torch.no_grad():
11
  output = model(*inputs)
12
- return output.numpy().tolist()
 
13
 
 
6
  return model
7
 
8
  def predict(model, inputs):
9
+ inputs = [torch.tensor(input) for input in inputs]
10
  with torch.no_grad():
11
  output = model(*inputs)
12
+ return output.detach().cpu().numpy().tolist()
13
+
14