Spaces:
Runtime error
Runtime error
File size: 5,096 Bytes
e1d0160 b1d9c58 e1d0160 dd44216 e1d0160 671ee28 e1d0160 671ee28 b1d9c58 e1d0160 b1d9c58 671ee28 b1d9c58 671ee28 6ddaf47 671ee28 2c38e04 671ee28 b1d9c58 671ee28 e1d0160 671ee28 b1d9c58 671ee28 b1d9c58 671ee28 e1d0160 671ee28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import boto3
import logging
import sagemaker
from sagemaker.model import Model
import argparse
import os
from datetime import datetime
# Configure logging when the module is loaded: every record shows the
# timestamp, the logger name, the severity and the message, from INFO upward.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)

# Module-level logger shared by all functions in this script.
logger = logging.getLogger(__name__)
def create_model_archive(model_path, bucket='customer-support-gpt', model_key='models/model.tar.gz'):
    """
    Resolve the S3 URI of the model archive and log whether it exists.

    Despite its name, this function does not build or upload anything: it
    only probes S3 for the archive object and logs the outcome. The URI is
    returned whether or not the object could be found.

    Args:
        model_path (str): Path to model files. Currently unused; kept so that
            existing callers keep working.
        bucket (str): Name of the S3 bucket that holds the archive.
        model_key (str): Object key of the archive inside the bucket.
    Returns:
        str: S3 URI of the model archive
    Raises:
        Exception: Any error raised outside the existence probe (for example
            while creating the S3 client) is logged and re-raised.
    """
    try:
        # Initialize S3 client
        s3 = boto3.client('s3')

        # The existence check is best-effort: a failed probe never aborts the
        # deployment, but the log now tells "missing" apart from "could not
        # check" instead of reporting every failure as a missing archive.
        try:
            s3.head_object(Bucket=bucket, Key=model_key)
        except Exception as probe_error:
            response = getattr(probe_error, "response", None)
            error_info = response.get("Error", {}) if isinstance(response, dict) else {}
            error_code = str(error_info.get("Code", ""))
            if error_code in ("404", "NoSuchKey", "NotFound"):
                logger.info("Model archive not found in S3, will be created during deployment")
            else:
                logger.warning(
                    "Could not verify model archive s3://%s/%s: %s",
                    bucket, model_key, probe_error,
                )
        else:
            logger.info("Model archive already exists in S3")

        return f's3://{bucket}/{model_key}'
    except Exception as e:
        logger.error(f"Error creating model archive: {str(e)}")
        raise
def deploy_app(acc_id, region_name, role_arn, ecr_repo_name, endpoint_name="customer-support-chatbot"):
    """
    Deploys a Gradio app as a SageMaker endpoint using an ECR image.

    The image `<acc_id>.dkr.ecr.<region_name>.amazonaws.com/<ecr_repo_name>:latest`
    is wrapped in a SageMaker `Model` together with the model archive URI
    returned by `create_model_archive`, then deployed on a single
    `ml.m5.large` instance.

    Args:
        acc_id (str): AWS account ID
        region_name (str): AWS region name
        role_arn (str): IAM role ARN for SageMaker
        ecr_repo_name (str): ECR repository name
        endpoint_name (str): SageMaker endpoint name
    Returns:
        Whatever `Model.deploy` returns for this model.
    Raises:
        Exception: Any failure during deployment is logged and re-raised.
    """
    try:
        logger.info("Starting SageMaker deployment process...")

        # Bind the SageMaker session to the requested region so the endpoint
        # is created in the same region the ECR image URI points at, instead
        # of whatever region the default AWS configuration happens to select.
        sagemaker_session = sagemaker.Session(
            boto_session=boto3.Session(region_name=region_name)
        )

        # Define the image URI in ECR
        ecr_image = f"{acc_id}.dkr.ecr.{region_name}.amazonaws.com/{ecr_repo_name}:latest"
        logger.info(f"Using ECR image: {ecr_image}")

        # Get model archive S3 URI
        model_data = create_model_archive("models/customer_support_gpt")

        # Container environment. The two port variables used to be passed to
        # Model.deploy(environment=...), which is not a deploy parameter;
        # container variables are supplied through Model(env=...).
        model_environment = {
            "MODEL_PATH": "/opt/ml/model",
            "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/code",
            "SAGEMAKER_PROGRAM": "inference.py",
            "SAGEMAKER_CONTAINER_PORT": "8080",
            "FLASK_HEALTHCHECK_PORT": "8081",
        }

        # Create model
        logger.info("Creating SageMaker model...")
        model = Model(
            image_uri=ecr_image,
            model_data=model_data,
            role=role_arn,
            sagemaker_session=sagemaker_session,
            env=model_environment,
            enable_network_isolation=False
        )

        # Define deployment configuration
        deployment_config = {
            "initial_instance_count": 1,
            "instance_type": "ml.m5.large",
            "endpoint_name": endpoint_name,
            # Ask for an in-place update only when the endpoint is already there.
            "update_endpoint": _endpoint_exists(sagemaker_session, endpoint_name),
            # Give the container up to 180 seconds to pass the health check.
            "container_startup_health_check_timeout": 180,
        }

        # Deploy model
        logger.info(f"Deploying model to endpoint: {endpoint_name}")
        logger.info(f"Deployment configuration: {deployment_config}")
        predictor = model.deploy(**deployment_config)

        logger.info(f"Successfully deployed to endpoint: {endpoint_name}")
        return predictor
    except Exception as e:
        logger.error(f"Deployment failed: {str(e)}")
        raise
def _endpoint_exists(sagemaker_session, endpoint_name):
"""Check if SageMaker endpoint already exists"""
client = sagemaker_session.boto_session.client('sagemaker')
try:
client.describe_endpoint(EndpointName=endpoint_name)
return True
except client.exceptions.ClientError:
return False
def main():
    """
    Parse the command-line options and run the deployment.

    Required options: --account_id, --region, --role_arn, --ecr_repo_name.
    Optional: --endpoint_name (defaults to "customer-support-chatbot").
    Any exception raised by the deployment is logged and re-raised.
    """
    parser = argparse.ArgumentParser(description="Deploy Gradio app to SageMaker")

    # Mandatory string options, registered in the order they appear in --help.
    mandatory_options = (
        ("--account_id", "AWS Account ID"),
        ("--region", "AWS Region"),
        ("--role_arn", "IAM Role ARN for SageMaker"),
        ("--ecr_repo_name", "ECR Repository name"),
    )
    for flag, description in mandatory_options:
        parser.add_argument(flag, type=str, required=True, help=description)

    parser.add_argument(
        "--endpoint_name",
        type=str,
        default="customer-support-chatbot",
        help="SageMaker Endpoint Name",
    )

    cli = parser.parse_args()

    try:
        logger.info("Starting deployment process...")
        deploy_app(
            acc_id=cli.account_id,
            region_name=cli.region,
            role_arn=cli.role_arn,
            ecr_repo_name=cli.ecr_repo_name,
            endpoint_name=cli.endpoint_name,
        )
        logger.info("Deployment completed successfully!")
    except Exception as exc:
        logger.error(f"Deployment failed: {str(exc)}")
        raise
# Run the deployment only when this file is executed as a script.
if __name__ == "__main__":
    main()