Rewrite image embedding to remove the in-place op

#53
by YenChunChen - opened

When LoRA-finetuning Phi-3-V, the following error occurs if LoRA is applied to the CLIP encoder together with the Phi-3 LM:

  File "/home/yenchun/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3-vision-128k-instruct/f998a184b56bf0399b3af85c50b20ec0d5688f5f/image_embedding_phi3_v.py", line 280, in forward
    hidden_states[positions[idx, 0], positions[idx, 1] : positions[idx, 1] + cnt] = (
RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.

This PR uses the out-of-place equivalent, `index_put`, instead. In addition, `forward` is completely rewritten and the code paths for early model variants are removed for better readability; the final version of `hd_transform` is also refactored for readability. See the comment below for parity tests.

@haipingwu please review.

Parity and batching tests:

import copy

import requests
import torch
import torch.testing
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

from configuration_phi3_v import Phi3VConfig
from image_embedding_phi3_v import Phi3ImageEmbedding


def load_models():
    """Load the processor and the old and new image embedders with identical weights.

    The "old" embedder is a deep copy of ``model.model.vision_embed_tokens`` from the
    pinned checkpoint revision. The "new" one is a ``Phi3ImageEmbedding`` built from
    the same config and loaded with the same state dict. Any output difference
    between them therefore comes from the forward implementation, not the weights.

    Returns:
        tuple: ``(processor, old_image_embedder, new_image_embedder)``.
    """
    model_path = 'microsoft/Phi-3-vision-128k-instruct'

    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
        # NOTE(review): 'auto' keeps the dtype stored in the checkpoint, while the new
        # embedder below is cast to bfloat16 explicitly; parity therefore assumes the
        # checkpoint weights are bfloat16 as well - confirm.
        torch_dtype='auto',
        _attn_implementation='eager',
        # Pin the revision that still contains the original (in-place) forward.
        revision='f998a184b56bf0399b3af85c50b20ec0d5688f5f',
    ).cuda()
    config = Phi3VConfig.from_pretrained(model_path, _attn_implementation='eager')
    embedding_config = {'embedding_cls': config.embd_layer['embedding_cls'], **config.embd_layer}

    image_embed_state_dict = model.model.vision_embed_tokens.state_dict()

    # A deep copy already carries identical parameters, so no reload is needed here.
    old_image_embedder = copy.deepcopy(model.model.vision_embed_tokens)
    new_image_embedder = (
        Phi3ImageEmbedding(config, wte=model.model.embed_tokens, **embedding_config)
        .bfloat16()
        .cuda()
    )
    new_image_embedder.load_state_dict(image_embed_state_dict)

    # Drops this function's reference to the full model. The new embedder still
    # references model.model.embed_tokens (passed as ``wte`` above).
    del model
    return processor, old_image_embedder, new_image_embedder


def test_input_1(processor):
    """Build processor inputs for a single-image "What is shown in this image?" prompt.

    NOTE(review): despite the ``test_`` prefix this is an input builder, not a test.
    If this file is ever collected by pytest, it would be picked up and fail on the
    missing ``processor`` fixture.

    Args:
        processor: the Phi-3-V processor returned by ``load_models``.

    Returns:
        The processor output moved to ``cuda:0``. Callers read its ``input_ids``,
        ``pixel_values`` and ``image_sizes`` entries.

    Raises:
        requests.HTTPError: if the image download returns an error status.
    """
    import io

    user_prompt = '<|user|>\n'
    assistant_prompt = '<|assistant|>\n'
    prompt_suffix = '<|end|>\n'

    prompt = (
        f'{user_prompt}<|image_1|>\nWhat is shown in this image?{prompt_suffix}{assistant_prompt}'
    )
    url = 'https://www.ilankelman.org/stopsigns/australia.jpg'
    print(f'>>> Prompt\n{prompt}')
    # Fail fast with a clear HTTP error, and never hang indefinitely, instead of
    # letting PIL choke on an error page or a stalled stream. ``.content`` also
    # decodes any Content-Encoding, which the raw stream does not.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content))
    inputs = processor(prompt, image, return_tensors='pt').to('cuda:0')
    return inputs


def test_input_2(processor):
    """Build processor inputs for a single-image "convert the table to markdown" prompt.

    NOTE(review): despite the ``test_`` prefix this is an input builder, not a test.
    If this file is ever collected by pytest, it would be picked up and fail on the
    missing ``processor`` fixture.

    Args:
        processor: the Phi-3-V processor returned by ``load_models``.

    Returns:
        The processor output moved to ``cuda:0``. Callers read its ``input_ids``,
        ``pixel_values`` and ``image_sizes`` entries.

    Raises:
        requests.HTTPError: if the image download returns an error status.
    """
    import io

    user_prompt = '<|user|>\n'
    assistant_prompt = '<|assistant|>\n'
    prompt_suffix = '<|end|>\n'

    prompt = f'{user_prompt}<|image_1|>\nCan you convert the table to markdown format?{prompt_suffix}{assistant_prompt}'
    url = 'https://support.content.office.net/en-us/media/3dd2b79b-9160-403d-9967-af893d17b580.png'
    # Fail fast with a clear HTTP error, and never hang indefinitely, instead of
    # letting PIL choke on an error page or a stalled stream. ``.content`` also
    # decodes any Content-Encoding, which the raw stream does not.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content))
    inputs = processor(prompt, image, return_tensors='pt').to('cuda:0')
    return inputs


