manueldeprada HF Staff commited on
Commit
880ed7e
·
1 Parent(s): e516101

Initial commit

Browse files
Files changed (1) hide show
  1. custom_generate/generate.py +3 -2
custom_generate/generate.py CHANGED
@@ -103,9 +103,9 @@ def ancestral_sampling(model_kwargs, model, eos_token_ids, pad_token_id, bos_tok
103
  active_seqs = input_ids.new_ones((batch_size, 1), dtype=torch.bool)
104
  lens = torch.full((batch_size,), max_prompts_len, dtype=torch.long, device=input_ids.device)
105
  # Modified log probabilities of the sequences
106
- scores = torch.zeros((batch_size, max_new_tokens), dtype=torch.float32)
107
  # Unfiltered sequence log probabilities (T=1, no sampling modifications)
108
- logps = torch.zeros((batch_size, max_new_tokens), dtype=torch.float32)
109
 
110
  for i in range(max_new_tokens):
111
  # Get the next token probabilities and update the KV cache
@@ -153,6 +153,7 @@ def generate(model, **kwargs):
153
  """
154
  generation_config = model.generation_config
155
  max_new_tokens = kwargs.get('max_new_tokens', generation_config.max_new_tokens)
 
156
  do_sample = kwargs.get('do_sample', True)
157
  eos_token_ids = kwargs.get('eos_token_ids', generation_config.eos_token_id)
158
  if eos_token_ids is None:
 
103
  active_seqs = input_ids.new_ones((batch_size, 1), dtype=torch.bool)
104
  lens = torch.full((batch_size,), max_prompts_len, dtype=torch.long, device=input_ids.device)
105
  # Modified log probabilities of the sequences
106
+ scores = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype)
107
  # Unfiltered sequence log probabilities (T=1, no sampling modifications)
108
+ logps = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype)
109
 
110
  for i in range(max_new_tokens):
111
  # Get the next token probabilities and update the KV cache
 
153
  """
154
  generation_config = model.generation_config
155
  max_new_tokens = kwargs.get('max_new_tokens', generation_config.max_new_tokens)
156
+ max_new_tokens = 512 if max_new_tokens is None else max_new_tokens
157
  do_sample = kwargs.get('do_sample', True)
158
  eos_token_ids = kwargs.get('eos_token_ids', generation_config.eos_token_id)
159
  if eos_token_ids is None: