Customer-Support-Chatbot / src /deploy_sagemaker.py
VenkateshRoshan
flask ping code updated
2c38e04
import boto3
import logging
import sagemaker
from sagemaker.model import Model
import argparse
import os
from datetime import datetime
# Set up logging
logging.basicConfig(
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
level=logging.INFO
)
logger = logging.getLogger(__name__)
def create_model_archive(model_path):
"""
Create a model archive if needed
Args:
model_path (str): Path to model files
Returns:
str: S3 URI of the model archive
"""
try:
# Initialize S3 client
s3 = boto3.client('s3')
bucket = 'customer-support-gpt'
model_key = 'models/model.tar.gz'
# Check if model archive exists in S3
try:
s3.head_object(Bucket=bucket, Key=model_key)
logger.info("Model archive already exists in S3")
except:
logger.info("Model archive not found in S3, will be created during deployment")
return f's3://{bucket}/{model_key}'
except Exception as e:
logger.error(f"Error creating model archive: {str(e)}")
raise
def deploy_app(acc_id, region_name, role_arn, ecr_repo_name, endpoint_name="customer-support-chatbot"):
"""
Deploys a Gradio app as a SageMaker endpoint using an ECR image.
Args:
acc_id (str): AWS account ID
region_name (str): AWS region name
role_arn (str): IAM role ARN for SageMaker
ecr_repo_name (str): ECR repository name
endpoint_name (str): SageMaker endpoint name
"""
try:
logger.info("Starting SageMaker deployment process...")
# Initialize SageMaker session
sagemaker_session = sagemaker.Session()
# Define the image URI in ECR
ecr_image = f"{acc_id}.dkr.ecr.{region_name}.amazonaws.com/{ecr_repo_name}:latest"
logger.info(f"Using ECR image: {ecr_image}")
# Get model archive S3 URI
model_data = create_model_archive("models/customer_support_gpt")
# Define model configuration
model_environment = {
"MODEL_PATH": "/opt/ml/model",
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/code",
"SAGEMAKER_PROGRAM": "inference.py"
}
# Create model
logger.info("Creating SageMaker model...")
model = Model(
image_uri=ecr_image,
model_data=model_data,
role=role_arn,
sagemaker_session=sagemaker_session,
env=model_environment,
enable_network_isolation=False
)
# Define deployment configuration
deployment_config = {
"initial_instance_count": 1,
"instance_type": "ml.m5.large",
"endpoint_name": endpoint_name,
"update_endpoint": True if _endpoint_exists(sagemaker_session, endpoint_name) else False
}
# Deploy model
logger.info(f"Deploying model to endpoint: {endpoint_name}")
logger.info(f"Deployment configuration: {deployment_config}")
predictor = model.deploy(**deployment_config,health_check_timeout_seconds=180, # Optionally, increase the timeout
environment={
"SAGEMAKER_CONTAINER_PORT": "8080",
"FLASK_HEALTHCHECK_PORT": "8081",
})
logger.info(f"Successfully deployed to endpoint: {endpoint_name}")
return predictor
except Exception as e:
logger.error(f"Deployment failed: {str(e)}")
raise
def _endpoint_exists(sagemaker_session, endpoint_name):
"""Check if SageMaker endpoint already exists"""
client = sagemaker_session.boto_session.client('sagemaker')
try:
client.describe_endpoint(EndpointName=endpoint_name)
return True
except client.exceptions.ClientError:
return False
def main():
parser = argparse.ArgumentParser(description="Deploy Gradio app to SageMaker")
parser.add_argument("--account_id", type=str, required=True,
help="AWS Account ID")
parser.add_argument("--region", type=str, required=True,
help="AWS Region")
parser.add_argument("--role_arn", type=str, required=True,
help="IAM Role ARN for SageMaker")
parser.add_argument("--ecr_repo_name", type=str, required=True,
help="ECR Repository name")
parser.add_argument("--endpoint_name", type=str,
default="customer-support-chatbot",
help="SageMaker Endpoint Name")
args = parser.parse_args()
try:
logger.info("Starting deployment process...")
deploy_app(
args.account_id,
args.region,
args.role_arn,
args.ecr_repo_name,
args.endpoint_name
)
logger.info("Deployment completed successfully!")
except Exception as e:
logger.error(f"Deployment failed: {str(e)}")
raise
if __name__ == "__main__":
main()