import argparse
import logging
import os
from datetime import datetime

import boto3
import sagemaker
from sagemaker.model import Model

# Set up logging
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO
)
logger = logging.getLogger(__name__)

# Default S3 location of the packaged model archive.
DEFAULT_MODEL_BUCKET = 'customer-support-gpt'
DEFAULT_MODEL_KEY = 'models/model.tar.gz'

# S3 error codes that mean "object does not exist" for HeadObject.
_S3_NOT_FOUND_CODES = ('404', 'NoSuchKey', 'NotFound')


def create_model_archive(model_path, bucket=DEFAULT_MODEL_BUCKET,
                         model_key=DEFAULT_MODEL_KEY, s3_client=None):
    """
    Return the S3 URI of the model archive, checking whether it exists.

    Despite its name, this function does not build or upload an archive. It
    only performs a best-effort ``HeadObject`` check on
    ``s3://{bucket}/{model_key}`` and logs the outcome. The URI is returned
    whether or not the object was found.

    Args:
        model_path (str): Local path to model files. Currently unused; kept
            for backward compatibility.
        bucket (str): S3 bucket that holds the archive.
        model_key (str): Key of the archive inside ``bucket``.
        s3_client: Optional boto3 S3 client to reuse. A new default client
            is created when omitted.

    Returns:
        str: S3 URI of the model archive.

    Raises:
        Exception: Re-raised if the S3 client cannot be created.
    """
    uri = f's3://{bucket}/{model_key}'

    try:
        s3 = s3_client if s3_client is not None else boto3.client('s3')
    except Exception as e:
        logger.error("Error creating model archive: %s", e)
        raise

    # Best-effort existence check: a failure here is logged, not raised,
    # because SageMaker itself reports a missing archive at deploy time.
    try:
        s3.head_object(Bucket=bucket, Key=model_key)
    except s3.exceptions.ClientError as e:
        code = e.response.get('Error', {}).get('Code', '')
        if code in _S3_NOT_FOUND_CODES:
            logger.warning(
                "Model archive %s not found in S3; upload it before deploying "
                "or model loading will fail", uri
            )
        else:
            logger.warning("Could not verify model archive %s (%s): %s", uri, code, e)
    except Exception as e:
        # e.g. credential or network errors; still best-effort, but unlike a
        # bare ``except:`` this does not swallow KeyboardInterrupt/SystemExit.
        logger.warning("Could not verify model archive %s: %s", uri, e)
    else:
        logger.info("Model archive already exists in S3")

    return uri


