---
library_name: transformers
tags:
- trl
- sft
---

# tinygemma3 with vision

This model is trained on the [CIFAR-10](https://huggingface.co/datasets/uoft-cs/cifar10) dataset.
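Its answers are expected to be CIFAR-10 class names (the example at the end of this card predicts "airplane"). If you want to see the full label set, a minimal sketch using the `datasets` library (the `label_names` variable is just illustrative):

```py
from datasets import load_dataset

# CIFAR-10 exposes a ClassLabel feature; .names holds the ten class strings
label_names = load_dataset("uoft-cs/cifar10", split="train").features["label"].names
print(label_names)  # ["airplane", "automobile", "bird", ...]
```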

How to use:

```py
from transformers import AutoModelForImageTextToText, AutoProcessor

model_id = "ngxson/tinygemma3_cifar"

# Load the model and its processor
model = AutoModelForImageTextToText.from_pretrained(model_id).to("cuda")
processor = AutoProcessor.from_pretrained(model_id)

#####################

from datasets import load_dataset

ds_full = load_dataset("uoft-cs/cifar10")

# Convert a CIFAR-10 example into a chat message plus the image it refers to
def ex_to_msg(ex):
    txt = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is this:"},
                {"type": "image"},
            ],
        }
    ]
    img = ex["img"]
    return {
        "messages": txt,
        "images": [img],
    }

#####################

test_idx = 0

test_msg = ex_to_msg(ds_full["train"][test_idx])

# Render the chat template to text, then tokenize text and image together
test_txt = processor.apply_chat_template(test_msg["messages"], tokenize=False, add_generation_prompt=True)
test_input = processor(text=test_txt, images=test_msg["images"], return_tensors="pt").to(model.device)

#####################

# Greedy decoding with a single new token
generated_ids = model.generate(**test_input, do_sample=False, max_new_tokens=1)
generated_texts = processor.batch_decode(
    generated_ids,
    skip_special_tokens=True,
)
print(generated_texts)

# expected answer for test_idx = 0 is "airplane"
```
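For a decoder-only model like this one, `generate` typically returns the prompt tokens followed by the new token, so the decoded string above also contains the prompt text. A minimal follow-up sketch (the names `prompt_len`, `answer`, and `gt_label` are illustrative) that keeps only the newly generated answer and compares it with the dataset's ground-truth label:

```py
# Decode only the tokens generated after the prompt
prompt_len = test_input["input_ids"].shape[1]
answer = processor.batch_decode(generated_ids[:, prompt_len:], skip_special_tokens=True)[0]
print(answer)

# Ground-truth class name for the same example ("airplane" for test_idx = 0)
gt_label = ds_full["train"].features["label"].int2str(ds_full["train"][test_idx]["label"])
print(gt_label)
```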