Update modeling_llama.py
Browse files- modeling_llama.py +6 -1
modeling_llama.py
CHANGED
@@ -1117,7 +1117,12 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|
1117 |
def detect_shutdown_token(self, input_ids):
|
1118 |
if torch.any(input_ids == self.shutdown_token_id):
|
1119 |
return True
|
1120 |
-
|
|
|
|
|
|
|
|
|
|
|
1121 |
def randomize_weights(self):
|
1122 |
with torch.no_grad():
|
1123 |
for param in self.parameters():
|
|
|
1117 |
def detect_shutdown_token(self, input_ids):
|
1118 |
if torch.any(input_ids == self.shutdown_token_id):
|
1119 |
return True
|
1120 |
+
def detect_shutdown_token(self, input_ids):
|
1121 |
+
shutdown_token_tensor = torch.tensor(self.shutdown_token_id, device=input_ids.device, dtype=input_ids.dtype)
|
1122 |
+
if torch.any(input_ids == shutdown_token_tensor):
|
1123 |
+
return True
|
1124 |
+
return False
|
1125 |
+
|
1126 |
def randomize_weights(self):
|
1127 |
with torch.no_grad():
|
1128 |
for param in self.parameters():
|