metadata
license: mit
pipeline_tag: image-feature-extraction
tags:
  - pretrained
datasets:
  - ylecun/mnist

Model Card for Llava-mnist

Llava-mnist is a simple example of a vision-language model built with the LLaVA architecture and trained on the MNIST dataset.

This model is only the vision encoder, a single linear layer, and is meant to be used alongside meta-llama/Meta-Llama-3.1-8B-Instruct as the language model.
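A one-linear-layer encoder can be pictured as a single projection from pixel space into Llama's token-embedding space. The following is a minimal PyTorch sketch under stated assumptions: the class name `LinearVisionEncoder` is hypothetical, the input size assumes a flattened 28x28 MNIST image (784 values), and the output size assumes the 4096-dimensional hidden size of Llama 3.1 8B. The published `speed/llava-mnist` code is loaded with `trust_remote_code=True` and may be implemented differently.

```python
import torch
import torch.nn as nn


class LinearVisionEncoder(nn.Module):
    """Hypothetical sketch: map a flattened MNIST image to one LLM embedding vector."""

    def __init__(self, in_features: int = 28 * 28, embed_dim: int = 4096):
        super().__init__()
        # A single linear projection from pixel space into the LLM's token-embedding space.
        self.proj = nn.Linear(in_features, embed_dim)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # pixel_values: [batch, 784] -> [batch, embed_dim]
        return self.proj(pixel_values)


encoder = LinearVisionEncoder()
image = torch.randn(1, 28 * 28)  # stand-in for one normalized, flattened digit
print(encoder(image).shape)  # torch.Size([1, 4096])
```

With these assumed sizes the encoder has 784 x 4096 weights plus 4096 biases, which is tiny compared with the frozen 8B-parameter language model.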

Training Details

Overview

The model was trained on a chat-style version of the MNIST dataset, in which each example is structured as follows:

prompt: "<image>What digit is this?"

output: "The digit is {label}."
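The template above can be produced from each MNIST label with a few lines of Python. This is a minimal sketch; the helper name `to_chat_example` is hypothetical, and the image itself would stay paired with the example so it can be substituted at the `<image>` placeholder.

```python
def to_chat_example(label: int) -> dict:
    """Hypothetical helper: turn an MNIST label into the prompt/output pair."""
    return {
        "prompt": "<image>What digit is this?",
        "output": f"The digit is {label}.",
    }


print(to_chat_example(7))
# {'prompt': '<image>What digit is this?', 'output': 'The digit is 7.'}
```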

The Llava-MNIST model maps each digit image to an embedding vector in the same space as Llama 3.1's text token embeddings, so the image can be placed in the input sequence alongside ordinary tokens.
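Because the image vector has the same width as a token embedding, it can occupy a position in the input sequence between the text on either side of `<image>`. The toy sketch below illustrates this, assuming the encoder emits one vector per image; the vocabulary size, embedding width, and token ids are made up, and `nn.Embedding` merely stands in for Llama's embedding table.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
VOCAB, DIM = 100, 16  # toy sizes, not Llama's

token_embedding = nn.Embedding(VOCAB, DIM)  # stand-in for model.get_input_embeddings()
vision_encoder = nn.Linear(28 * 28, DIM)    # stand-in for the Llava-MNIST encoder

prefix_ids = torch.tensor([1, 2, 3])     # hypothetical ids for the text before <image>
suffix_ids = torch.tensor([4, 5, 6, 7])  # hypothetical ids for "What digit is this?"
image = torch.randn(1, 28 * 28)

# Concatenate along the sequence axis: text, image, text.
sequence = torch.cat(
    [token_embedding(prefix_ids), vision_encoder(image), token_embedding(suffix_ids)],
    dim=0,
)
print(sequence.shape)  # torch.Size([8, 16]): 3 prefix + 1 image + 4 suffix positions
```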

The loss function optimized during training is defined as:

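$L(W) = -\log P_W\left(\text{The digit is } \{\text{label}\}. \mid \langle\text{image}\rangle\,\text{What digit is this?}\right)$

where $W$ denotes the parameters of the vision encoder.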
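In practice this negative log-likelihood is a causal-LM cross-entropy restricted to the answer tokens. The sketch below assumes the common convention of marking prompt and image positions with the label `-100` so they are ignored; the function name `answer_nll` is hypothetical. Summing the per-token losses gives the log of the joint probability in the formula; Hugging Face causal-LM models typically average over tokens instead when given `labels`.

```python
import math

import torch
import torch.nn.functional as F

IGNORE = -100  # positions with this label do not contribute to the loss


def answer_nll(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: -log P(answer | prompt) for one sequence.

    logits: [seq_len, vocab] produced by the causal LM over the spliced input.
    labels: [seq_len] with IGNORE at prompt/image positions, token ids at answer positions.
    """
    # Causal shift: the logits at position t predict the token at position t + 1.
    shift_logits = logits[:-1]
    shift_labels = labels[1:]
    # Summing per-token NLL gives -log of the joint probability of the answer tokens.
    return F.cross_entropy(shift_logits, shift_labels, ignore_index=IGNORE, reduction="sum")


# Toy check: uniform logits over a 10-token vocabulary and 2 answer tokens.
logits = torch.zeros(5, 10)
labels = torch.tensor([IGNORE, IGNORE, IGNORE, 3, 7])
loss = answer_nll(logits, labels)
print(round(loss.item(), 4), round(2 * math.log(10), 4))  # 4.6052 4.6052
```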

During training, the parameters of the Llama 3.1 model are kept frozen, and only the parameters of the vision encoder (Llava-MNIST) are optimized.
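The freezing scheme can be illustrated with toy modules. In this sketch `ToyCausalLM` is a hypothetical stand-in for Llama 3.1 in which a running mean replaces attention, so that earlier positions (including the image) influence later predictions. The AdamW optimizer and learning rate are assumptions for illustration, not the author's training settings.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
VOCAB, DIM = 32, 16  # toy sizes


class ToyCausalLM(nn.Module):
    """Hypothetical stand-in for a frozen LLM that accepts input embeddings."""

    def __init__(self):
        super().__init__()
        self.head = nn.Linear(DIM, VOCAB)

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        # Causal mixing via a running mean: position t only sees positions 0..t.
        positions = torch.arange(1, inputs_embeds.shape[0] + 1).unsqueeze(1)
        hidden = inputs_embeds.cumsum(dim=0) / positions
        return self.head(hidden)  # [seq_len, vocab]


llm = ToyCausalLM()
encoder = nn.Linear(28 * 28, DIM)  # stand-in for the vision encoder

# Freeze the language model; only the encoder's parameters go to the optimizer.
for p in llm.parameters():
    p.requires_grad_(False)
optimizer = torch.optim.AdamW(encoder.parameters(), lr=1e-3)

llm_before = [p.detach().clone() for p in llm.parameters()]
enc_before = [p.detach().clone() for p in encoder.parameters()]

image = torch.randn(1, 28 * 28)
answer_embeds = torch.randn(3, DIM)     # stand-in embeddings of the answer tokens
labels = torch.tensor([-100, 5, 9, 2])  # image position ignored, 3 answer tokens

# One training step: image embedding followed by the teacher-forced answer.
inputs_embeds = torch.cat([encoder(image), answer_embeds], dim=0)  # [4, DIM]
logits = llm(inputs_embeds)
loss = F.cross_entropy(logits[:-1], labels[1:], ignore_index=-100, reduction="sum")
optimizer.zero_grad()
loss.backward()
optimizer.step()

llm_unchanged = all(torch.equal(a, b) for a, b in zip(llm_before, llm.parameters()))
enc_changed = any(not torch.equal(a, b) for a, b in zip(enc_before, encoder.parameters()))
print(llm_unchanged, enc_changed)  # True True
```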

How to use

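To feed multimodal (image and text) input to Llama 3.1, use Llava-MNIST as the vision encoder: the image embedding is spliced into the prompt's token embeddings at the <image> placeholder, and the combined sequence is passed to generate via inputs_embeds.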

from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
import torch
from datasets import load_dataset
from torchvision import transforms


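def build_multi_modal_prompt(
    prompt: str,
    image: torch.Tensor,
    tokenizer: AutoTokenizer,
    model: AutoModelForCausalLM,
    vision_model: AutoModel,
) -> torch.Tensor:
    # Split the prompt at the <image> placeholder and embed the text on either side.
    parts = prompt.split("<image>")
    prefix = tokenizer(parts[0])
    suffix = tokenizer(parts[1])
    embed_tokens = model.get_input_embeddings()
    # Create the id tensors on the model's device to avoid a CPU/GPU mismatch.
    prefix_embedding = embed_tokens(torch.tensor(prefix["input_ids"], device=model.device))
    suffix_embedding = embed_tokens(torch.tensor(suffix["input_ids"], device=model.device))
    # Project the image into the token-embedding space and match the LLM's dtype and device.
    image_embedding = vision_model(image).to(torch.bfloat16).to(model.device)
    # Splice [prefix tokens] + [image embedding] + [suffix tokens] along the sequence axis.
    multi_modal_embedding = torch.cat(
        [prefix_embedding, image_embedding, suffix_embedding], dim=0
    )
    return multi_modal_embedding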


model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

vision_model = AutoModel.from_pretrained(
    "speed/llava-mnist", trust_remote_code=True
)

terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]

system_prompt = (
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|><|eot_id|>"
)
user_prompt = "<|start_header_id|>user<|end_header_id|>"
question = "<image>What digit is this?"
assistant_prompt = "<|start_header_id|>assistant<|end_header_id|>"

prompt = system_prompt + user_prompt + question + assistant_prompt

ds = load_dataset("ylecun/mnist", split="test")


# Normalize with the standard MNIST mean/std and flatten each 28x28 image to a 784-dim vector.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        transforms.Lambda(lambda x: torch.flatten(x)),
    ]
)


def transform_image(examples):
    examples["pixel_values"] = [transform(image) for image in examples["image"]]
    return examples


ds.set_transform(transform_image)


model.eval()
vision_model.eval()

example = ds[0]

# Inference only: no gradients are needed for the embeddings or generation.
with torch.no_grad():
    inputs_embeds = build_multi_modal_prompt(
        prompt, example["pixel_values"].unsqueeze(0), tokenizer, model, vision_model
    ).unsqueeze(0)  # add a batch dimension: [1, seq_len, hidden_size]
    response = model.generate(
        inputs_embeds=inputs_embeds,
        max_new_tokens=20,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
response = response[0]
print("Label:", example["label"])  # Label: 7
answer = tokenizer.decode(response, skip_special_tokens=True)
print("Answer:", answer)  # Answer: The digit is 7.

References