File size: 5,096 Bytes
e1d0160
 
b1d9c58
 
 
e1d0160
dd44216
e1d0160
 
671ee28
 
 
 
e1d0160
 
671ee28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1d9c58
 
 
e1d0160
b1d9c58
 
 
 
 
671ee28
b1d9c58
671ee28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ddaf47
671ee28
 
 
 
 
 
 
 
2c38e04
 
 
 
 
671ee28
 
 
 
 
 
 
b1d9c58
671ee28
 
 
 
 
 
 
 
e1d0160
671ee28
b1d9c58
671ee28
 
 
 
 
 
 
 
 
 
 
 
b1d9c58
671ee28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1d0160
671ee28
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import boto3
import logging
import sagemaker
from sagemaker.model import Model
import argparse
import os
from datetime import datetime

# Set up logging
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO
)
logger = logging.getLogger(__name__)

def create_model_archive(model_path, bucket='customer-support-gpt', model_key='models/model.tar.gz'):
    """
    Return the S3 URI where the model archive lives (or will live).

    Only checks whether the archive already exists in S3 and logs the
    result; the archive itself is created later, during deployment,
    if it is missing.

    Args:
        model_path (str): Path to local model files. Not used by the
            existence check itself; kept for interface compatibility.
        bucket (str): S3 bucket that holds the archive.
        model_key (str): Object key of the model tarball in the bucket.

    Returns:
        str: S3 URI of the model archive.

    Raises:
        Exception: Re-raised after logging if the S3 client cannot be
            created or the URI cannot be built.
    """
    try:
        # Initialize S3 client
        s3 = boto3.client('s3')

        # head_object raises a ClientError (404) when the key is absent.
        # NOTE: the original bare `except:` also swallowed
        # KeyboardInterrupt/SystemExit; `except Exception` keeps the
        # best-effort behavior without masking interpreter exits.
        try:
            s3.head_object(Bucket=bucket, Key=model_key)
            logger.info("Model archive already exists in S3")
        except Exception:
            logger.info("Model archive not found in S3, will be created during deployment")

        return f's3://{bucket}/{model_key}'
    except Exception as e:
        logger.error(f"Error creating model archive: {str(e)}")
        raise

def deploy_app(acc_id, region_name, role_arn, ecr_repo_name, endpoint_name="customer-support-chatbot"):
    """
    Deploys a Gradio app as a SageMaker endpoint using an ECR image.

    Args:
        acc_id (str): AWS account ID
        region_name (str): AWS region name
        role_arn (str): IAM role ARN for SageMaker
        ecr_repo_name (str): ECR repository name (the ":latest" tag is deployed)
        endpoint_name (str): SageMaker endpoint name

    Returns:
        sagemaker.predictor.Predictor: Predictor bound to the live endpoint.

    Raises:
        Exception: Re-raised after logging if any deployment step fails.
    """
    try:
        logger.info("Starting SageMaker deployment process...")

        # Initialize SageMaker session
        sagemaker_session = sagemaker.Session()

        # Fully-qualified image URI in this account's ECR registry.
        ecr_image = f"{acc_id}.dkr.ecr.{region_name}.amazonaws.com/{ecr_repo_name}:latest"
        logger.info(f"Using ECR image: {ecr_image}")

        # S3 location of the model artifacts.
        model_data = create_model_archive("models/customer_support_gpt")

        # Environment variables handed to the serving container.
        model_environment = {
            "MODEL_PATH": "/opt/ml/model",
            "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/code",
            "SAGEMAKER_PROGRAM": "inference.py"
        }

        # Create model
        logger.info("Creating SageMaker model...")
        model = Model(
            image_uri=ecr_image,
            model_data=model_data,
            role=role_arn,
            sagemaker_session=sagemaker_session,
            env=model_environment,
            enable_network_isolation=False
        )

        # Update in place when the endpoint already exists, create otherwise.
        # (_endpoint_exists already returns a bool; the previous
        # `True if ... else False` ternary was redundant.)
        deployment_config = {
            "initial_instance_count": 1,
            "instance_type": "ml.m5.large",
            "endpoint_name": endpoint_name,
            "update_endpoint": _endpoint_exists(sagemaker_session, endpoint_name)
        }

        # Deploy model
        logger.info(f"Deploying model to endpoint: {endpoint_name}")
        logger.info(f"Deployment configuration: {deployment_config}")

        predictor = model.deploy(
            **deployment_config,
            # Give the container extra time to come up before health checks fail.
            health_check_timeout_seconds=180,
            environment={
                "SAGEMAKER_CONTAINER_PORT": "8080",
                "FLASK_HEALTHCHECK_PORT": "8081",
            },
        )

        logger.info(f"Successfully deployed to endpoint: {endpoint_name}")
        return predictor

    except Exception as e:
        logger.error(f"Deployment failed: {str(e)}")
        raise

def _endpoint_exists(sagemaker_session, endpoint_name):
    """Return True iff a SageMaker endpoint named *endpoint_name* exists."""
    sm_client = sagemaker_session.boto_session.client('sagemaker')
    try:
        # describe_endpoint raises ClientError for an unknown endpoint.
        sm_client.describe_endpoint(EndpointName=endpoint_name)
    except sm_client.exceptions.ClientError:
        return False
    return True

def main():
    """Parse CLI arguments and kick off the SageMaker deployment."""
    parser = argparse.ArgumentParser(description="Deploy Gradio app to SageMaker")

    # Required string flags, registered table-style to keep them uniform.
    for flag, help_text in (
        ("--account_id", "AWS Account ID"),
        ("--region", "AWS Region"),
        ("--role_arn", "IAM Role ARN for SageMaker"),
        ("--ecr_repo_name", "ECR Repository name"),
    ):
        parser.add_argument(flag, type=str, required=True, help=help_text)

    parser.add_argument(
        "--endpoint_name",
        type=str,
        default="customer-support-chatbot",
        help="SageMaker Endpoint Name",
    )

    args = parser.parse_args()

    try:
        logger.info("Starting deployment process...")
        deploy_app(
            args.account_id,
            args.region,
            args.role_arn,
            args.ecr_repo_name,
            args.endpoint_name,
        )
        logger.info("Deployment completed successfully!")

    except Exception as e:
        logger.error(f"Deployment failed: {str(e)}")
        raise

# Script entry point: run the deployment only when executed directly,
# never on import.
if __name__ == "__main__":
    main()