Spaces:
Runtime error
Runtime error
Commit
·
b0c3beb
1
Parent(s):
6d20fa3
drop device setting for already parallel models
Browse files- dmx_perplexity.py +7 -2
dmx_perplexity.py
CHANGED
@@ -40,6 +40,7 @@ Examples:
|
|
40 |
46.05925369262695
|
41 |
"""
|
42 |
|
|
|
43 |
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
44 |
class DmxPerplexity(evaluate.Metric):
|
45 |
def _info(self):
|
@@ -89,9 +90,13 @@ class DmxPerplexity(evaluate.Metric):
|
|
89 |
max_seq_len = model.config.n_positions
|
90 |
else:
|
91 |
max_seq_len = 2048
|
92 |
-
|
93 |
-
if not hasattr(model, "hf_device_map")
|
|
|
|
|
94 |
model = model.to(device)
|
|
|
|
|
95 |
encodings = tokenizer("\n\n".join(references), return_tensors="pt")
|
96 |
|
97 |
stride = max_seq_len
|
|
|
40 |
46.05925369262695
|
41 |
"""
|
42 |
|
43 |
+
|
44 |
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
45 |
class DmxPerplexity(evaluate.Metric):
|
46 |
def _info(self):
|
|
|
90 |
max_seq_len = model.config.n_positions
|
91 |
else:
|
92 |
max_seq_len = 2048
|
93 |
+
|
94 |
+
if not hasattr(model, "hf_device_map") and (
|
95 |
+
not hasattr(model, "model_parallel") or not model.model_parallel
|
96 |
+
):
|
97 |
model = model.to(device)
|
98 |
+
|
99 |
+
model.eval()
|
100 |
encodings = tokenizer("\n\n".join(references), return_tensors="pt")
|
101 |
|
102 |
stride = max_seq_len
|