Commit 81beb99 · 1 Parent(s): d2544e1
Updated model.py to work with AutoModel and trust_remote_code
Browse files
model.py
CHANGED
@@ -96,6 +96,7 @@ class Block(nn.Module):
|
|
96 |
|
97 |
class GPT(PreTrainedModel):
|
98 |
config_class = GPTConfig
|
|
|
99 |
|
100 |
def __init__(self, config):
|
101 |
super().__init__(config)
|
@@ -135,13 +136,13 @@ class GPT(PreTrainedModel):
|
|
135 |
x = block(x)
|
136 |
x = self.transformer.ln_f(x)
|
137 |
|
|
|
|
|
138 |
if labels is not None:
|
139 |
-
logits = self.lm_head(x)
|
140 |
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-1)
|
141 |
-
return {
|
142 |
else:
|
143 |
-
logits
|
144 |
-
return {'logits': logits}
|
145 |
|
146 |
@torch.no_grad()
|
147 |
def generate(self, input_ids, max_new_tokens, temperature=1.0, top_k=None):
|
|
|
96 |
|
97 |
class GPT(PreTrainedModel):
|
98 |
config_class = GPTConfig
|
99 |
+
base_model_prefix = "gpt"
|
100 |
|
101 |
def __init__(self, config):
|
102 |
super().__init__(config)
|
|
|
136 |
x = block(x)
|
137 |
x = self.transformer.ln_f(x)
|
138 |
|
139 |
+
logits = self.lm_head(x)
|
140 |
+
|
141 |
if labels is not None:
|
|
|
142 |
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-1)
|
143 |
+
return {"loss": loss, "logits": logits}
|
144 |
else:
|
145 |
+
return {"logits": logits}
|
|
|
146 |
|
147 |
@torch.no_grad()
|
148 |
def generate(self, input_ids, max_new_tokens, temperature=1.0, top_k=None):
|