|
import os
|
|
from transformers import AutoTokenizer
|
|
from llmcompressor.transformers import SparseAutoModelForCausalLM
|
|
from llmcompressor.transformers import oneshot
|
|
from llmcompressor.modifiers.quantization import QuantizationModifier
|
|
|
|
def get_user_input():
    """Prompt on stdin until a non-blank model ID is entered.

    Prints a configuration header once, then keeps asking for a
    HuggingFace model ID, rejecting answers that are empty after
    stripping surrounding whitespace.

    Returns:
        str: The entered model ID with surrounding whitespace removed.
    """
    prompt = "\nEnter the HuggingFace model ID (e.g., meta-llama/Llama-2-7b-chat-hf): "

    print("\n=== Model Quantization Configuration ===")

    answer = input(prompt).strip()
    while not answer:
        # Blank or whitespace-only input: complain and ask again.
        print("Model ID cannot be empty. Please try again.")
        answer = input(prompt).strip()
    return answer
|
|
|
|
def quantize_model_fp8(model_id):
    """Quantize a model to FP8 Dynamic format using llm-compressor on CPU.

    Loads the model and tokenizer for ``model_id``, runs ``oneshot`` with a
    ``QuantizationModifier`` recipe (``targets="Linear"``,
    ``scheme="FP8_DYNAMIC"``, ``lm_head`` in the ignore list), and saves the
    compressed model and tokenizer into ``<model-name>-FP8-Dynamic`` relative
    to the current working directory. Progress is reported with ``print``.

    Args:
        model_id (str): HuggingFace model ID

    Returns:
        str | None: The directory the quantized model was saved to, or
        ``None`` if any step raised. The error message is printed rather
        than propagated, so callers must check for ``None``.
    """
    try:
        print(f"\nLoading model and tokenizer: {model_id}")
        model = SparseAutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="cpu",
            torch_dtype="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        print("\nConfiguring FP8 quantization recipe...")
        recipe = QuantizationModifier(
            targets="Linear",
            scheme="FP8_DYNAMIC",
            ignore=["lm_head"],
        )

        print("\nApplying quantization (this may take a while)...")
        oneshot(model=model, recipe=recipe)

        # Name the output after the last "/"-separated component of the ID.
        # Trailing slashes are stripped first so an ID such as "models/foo/"
        # yields "foo" instead of an empty name ("-FP8-Dynamic").
        model_name = model_id.rstrip("/").split("/")[-1]
        save_dir = f"{model_name}-FP8-Dynamic"

        print(f"\nSaving quantized model to: {save_dir}")
        model.save_pretrained(save_dir, save_compressed=True)
        tokenizer.save_pretrained(save_dir)

        print("\n✅ Quantization completed successfully!")
        print(f"📁 Quantized model saved to: {os.path.abspath(save_dir)}")
        return save_dir

    except Exception as e:
        # Top-level boundary for the CLI: report the failure and signal it
        # to the caller through the None return value instead of raising.
        print(f"\n❌ Error during quantization: {e}")
        return None
|
|
|
|
def main():
    """Run the interactive FP8 quantization workflow.

    Prints a banner, asks for a model ID via ``get_user_input``, shows a
    configuration summary, and asks for a y/n confirmation before calling
    ``quantize_model_fp8``.

    Returns:
        int: Process exit status. 0 when quantization succeeds or the user
        cancels; 1 when ``quantize_model_fp8`` returns ``None`` (failure).
    """
    # Draw the title box programmatically so the borders and the centred
    # titles always line up.
    box_width = 38
    print("\n╔" + "═" * box_width + "╗")
    for title in ("Model Quantization to FP8", "(Dynamic Per-Token)"):
        print("║" + title.center(box_width) + "║")
    print("╚" + "═" * box_width + "╝\n")

    model_id = get_user_input()

    print("\n=== Configuration Summary ===")
    print(f"Model ID: {model_id}")
    print("Quantization Type: FP8 Dynamic (per-token)")
    print("Device: CPU")

    # Keep asking until the answer is exactly "y" or "n" (case-insensitive,
    # surrounding whitespace ignored).
    while True:
        confirm = input("\nProceed with quantization? (y/n): ").lower().strip()
        if confirm in ("y", "n"):
            break
        print("Please enter 'y' for yes or 'n' for no.")

    if confirm == "n":
        print("\nQuantization cancelled.")
        return 0

    # quantize_model_fp8 reports failure by returning None; surface that as
    # a non-zero exit status so shells and scripts can detect it.
    quantized_model_path = quantize_model_fp8(model_id)
    return 0 if quantized_model_path is not None else 1


if __name__ == "__main__":
    raise SystemExit(main())