def compare_old_and_new_forward(
    old_image_embedder, new_image_embedder, inputs, device='cuda', dtype=torch.bfloat16
):
    """Run both embedders on the same inputs under ``no_grad`` and assert the outputs match.

    Args:
        old_image_embedder: reference embedder (the checkpoint's original module).
        new_image_embedder: rewritten embedder under test.
        inputs: mapping with ``input_ids``, ``pixel_values`` and ``image_sizes``
            entries (for example, the processor output).
        device: device the ``input_ids`` and ``pixel_values`` tensors are moved to.
            Defaults to ``'cuda'``. ``image_sizes`` is passed through unchanged.
        dtype: dtype ``pixel_values`` is cast to. Defaults to ``torch.bfloat16``.

    Returns:
        ``(old_outputs, new_outputs)``.

    Raises:
        AssertionError: if ``torch.testing.assert_close`` finds the outputs differ.
    """
    pixel_values = inputs['pixel_values'].to(device=device, dtype=dtype)
    image_sizes = inputs['image_sizes']

    def fresh_input_ids():
        # Each embedder gets its own copy of ``input_ids``. Presumably this guards
        # against an embedder modifying ``input_ids`` in place, so one forward pass
        # cannot affect the other or the caller's ``inputs``.
        return inputs['input_ids'].clone().to(device)

    with torch.no_grad():
        old_outputs = old_image_embedder(
            input_ids=fresh_input_ids(), pixel_values=pixel_values, image_sizes=image_sizes
        )
        new_outputs = new_image_embedder(
            input_ids=fresh_input_ids(), pixel_values=pixel_values, image_sizes=image_sizes
        )
    torch.testing.assert_close(old_outputs, new_outputs)
    return old_outputs, new_outputs


def _pad_input_ids(input_ids_list, pad_token_id):
    """Right-pad several ``(1, seq_len)`` ``input_ids`` tensors into one ``(batch, max_len)`` tensor."""
    return torch.nn.utils.rnn.pad_sequence(
        [ids.squeeze(0) for ids in input_ids_list],
        batch_first=True,
        padding_value=pad_token_id,
    )


def _concat_images(*inputs_list):
    """Concatenate ``pixel_values`` and ``image_sizes`` of several inputs, in order."""
    return {
        'pixel_values': torch.cat([inp['pixel_values'] for inp in inputs_list]),
        'image_sizes': torch.cat([inp['image_sizes'] for inp in inputs_list]),
    }


def main():
    """Check old/new embedder parity and batching consistency on two sample prompts."""
    processor, old_image_embedder, new_image_embedder = load_models()
    pad_token_id = processor.tokenizer.pad_token_id

    def compare(inputs):
        # Parity with the old embedder is asserted inside compare_old_and_new_forward;
        # only the new embedder's outputs are needed for the checks below.
        _, new_outputs = compare_old_and_new_forward(
            old_image_embedder, new_image_embedder, inputs
        )
        return new_outputs

    inputs_1 = test_input_1(processor)
    inputs_2 = test_input_2(processor)
    len_1 = inputs_1['input_ids'].shape[1]
    len_2 = inputs_2['input_ids'].shape[1]

    # test parity of single example
    new_outputs_1 = compare(inputs_1)
    new_outputs_2 = compare(inputs_2)

    # test parity of batched examples
    inputs_1and2 = {
        'input_ids': _pad_input_ids([inputs_1['input_ids'], inputs_2['input_ids']], pad_token_id),
        **_concat_images(inputs_1, inputs_2),
    }
    new_outputs_1and2 = compare(inputs_1and2)

    # test batching correctness: padding is on the right, so each example's
    # unpadded prefix must match its single-example output.
    torch.testing.assert_close(new_outputs_1[0], new_outputs_1and2[0, :len_1])
    torch.testing.assert_close(new_outputs_2[0], new_outputs_1and2[1, :len_2])

    # test parity for single example with multiple images
    inputs_1plus2 = {
        'input_ids': torch.cat([inputs_1['input_ids'], inputs_2['input_ids']], dim=1),
        **_concat_images(inputs_1, inputs_2),
    }
    new_outputs_1plus2 = compare(inputs_1plus2)
    torch.testing.assert_close(new_outputs_1, new_outputs_1plus2[:, :len_1])
    torch.testing.assert_close(new_outputs_2, new_outputs_1plus2[:, -len_2:])

    # test batched examples with potentially different number of images
    inputs_complex = {
        'input_ids': _pad_input_ids(
            [inputs_1['input_ids'], inputs_1plus2['input_ids'], inputs_2['input_ids']],
            pad_token_id,
        ),
        **_concat_images(inputs_1, inputs_1plus2, inputs_2),
    }
    new_outputs_complex = compare(inputs_complex)
    torch.testing.assert_close(new_outputs_1[0], new_outputs_complex[0, :len_1])
    torch.testing.assert_close(new_outputs_1plus2[0], new_outputs_complex[1, : len_1 + len_2])
    torch.testing.assert_close(new_outputs_2[0], new_outputs_complex[2, :len_2])


if __name__ == "__main__":
    # Running the file as a script executes every parity and batching check in main().
    main()
YenChunChen changed pull request status to open
Microsoft org

Hi @leoxiaobin, please merge this PR.

Microsoft org

@YenChunChen, the branch has merge conflicts. Please fix them.

leoxiaobin changed pull request status to merged

Sign up or log in to comment