def deploy_app(acc_id, region_name, role_arn, ecr_repo_name,
               endpoint_name="customer-support-chatbot", image_tag="latest",
               instance_type="ml.m5.large", health_check_timeout_seconds=180):
    """
    Deploy a Gradio app as a SageMaker endpoint using an ECR image.

    If an endpoint named ``endpoint_name`` already exists, ``update_endpoint``
    is set so the deployment updates it instead of creating a new one.

    Args:
        acc_id (str): AWS account ID.
        region_name (str): AWS region. Used for both the ECR image URI and
            the SageMaker/S3 clients.
        role_arn (str): IAM role ARN for SageMaker.
        ecr_repo_name (str): ECR repository name.
        endpoint_name (str): SageMaker endpoint name.
        image_tag (str): Tag of the ECR image to deploy.
        instance_type (str): SageMaker instance type for the endpoint.
        health_check_timeout_seconds (int): Seconds SageMaker waits for the
            container to pass its startup health check.

    Returns:
        The predictor returned by ``Model.deploy``.

    Raises:
        Exception: Any error raised during deployment is logged and re-raised.
    """
    try:
        logger.info("Starting SageMaker deployment process...")

        # Pin the session to the requested region. The ECR image URI is
        # region-specific, so the endpoint must be created in that region too.
        # A bare sagemaker.Session() would use the default boto region instead.
        boto_session = boto3.Session(region_name=region_name)
        sagemaker_session = sagemaker.Session(boto_session=boto_session)

        # Define the image URI in ECR
        ecr_image = f"{acc_id}.dkr.ecr.{region_name}.amazonaws.com/{ecr_repo_name}:{image_tag}"
        logger.info("Using ECR image: %s", ecr_image)

        # Get model archive S3 URI
        model_data = create_model_archive(
            "models/customer_support_gpt",
            s3_client=boto_session.client('s3'),
        )

        # Container environment. The port variables were previously passed as
        # deploy(environment=...), which Model.deploy does not accept, so they
        # never reached the container. Container env vars belong on the Model.
        # Presumably the container reads these two port variables; confirm
        # against the image's entrypoint.
        model_environment = {
            "MODEL_PATH": "/opt/ml/model",
            "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/code",
            "SAGEMAKER_PROGRAM": "inference.py",
            "SAGEMAKER_CONTAINER_PORT": "8080",
            "FLASK_HEALTHCHECK_PORT": "8081",
        }

        logger.info("Creating SageMaker model...")
        model = Model(
            image_uri=ecr_image,
            model_data=model_data,
            role=role_arn,
            sagemaker_session=sagemaker_session,
            env=model_environment,
            enable_network_isolation=False
        )

        deployment_config = {
            "initial_instance_count": 1,
            "instance_type": instance_type,
            "endpoint_name": endpoint_name,
            "update_endpoint": _endpoint_exists(sagemaker_session, endpoint_name),
            # Correct SDK name for the startup health-check timeout
            # (replaces the unsupported ``health_check_timeout_seconds`` kwarg).
            "container_startup_health_check_timeout": health_check_timeout_seconds,
        }

        logger.info("Deploying model to endpoint: %s", endpoint_name)
        logger.info("Deployment configuration: %s", deployment_config)
        predictor = model.deploy(**deployment_config)

        logger.info("Successfully deployed to endpoint: %s", endpoint_name)
        return predictor

    except Exception as e:
        logger.error("Deployment failed: %s", e)
        raise


def _endpoint_exists(sagemaker_session, endpoint_name):
    """
    Return True if a SageMaker endpoint named ``endpoint_name`` exists.

    ``DescribeEndpoint`` reports a missing endpoint as a
    ``ValidationException``. Only that case maps to False. Other client errors
    (AccessDenied, throttling, ...) are re-raised so they are not mistaken for
    a missing endpoint.
    """
    client = sagemaker_session.boto_session.client('sagemaker')
    try:
        client.describe_endpoint(EndpointName=endpoint_name)
    except client.exceptions.ClientError as e:
        if e.response.get('Error', {}).get('Code') == 'ValidationException':
            return False
        raise
    return True
def main():
    """
    Command-line entry point: parse deployment options and call ``deploy_app``.

    Any exception from the deployment is logged and then re-raised, so the
    script terminates with a traceback and a non-zero exit status.
    """
    arg_parser = argparse.ArgumentParser(description="Deploy Gradio app to SageMaker")

    # All of these are mandatory string options.
    mandatory_options = (
        ("--account_id", "AWS Account ID"),
        ("--region", "AWS Region"),
        ("--role_arn", "IAM Role ARN for SageMaker"),
        ("--ecr_repo_name", "ECR Repository name"),
    )
    for flag, help_text in mandatory_options:
        arg_parser.add_argument(flag, type=str, required=True, help=help_text)
    arg_parser.add_argument(
        "--endpoint_name",
        type=str,
        default="customer-support-chatbot",
        help="SageMaker Endpoint Name",
    )
    cli = arg_parser.parse_args()

    try:
        logger.info("Starting deployment process...")
        deploy_app(
            acc_id=cli.account_id,
            region_name=cli.region,
            role_arn=cli.role_arn,
            ecr_repo_name=cli.ecr_repo_name,
            endpoint_name=cli.endpoint_name,
        )
        logger.info("Deployment completed successfully!")
    except Exception as exc:
        logger.error(f"Deployment failed: {str(exc)}")
        raise


if __name__ == "__main__":
    main()