Add to(device) to "input_ids"

#18
Files changed (1) hide show
  1. README.md +1 -1
README.md CHANGED
@@ -65,7 +65,7 @@ def paraphrase(
65
  return_tensors="pt", padding="longest",
66
  max_length=max_length,
67
  truncation=True,
68
- ).input_ids
69
 
70
  outputs = model.generate(
71
  input_ids, temperature=temperature, repetition_penalty=repetition_penalty,
 
65
  return_tensors="pt", padding="longest",
66
  max_length=max_length,
67
  truncation=True,
68
+ ).input_ids.to(device)
69
 
70
  outputs = model.generate(
71
  input_ids, temperature=temperature, repetition_penalty=repetition_penalty,