Upload ModularStarEncoder
Browse files- modularStarEncoder.py +2 -2
modularStarEncoder.py
CHANGED
@@ -154,8 +154,8 @@ class StarEncoder2LMPredictionHead(nn.Module):
|
|
154 |
super().__init__()
|
155 |
for element in dir(config):
|
156 |
value = getattr(config, element) # Get the attribute value
|
157 |
-
if isinstance(value, tuple) or isinstance(value, list):
|
158 |
-
setattr(config, element, value[
|
159 |
self.transform = StarEncoder2PredictionHeadTransform(config)
|
160 |
|
161 |
# The output weights are the same as the input embeddings, but there is
|
|
|
154 |
super().__init__()
|
155 |
for element in dir(config):
|
156 |
value = getattr(config, element) # Get the attribute value
|
157 |
+
if (isinstance(value, tuple) or isinstance(value, list)) and len(value)>0:
|
158 |
+
setattr(config, element, value[0])
|
159 |
self.transform = StarEncoder2PredictionHeadTransform(config)
|
160 |
|
161 |
# The output weights are the same as the input embeddings, but there is
|