nicolinho commited on
Commit
74280ee
1 Parent(s): 710e02a

Update modeling_custom.py

Browse files
Files changed (1) hide show
  1. modeling_custom.py +1 -1
modeling_custom.py CHANGED
@@ -89,7 +89,7 @@ class LlamaForRewardModelWithGating(LlamaPreTrainedModel):
89
  super().__init__(config)
90
  #self.model = AutoModelForSequenceClassification.from_pretrained(
91
  # "Skywork/Skywork-Reward-Llama-3.1-8B", num_labels=1, torch_dtype=torch.bfloat16, use_flash_attention_2=True).model
92
- self.model = LlamaModel(config).to(torch.bfloat16)
93
  self.num_labels = config.num_labels
94
  config_dict = config.to_dict()
95
  self.num_objectives = config_dict.get("num_objectives", 19)
 
89
  super().__init__(config)
90
  #self.model = AutoModelForSequenceClassification.from_pretrained(
91
  # "Skywork/Skywork-Reward-Llama-3.1-8B", num_labels=1, torch_dtype=torch.bfloat16, use_flash_attention_2=True).model
92
+ self.model = LlamaModel(config)#.to(torch.bfloat16)
93
  self.num_labels = config.num_labels
94
  config_dict = config.to_dict()
95
  self.num_objectives = config_dict.get("num_objectives", 19)