---
library_name: transformers
datasets:
- nvidia/OpenMathInstruct-2
base_model:
- google/gemma-2-2b-it
---

# Gemma2-2B Distilled from OpenMath2-Llama3.1-8B Model Card

Gemma2-2B distilled from [OpenMath2-Llama3.1-8B](https://huggingface.co/nvidia/OpenMath2-Llama3.1-8B) for math tasks.

This model greatly outperforms the general-purpose instruction-tuned Gemma2-2B-IT on math tasks.

## Model Details

- **Base Model:** Gemma2-2B
- **Tokenization:** Gemma2-2B
- **Training Methodology:** Distillation from OpenMath2-Llama3.1-8B on [OpenMathInstruct-2](https://huggingface.co/datasets/nvidia/OpenMathInstruct-2).


| **Benchmark**         | **Gemma2-2B-Distilled-Math** | **Original Gemma2-2B-IT** |
|-----------------------|------------------------------|---------------------------|
| **GSM8K (zero-shot)** | 65.1                         | 6.1                       |
| **MATH (zero-shot)**  | 52.1                         | 9.9                       |



## Training Details

Details on the training methodology are forthcoming.
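
In the meantime, the snippet below is a minimal sketch of one plausible recipe, sequence-level distillation: the teacher writes a solution, and the student is fine-tuned on it with the ordinary causal-LM loss. This is an assumption, not the published method; logit-level distillation would additionally have to bridge the mismatched Llama and Gemma vocabularies, so it is omitted here.

```python
# Hypothetical sketch of sequence-level distillation; the actual recipe may differ.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

teacher_name = "nvidia/OpenMath2-Llama3.1-8B"
student_name = "google/gemma-2-2b-it"

teacher_tok = AutoTokenizer.from_pretrained(teacher_name)
teacher = AutoModelForCausalLM.from_pretrained(
    teacher_name, torch_dtype=torch.bfloat16, device_map="auto"
)

# 1) The teacher produces a solution for each problem (a single example here).
problem = "What is $2 + 2$?"
prompt = teacher_tok.apply_chat_template(
    [{"role": "user", "content": problem}],
    tokenize=False,
    add_generation_prompt=True,
)
inputs = teacher_tok(prompt, return_tensors="pt").to(teacher.device)
solution_ids = teacher.generate(**inputs, max_new_tokens=256)
solution = teacher_tok.decode(
    solution_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True
)

# 2) The student is fine-tuned on (problem, solution) pairs with the usual
#    causal-LM loss; a full training loop (Trainer/TRL) is omitted here.
student_tok = AutoTokenizer.from_pretrained(student_name)
student = AutoModelForCausalLM.from_pretrained(student_name)
batch = student_tok(problem + "\n" + solution, return_tensors="pt")
loss = student(**batch, labels=batch["input_ids"]).loss
loss.backward()  # one illustrative gradient step
```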

## Use

```python
import torch
from transformers import pipeline

# The model was distilled from a Llama-3.1 teacher, so the prompt follows the
# Llama 3 chat format. Note the escaping: "\\boxed{{}}" survives .format() as a
# literal \boxed{}, while {problem} is the substitution slot.
template = (
    "<|start_header_id|>user<|end_header_id|>\n\n"
    "Solve the following math problem. Make sure to put the answer "
    "(and only answer) inside \\boxed{{}}.\n\n"
    "{problem}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
problem = "What is the minimum value of $a^2+6a-7$?"

pipe = pipeline(
    "text-generation",
    model="benjamin/Gemma2-2B-Distilled-Math",
    model_kwargs={"torch_dtype": torch.bfloat16},
    eos_token_id=107,  # Gemma2 <end_of_turn>
    device_map="auto",
)

# return_full_text=False strips the prompt from the pipeline output.
outputs = pipe(
    template.format(problem=problem),
    max_new_tokens=256,
    return_full_text=False,
)
assistant_response = outputs[0]["generated_text"].strip()
print(assistant_response)
```
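
Since the prompt instructs the model to put the answer inside `\boxed{}`, the final answer can be pulled out of the generated solution with a small regex helper. This helper is not part of the model card; it is a sketch and assumes the boxed content contains no nested braces.

```python
import re

def extract_boxed_answer(solution: str):
    """Return the contents of the last \\boxed{...} in a solution, or None."""
    matches = re.findall(r"\\boxed\{([^{}]*)\}", solution)
    return matches[-1] if matches else None

# For the example problem above, this should print "-16"
# (the minimum of a^2+6a-7, attained at a=-3).
print(extract_boxed_answer(assistant_response))
```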