---
library_name: transformers
tags: []
---
# Model Overview
This model was trained as a small-scale experiment to determine how easy it is to fine-tune [ai21labs/Jamba-v0.1](https://huggingface.co/ai21labs/Jamba-v0.1) to work as a chatbot.
The aim of this experiment was to find out how intelligently and reliably Jamba can chat in both English and other languages when QLoRA fine-tuned for only a few hours.
Initial subjective testing has shown that this model can chat reasonably well in both English and Japanese, so feel free to give it a try!
## Model Details
- **Model type:** Joint Attention and Mamba (Jamba)
- **License:** Apache 2.0
- **Context length:** 256K
- **Knowledge cutoff date:** March 5, 2024
## How to use
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("kinokokoro/jamba_airoboros3.2_sharegpt4",
trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("kinokokoro/jamba_airoboros3.2_sharegpt4")
input_text = """<|im_start|>system
You are GPT-4, a helpful assistant.
<|im_end|>
<|im_start|>user
ๆœ€่ฟ‘ใ€้‹ๅ‹•ใ™ใ‚Œใฐใ€ใ™ใใซใ‚ใฃใกใ‚ƒใใฃใกใ‚ƒๆฑ—ใ‹ใ„ใกใ‚ƒใ†ใ‚“ใ ใ‘ใฉใ€ใฉใ†ใ—ใŸใ‚‰ใ„ใ„ใงใ™ใ‹๏ผŸ
<|im_end|>
<|im_start|>assistant
"""
# The Japanese prompt above asks: "Recently, I sweat a huge amount as soon as I exercise. What should I do?"
input_ids = tokenizer(input_text, return_tensors='pt').to(model.device)["input_ids"]
outputs = model.generate(input_ids, max_new_tokens=256, temperature=0.0)  # greedy decoding (do_sample defaults to False)
print(tokenizer.batch_decode([outputs[0][len(input_ids[0]):]]))
# ['ๆฑ—ใŒๅ‡บใ‚‹ใ“ใจใฏใ€้‹ๅ‹•ใ‚’ใ™ใ‚‹ใจใใซไฝ“ๆธฉใŒไธŠใŒใ‚Šใ€ไฝ“ๅ†…ใฎ็†ฑใ‚’ๅค–้ƒจใซๆ”พๅ‡บใ™ใ‚‹ใŸใ‚ใฎ่‡ช็„ถใชใƒกใ‚ซใƒ‹ใ‚บใƒ ใงใ™ใ€‚ๆฑ—ใŒๅ‡บใ‚‹ใ“ใจใŒๅคšใ„ใ“ใจใฏใ€ไธ€่ˆฌ็š„ใซใฏใ€ไฝ“ใฎๆธฉๅบฆ่ชฟ็ฏ€ๆฉŸ่ƒฝใŒๅƒใ„ใฆใ„ใ‚‹ใ“ใจใ‚’ๆ„ๅ‘ณใ—ใพใ™ใ€‚ใ—ใ‹ใ—ใ€ๆฑ—ใŒๅ‡บใ‚‹ใ“ใจใŒๅคšใ™ใŽใ‚‹ใจใ€ไธๅฟซๆ„Ÿใ‚„ๆฑ—็—‡ใชใฉใฎๅ•้กŒใŒ็™บ็”Ÿใ™ใ‚‹ใ“ใจใŒใ‚ใ‚Šใพใ™ใ€‚ไปฅไธ‹ใซใ€ๆฑ—ใŒๅ‡บใ‚‹ใ“ใจใŒๅคšใ„ๅ ดๅˆใฎๅฏพ็ญ–ใ‚’็ดนไป‹ใ—ใพใ™ใ€‚\n\n1. ้ฉๅˆ‡ใชๆœ่ฃ…ใ‚’้ธใถ: ๆฑ—ใŒๅ‡บใ‚‹ใ“ใจใŒๅคšใ„ๅ ดๅˆใ€่ปฝ้‡ใง้€ๆนฟๆ€งใฎ้ซ˜ใ„ๆœใ‚’้ธใถใ“ใจใŒ้‡่ฆใงใ™ใ€‚ใ“ใ‚Œใซใ‚ˆใ‚Šใ€ๆฑ—ใŒไฝ“ใ‹ใ‚‰ๅค–้ƒจใซ๏ฟฝ']
```
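If you are short on VRAM, the weights can also be loaded in 4-bit with bitsandbytes (the same precision used for the QLoRA training described below). The snippet below is a minimal sketch rather than an officially supported path; the NF4/bfloat16 settings are assumptions that you may want to tune for your hardware.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Assumed quantisation settings (NF4 weights, bfloat16 compute); adjust as needed.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "kinokokoro/jamba_airoboros3.2_sharegpt4",
    quantization_config=quant_config,
    device_map="auto",  # spread the layers across available GPUs
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("kinokokoro/jamba_airoboros3.2_sharegpt4")
```

Generation then works exactly as in the example above.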
# Initial testing results
Initial subjective testing indicates that the model can chat reasonably well in both English and Japanese (see the example output above).
# Training details
The model was trained on 2 open source datasets (one multilingual) for one epoch in an A100 (80GB) x 4 environment for roughly 3 hours.
## Training data
* [jondurbin/airoboros-3.2](https://huggingface.co/datasets/jondurbin/airoboros-3.2)
A ~59K example dataset of curated LLM tasks in English, primarily generated with GPT-4. This dataset has been used by some of the best performing open source LLMs in the world (e.g. [jondurbin/bagel-7b-v0.4](https://huggingface.co/jondurbin/bagel-7b-v0.4), [NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO](https://huggingface.co/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO)) and contains a wide variety of tasks, so we hypothesized that this would lead to a multi-talented, accurate model. For this reason, this dataset was chosen for the bulk of our training data.
Note: Each element in jondurbin/airoboros-3.2 already contains a system message.
* [openchat/openchat_sharegpt4_dataset](https://huggingface.co/datasets/openchat/openchat_sharegpt4_dataset) (GPT-4 responses only)
A ~6K example dataset of multilingual, multi-turn chats between users and GPT-4. While jondurbin/airoboros-3.2 has delivered good results for models previously, it sadly contains no (or seemingly very little) multilingual data. We are a Japanese AI company, so we require an LLM to be able to output in Japanese too. Hence we also selected a small, seemingly high quality dataset of GPT-4 responses in many languages from the ShareGPT dataset. We selected only the GPT-4 responses, as we wanted to keep our dataset as small and high quality as possible to maximise the efficiency of our training.
Note: openchat/openchat_sharegpt4_dataset does not contain system messages, so we added 'You are GPT-4, a helpful assistant.' as our system message.
<details>
<summary>Data preparation code</summary>
```python
import os

# Set the cache location and enable hf_transfer before importing the Hugging Face
# libraries, so that the values are picked up when those libraries initialise.
os.environ['HF_HOME'] = "/workspace/hf_home"
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = "1"

import pandas as pd
from datasets import load_dataset, Dataset, concatenate_datasets

# English instruction data; each example already contains a system message.
boros_dataset = load_dataset("jondurbin/airoboros-3.2", split='train')

# Multilingual ShareGPT chats (GPT-4 responses only); prepend our own system message.
gpt4_df = pd.read_json("https://huggingface.co/datasets/openchat/openchat_sharegpt4_dataset/resolve/main/sharegpt_gpt4.json?download=true")
gpt4_df["conversations"] = gpt4_df["items"].apply(lambda x: [{'from': 'system', 'value': 'You are GPT-4, a helpful assistant.'}] + x)
gpt4_dataset = Dataset.from_pandas(gpt4_df[["conversations"]])

# Combine both sources, shuffle, and write a single ShareGPT-format JSON file for Axolotl to consume.
dataset = concatenate_datasets([gpt4_dataset, boros_dataset]).shuffle()
dataset.select_columns(["conversations"]).to_json("/workspace/airoboros-3.2_plus_openchat_sharegpt4.json")
```
</details>
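The file written above stores each conversation in ShareGPT format (`from`/`value` turns). During training, Axolotl's `type: sharegpt` and `conversation: chatml` settings (see the config below) render those turns into the same ChatML prompt format shown in the usage example. The helper below is only an illustrative sketch of that mapping, not the actual templating code.

```python
# Illustrative only: render a ShareGPT-style conversation as a ChatML prompt
# matching the format shown in the "How to use" example above.
ROLE_MAP = {"system": "system", "human": "user", "gpt": "assistant"}

def to_chatml(conversation: list) -> str:
    parts = [
        f"<|im_start|>{ROLE_MAP[turn['from']]}\n{turn['value']}\n<|im_end|>"
        for turn in conversation
    ]
    # Append an empty assistant turn so the model continues from there at inference time.
    return "\n".join(parts) + "\n<|im_start|>assistant\n"

example = [
    {"from": "system", "value": "You are GPT-4, a helpful assistant."},
    {"from": "human", "value": "Hello! Can you introduce yourself?"},
]
print(to_chatml(example))
```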
## Training
The Jamba-v0.1 base model was trained for roughly 3 hours in an A100 (80GB) x 4 environment on the Azure cloud (Standard_NC96ads_A100_v4).
Our training harness was Axolotl, using the following training config:
<details>
<summary>Training config</summary>
```yaml
base_model: ai21labs/Jamba-v0.1
trust_remote_code: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
  - path: /workspace/airoboros-3.2_plus_openchat_sharegpt4.json
    ds_type: json
    type: sharegpt
    conversation: chatml
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./airoboros-3.2_plus_openchat_sharegpt4_one_epoch
sequence_len: 6000
sample_packing: true
pad_to_sequence_len: false
eval_sample_packing: true
use_wandb: true
wandb_project: axolotl
wandb_entity: peterd
wandb_name: airoboros-3.2_plus_openchat_sharegpt4
adapter: qlora
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
low_cpu_mem_usage: true
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 5
saves_per_epoch: 5
debug:
deepspeed: /workspace/axolotl/deepspeed_configs/zero2.json
weight_decay: 0.0
special_tokens:
```
</details>
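For readers not using Axolotl, a rough PEFT equivalent of the adapter settings above (4-bit base model; LoRA with r=8, alpha=16, dropout 0.05 on all linear layers) might look like the sketch below. This is an assumption-laden approximation, not the code that actually produced this model; in particular, `target_modules="all-linear"` requires a recent version of peft.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# 4-bit base model, mirroring `load_in_4bit: true` in the Axolotl config above.
model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/Jamba-v0.1",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
    device_map="auto",
    trust_remote_code=True,
)
model = prepare_model_for_kbit_training(model)

# LoRA hyperparameters copied from the config (lora_r, lora_alpha, lora_dropout),
# with `lora_target_linear: true` approximated by targeting all linear layers.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```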
<details>
<summary>Training graphs</summary>
![image/png](https://cdn-uploads.huggingface.co/production/uploads/64b63f8ad57e02621dc93c8b/umxTIsNRHUtKS_kL81Uyf.png)
![image/png](https://cdn-uploads.huggingface.co/production/uploads/64b63f8ad57e02621dc93c8b/mpuCoL99rxX8RCgXH1CJo.png)
![image/png](https://cdn-uploads.huggingface.co/production/uploads/64b63f8ad57e02621dc93c8b/5FvwYNdte-bgzEvcvFO8I.png)
</details>
<br/>
# Developers
Lead developer - Peter Devine [ptrdvn](https://huggingface.co/ptrdvn)
Administrative supervisor - Shunichi Taniguchi