"""Interactive CLI that quantizes a HuggingFace causal LM to FP8 Dynamic with llm-compressor."""

import os
import sys
import traceback

from transformers import AutoTokenizer
# NOTE(review): recent llm-compressor releases deprecate SparseAutoModelForCausalLM and
# llmcompressor.transformers.oneshot in favor of transformers.AutoModelForCausalLM and
# llmcompressor.oneshot — confirm against the pinned llm-compressor version before upgrading.
from llmcompressor.transformers import SparseAutoModelForCausalLM
from llmcompressor.transformers import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

# Start-up banner (box-drawing characters restored from mis-decoded UTF-8).
_BANNER = """
╔══════════════════════════════════════╗
║      Model Quantization to FP8       ║
║         (Dynamic Per-Token)          ║
╚══════════════════════════════════════╝
"""


def get_user_input():
    """Prompt until the user enters a non-empty HuggingFace model ID.

    Returns:
        str: The model ID with surrounding whitespace stripped.
    """
    print("\n=== Model Quantization Configuration ===")
    while True:
        model_id = input(
            "\nEnter the HuggingFace model ID (e.g., meta-llama/Llama-2-7b-chat-hf): "
        ).strip()
        if model_id:
            return model_id
        print("Model ID cannot be empty. Please try again.")


def _default_save_dir(model_id):
    """Return the output directory name ``<model-name>-FP8-Dynamic`` for *model_id*."""
    # normpath drops trailing separators, so a local path such as "./my-model/"
    # still yields "my-model" instead of an empty name ("-FP8-Dynamic").
    model_name = os.path.basename(os.path.normpath(model_id))
    return f"{model_name}-FP8-Dynamic"


def quantize_model_fp8(model_id):
    """Quantize a model to FP8 Dynamic format using llm-compressor on CPU.

    Loads the model on CPU, applies a ``QuantizationModifier`` with scheme
    ``FP8_DYNAMIC`` to ``Linear`` layers (skipping ``lm_head``), and saves the
    compressed model together with its tokenizer to ``<model-name>-FP8-Dynamic``
    in the current working directory.

    Args:
        model_id (str): HuggingFace model ID (or local model path).

    Returns:
        str | None: The relative output directory on success, or None if any
        step failed (the error and its traceback are printed to stderr).
    """
    try:
        print(f"\nLoading model and tokenizer: {model_id}")
        model = SparseAutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="cpu",
            torch_dtype="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        print("\nConfiguring FP8 quantization recipe...")
        recipe = QuantizationModifier(
            targets="Linear",
            scheme="FP8_DYNAMIC",
            ignore=["lm_head"],
        )

        print("\nApplying quantization (this may take a while)...")
        # No calibration dataset is passed: the recipe is applied to the model alone.
        oneshot(model=model, recipe=recipe)

        save_dir = _default_save_dir(model_id)
        print(f"\nSaving quantized model to: {save_dir}")
        model.save_pretrained(save_dir, save_compressed=True)
        tokenizer.save_pretrained(save_dir)
    except Exception as e:
        # CLI boundary: report the failure with its type and full traceback (the bare
        # message from deep inside transformers/llm-compressor is often not enough to
        # diagnose) and signal failure through the None return instead of crashing.
        print(
            f"\n❌ Error during quantization: {type(e).__name__}: {e}",
            file=sys.stderr,
        )
        traceback.print_exc()
        return None

    print("\n✅ Quantization completed successfully!")
    print(f"📁 Quantized model saved to: {os.path.abspath(save_dir)}")
    return save_dir


def main():
    """Run the interactive flow: prompt for a model, confirm, then quantize.

    Returns:
        int: Process exit status — 0 on success or user cancellation,
        1 if quantization failed.
    """
    print(_BANNER)

    model_id = get_user_input()

    print("\n=== Configuration Summary ===")
    print(f"Model ID: {model_id}")
    print("Quantization Type: FP8 Dynamic (per-token)")
    print("Device: CPU")

    while True:
        confirm = input("\nProceed with quantization? (y/n): ").lower().strip()
        if confirm in ("y", "n"):
            break
        print("Please enter 'y' for yes or 'n' for no.")

    if confirm == "n":
        print("\nQuantization cancelled.")
        return 0

    # quantize_model_fp8 prints its own diagnostics; surface failure as a non-zero
    # exit status so shells and CI jobs can detect it.
    return 0 if quantize_model_fp8(model_id) else 1


if __name__ == "__main__":
    sys.exit(main())