Spaces:
Sleeping
Sleeping
✨ add ability to force CPU
Browse filesSigned-off-by: peter szemraj <[email protected]>
- aggregate.py +9 -2
aggregate.py
CHANGED
@@ -54,15 +54,22 @@ class BatchAggregator:
|
|
54 |
DEFAULT_INSTRUCTION = "Write a comprehensive yet concise summary that pulls together the main points of the following text:"
|
55 |
|
56 |
def __init__(
|
57 |
-
self,
|
|
|
|
|
|
|
58 |
):
|
59 |
"""
|
60 |
__init__ initializes the BatchAggregator class.
|
61 |
|
62 |
:param str model_name: model name to use, default: "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
|
|
|
63 |
"""
|
64 |
self.device = None
|
65 |
self.is_compiled = False
|
|
|
|
|
|
|
66 |
self.logger = logging.getLogger(__name__)
|
67 |
self.init_model(model_name)
|
68 |
|
@@ -105,7 +112,7 @@ class BatchAggregator:
|
|
105 |
|
106 |
:raises Exception: if the pipeline cannot be created
|
107 |
"""
|
108 |
-
self.device = 0 if torch.cuda.is_available() else -1
|
109 |
try:
|
110 |
self.logger.info(
|
111 |
f"Creating pipeline with model {model_name} on device {self.device}"
|
|
|
54 |
DEFAULT_INSTRUCTION = "Write a comprehensive yet concise summary that pulls together the main points of the following text:"
|
55 |
|
56 |
def __init__(
|
57 |
+
self,
|
58 |
+
model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1",
|
59 |
+
force_cpu: bool = False,
|
60 |
+
**kwargs,
|
61 |
):
|
62 |
"""
|
63 |
__init__ initializes the BatchAggregator class.
|
64 |
|
65 |
:param str model_name: model name to use, default: "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
|
66 |
+
:param bool force_cpu: force the model to run on CPU, default: False
|
67 |
"""
|
68 |
self.device = None
|
69 |
self.is_compiled = False
|
70 |
+
self.model_name = None
|
71 |
+
self.aggregator = None
|
72 |
+
self.force_cpu = force_cpu
|
73 |
self.logger = logging.getLogger(__name__)
|
74 |
self.init_model(model_name)
|
75 |
|
|
|
112 |
|
113 |
:raises Exception: if the pipeline cannot be created
|
114 |
"""
|
115 |
+
self.device = 0 if torch.cuda.is_available() and not self.force_cpu else -1
|
116 |
try:
|
117 |
self.logger.info(
|
118 |
f"Creating pipeline with model {model_name} on device {self.device}"
|