Upload quant.py
Browse files- quant/quant.py +84 -0
quant/quant.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from transformers import AutoTokenizer
|
3 |
+
from llmcompressor.transformers import SparseAutoModelForCausalLM
|
4 |
+
from llmcompressor.transformers import oneshot
|
5 |
+
from llmcompressor.modifiers.quantization import QuantizationModifier
|
6 |
+
|
7 |
+
def get_user_input():
|
8 |
+
"""Get model configuration from user input"""
|
9 |
+
print("\n=== Model Quantization Configuration ===")
|
10 |
+
|
11 |
+
while True:
|
12 |
+
model_id = input("\nEnter the HuggingFace model ID (e.g., meta-llama/Llama-2-7b-chat-hf): ").strip()
|
13 |
+
if model_id:
|
14 |
+
break
|
15 |
+
print("Model ID cannot be empty. Please try again.")
|
16 |
+
|
17 |
+
return model_id
|
18 |
+
|
19 |
+
def quantize_model_fp8(model_id):
|
20 |
+
"""
|
21 |
+
Quantize a model to FP8 Dynamic format using llm-compressor on CPU.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
model_id (str): HuggingFace model ID
|
25 |
+
"""
|
26 |
+
try:
|
27 |
+
print(f"\nLoading model and tokenizer: {model_id}")
|
28 |
+
model = SparseAutoModelForCausalLM.from_pretrained(
|
29 |
+
model_id,
|
30 |
+
device_map="cpu",
|
31 |
+
torch_dtype="auto"
|
32 |
+
)
|
33 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
34 |
+
|
35 |
+
print("\nConfiguring FP8 quantization recipe...")
|
36 |
+
recipe = QuantizationModifier(
|
37 |
+
targets="Linear",
|
38 |
+
scheme="FP8_DYNAMIC",
|
39 |
+
ignore=["lm_head"]
|
40 |
+
)
|
41 |
+
|
42 |
+
print("\nApplying quantization (this may take a while)...")
|
43 |
+
oneshot(model=model, recipe=recipe)
|
44 |
+
|
45 |
+
model_name = model_id.split("/")[-1]
|
46 |
+
save_dir = f"{model_name}-FP8-Dynamic"
|
47 |
+
|
48 |
+
print(f"\nSaving quantized model to: {save_dir}")
|
49 |
+
model.save_pretrained(save_dir, save_compressed=True)
|
50 |
+
tokenizer.save_pretrained(save_dir)
|
51 |
+
|
52 |
+
print("\nβ
Quantization completed successfully!")
|
53 |
+
print(f"π Quantized model saved to: {os.path.abspath(save_dir)}")
|
54 |
+
return save_dir
|
55 |
+
|
56 |
+
except Exception as e:
|
57 |
+
print(f"\nβ Error during quantization: {str(e)}")
|
58 |
+
return None
|
59 |
+
|
60 |
+
if __name__ == "__main__":
|
61 |
+
print("""
|
62 |
+
ββββββββββββββββββββββββββββββββββββββββ
|
63 |
+
β Model Quantization to FP8 β
|
64 |
+
β (Dynamic Per-Token) β
|
65 |
+
ββββββββββββββββββββββββββββββββββββββββ
|
66 |
+
""")
|
67 |
+
|
68 |
+
model_id = get_user_input()
|
69 |
+
|
70 |
+
print("\n=== Configuration Summary ===")
|
71 |
+
print(f"Model ID: {model_id}")
|
72 |
+
print("Quantization Type: FP8 Dynamic (per-token)")
|
73 |
+
print("Device: CPU")
|
74 |
+
|
75 |
+
while True:
|
76 |
+
confirm = input("\nProceed with quantization? (y/n): ").lower().strip()
|
77 |
+
if confirm in ['y', 'n']:
|
78 |
+
break
|
79 |
+
print("Please enter 'y' for yes or 'n' for no.")
|
80 |
+
|
81 |
+
if confirm == 'y':
|
82 |
+
quantized_model_path = quantize_model_fp8(model_id)
|
83 |
+
else:
|
84 |
+
print("\nQuantization cancelled.")
|