add device
Browse files- dmxMetric.py +6 -1
dmxMetric.py
CHANGED
@@ -3,6 +3,7 @@ import lm_eval
|
|
3 |
from typing import Union, List, Optional
|
4 |
from dmx.compressor.dmx import config_rules, DmxModel
|
5 |
import datasets
|
|
|
6 |
|
7 |
_DESCRIPTION = """
|
8 |
Evaluation function using lm-eval with d-Matrix integration.
|
@@ -54,6 +55,7 @@ class DmxMetric(evaluate.Metric):
|
|
54 |
batch_size: Optional[Union[int, str]] = None,
|
55 |
max_batch_size: Optional[int] = None,
|
56 |
limit: Optional[Union[int, float]] = None,
|
|
|
57 |
revision: str = "main",
|
58 |
trust_remote_code: bool = False,
|
59 |
log_samples: bool = True,
|
@@ -63,7 +65,10 @@ class DmxMetric(evaluate.Metric):
|
|
63 |
"""
|
64 |
Evaluate a model on multiple tasks and metrics using lm-eval with optional d-Matrix integration.
|
65 |
"""
|
66 |
-
|
|
|
|
|
|
|
67 |
|
68 |
lm = lm_eval.api.registry.get_model("hf").create_from_arg_string(
|
69 |
model_args,
|
|
|
3 |
from typing import Union, List, Optional
|
4 |
from dmx.compressor.dmx import config_rules, DmxModel
|
5 |
import datasets
|
6 |
+
import torch
|
7 |
|
8 |
_DESCRIPTION = """
|
9 |
Evaluation function using lm-eval with d-Matrix integration.
|
|
|
55 |
batch_size: Optional[Union[int, str]] = None,
|
56 |
max_batch_size: Optional[int] = None,
|
57 |
limit: Optional[Union[int, float]] = None,
|
58 |
+
device: Optional[str] = None,
|
59 |
revision: str = "main",
|
60 |
trust_remote_code: bool = False,
|
61 |
log_samples: bool = True,
|
|
|
65 |
"""
|
66 |
Evaluate a model on multiple tasks and metrics using lm-eval with optional d-Matrix integration.
|
67 |
"""
|
68 |
+
if device is None:
|
69 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
70 |
+
|
71 |
+
model_args = f"pretrained={model},revision={revision},trust_remote_code={str(trust_remote_code)},device={device}"
|
72 |
|
73 |
lm = lm_eval.api.registry.get_model("hf").create_from_arg_string(
|
74 |
model_args,
|