---
library_name: transformers
datasets:
- nvidia/OpenMathInstruct-2
base_model:
- google/gemma-2-2b-it
---
# Gemma2-2B Distilled from OpenMath2-Llama3.1-8B Model Card
Gemma2-2B distilled from [OpenMath2-Llama3.1-8B](https://huggingface.co/nvidia/OpenMath2-Llama3.1-8B) for math tasks.
It substantially outperforms the general-purpose instruction-tuned Gemma2-2B-IT on math benchmarks (see the table below).
## Model Details
- **Base Model:** Gemma2-2B
- **Tokenization:** Gemma2-2B
- **Training Methodology:** Distillation from OpenMath2-Llama3.1-8B on [OpenMathInstruct-2](https://huggingface.co/datasets/nvidia/OpenMathInstruct-2).
| **Benchmark** | **Gemma2-2B-Distilled-Math** | **Original Gemma2-2B-IT** |
|---------------|------------------------------|---------------------------|
| **GSM8K (zero-shot, % accuracy)** | 65.1 | 6.1 |
| **MATH (zero-shot, % accuracy)** | 52.1 | 9.9 |
## Training Details
Details on the training methodology are forthcoming.
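In the meantime, here is a minimal sketch of one plausible recipe: sequence-level distillation, i.e., supervised fine-tuning of the student on the teacher-generated solutions in OpenMathInstruct-2 with a standard causal-LM loss. All hyperparameters, the text formatting, and the field names below are illustrative assumptions, not the actual training configuration.
```python
# Illustrative sketch only: sequence-level distillation as plain supervised
# fine-tuning on teacher-generated solutions. NOT the actual training recipe.
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it")

# The full split is large; a slice (e.g. "train[:100000]") works for a trial run.
dataset = load_dataset("nvidia/OpenMathInstruct-2", split="train")

def tokenize(example):
    # "problem" and "generated_solution" are the dataset's question/answer fields.
    text = example["problem"] + "\n\n" + example["generated_solution"]
    return tokenizer(text, truncation=True, max_length=1024)

tokenized = dataset.map(tokenize, remove_columns=dataset.column_names)

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="gemma2-2b-distilled-math",
        per_device_train_batch_size=4,  # assumed; tune for your hardware
        num_train_epochs=1,             # assumed
        bf16=True,
    ),
    train_dataset=tokenized,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
```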
## Use
```python
import torch
from transformers import pipeline

# Llama-3-style prompt template matching the teacher model's format.
# Braces are escaped so that after .format(), \\boxed{{}} renders as \boxed{}
# and {problem} is the substitution slot.
template = "<|start_header_id|>user<|end_header_id|>\n\nSolve the following math problem. Make sure to put the answer (and only answer) inside \\boxed{{}}.\n\n{problem}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
problem = "What is the minimum value of $a^2+6a-7$?"

pipe = pipeline(
    "text-generation",
    model="benjamin/Gemma2-2B-Distilled-Math",
    model_kwargs={"torch_dtype": torch.bfloat16},
    eos_token_id=107,  # Gemma's <end_of_turn> token
    device_map="auto",
)
outputs = pipe(
    template.format(problem=problem),
    max_new_tokens=256,
    return_full_text=False,  # return only the completion, not the prompt
)
assistant_response = outputs[0]["generated_text"].strip()
print(assistant_response)
```
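Since the prompt instructs the model to put its final answer in `\boxed{...}`, a small helper (illustrative, not part of any published API for this model) can extract it for downstream checking:
```python
import re

def extract_boxed_answer(text: str) -> str | None:
    # Take the last \boxed{...} in the solution; the prompt instructs the
    # model to put the final answer there. The pattern tolerates one level
    # of nested braces (e.g. \boxed{\frac{1}{2}}).
    matches = re.findall(r"\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}", text)
    return matches[-1] if matches else None

print(extract_boxed_answer(assistant_response))  # e.g. "-16" for the problem above
```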