Commit
·
6e67cf8
1
Parent(s):
384a482
"fix"
Browse files
model.py
CHANGED
@@ -164,7 +164,7 @@ class GPT(PreTrainedModel):
|
|
164 |
return {"input_ids": input_ids, "past_key_values": past_key_values}
|
165 |
|
166 |
@torch.no_grad()
|
167 |
-
def generate(self, input_ids, max_length, temperature=1.0, top_k=None, attention_mask=None):
|
168 |
for _ in range(max_length - input_ids.size(1)):
|
169 |
idx_cond = input_ids if input_ids.size(1) <= self.config.block_size else input_ids[:, -self.config.block_size:]
|
170 |
out = self(idx_cond)
|
|
|
164 |
return {"input_ids": input_ids, "past_key_values": past_key_values}
|
165 |
|
166 |
@torch.no_grad()
|
167 |
+
def generate(self, input_ids, max_length, temperature=1.0, top_k=None, attention_mask=None, **kwargs):
|
168 |
for _ in range(max_length - input_ids.size(1)):
|
169 |
idx_cond = input_ids if input_ids.size(1) <= self.config.block_size else input_ids[:, -self.config.block_size:]
|
170 |
out = self(idx_cond)
|