---
license: mit
pipeline_tag: image-feature-extraction
tags:
- pretrained
datasets:
- ylecun/mnist
---

# Model Card for Llava-mnist

Llava-mnist is a simple example of a vision-and-language model built on the LLaVA architecture and trained on the MNIST dataset.

You can use this model (a vision encoder consisting of just one linear layer) together with [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct).

## Training Details

![overview](./Llava-mnist.png)

The model was trained on a *chat-style* version of the MNIST dataset, where each example is structured as follows:

prompt: "\<image\>What digit is this?"

output: "The digit is {label}."
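
For concreteness, here is a tiny sketch of how one MNIST example could be turned into such a chat-style pair. The exact preprocessing script is not part of this card, and the function name below is only illustrative:

```python
def to_chat_example(label: int) -> dict:
    """Format one MNIST label as the chat-style prompt/output pair described above."""
    return {
        "prompt": "<image>What digit is this?",
        "output": f"The digit is {label}.",
    }

print(to_chat_example(7))
# {'prompt': '<image>What digit is this?', 'output': 'The digit is 7.'}
```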

The Llava-MNIST model transforms the digit image into an embedding vector that resides in the same space as the text token embeddings.
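
Since the released checkpoint is described as a single linear layer, the encoder amounts to a 784-to-hidden-size projection. Here is a minimal sketch; the class and attribute names are illustrative rather than the ones used in the actual `speed/llava-mnist` code, and 4096 is the hidden size of Llama 3.1 8B:

```python
import torch
import torch.nn as nn

class LinearVisionEncoder(nn.Module):
    """One linear layer mapping a flattened 28x28 digit into the LLM embedding space."""

    def __init__(self, image_dim: int = 28 * 28, hidden_size: int = 4096):
        super().__init__()
        self.proj = nn.Linear(image_dim, hidden_size)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # (batch, 784) -> (batch, 4096): each image becomes one "image token".
        return self.proj(pixel_values)

encoder = LinearVisionEncoder()
print(encoder(torch.randn(1, 784)).shape)  # torch.Size([1, 4096])
```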

The loss function optimized during training is defined as:

$L(W) = -\log P_W(\text{"The digit is \{label\}."} \mid \text{"<image>What digit is this?"})$

During training, the parameters of the Llama 3.1 model are kept frozen, and only the parameters of the vision encoder (Llava-MNIST) are optimized.
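
As a rough illustration of that setup (not the actual training script), a single training step could look like the sketch below. It reuses the `build_multi_modal_prompt` helper from the usage example in the next section, assumes `model`, `tokenizer`, and `vision_model` are loaded as shown there, takes `pixel_values` and `label` from one chat-style MNIST example, and omits device placement and batching:

```python
import torch

# Freeze the language model; only the vision encoder's parameters are trained.
for p in model.parameters():
    p.requires_grad_(False)
optimizer = torch.optim.AdamW(vision_model.parameters(), lr=1e-4)

# Embed the prompt (with the image spliced in) and the target answer.
prompt_embeds = build_multi_modal_prompt(
    "<image>What digit is this?", pixel_values.unsqueeze(0),
    tokenizer, model, vision_model,
)
target_ids = torch.tensor(
    tokenizer(f"The digit is {label}.", add_special_tokens=False)["input_ids"]
)
target_embeds = model.get_input_embeddings()(target_ids)

inputs_embeds = torch.cat([prompt_embeds, target_embeds], dim=0).unsqueeze(0)
# Mask the prompt/image positions with -100 so only the answer tokens are scored.
labels = torch.cat(
    [torch.full((prompt_embeds.shape[0],), -100), target_ids]
).unsqueeze(0)

loss = model(inputs_embeds=inputs_embeds, labels=labels).loss  # -log P_W(output | prompt)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```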

## How to use

You can input multi-modal data (vision and text) into the Llama 3.1 model by using the Llava-MNIST model as the vision encoder.

```python
import torch
from datasets import load_dataset
from torchvision import transforms
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer


def build_multi_modal_prompt(
    prompt: str,
    image: torch.Tensor,
    tokenizer: AutoTokenizer,
    model: AutoModelForCausalLM,
    vision_model: AutoModel,
) -> torch.Tensor:
    """Embed the text on either side of the <image> placeholder and splice the
    image embedding in between, returning a (seq_len, hidden_size) tensor."""
    parts = prompt.split("<image>")
    prefix = tokenizer(parts[0])
    suffix = tokenizer(parts[1])
    prefix_embedding = model.get_input_embeddings()(torch.tensor(prefix["input_ids"]))
    suffix_embedding = model.get_input_embeddings()(torch.tensor(suffix["input_ids"]))
    # The vision encoder yields one "image token" embedding; match the LLM's
    # dtype and device before concatenation.
    image_embedding = vision_model(image).to(torch.bfloat16).to(model.device)
    multi_modal_embedding = torch.cat(
        [prefix_embedding, image_embedding, suffix_embedding], dim=0
    )
    return multi_modal_embedding


model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

vision_model = AutoModel.from_pretrained(
    "speed/llava-mnist", trust_remote_code=True
)

terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]

system_prompt = (
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|><|eot_id|>"
)
user_prompt = "<|start_header_id|>user<|end_header_id|>"
question = "<image>What digit is this?"
assistant_prompt = "<|start_header_id|>assistant<|end_header_id|>"

prompt = system_prompt + user_prompt + question + assistant_prompt

ds = load_dataset("ylecun/mnist", split="test")


def transform_image(examples):
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),  # standard MNIST mean/std
            transforms.Lambda(lambda x: torch.flatten(x)),  # 1x28x28 -> 784
        ]
    )
    examples["pixel_values"] = [transform(image) for image in examples["image"]]
    return examples

ds.set_transform(transform_image)


model.eval()
vision_model.eval()

example = ds[0]

inputs_embeds = build_multi_modal_prompt(
    prompt, example["pixel_values"].unsqueeze(0), tokenizer, model, vision_model
).unsqueeze(0)
response = model.generate(
    inputs_embeds=inputs_embeds,
    max_new_tokens=20,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)
response = response[0]
print("Label:", example["label"]) # Label: 7
answer = tokenizer.decode(response, skip_special_tokens=True)
print("Answer:", answer) # Answer: The digit is 7.

```

## References
- Liu et al., LLaVA: Large Language and Vision Assistant, https://llava-vl.github.io/