---
license: cc-by-sa-3.0
library_name: transformers
tags:
- supertrainer2000
- human-data
datasets:
- euclaise/TinyCoT
- euclaise/reddit-instruct
- sablo/oasst2_curated
metrics:
- accuracy
---

![image/png](https://cdn-uploads.huggingface.co/production/uploads/64137e2150358a805203cbac/DlTWku8gant1yx6NaxqJX.png)

Memphis-CoT is a finetune of [StableLM 3B 4e1t](https://huggingface.co/stabilityai/stablelm-3b-4e1t) on [TinyCoT](https://huggingface.co/datasets/euclaise/TinyCoT), along with [reddit-instruct](https://huggingface.co/datasets/euclaise/reddit-instruct) (subsampled to 5,000 examples, excluding posts with brackets in the title) and a [curated](https://huggingface.co/datasets/sablo/oasst2_curated) subset of [oasst2](https://huggingface.co/datasets/OpenAssistant/oasst2).

**Memphis was trained *only* on human data! No GPT generations here.**

Finetuning was performed with my [supertrainer2000](https://github.com/euclaise/supertrainer2000) framework, using my Adalite optimizer.

## Training Procedure

I finetuned the model using an iterative rationale-bootstrapping procedure inspired by [STaR](https://research.google/pubs/star-self-taught-reasoner-bootstrapping-reasoning-with-reasoning/) and [SPIN](https://arxiv.org/abs/2401.01335).

First, I finetuned the model on all the datasets for 2 epochs, using a [MixCE](https://arxiv.org/abs/2305.16958) loss and [NEFTune](https://arxiv.org/abs/2310.05914).

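For readers unfamiliar with these two ingredients, here is a minimal pure-Python sketch of each. It is not the supertrainer2000 implementation. The MixCE part assumes the token-level approximation described in the MixCE paper (the usual negative log-likelihood re-weighted by the model's own probability of the target token) and assumes that the "MixCE ratio" listed under Hyperparameters weights the forward-CE term. The NEFTune part follows the paper's rule of adding uniform noise scaled by `alpha / sqrt(seq_len * dim)` to the input embeddings during training. Both function names are hypothetical.

```python
import math
import random


def mixce_token_loss(token_prob: float, eta: float = 0.75) -> float:
    """Approximate MixCE loss for one target token (illustrative sketch).

    Assumption: the reverse-CE term is approximated by re-weighting the
    forward NLL with the model's own token probability; in an autograd
    framework that weight would be detached (treated as a constant).
    Assumption: eta weights the forward-CE term, so eta=1 is plain CE.
    """
    nll = -math.log(token_prob)  # standard (forward) cross-entropy term
    weight = eta + (1.0 - eta) * token_prob
    return weight * nll


def neftune_noise(embeddings, alpha: float = 10.0, rng=random):
    """Add NEFTune noise to a (seq_len x dim) embedding matrix given as lists.

    Noise is drawn uniformly from [-1, 1] and scaled by
    alpha / sqrt(seq_len * dim); it is meant for training only.
    """
    seq_len, dim = len(embeddings), len(embeddings[0])
    scale = alpha / math.sqrt(seq_len * dim)
    return [[x + scale * rng.uniform(-1.0, 1.0) for x in row] for row in embeddings]
```

With `eta=1.0` the MixCE sketch reduces to ordinary cross-entropy; smaller values down-weight tokens the model already assigns low probability to.
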
I then performed the following steps 4 times:
1. Generate responses for each question in TinyCoT using the current model, check each response for correctness, and create a dataset of (correct, incorrect) pairs. Extra responses are discarded, so that each correct and each incorrect response is used in only one pair.
2. Finetune the model for 1 epoch using a ranking loss over the length-normalized log-probabilities of each sequence, similar to [Preference Ranking Optimization](https://arxiv.org/abs/2306.17492), comparing the correct and incorrect generated responses. A standard CE loss over the ground truth was included to prevent excessive drift.

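The pairing rule in step 1 can be sketched as below. This is an illustrative reading rather than the actual pipeline: `build_pairs` and the `is_correct` checker are hypothetical, and how correctness is judged (for example, comparing the final answer against TinyCoT's reference) is an assumption. Exact duplicates are dropped and the larger group is truncated, so no response appears in more than one pair.

```python
import random
from typing import Callable, List, Optional, Sequence, Tuple


def build_pairs(
    responses: Sequence[str],
    is_correct: Callable[[str], bool],
    rng: Optional[random.Random] = None,
) -> List[Tuple[str, str]]:
    """Pair correct with incorrect responses for a single question."""
    unique = list(dict.fromkeys(responses))  # drop exact duplicates, keep order
    correct = [r for r in unique if is_correct(r)]
    incorrect = [r for r in unique if not is_correct(r)]
    if rng is not None:  # optionally randomize which extras get discarded
        rng.shuffle(correct)
        rng.shuffle(incorrect)
    # zip() stops at the shorter list, discarding the surplus responses.
    return list(zip(correct, incorrect))
```

For example, with five samples of which two distinct ones are correct and two distinct ones are incorrect, this yields two pairs; a question whose samples are all correct (or all incorrect) contributes no pairs.
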
This should be more efficient than either STaR or SPIN: it uses a ranking loss rather than rejection sampling (unlike STaR), and it verifies correctness instead of assuming all model responses are incorrect (unlike SPIN).

## Prompt formats

The format for reddit-instruct and oasst2 was:

```
### User:
[insert instruction here]
### Assistant:
[insert response here]
### User:
...
```

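A minimal sketch of how a conversation in this format might be turned into training labels under the "no training on inputs" setting listed in the hyperparameters. Assumptions: exactly one newline after each header and after each message, tokenization segment by segment, and a toy character-level tokenizer for illustration. `-100` is the label value that PyTorch's `CrossEntropyLoss` ignores by default, which Hugging Face causal-LM losses also rely on; `build_example` and `HEADERS` are hypothetical names.

```python
from typing import Callable, List, Sequence, Tuple

IGNORE_INDEX = -100  # label value skipped by PyTorch's CrossEntropyLoss by default

# Assumed header strings; the exact newline placement is a guess.
HEADERS = {"user": "### User:\n", "assistant": "### Assistant:\n"}


def build_example(
    turns: Sequence[Tuple[str, str]],
    tokenize: Callable[[str], List[int]],
) -> Tuple[List[int], List[int]]:
    """Return (input_ids, labels) for a multi-turn conversation.

    Only assistant responses receive real labels; headers and user
    messages are masked with IGNORE_INDEX ("no training on inputs").
    """
    input_ids: List[int] = []
    labels: List[int] = []
    for role, content in turns:
        header_ids = tokenize(HEADERS[role])
        body_ids = tokenize(content + "\n")
        input_ids += header_ids + body_ids
        labels += [IGNORE_INDEX] * len(header_ids)
        if role == "assistant":
            labels += body_ids
        else:
            labels += [IGNORE_INDEX] * len(body_ids)
    return input_ids, labels


# Usage with a toy character-level "tokenizer" (illustration only).
ids, labels = build_example(
    [("user", "Hi"), ("assistant", "Hello")],
    lambda s: [ord(c) for c in s],
)
```

Tokenizing each segment separately keeps the mask boundaries exact; a real tokenizer may merge tokens differently at segment edges than it would on the full string.
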
The format for TinyCoT was:

```
### User:
[insert instruction here]
### Rationale:
[insert reasoning here]
### Answer:
[insert direct answer here]
```

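The TinyCoT layout can be rendered with a small helper like the hypothetical `format_tinycot` below. Omitting the rationale yields a prompt that ends at the `### Rationale:` header, which is presumably how questions would be posed when sampling rationales; that usage, and the single newline after each header and field, are assumptions.

```python
from typing import Optional


def format_tinycot(
    question: str,
    rationale: Optional[str] = None,
    answer: Optional[str] = None,
) -> str:
    """Render a TinyCoT example; omit the rationale to get a generation prompt."""
    prompt = f"### User:\n{question}\n### Rationale:\n"
    if rationale is None:
        # Prompt for sampling: the model continues with reasoning and an answer.
        return prompt
    if answer is None:
        raise ValueError("a full training example needs both a rationale and an answer")
    return prompt + f"{rationale}\n### Answer:\n{answer}\n"
```

Because the generation prompt is a prefix of the full training example, sampled continuations can be compared directly against the ground-truth rationale and answer.
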
## Benchmarks

| Model | Size | Data | Method | GSM8K (5-shot) | AGIEval (English/Nous subset, acc_norm) | BIG Bench Hard (CoT, few-shot\*) |
|:-----------------------------------------------------------------------|--------|:--------------------|---------------|:---------------|:----------------------------------------|:------------------------------|
| [StableLM 3B Base](https://hf.co/stabilityai/stablelm-3b-4e1t) | 3B | Base | Base | 2.05% | 25.14% | |
| [StableHermes 3B](https://hf.co/cxllin/StableHermes-3b) | 3B | GPT | SFT | 3.64% | 24.31% | *37.28%* |
| [MPT 7B Instruct](https://hf.co/mosaicml/mpt-7b-instruct) | **7B** | **Human**+Anthropic | SFT | 2.05% | 24.12% | 11.01% |
| [OpenLLaMA 7B v2 open-instruct](http://hf.co/VMware/open-llama-7b-v2-open-instruct) | **7B** | **Human** (nearly: ecqa is an exception) | SFT | 8.64% | 23.21% | 29.84% |
| [StableLM Zephyr 3B](https://hf.co/stabilityai/stablelm-zephyr-3b) | 3B | GPT | DPO | contaminated (45.72%) | **33.31%** | 0.91% |
| [**Memphis-CoT 3B**](https://hf.co/euclaise/memphis-cot-3b) | 3B | **Human** | Self-teaching | **13.8%** | *26.24%* | **38.24%** |

\*5-shot, as performed automatically by the LM Evaluation Harness `bbh_cot_fewshot` task, even with `num_fewshot=0`.

Memphis outperforms human-data models more than twice its size, as well as SFT models of its own size, and trades blows with the Zephyr DPO model. That said, Zephyr uses synthetic data, and *much* more of it.

Note that the BBH results have wide standard errors (SEs), exceeding 16%.

It is unclear why Zephyr performs so poorly on BBH. Perhaps it is overfit, or perhaps there was an issue with vllm.

Notes:
- Evaluations were performed using the `agieval` branch of [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) (commit `0bef5c9c273b1c2f68e6018d4bb9c32b9aaff298`), using the `vllm` model.
- I tried to find human-data-trained StableLM models, but couldn't find any. I did find a few OpenLLaMA models, but they wouldn't load with LM Eval Harness and vllm. (I believe this can be fixed by changing the xformers backend, but I'm too lazy for that.)
- OpenLLaMA 7B v2 open-instruct is a particularly relevant comparison, as it was trained on a *very* similar dataset.

## Hyperparameters

For the initial supervised finetuning step:
- Adalite optimizer, default hyperparameters of supertrainer2000 unless otherwise specified
- Lambda (Adalite's analogue of weight decay; see [here](https://arxiv.org/abs/2103.06583) for details) of 0.01
- LR of 1e-5
- MixCE ratio of 0.75
- Sequence length of 4096
- Cosine decay with a 20% warmup
- Frozen embeddings
- No training on inputs
- Accumulated batch size of 128
- NEFTune with an alpha of 10

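The schedule and batch settings above can be sketched as follows. Assumptions: warmup ramps linearly from 0, the cosine decays to 0 by the end of training, and the micro-batch sizes in the tests are hypothetical; supertrainer2000's own scheduler may differ in these details. The same `lr_at` helper would also cover the rank-finetuning schedule (LR of 5e-7 with 10% warmup).

```python
import math


def lr_at(step: int, total_steps: int, peak_lr: float, warmup_frac: float) -> float:
    """Learning rate with linear warmup followed by cosine decay to zero."""
    warmup_steps = round(total_steps * warmup_frac)
    if step < warmup_steps:
        # Linear ramp that reaches peak_lr on the last warmup step.
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))


def accumulation_steps(global_batch: int, micro_batch: int, num_devices: int = 1) -> int:
    """Gradient-accumulation steps needed to reach the accumulated batch size."""
    per_step = micro_batch * num_devices
    if global_batch % per_step:
        raise ValueError("global batch must be a multiple of micro_batch * num_devices")
    return global_batch // per_step
```

For instance, an accumulated batch of 128 with a (hypothetical) micro-batch of 8 on one device means 16 forward/backward passes per optimizer step.
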
For the generations:
- Generated using the current git version of `vllm`
- N=8 responses per question
- Temperature of 0.5
- `top_p` of 0.8
- Maximum of 512 generated tokens, discarding responses that do not have a valid rationale and answer

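A sketch of the generation step using vLLM's `LLM` / `SamplingParams` API with the settings listed above. The validity rule in `parse_generation` (a non-empty rationale, an `### Answer:` header, a non-empty answer, and truncation of the answer at any new `### ` header) is an assumption, since the card does not spell out what counts as valid; both function names are hypothetical. vLLM is imported lazily so the parsing logic can be used without it.

```python
from typing import List, Optional, Tuple

# Settings from the card, written as keyword arguments for vllm.SamplingParams.
SAMPLING_KWARGS = {"n": 8, "temperature": 0.5, "top_p": 0.8, "max_tokens": 512}


def parse_generation(text: str) -> Optional[Tuple[str, str]]:
    """Split a continuation of '### Rationale:\\n' into (rationale, answer).

    Returns None for malformed responses (assumed validity rule).
    """
    rationale, sep, answer = text.partition("### Answer:")
    if not sep:
        return None
    # Drop anything after a new header the model may start on its own.
    answer = answer.split("### ", 1)[0]
    rationale, answer = rationale.strip(), answer.strip()
    if not rationale or not answer:
        return None
    return rationale, answer


def sample_valid(model_path: str, prompts: List[str]) -> List[List[Tuple[str, str]]]:
    """Sample responses per prompt with vLLM and keep only well-formed ones."""
    from vllm import LLM, SamplingParams  # lazy import: requires vllm installed

    llm = LLM(model=model_path)
    outputs = llm.generate(prompts, SamplingParams(**SAMPLING_KWARGS))
    results = []
    for out in outputs:
        parsed = [parse_generation(c.text) for c in out.outputs]
        results.append([p for p in parsed if p is not None])
    return results
```

Responses cut off at the 512-token limit usually lack an `### Answer:` section, so this filter also drops most truncated generations.
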
For the rank finetuning:
- Adalite optimizer, default hyperparameters of supertrainer2000 unless otherwise specified
- Lambda of 0.01
- LR of 5e-7
- Rank loss weight of 5
- Sequence length of 1024
- Cosine schedule with 10% warmup
- Frozen embeddings
- No training on inputs
- Accumulated batch size of 128
- NEFTune with an alpha of 10
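
Putting the rank-finetuning pieces together, here is a minimal sketch of the per-pair objective. Assumptions: each response's score is its mean token log-probability (the length normalization), the PRO-style listwise softmax over only two candidates reduces to `-log sigmoid(s_correct - s_incorrect)`, and the "rank loss weight of 5" multiplies the ranking term that is added to the ground-truth CE loss. The function names are hypothetical and this is not the supertrainer2000 code.

```python
import math
from typing import Sequence


def length_normalized_logprob(token_logprobs: Sequence[float]) -> float:
    """Mean per-token log-probability of a response."""
    return sum(token_logprobs) / len(token_logprobs)


def pairwise_rank_loss(correct_lp: Sequence[float], incorrect_lp: Sequence[float]) -> float:
    """-log softmax of the correct response's score over the two candidates.

    Equals -log sigmoid(margin) = softplus(-margin), computed stably.
    """
    margin = length_normalized_logprob(correct_lp) - length_normalized_logprob(incorrect_lp)
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))


def rank_finetune_loss(
    ce_loss: float,
    correct_lp: Sequence[float],
    incorrect_lp: Sequence[float],
    rank_weight: float = 5.0,
) -> float:
    """Ground-truth CE (anti-drift term) plus the weighted ranking term."""
    return ce_loss + rank_weight * pairwise_rank_loss(correct_lp, incorrect_lp)
```

Length normalization keeps the ranking term from simply preferring shorter responses: a 4-token and a 1-token response with the same average log-probability tie, giving a loss of log 2.
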