---
tags:
- mergekit
- Merge
- Mistral_Star
- Mistral_Quiet
- Mistral
- Mixtral
- Question-Answer
- Token-Classification
- Sequence-Classification
- SpydazWeb-AI
- chemistry
- biology
- legal
- code
- climate
- medical
- LCARS_AI_StarTrek_Computer
- text-generation-inference
- chain-of-thought
- tree-of-knowledge
- forest-of-thoughts
- visual-spacial-sketchpad
- alpha-mind
- knowledge-graph
- entity-detection
- encyclopedia
- wikipedia
- stack-exchange
- Reddit
- Cyber-series
- MegaMind
- Cybertron
- SpydazWeb
- Spydaz
- LCARS
- star-trek
- mega-transformers
- Mulit-Mega-Merge
- Multi-Lingual
- Afro-Centric
- African-Model
- Ancient-One
---
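This model ships a custom, thought-token-aware generation loop as remote code. `custom_generate` decodes one token per step: it runs a forward pass over the unfinished rows of the batch, masks out the special thought-token ids (every id at or above `tokenizer.vocab_size`), samples from the temperature-scaled softmax, and writes the sampled token into the slot after the last non-padding position of each row. The code assumes the model object exposes `self.tokenizer` and `self.device`.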
```python
import torch


def custom_generate(
    self,
    input_ids,
    attention_mask=None,
    max_new_tokens=None,
    min_length=None,
    do_sample=None,
    early_stopping=None,
    num_beams=None,
    temperature=None,
    top_k=None,
    top_p=None,
    repetition_penalty=None,
    bad_words_ids=None,
    bos_token_id=None,
    pad_token_id=None,
    eos_token_id=None,
    streamer=None,
    length_penalty=None,
    no_repeat_ngram_size=None,
    num_return_sequences=None,
    decoder_start_token_id=None,
    use_cache=None,
    num_beam_groups=None,
    diversity_penalty=None,
    prefix_allowed_tokens_fn=None,
    output_attentions=None,
    output_hidden_states=None,
    output_scores=None,
    return_dict_in_generate=None,
    forced_bos_token_id=None,
    forced_eos_token_id=None,
    remove_invalid_values=None,
    synced_gpus=None,
    **kwargs,
):
    if input_ids is None or input_ids.nelement() == 0:
        # If input_ids is None or an empty tensor, fall back to a single BOS token
        input_ids = torch.LongTensor([[self.tokenizer.bos_token_id]]).to(self.device)
        attention_mask = torch.ones_like(input_ids).to(self.device)
    if temperature is None:
        temperature = 1.0  # guard: temperature is used for sampling below and must not be None
    device = input_ids.device
    with torch.no_grad():
        batch_size = input_ids.shape[0]
        finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=device)
        generated_token_ids = torch.full(
            (batch_size, max_new_tokens), self.tokenizer.pad_token_id, dtype=torch.long, device=device
        )
        for cur_token_idx in range(max_new_tokens):
            # Forward pass over the sequences that are still generating
            new_ids = self(
                input_ids[~finished_generating],
                attention_mask=attention_mask[~finished_generating] if attention_mask is not None else None,
                **kwargs
            )['logits']
            # Mask out the start and end thought tokens so we don't accidentally sample them
            new_ids[:, :, self.tokenizer.vocab_size:] = -float("inf")
            for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
                # Find the index of the last token that is not padding
                base_answer_ids = input_ids[answer_idx]
                new_answer_ids = new_ids[list_idx]
                last_token_idx = (base_answer_ids != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
                # Sample the next token from the temperature-scaled distribution
                new_ids_sampled = torch.multinomial(
                    torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temperature, dim=-1), 1)
                # If the row is full, grow every row by one padding column first
                if last_token_idx + 1 >= len(base_answer_ids):
                    new_padding = torch.full((batch_size, 1), self.tokenizer.pad_token_id,
                                             dtype=torch.long, device=device)
                    input_ids = torch.cat([input_ids, new_padding], dim=-1)
                    if attention_mask is not None:
                        attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
                # Write the sampled token into the slot after the last real token
                if attention_mask is not None:
                    attention_mask[answer_idx, last_token_idx + 1] = 1
                input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
                generated_token_ids[answer_idx, cur_token_idx] = new_ids_sampled
                # Stop this sequence on EOS/BOS/PAD...
                if (new_ids_sampled == self.tokenizer.eos_token_id
                        or new_ids_sampled == self.tokenizer.bos_token_id
                        or new_ids_sampled == self.tokenizer.pad_token_id):
                    finished_generating[answer_idx] = True
                # ...or on the explicit end-of-sequence token
                if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("</s>"):
                    finished_generating[answer_idx] = True
            if finished_generating.all():
                break
            if streamer is not None:
                streamer.put(new_ids_sampled)
    return generated_token_ids
```
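`generate` is the public entry point. It fills in a default token budget, configures the thought/talk-head attributes on the model (`n_ahead`, `use_weighted_talk_head`, and related flags), accepts either a tensor or a raw prompt string, and then delegates the actual sampling loop to `custom_generate` above.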
```python
def generate(
    self,
    input_ids,
    attention_mask=None,
    max_new_tokens=None,
    min_length=None,
    do_sample=None,
    early_stopping=None,
    num_beams=None,
    temperature=1.1,
    streamer=None,
    top_k=None,
    top_p=None,
    repetition_penalty=None,
    bad_words_ids=None,
    bos_token_id=None,
    pad_token_id=None,
    eos_token_id=None,
    length_penalty=None,
    no_repeat_ngram_size=None,
    num_return_sequences=None,
    decoder_start_token_id=None,
    use_cache=None,
    num_beam_groups=None,
    diversity_penalty=None,
    prefix_allowed_tokens_fn=None,
    output_attentions=None,
    output_hidden_states=None,
    output_scores=None,
    return_dict_in_generate=None,
    forced_bos_token_id=None,
    forced_eos_token_id=None,
    remove_invalid_values=None,
    synced_gpus=None,
    n_ahead=4,
    n_ahead_talk=4,
    merged_talk_heads=True,
    merged_lm_and_talk_heads=False,
    merged_lm_and_think_heads=True,
    use_concat_talk_head=True,
    use_shallow_think=True,
    use_shallow_talk=False,
    use_complex_think_head=False,
    use_complex_talk_head=True,
    use_weighted_talk_head=True,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    **model_kwargs,
):
    if max_new_tokens is None:
        max_new_tokens = 128
    # Set model attributes
    self.max_thoughts = n_ahead + n_ahead_talk + 1
    self.merged_talk_heads = merged_talk_heads
    self.merged_lm_and_talk_heads = merged_lm_and_talk_heads
    self.merged_lm_and_think_heads = merged_lm_and_think_heads
    self.use_concat_talk_head = use_concat_talk_head
    self.use_shallow_think = use_shallow_think
    self.use_shallow_talk = use_shallow_talk
    self.use_complex_think_head = use_complex_think_head
    self.use_complex_talk_head = use_complex_talk_head
    self.use_weighted_talk_head = use_weighted_talk_head
    # Set model properties
    self.use_end_thought_token = True
    self.use_start_thought_token = True
    self.n_ahead = n_ahead
    self.n_passes = 1
    self.eval_mode = True
    self.first_run = False
    self.rm_initialized = True
    self.original_mode = False
    # Accept a raw string prompt (for compatibility with text-generation-webui)
    if isinstance(input_ids, str):
        input_ids = self.tokenizer.encode(input_ids, return_tensors='pt')
    # Move input_ids and attention_mask to the same device as the model
    input_ids = input_ids.to(self.device)
    if attention_mask is not None:
        attention_mask = attention_mask.to(self.device)
    generated_token_ids = custom_generate(
        self,
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        min_length=min_length,
        do_sample=do_sample,
        early_stopping=early_stopping,
        num_beams=num_beams,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        bad_words_ids=bad_words_ids,
        bos_token_id=bos_token_id,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        length_penalty=length_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        num_return_sequences=num_return_sequences,
        decoder_start_token_id=decoder_start_token_id,
        use_cache=use_cache,
        num_beam_groups=num_beam_groups,
        diversity_penalty=diversity_penalty,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_scores=output_scores,
        return_dict_in_generate=return_dict_in_generate,
        forced_bos_token_id=forced_bos_token_id,
        forced_eos_token_id=forced_eos_token_id,
        remove_invalid_values=remove_invalid_values,
        synced_gpus=synced_gpus,
        streamer=streamer,
        **model_kwargs,
    )
    return generated_token_ids
```
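If the functions above are installed as methods on the model class (as loading with remote code would do), usage might look like the following sketch. The checkpoint id is a placeholder, and attaching the tokenizer to the model is an assumption: `custom_generate` reads `self.tokenizer` and `self.device`, so both must be present on the model object.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "path/to/checkpoint"  # placeholder: substitute the actual repo id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,  # required so the custom generate() above is used
)
model.tokenizer = tokenizer  # assumption: custom_generate reads self.tokenizer

inputs = tokenizer("Hello, computer.", return_tensors="pt").to(model.device)

# generate() also accepts a raw string (text-generation-webui compatibility)
token_ids = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=64,
    temperature=0.9,
)
# The return value holds only the newly generated tokens, padded to max_new_tokens
print(tokenizer.decode(token_ids[0], skip_special_tokens=True))
```

Note that the loop returns the generated tokens alone rather than the prompt plus completion, so the decode step above needs no slicing.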