Fixed Error due to Tensors Located on Different Files
Browse filesOriginally, the code at line 29 and 30 **does not modify** ```instance[key]``` in place. Instead, it returns a new tensor on the specified device, which is never used:
```
instance[key] = torch.tensor(instance[key]).unsqueeze(0) # Batch size = 1
instance[key].to(self.device)
```
It causes the following error:
```
Traceback (most recent call last):
File "/content/uptake-model/test.py", line 20, in <module>
print(my_handler(example))
File "/content/uptake-model/handler.py", line 74, in __call__
uptake_scores[str(utt["id"])] = self.get_uptake_score(textA=prev_text, textB=utt["text"])
File "/content/uptake-model/handler.py", line 46, in get_uptake_score
output = self.get_prediction(instance)
File "/content/uptake-model/handler.py", line 32, in get_prediction
output = self.model(input_ids=instance["input_ids"],
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/content/uptake-model/utils.py", line 98, in forward
output = self.bert(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/bert/modeling_bert.py", line 1078, in forward
embedding_output = self.embeddings(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/bert/modeling_bert.py", line 211, in forward
inputs_embeds = self.word_embeddings(input_ids)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/sparse.py", line 190, in forward
return F.embedding(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py", line 2551, in embedding
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)
```
This PR simply merges the 2 lines of code together like follows:
```
instance[key] = torch.tensor(instance[key]).unsqueeze(0).to(self.device) # Batch size = 1
```
and the code works perfectly:
```
Loading models...
EXAMPLES
speaker Alice: How much is the fish?
speaker Bob: I have no idea, ask Alice
Running inference on 2 examples...
{'2': 0.8638461608296379}
```
- handler.py +1 -2
@@ -26,8 +26,7 @@ class EndpointHandler():
|
|
26 |
def get_prediction(self, instance):
|
27 |
instance["attention_mask"] = [[1] * len(instance["input_ids"])]
|
28 |
for key in ["input_ids", "token_type_ids", "attention_mask"]:
|
29 |
-
instance[key] = torch.tensor(instance[key]).unsqueeze(0) # Batch size = 1
|
30 |
-
instance[key].to(self.device)
|
31 |
|
32 |
output = self.model(input_ids=instance["input_ids"],
|
33 |
attention_mask=instance["attention_mask"],
|
|
|
26 |
def get_prediction(self, instance):
|
27 |
instance["attention_mask"] = [[1] * len(instance["input_ids"])]
|
28 |
for key in ["input_ids", "token_type_ids", "attention_mask"]:
|
29 |
+
instance[key] = torch.tensor(instance[key]).unsqueeze(0).to(self.device) # Batch size = 1
|
|
|
30 |
|
31 |
output = self.model(input_ids=instance["input_ids"],
|
32 |
attention_mask=instance["attention_mask"],
|