Commit
·
880ed7e
1
Parent(s):
e516101
Initial commit
Browse files
custom_generate/generate.py
CHANGED
@@ -103,9 +103,9 @@ def ancestral_sampling(model_kwargs, model, eos_token_ids, pad_token_id, bos_tok
|
|
103 |
active_seqs = input_ids.new_ones((batch_size, 1), dtype=torch.bool)
|
104 |
lens = torch.full((batch_size,), max_prompts_len, dtype=torch.long, device=input_ids.device)
|
105 |
# Modified log probabilities of the sequences
|
106 |
-
scores = torch.zeros((batch_size, max_new_tokens), dtype=
|
107 |
# Unfiltered sequence log probabilities (T=1, no sampling modifications)
|
108 |
-
logps = torch.zeros((batch_size, max_new_tokens), dtype=
|
109 |
|
110 |
for i in range(max_new_tokens):
|
111 |
# Get the next token probabilities and update the KV cache
|
@@ -153,6 +153,7 @@ def generate(model, **kwargs):
|
|
153 |
"""
|
154 |
generation_config = model.generation_config
|
155 |
max_new_tokens = kwargs.get('max_new_tokens', generation_config.max_new_tokens)
|
|
|
156 |
do_sample = kwargs.get('do_sample', True)
|
157 |
eos_token_ids = kwargs.get('eos_token_ids', generation_config.eos_token_id)
|
158 |
if eos_token_ids is None:
|
|
|
103 |
active_seqs = input_ids.new_ones((batch_size, 1), dtype=torch.bool)
|
104 |
lens = torch.full((batch_size,), max_prompts_len, dtype=torch.long, device=input_ids.device)
|
105 |
# Modified log probabilities of the sequences
|
106 |
+
scores = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype)
|
107 |
# Unfiltered sequence log probabilities (T=1, no sampling modifications)
|
108 |
+
logps = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype)
|
109 |
|
110 |
for i in range(max_new_tokens):
|
111 |
# Get the next token probabilities and update the KV cache
|
|
|
153 |
"""
|
154 |
generation_config = model.generation_config
|
155 |
max_new_tokens = kwargs.get('max_new_tokens', generation_config.max_new_tokens)
|
156 |
+
max_new_tokens = 512 if max_new_tokens is None else max_new_tokens
|
157 |
do_sample = kwargs.get('do_sample', True)
|
158 |
eos_token_ids = kwargs.get('eos_token_ids', generation_config.eos_token_id)
|
159 |
if eos_token_ids is None:
|