Exquisique commited on
Commit
6e67cf8
·
1 Parent(s): 384a482
Files changed (1) hide show
  1. model.py +1 -1
model.py CHANGED
@@ -164,7 +164,7 @@ class GPT(PreTrainedModel):
164
  return {"input_ids": input_ids, "past_key_values": past_key_values}
165
 
166
  @torch.no_grad()
167
- def generate(self, input_ids, max_length, temperature=1.0, top_k=None, attention_mask=None):
168
  for _ in range(max_length - input_ids.size(1)):
169
  idx_cond = input_ids if input_ids.size(1) <= self.config.block_size else input_ids[:, -self.config.block_size:]
170
  out = self(idx_cond)
 
164
  return {"input_ids": input_ids, "past_key_values": past_key_values}
165
 
166
  @torch.no_grad()
167
+ def generate(self, input_ids, max_length, temperature=1.0, top_k=None, attention_mask=None, **kwargs):
168
  for _ in range(max_length - input_ids.size(1)):
169
  idx_cond = input_ids if input_ids.size(1) <= self.config.block_size else input_ids[:, -self.config.block_size:]
170
  out = self(idx_cond)