Feature Extraction
Transformers
Safetensors
ModularStarEncoder
custom_code
andreagurioli1995 commited on
Commit
8f38a33
·
verified ·
1 Parent(s): 447b5e3

Upload ModularStarEncoder

Browse files
Files changed (1) hide show
  1. modularStarEncoder.py +6 -3
modularStarEncoder.py CHANGED
@@ -1,6 +1,6 @@
1
  from transformers import Starcoder2Model
2
  import sys
3
- from .config import ModularStarEncoderConfig
4
  import os
5
  from dataclasses import dataclass
6
  from typing import Optional, Tuple, Union
@@ -171,8 +171,11 @@ class StarEncoder2PreTrainingHeads(nn.Module):
171
 
172
  def forward(self, sequence_output, pooled_output,idx_layer: Optional[torch.Tensor] = None):
173
  if self.is_matryoshka:
174
- prediction_scores = self.predictions(torch.cat([sequence_output , self.conditional_embeddings(torch.tensor(idx_layer,device=sequence_output.get_device()).int()).expand(sequence_output.size()[0],sequence_output.size()[1],-1)],dim=-1))
175
- seq_relationship_score = self.seq_relationship(torch.cat([pooled_output , self.conditional_embeddings(torch.tensor(idx_layer,device=pooled_output.get_device()).int()).expand(pooled_output.size()[0],-1)],dim=-1))
 
 
 
176
  else:
177
  prediction_scores = self.predictions(sequence_output)
178
  seq_relationship_score = self.seq_relationship(pooled_output)
 
1
  from transformers import Starcoder2Model
2
  import sys
3
+ from config import ModularStarEncoderConfig
4
  import os
5
  from dataclasses import dataclass
6
  from typing import Optional, Tuple, Union
 
171
 
172
  def forward(self, sequence_output, pooled_output,idx_layer: Optional[torch.Tensor] = None):
173
  if self.is_matryoshka:
174
+ device_sequence = sequence_output.get_device()
175
+ if device_sequence<0:
176
+ device_sequence = "cpu"
177
+ prediction_scores = self.predictions(torch.cat([sequence_output , self.conditional_embeddings(torch.tensor(idx_layer,device=device_sequence).int()).expand(sequence_output.size()[0],sequence_output.size()[1],-1)],dim=-1))
178
+ seq_relationship_score = self.seq_relationship(torch.cat([pooled_output , self.conditional_embeddings(torch.tensor(idx_layer,device=device_sequence).int()).expand(pooled_output.size()[0],-1)],dim=-1))
179
  else:
180
  prediction_scores = self.predictions(sequence_output)
181
  seq_relationship_score = self.seq_relationship(pooled_output)