---
library_name: transformers
datasets:
- nvidia/OpenMathInstruct-2
base_model:
- google/gemma-2-2b-it
---

# Gemma2-2B Distilled from OpenMath2-Llama3.1-8B Model Card

A Gemma2-2B model distilled from [OpenMath2-Llama3.1-8B](https://huggingface.co/nvidia/OpenMath2-Llama3.1-8B) for math tasks.

On math tasks, this model substantially outperforms the general-purpose instruction-tuned Gemma2-2B-IT; see the zero-shot GSM8K and MATH results below.

## Model Details

- **Base Model:** Gemma2-2B
- **Tokenization:** Gemma2-2B
- **Training Methodology:** Distillation from OpenMath2-Llama3.1-8B on [OpenMathInstruct-2](https://huggingface.co/datasets/nvidia/OpenMathInstruct-2).

| **Benchmark** | **Gemma2-2B-Distilled-Math** | **Original Gemma2-2B-IT** |
|---------------|------------------------------|---------------------------|
| **GSM8K (zero-shot)** | 65.1 | 6.1 |
| **MATH (zero-shot)** | 52.1 | 9.9 |

## Training Details

Details on the training methodology are forthcoming.

## Use

```python
import torch
from transformers import pipeline

# Prompt template. The backslash is escaped ("\\boxed") so Python does not read
# "\b" as a backspace, and the literal braces of \boxed{} are doubled ("{{}}")
# so str.format leaves them intact and only fills the {problem} field.
template = (
    "<|start_header_id|>user<|end_header_id|>\n\n"
    "Solve the following math problem. Make sure to put the answer "
    "(and only answer) inside \\boxed{{}}.\n\n"
    "{problem}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
problem = "What is the minimum value of $a^2+6a-7$?"

pipe = pipeline(
    "text-generation",
    model="benjamin/Gemma2-2B-Distilled-Math",
    model_kwargs={"torch_dtype": torch.bfloat16},
    eos_token_id=107,  # stop at token id 107 (<end_of_turn> in the Gemma2 tokenizer)
    device_map="auto",
)

# return_full_text=False returns only the newly generated text, not the prompt.
outputs = pipe(
    template.format(problem=problem),
    max_new_tokens=256,
    return_full_text=False,
)
assistant_response = outputs[0]["generated_text"].strip()
print(assistant_response)
```
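
The prompt asks the model to put its final answer inside `\boxed{}`, so scoring a response (for example, against a GSM8K or MATH reference answer) starts by pulling out that boxed expression. The helper below, `extract_boxed`, is a hypothetical sketch and is not part of this repository or of `transformers`; it assumes the answer is the last `\boxed{...}` in the generated text and matches nested braces by counting depth. For the example problem, completing the square gives $a^2+6a-7=(a+3)^2-16$, so a correct response should box $-16$.

```python
from typing import Optional


def extract_boxed(response: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in `response`, or None.

    Hypothetical helper: braces are matched by counting nesting depth, so
    nested groups such as \\boxed{\\frac{1}{2}} are returned whole.
    """
    start = response.rfind("\\boxed{")
    if start == -1:
        return None  # no boxed answer in the text
    i = start + len("\\boxed{")
    depth = 1
    for j in range(i, len(response)):
        ch = response[j]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return response[i:j].strip()
    return None  # unbalanced braces: the answer was cut off


sample = "Completing the square gives (a+3)^2 - 16, so the minimum is \\boxed{-16}."
print(extract_boxed(sample))  # -16
```

Taking the last box rather than the first keeps the helper usable even if the prompt text, which itself contains `\boxed{}`, is included in the generated output.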