---
license: mit
pipeline_tag: image-feature-extraction
tags:
- pretrained
datasets:
- ylecun/mnist
---
# Model Card for Llava-mnist
Llava-mnist is a simple example of a vision-and-language model that uses the LLaVA architecture and is trained on the MNIST dataset.
You can use this model (a vision encoder consisting of a single linear layer) alongside [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct).
## Training Details
![overview](./Llava-mnist.png)
The model was trained on a *chat-style* MNIST dataset, structured as follows:
- prompt: `<image>What digit is this?`
- output: `The digit is {label}.`
The Llava-mnist vision encoder transforms each digit image into an embedding vector that resides in the same space as the text token embeddings.
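
Concretely, the encoder is a single linear projection from the flattened 28×28 image (784 values) into Llama 3.1 8B's 4096-dimensional token-embedding space. A minimal sketch (class and parameter names here are illustrative; the actual weights live in this repository):

```python
import torch
import torch.nn as nn


class LlavaMnistEncoder(nn.Module):
    """Illustrative sketch of the one-linear-layer vision encoder."""

    def __init__(self, image_dim: int = 784, hidden_dim: int = 4096):
        super().__init__()
        # 784 flattened pixels -> the LLM's token-embedding dimension.
        self.proj = nn.Linear(image_dim, hidden_dim)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # (batch, 784) -> (batch, 4096): one "image token" per digit.
        return self.proj(pixel_values)
```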
The loss function optimized during training is defined as:
$L(W) = -\log P_W(\text{The digit is \{label\}.} \mid \text{<image>What digit is this?})$
During training, the parameters of the Llama 3.1 model are kept frozen, and only the parameters of the vision encoder (Llava-mnist) are optimized.
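
A rough sketch of that setup, assuming `model` and `tokenizer` are loaded as in the usage example below and reusing the illustrative `LlavaMnistEncoder` from the sketch above:

```python
import torch

# Freeze every Llama 3.1 parameter; only the vision encoder is trained.
for param in model.parameters():
    param.requires_grad = False

encoder = LlavaMnistEncoder().to(model.device)  # illustrative encoder from above
optimizer = torch.optim.AdamW(encoder.parameters(), lr=1e-4)


def training_step(pixel_values: torch.Tensor, label: int) -> float:
    question = "<image>What digit is this?"
    answer = f"The digit is {label}."
    prefix_text, suffix_text = question.split("<image>")
    embed = model.get_input_embeddings()

    prefix_ids = tokenizer(prefix_text, return_tensors="pt").input_ids.to(model.device)
    suffix_ids = tokenizer(
        suffix_text, return_tensors="pt", add_special_tokens=False
    ).input_ids.to(model.device)
    answer_ids = tokenizer(
        answer, return_tensors="pt", add_special_tokens=False
    ).input_ids.to(model.device)

    # Text prefix + projected image + text suffix + target answer, as embeddings.
    image_embedding = (
        encoder(pixel_values.to(model.device)).to(model.dtype).unsqueeze(1)
    )
    inputs_embeds = torch.cat(
        [embed(prefix_ids), image_embedding, embed(suffix_ids), embed(answer_ids)],
        dim=1,
    )

    # Cross-entropy only on the answer tokens; everything else is masked (-100).
    labels = torch.full(
        inputs_embeds.shape[:2], -100, dtype=torch.long, device=model.device
    )
    labels[:, -answer_ids.shape[1]:] = answer_ids

    loss = model(inputs_embeds=inputs_embeds, labels=labels).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
```

In practice the loop would iterate over the chat-style pairs shown earlier and save only the encoder weights at the end.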
## How to use
You can feed multi-modal data (vision and text) into the Llama 3.1 model by using the Llava-mnist model as the vision encoder. The snippet below defines a small helper that splices the image embedding into the token-embedding sequence at the `<image>` placeholder and then generates an answer for one MNIST test example.
```python
import torch
from datasets import load_dataset
from torchvision import transforms
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer


def build_multi_modal_prompt(
    prompt: str,
    image: torch.Tensor,
    tokenizer: AutoTokenizer,
    model: AutoModelForCausalLM,
    vision_model: AutoModel,
) -> torch.Tensor:
    # Split the prompt at the <image> placeholder and embed the two text parts
    # with the LLM's own input embedding layer.
    parts = prompt.split("<image>")
    prefix = tokenizer(parts[0])
    suffix = tokenizer(parts[1])
    embed = model.get_input_embeddings()
    prefix_embedding = embed(torch.tensor(prefix["input_ids"]).to(model.device))
    suffix_embedding = embed(torch.tensor(suffix["input_ids"]).to(model.device))
    # Project the image into the token-embedding space and splice it in between.
    image_embedding = vision_model(image).to(torch.bfloat16).to(model.device)
    multi_modal_embedding = torch.cat(
        [prefix_embedding, image_embedding, suffix_embedding], dim=0
    )
    return multi_modal_embedding


model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
vision_model = AutoModel.from_pretrained(
    "speed/llava-mnist", trust_remote_code=True
)

terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]

# Llama 3.1 chat format, with the <image> placeholder inside the user turn.
system_prompt = (
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|><|eot_id|>"
)
user_prompt = "<|start_header_id|>user<|end_header_id|>"
question = "<image>What digit is this?"
assistant_prompt = "<|start_header_id|>assistant<|end_header_id|>"
prompt = system_prompt + user_prompt + question + assistant_prompt

ds = load_dataset("ylecun/mnist", split="test")


def transform_image(examples):
    # Normalize each 28x28 MNIST image and flatten it to a 784-dim vector.
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
            transforms.Lambda(lambda x: torch.flatten(x)),
        ]
    )
    examples["pixel_values"] = [transform(image) for image in examples["image"]]
    return examples


ds.set_transform(transform_image)

model.eval()
vision_model.eval()

example = ds[0]
input_embeds = build_multi_modal_prompt(
    prompt, example["pixel_values"].unsqueeze(0), tokenizer, model, vision_model
).unsqueeze(0)

response = model.generate(
    inputs_embeds=input_embeds,
    max_new_tokens=20,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)
response = response[0]
print("Label:", example["label"])  # Label: 7
answer = tokenizer.decode(response, skip_special_tokens=True)
print("Answer:", answer)  # Answer: The digit is 7.
```
## References
- Liu et al., LLaVA: Large Language and Vision Assistant, https://llava-vl.github.io/