Update modeling_bunny_qwen2.py
Browse files — modeling_bunny_qwen2.py: +6 −0
modeling_bunny_qwen2.py
CHANGED
@@ -701,11 +701,17 @@ class BunnyMetaForCausalLM(ABC):
Before:

701        if labels is None:
702            labels = torch.full_like(input_ids, IGNORE_INDEX)
703
704        # remove the padding using attention_mask -- TODO: double check
705        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in
706                     zip(input_ids, attention_mask)]
707        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
708
709        new_input_embeds = []
710        new_labels = []
711        cur_image_idx = 0
|
After:

701        if labels is None:
702            labels = torch.full_like(input_ids, IGNORE_INDEX)
703
704  +     input_ids_temp = input_ids  # points to the actual input_ids tensor
705  +
706        # remove the padding using attention_mask -- TODO: double check
707        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in
708                     zip(input_ids, attention_mask)]
709        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
710
711  +     # -- TODO: better implementation?
712  +     # replace IMAGE_TOKEN_INDEX(-200) with 0 to be compatible with repetition penalty
713  +     input_ids_temp[input_ids_temp == IMAGE_TOKEN_INDEX] = 0
714  +
715        new_input_embeds = []
716        new_labels = []
717        cur_image_idx = 0