from pathlib import Path from huggingface_hub import HfApi from loguru import logger try: from sagemaker.huggingface import HuggingFace except ModuleNotFoundError: logger.warning("Couldn't load SageMaker imports. Run 'poetry install --with aws' to support AWS.") from llm_engineering.settings import settings finetuning_dir = Path(__file__).resolve().parent finetuning_requirements_path = finetuning_dir / "requirements.txt" def run_finetuning_on_sagemaker( finetuning_type: str = "sft", num_train_epochs: int = 3, per_device_train_batch_size: int = 2, learning_rate: float = 3e-4, dataset_huggingface_workspace: str = "mlabonne", is_dummy: bool = False, ) -> None: assert settings.HUGGINGFACE_ACCESS_TOKEN, "Hugging Face access token is required." assert settings.AWS_ARN_ROLE, "AWS ARN role is required." if not finetuning_dir.exists(): raise FileNotFoundError(f"The directory {finetuning_dir} does not exist.") if not finetuning_requirements_path.exists(): raise FileNotFoundError(f"The file {finetuning_requirements_path} does not exist.") api = HfApi() user_info = api.whoami(token=settings.HUGGINGFACE_ACCESS_TOKEN) huggingface_user = user_info["name"] logger.info(f"Current Hugging Face user: {huggingface_user}") hyperparameters = { "finetuning_type": finetuning_type, "num_train_epochs": num_train_epochs, "per_device_train_batch_size": per_device_train_batch_size, "learning_rate": learning_rate, "dataset_huggingface_workspace": dataset_huggingface_workspace, "model_output_huggingface_workspace": huggingface_user, } if is_dummy: hyperparameters["is_dummy"] = True # Create the HuggingFace SageMaker estimator huggingface_estimator = HuggingFace( entry_point="finetune.py", source_dir=str(finetuning_dir), instance_type="ml.g5.2xlarge", instance_count=1, role=settings.AWS_ARN_ROLE, transformers_version="4.36", pytorch_version="2.1", py_version="py310", hyperparameters=hyperparameters, requirements_file=finetuning_requirements_path, environment={ "HUGGING_FACE_HUB_TOKEN": settings.HUGGINGFACE_ACCESS_TOKEN, "COMET_API_KEY": settings.COMET_API_KEY, "COMET_PROJECT_NAME": settings.COMET_PROJECT, }, ) # Start the training job on SageMaker. huggingface_estimator.fit() if __name__ == "__main__": run_finetuning_on_sagemaker()