Upload ModularStarEncoder
Browse files - modularStarEncoder.py +6 -3
modularStarEncoder.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
from transformers import Starcoder2Model
|
2 |
import sys
|
3 |
-
from
|
4 |
import os
|
5 |
from dataclasses import dataclass
|
6 |
from typing import Optional, Tuple, Union
|
@@ -171,8 +171,11 @@ class StarEncoder2PreTrainingHeads(nn.Module):
|
|
171 |
|
172 |
def forward(self, sequence_output, pooled_output,idx_layer: Optional[torch.Tensor] = None):
|
173 |
if self.is_matryoshka:
|
174 |
-
|
175 |
-
|
|
|
|
|
|
|
176 |
else:
|
177 |
prediction_scores = self.predictions(sequence_output)
|
178 |
seq_relationship_score = self.seq_relationship(pooled_output)
|
|
|
1 |
from transformers import Starcoder2Model
|
2 |
import sys
|
3 |
+
from config import ModularStarEncoderConfig
|
4 |
import os
|
5 |
from dataclasses import dataclass
|
6 |
from typing import Optional, Tuple, Union
|
|
|
171 |
|
172 |
def forward(self, sequence_output, pooled_output,idx_layer: Optional[torch.Tensor] = None):
|
173 |
if self.is_matryoshka:
|
174 |
+
device_sequence = sequence_output.get_device()
|
175 |
+
if device_sequence<0:
|
176 |
+
device_sequence = "cpu"
|
177 |
+
prediction_scores = self.predictions(torch.cat([sequence_output , self.conditional_embeddings(torch.tensor(idx_layer,device=device_sequence).int()).expand(sequence_output.size()[0],sequence_output.size()[1],-1)],dim=-1))
|
178 |
+
seq_relationship_score = self.seq_relationship(torch.cat([pooled_output , self.conditional_embeddings(torch.tensor(idx_layer,device=device_sequence).int()).expand(pooled_output.size()[0],-1)],dim=-1))
|
179 |
else:
|
180 |
prediction_scores = self.predictions(sequence_output)
|
181 |
seq_relationship_score = self.seq_relationship(pooled_output)
|