# How to use

We write our prompts in the ChatML format.
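
For reference, the helper `process_chat_history` used in the snippets below renders a single-turn prompt like this (system message, then the user turn, then an opened assistant turn; the example question means "What is the highest mountain in Japan?"):

```
<|im_start|>system
You are a helpful assistant.<|im_end|>

<|im_start|>user
日本の一番高い山は？<|im_end|>

<|im_start|>assistant
```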

### With vLLM (recommended for much faster inference)

<details><summary>Install vLLM</summary>

```bash
pip install vllm
```

[Reference](https://vllm.readthedocs.io/en/latest/getting_started/installation.html)

</details>

```python
from vllm import LLM, SamplingParams

model_name = "lightblue/jod"
llm = LLM(model=model_name)

SYSTEM_MESSAGE = "You are a helpful assistant."

def process_chat_history(next_user_msg, text_chat_history=[]):
    # Build a ChatML prompt: system message, then alternating user/assistant turns.
    prompt_text = "<|im_start|>system\n"
    prompt_text += SYSTEM_MESSAGE
    prompt_text += "<|im_end|>\n\n"

    for user_msg, ai_msg in text_chat_history:
        prompt_text += "<|im_start|>user\n"
        prompt_text += user_msg
        prompt_text += "<|im_end|>\n\n"
        prompt_text += "<|im_start|>assistant\n"
        prompt_text += ai_msg
        prompt_text += "<|im_end|>\n\n"

    # Add the new user message and open the assistant turn for generation.
    prompt_text += "<|im_start|>user\n"
    prompt_text += next_user_msg
    prompt_text += "<|im_end|>\n\n"
    prompt_text += "<|im_start|>assistant\n"
    return prompt_text

user_prompt = "日本の一番高い山は？"  # "What is the highest mountain in Japan?"
prompt = process_chat_history(user_prompt)

sampling_params = SamplingParams(temperature=0, max_tokens=528)
outputs = llm.generate(prompt, sampling_params)
bot_message = outputs[0].outputs[0].text.strip()
print(bot_message)
```
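
The `text_chat_history` argument takes earlier (user, assistant) message pairs, so the same helper can drive a multi-turn chat. A minimal sketch that continues the conversation above (the follow-up question is purely illustrative):

```python
# Hypothetical second turn: feed the previous exchange back in as history.
chat_history = [(user_prompt, bot_message)]
followup = "その山の高さはどれくらいですか？"  # "Roughly how tall is that mountain?" (example only)
next_prompt = process_chat_history(followup, chat_history)
outputs = llm.generate(next_prompt, sampling_params)
print(outputs[0].outputs[0].text.strip())
```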

### With Hugging Face

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

model_name = "lightblue/jod"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map='auto', load_in_4bit=True,
)

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

SYSTEM_MESSAGE = "You are a helpful assistant."

def process_chat_history(next_user_msg, text_chat_history=[]):
    # Build a ChatML prompt: system message, then alternating user/assistant turns.
    prompt_text = "<|im_start|>system\n"
    prompt_text += SYSTEM_MESSAGE
    prompt_text += "<|im_end|>\n\n"

    for user_msg, ai_msg in text_chat_history:
        prompt_text += "<|im_start|>user\n"
        prompt_text += user_msg
        prompt_text += "<|im_end|>\n\n"
        prompt_text += "<|im_start|>assistant\n"
        prompt_text += ai_msg
        prompt_text += "<|im_end|>\n\n"

    prompt_text += "<|im_start|>user\n"
    prompt_text += next_user_msg
    prompt_text += "<|im_end|>\n\n"
    prompt_text += "<|im_start|>assistant\n"
    return prompt_text

user_prompt = "日本の一番高い山は？"  # "What is the highest mountain in Japan?"
prompt = process_chat_history(user_prompt)

bot_message = pipe(prompt, max_new_tokens=128, temperature=0)[0]["generated_text"]
print(bot_message)
```
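
Note that the `text-generation` pipeline returns the prompt together with the completion by default. If only the model's reply is wanted, the prompt can be dropped with the pipeline's `return_full_text` flag (a small sketch building on the code above):

```python
# Return only the newly generated reply, without echoing the ChatML prompt.
reply = pipe(prompt, max_new_tokens=128, temperature=0, return_full_text=False)[0]["generated_text"]
print(reply.strip())
```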

# Training datasets

This model was trained with prompts in the ChatML format, so it should also be prompted in the ChatML chat format at inference time.
We chose this format because the base model ([Open-Orca/Mistral-7B-SlimOrca](https://huggingface.co/Open-Orca/Mistral-7B-SlimOrca)) was trained with it, and we find the chat format more compelling for practical use than the Alpaca-style instruction format.

The training data consisted of:

* [JASTER](https://github.com/llm-jp/llm-jp-eval)
* [kunishou/oasst1-89k-ja](https://huggingface.co/datasets/kunishou/oasst1-89k-ja/)
* [kunishou/databricks-dolly-15k-ja](https://huggingface.co/datasets/kunishou/databricks-dolly-15k-ja/)

We trained for 1 epoch using the following Axolotl config. (Early stopping was not performed during our training.)

<details><summary>Axolotl config .yaml</summary>

```yaml
base_model: Open-Orca/Mistral-7B-SlimOrca
base_model_config: Open-Orca/Mistral-7B-SlimOrca
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: ./data/jaster_plus.jsonl
    ds_type: json # see other options below
    type: sharegpt
    conversation: chatml
dataset_prepared_path: false
val_set_size: 0.002
output_dir: ./train_output/openorca-mistral-jaster-1epoch

use_wandb: true
wandb_project: <HIDDEN>
wandb_entity: <HIDDEN>

debug:

adapter: qlora
lora_model_dir:

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj

gradient_accumulation_steps: 1
micro_batch_size: 10
eval_batch_size: 4
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience: 10
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
eval_steps: 10
eval_table_size: 5
eval_table_max_new_tokens: 128
save_steps: 10
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
```

</details>
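
For reproduction, a config like this is normally launched through Axolotl's training entry point, roughly as below. This is only a sketch: the exact command depends on the installed Axolotl version, and `jod_qlora.yaml` is a hypothetical file name for the config above.

```bash
# Hypothetical launch command; adjust to your Axolotl installation.
accelerate launch -m axolotl.cli.train jod_qlora.yaml
```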