---
library_name: peft
base_model: mistralai/Mistral-7B-v0.1
license: mit
tags:
- Mathematical Reasoning
language:
- en
datasets:
- meta-math/MetaMathQA
- TIGER-Lab/MathInstruct
---

**This repo contains LoRA adapter weights.**

### Model Description

- **Project GitHub Page:** https://github.com/adityasihag1996/math_QA.git
- **Developed by:** [Aditya Sihag](https://www.linkedin.com/in/aditya-sihag-ab29681a9/)
- **Model type:** fine-tuned using QLoRA on 1x RTX 4090
- **Finetuned from model:** mistralai/Mistral-7B-v0.1

## Results
| Prompt Approach | GSM8k | MATH |
|-----------------|-------|------|
| Zero-Shot CoT   | 67.5  | -    |
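The card does not publish its evaluation harness. As a hypothetical sketch (not the author's script), a zero-shot CoT Pass@1 score on GSM8k can be computed by prompting once per problem, taking the last number in the generated text as the prediction, and comparing it with the number after the `####` delimiter that ends every GSM8k reference answer. The helper names, the exact prompt spacing, and the last-number heuristic are all assumptions.

```python
import re
from typing import List, Optional

# Hypothetical helpers; the card does not publish its evaluation harness.

def build_prompt(question: str) -> str:
    """Zero-shot CoT prompt following the card's template (exact spacing assumed)."""
    return f"Question: {question}\nAnswer: Let's think step by step."

# Integers or decimals, optionally with thousands separators (e.g. "1,000").
_NUMBER = re.compile(r"-?\d[\d,]*(?:\.\d+)?")

def extract_last_number(text: str) -> Optional[str]:
    """Heuristic: treat the last number in the generation as the final answer."""
    matches = _NUMBER.findall(text)
    if not matches:
        return None
    return matches[-1].replace(",", "")

def gsm8k_reference(answer_field: str) -> str:
    """GSM8k reference solutions end with '#### <final answer>'."""
    return answer_field.split("####")[-1].strip().replace(",", "")

def is_correct(model_output: str, answer_field: str) -> bool:
    """Exact numeric match between the extracted prediction and the reference."""
    pred = extract_last_number(model_output)
    if pred is None:
        return False
    try:
        return float(pred) == float(gsm8k_reference(answer_field))
    except ValueError:
        return False

def pass_at_1(outputs: List[str], answer_fields: List[str]) -> float:
    """Percentage of problems whose single generated sample is correct."""
    correct = sum(is_correct(o, a) for o, a in zip(outputs, answer_fields))
    return 100.0 * correct / len(answer_fields)
```

For example, `is_correct("So, the value of x is 3.", "#### 3")` returns `True`. The last-number rule is a simple heuristic and can misfire when a generation ends with an intermediate value.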
## Training procedure

The following `bitsandbytes` quantization config was used during training:

- quant_method: bitsandbytes
- load_in_8bit: False
- load_in_4bit: True
- bnb_4bit_quant_type: nf4
- bnb_4bit_use_double_quant: True
- bnb_4bit_compute_dtype: float16

`LoraConfig` params:

- r: 128
- lora_alpha: 256 (`2 * r`)
- lora_dropout: 0.05
- bias: "none"
- task_type: "CAUSAL_LM"
- target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

The hyperparameters for the LoRA fine-tuning are listed below:

- epochs: 3
- learning_rate: 5e-5
- batch_size: 256
- max_grad_norm: 1.0
- weight_decay: 0.001
- lr_scheduler_type: "cosine"
- warmup_ratio: 0.03

## Dataset

The math_QA dataset is a combination of [MetaMathQA](https://huggingface.co/datasets/meta-math/MetaMathQA) and [MathInstruct](https://huggingface.co/datasets/TIGER-Lab/MathInstruct).

## Model Usage

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

model_path = "mistralai/Mistral-7B-v0.1"

# Load the base model in fp16 on GPU 0
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    device_map={"": 0},
)

# Load the LoRA adapter and merge it into the base weights
model = PeftModel.from_pretrained(model, "adityasihag/math_QA-Mistral-7B-QLoRA-adapter")
model = model.merge_and_unload()

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Mistral's tokenizer has no pad token; reuse EOS
tokenizer.pad_token = tokenizer.eos_token

question = """Solve the linear equations. $3(x+2)-x=x + 9$"""
sample_input = f"""Question: {question}. Find the value of x. \n Answer: Let's think step by step. """

sample_input_tokenised = tokenizer(sample_input, return_tensors="pt").to("cuda")
generated_ids = model.generate(
    **sample_input_tokenised,
    max_new_tokens=1024,
    do_sample=True,  # required for `temperature` to take effect
    temperature=0.3,
)

output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(output)
```

##### Sample Input:

```
Question: Solve the linear equations. $3(x+2)-x=x + 9$. Find the value of x. 
\n Answer: Let's think step by step. 
```

##### Model Output:

```
To solve the linear equation $3(x+2)-x=x + 9$, we first distribute the 3 to the terms inside the parentheses:

$3x + 6 - x = x + 9$

Now, we combine like terms:

$2x + 6 = x + 9$

Next, we isolate the variable x by subtracting x from both sides:

$2x - x = 9 - 6$

$x = 3$

So, the value of x is 3.
```

#### Prompt Template (CoT):

```
Question: {question}
Answer: Let's think step by step.
```

## Comparing math_QA with other SFT LLMs

| Model                  | GSM8k Pass@1 | MATH Pass@1 |
|------------------------|--------------|-------------|
| LLaMA-2-7B             | 14.6         | 2.5         |
| LLaMA-2-13B            | 28.7         | 3.9         |
| LLaMA-2-34B            | 42.2         | 6.24        |
| WizardMath-7B          | 54.9         | 10.7        |
| LLaMA-2-70B            | 56.8         | 13.5        |
| WizardMath-13B         | 63.9         | 14.0        |
| MetaMath-7B            | 66.5         | 19.8        |
| **math_QA-Mistral-7B** | **67.5**     | -           |
| MetaMath-13B           | 72.3         | 22.4        |
| MetaMath-Mistral-7B    | 77.7         | 28.2        |
| Arithmo2-Mistral-7B    | 76.4         | 27.2        |

### Reference


```
@article{yu2023metamath,
  title={MetaMath: Bootstrap Your Own Mathematical Questions for Large Language Models},
  author={Yu, Longhui and Jiang, Weisen and Shi, Han and Yu, Jincheng and Liu, Zhengying and Zhang, Yu and Kwok, James T and Li, Zhenguo and Weller, Adrian and Liu, Weiyang},
  journal={arXiv preprint arXiv:2309.12284},
  year={2023}
}

@article{Yue2023mammoth,
  title={MAmmoTH: Building math generalist models through hybrid instruction tuning},
  author={Yue, Xiang and Qu, Xingwei and Zhang, Ge and Fu, Yao and Huang, Wenhao and Sun, Huan and Su, Yu and Chen, Wenhu},
  journal={arXiv preprint arXiv:2309.05653},
  year={2023}
}
```