forestav committed
Commit 471f7d4 · verified · 1 Parent(s): 22dc869

Update README.md

Files changed (1)
  1. README.md +69 -0
README.md CHANGED
@@ -9,4 +9,73 @@ app_file: app.py
  pinned: false
  ---
 
+ # Fine-Tuned Medical Language Model
+
+ ## Overview
+ This project fine-tunes the LLaMA 3.2 3B model on the **FineTome-100k** instruction dataset. The goal is a performant language model for medical instruction tasks, optimized for CPU inference.
+
+ ## Key Features
+ - **Base Model**: LLaMA 3.2 3B (fine-tuned with Hugging Face Transformers and Unsloth).
+ - **Dataset**: FineTome-100k, a high-quality instruction dataset.
+ - **Inference Optimization**: Quantized to the GGUF format for faster CPU inference using methods such as Q4_K_M (see the export sketch below).
+
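+ The snippet below is a rough sketch of the quantization step, assuming the model was trained with Unsloth and exported with its GGUF helper; the checkpoint path and output directory are illustrative, not taken from this repository.
+
+ ```python
+ # Sketch: export the fine-tuned model to GGUF for CPU inference (assumes Unsloth).
+ from unsloth import FastLanguageModel
+
+ # Hypothetical checkpoint directory; replace with the actual fine-tuned model path.
+ model, tokenizer = FastLanguageModel.from_pretrained("outputs/checkpoint-6000", max_seq_length=2048)
+
+ # Unsloth's GGUF export writes a llama.cpp-compatible file; "q4_k_m" selects the
+ # Q4_K_M quantization scheme, while "f16" would keep half precision for comparison.
+ model.save_pretrained_gguf("gguf_model", tokenizer, quantization_method="q4_k_m")
+ ```
+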
+ ## Improvements
+ ### Model-Centric Approach
+ 1. **Hyperparameter Tuning**:
+    - **Learning Rate**: Reduced to `1e-4` and tested against `2e-4` for better generalization.
+    - **Warmup Steps**: Increased to 100 to stabilize early training.
+    - **Batch Size**: Adjusted via gradient accumulation to simulate larger effective batch sizes.
+
+ 2. **Fine-Tuning Techniques**:
+    - Resumed training from a 3,000-step checkpoint to save time (see the sketch after this list).
+    - Applied the `adamw_8bit` optimizer for memory-efficient training.
+
+ 3. **Experimentation with Foundation Models**:
+    - Tested alternative open-source models, including Falcon-7B and Mistral 3B, for comparison.
+
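+ A minimal sketch of how the resumed run and the 8-bit optimizer fit together, assuming TRL's `SFTTrainer` (as in common Unsloth recipes) and that `model`, `tokenizer`, and `train_ds` have already been prepared; the output directory is illustrative, and the full hyperparameter set is listed under **Hyperparameters** below.
+
+ ```python
+ # Sketch: resume fine-tuning from the saved 3,000-step checkpoint with 8-bit AdamW.
+ # Assumes `model`, `tokenizer`, and `train_ds` were prepared earlier (e.g., via Unsloth).
+ from transformers import TrainingArguments
+ from trl import SFTTrainer
+
+ trainer = SFTTrainer(
+     model=model,
+     tokenizer=tokenizer,
+     train_dataset=train_ds,
+     dataset_text_field="text",      # column holding the formatted instruction text
+     args=TrainingArguments(
+         output_dir="outputs",       # checkpoints are written here, e.g. outputs/checkpoint-3000
+         max_steps=6000,             # 3,000 initial + 3,000 resumed
+         optim="adamw_8bit",         # memory-efficient 8-bit AdamW (requires bitsandbytes)
+     ),
+ )
+
+ # Reloads model, optimizer, and scheduler state from the latest checkpoint in output_dir.
+ trainer.train(resume_from_checkpoint=True)
+ ```
+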
+ ### Data-Centric Approach
+ 1. **Additional Data Sources**:
+    - Plans to augment training with datasets such as PubMedQA or MedQA for domain-specific improvements (see the sketch below).
+    - Increased instruction diversity to improve robustness across medical queries.
+
+ 2. **Dataset Analysis**:
+    - Addressed class imbalances and ensured validation-split consistency.
+
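+ As a hedged sketch of what the augmentation could look like with the `datasets` library: the Hub IDs and column names below are assumptions based on the public FineTome-100k and PubMedQA datasets, not this project's exact pipeline.
+
+ ```python
+ # Sketch: mix FineTome-100k with a medical QA source and keep a reproducible validation split.
+ from datasets import load_dataset, concatenate_datasets
+
+ finetome = load_dataset("mlabonne/FineTome-100k", split="train").select_columns(["conversations"])
+ pubmedqa = load_dataset("qiaojin/PubMedQA", "pqa_labeled", split="train")
+
+ # Map PubMedQA question/answer pairs into FineTome's ShareGPT-style "conversations" schema.
+ def to_conversations(example):
+     return {"conversations": [
+         {"from": "human", "value": example["question"]},
+         {"from": "gpt", "value": example["long_answer"]},
+     ]}
+
+ pubmedqa = pubmedqa.map(to_conversations, remove_columns=pubmedqa.column_names)
+
+ # Concatenate, then split once with a fixed seed so the validation set stays consistent.
+ combined = concatenate_datasets([finetome, pubmedqa])
+ splits = combined.train_test_split(test_size=0.05, seed=42)
+ train_ds, eval_ds = splits["train"], splits["test"]
+ ```
+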
+ ## Hyperparameters
+ The final training run used the following hyperparameters (see the sketch below the list):
+ - **Learning Rate**: 1e-4
+ - **Warmup Steps**: 100
+ - **Batch Size**: Effective batch size of 8 (2 samples per device × 4 gradient accumulation steps).
+ - **Optimizer**: AdamW (8-bit).
+ - **Weight Decay**: 0.01
+ - **Learning Rate Scheduler**: Linear decay.
+
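+ Restated as code, these map directly onto Hugging Face `TrainingArguments` (a sketch; `output_dir`, `max_steps`, and `logging_steps` are assumptions for illustration):
+
+ ```python
+ # Sketch: the hyperparameters above expressed as TrainingArguments.
+ from transformers import TrainingArguments
+
+ training_args = TrainingArguments(
+     output_dir="outputs",               # assumption: where checkpoints are written
+     learning_rate=1e-4,                 # reduced from 2e-4
+     warmup_steps=100,                   # longer warmup to stabilize early training
+     per_device_train_batch_size=2,      # 2 samples per device ...
+     gradient_accumulation_steps=4,      # ... x 4 accumulation steps = effective batch size 8
+     optim="adamw_8bit",                 # AdamW with 8-bit optimizer state (requires bitsandbytes)
+     weight_decay=0.01,
+     lr_scheduler_type="linear",         # linear decay after warmup
+     max_steps=6000,                     # 3,000 initial + 3,000 resumed
+     logging_steps=50,                   # assumption
+ )
+ ```
+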
+ ## Model Performance
+ ### Training
+ - **Steps**: Fine-tuned for 6,000 steps total (3,000 initial + 3,000 resumed).
+ - **Validation Loss**: Improved from X to Y during fine-tuning.
+
+ ### Inference
+ - **Quantized Format**: Q4_K_M and F16 formats evaluated for inference speed (see the sketch below).
+ - **CPU Latency**: Achieved X ms per query on a single-core CPU.
+
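+ The latency measurement can be reproduced roughly with a sketch like the one below, using `llama-cpp-python`; the GGUF file name, context size, and prompt are illustrative.
+
+ ```python
+ # Sketch: run the Q4_K_M GGUF export on CPU with llama-cpp-python and time one query.
+ import time
+ from llama_cpp import Llama
+
+ llm = Llama(
+     model_path="gguf_model/unsloth.Q4_K_M.gguf",  # hypothetical path to the exported file
+     n_ctx=2048,
+     n_threads=1,   # single-core measurement, as reported above
+ )
+
+ start = time.perf_counter()
+ out = llm("What are the symptoms of diabetes?", max_tokens=128)
+ elapsed_ms = (time.perf_counter() - start) * 1000
+
+ print(out["choices"][0]["text"])
+ print(f"Latency: {elapsed_ms:.0f} ms")
+ ```
+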
+ ## Next Steps
+ 1. Continue fine-tuning with additional data sources (e.g., MedQA).
+ 2. Explore LoRA or other parameter-efficient tuning for larger models (see the sketch below).
+ 3. Deploy and evaluate the model in real-world scenarios.
+
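+ For step 2, a hedged starting point with the `peft` library is sketched below; the rank, alpha, dropout, and target module names are typical Llama-style choices rather than values taken from this project, and the base model ID refers to Meta's gated 3B checkpoint.
+
+ ```python
+ # Sketch: attach LoRA adapters with peft instead of full fine-tuning (illustrative settings).
+ from peft import LoraConfig, get_peft_model
+ from transformers import AutoModelForCausalLM
+
+ base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
+
+ lora_config = LoraConfig(
+     r=16,                       # adapter rank
+     lora_alpha=32,              # scaling factor
+     lora_dropout=0.05,
+     target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # attention projections in Llama blocks
+     task_type="CAUSAL_LM",
+ )
+
+ model = get_peft_model(base, lora_config)
+ model.print_trainable_parameters()  # only a small fraction of the weights is trainable
+ ```
+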
+ ## Usage
+ To load and use the model:
+ ```python
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ model_name = "forestav/medical_model"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(model_name)
+
+ # Generate a prediction for a sample medical question
+ inputs = tokenizer("What are the symptoms of diabetes?", return_tensors="pt")
+ outputs = model.generate(**inputs, max_new_tokens=256)
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+ ```
+
  An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).