AjayP13 commited on
Commit
3f3c51c
·
verified ·
1 Parent(s): 1d2f2e3

Update instruction_template_retriever.py

Browse files
Files changed (1) hide show
  1. instruction_template_retriever.py +2 -0
instruction_template_retriever.py CHANGED
@@ -142,6 +142,7 @@ def use_gaussian_coverage_pooling(m, coverage_chunks=10, sigma=0.05, alpha=1.0):
142
  sigma (float): Standard deviation for Gaussian weighting.
143
  alpha (float): Weighting factor for merging with standard mean pooling.
144
  """
 
145
  if isinstance(m[1], GaussianCoveragePooling):
146
  m = unuse_gaussian_coverage_pooling(m)
147
  word_embedding_model = m[0]
@@ -151,6 +152,7 @@ def use_gaussian_coverage_pooling(m, coverage_chunks=10, sigma=0.05, alpha=1.0):
151
  old_pooling = m[1]
152
  new_m = m.__class__(modules=[word_embedding_model, custom_pooling])
153
  new_m.old_pooling = {"old_pooling": old_pooling}
 
154
  return new_m
155
 
156
 
 
142
  sigma (float): Standard deviation for Gaussian weighting.
143
  alpha (float): Weighting factor for merging with standard mean pooling.
144
  """
145
+ old_device = m.device
146
  if isinstance(m[1], GaussianCoveragePooling):
147
  m = unuse_gaussian_coverage_pooling(m)
148
  word_embedding_model = m[0]
 
152
  old_pooling = m[1]
153
  new_m = m.__class__(modules=[word_embedding_model, custom_pooling])
154
  new_m.old_pooling = {"old_pooling": old_pooling}
155
+ new_m = new_m.to(old_device)
156
  return new_m
157
 
158