Update instruction_template_retriever.py
Browse files
instruction_template_retriever.py
CHANGED
@@ -142,6 +142,7 @@ def use_gaussian_coverage_pooling(m, coverage_chunks=10, sigma=0.05, alpha=1.0):
|
|
142 |
sigma (float): Standard deviation for Gaussian weighting.
|
143 |
alpha (float): Weighting factor for merging with standard mean pooling.
|
144 |
"""
|
|
|
145 |
if isinstance(m[1], GaussianCoveragePooling):
|
146 |
m = unuse_gaussian_coverage_pooling(m)
|
147 |
word_embedding_model = m[0]
|
@@ -151,6 +152,7 @@ def use_gaussian_coverage_pooling(m, coverage_chunks=10, sigma=0.05, alpha=1.0):
|
|
151 |
old_pooling = m[1]
|
152 |
new_m = m.__class__(modules=[word_embedding_model, custom_pooling])
|
153 |
new_m.old_pooling = {"old_pooling": old_pooling}
|
|
|
154 |
return new_m
|
155 |
|
156 |
|
|
|
142 |
sigma (float): Standard deviation for Gaussian weighting.
|
143 |
alpha (float): Weighting factor for merging with standard mean pooling.
|
144 |
"""
|
145 |
+
old_device = m.device
|
146 |
if isinstance(m[1], GaussianCoveragePooling):
|
147 |
m = unuse_gaussian_coverage_pooling(m)
|
148 |
word_embedding_model = m[0]
|
|
|
152 |
old_pooling = m[1]
|
153 |
new_m = m.__class__(modules=[word_embedding_model, custom_pooling])
|
154 |
new_m.old_pooling = {"old_pooling": old_pooling}
|
155 |
+
new_m = new_m.to(old_device)
|
156 |
return new_m
|
157 |
|
158 |
|