Spaces:
Runtime error
Runtime error
import argparse
import logging
import os
import tarfile
import tempfile
from datetime import datetime

import boto3
import sagemaker
from sagemaker.model import Model
# Script-wide logging: INFO and above, each record prefixed with a timestamp,
# the logger name and the level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)

# Module logger shared by every function in this script.
logger = logging.getLogger(__name__)
def create_model_archive(model_path, bucket='customer-support-gpt', model_key='models/model.tar.gz'):
    """
    Ensure the model archive exists in S3, packaging and uploading it if needed.

    Behaviour:
      * If ``s3://{bucket}/{model_key}`` already exists, it is reused as-is.
      * If S3 reports it missing and ``model_path`` is a local directory, the
        directory's entries are packed into a gzip tarball (entries at the
        archive root) and uploaded to that key.
      * If it is missing and ``model_path`` is not a local directory, a warning
        is logged and the URI is still returned.
      * If the existence check fails for any reason other than "not found"
        (e.g. 403, throttling), a warning is logged and the URI is returned
        without attempting an upload.

    Args:
        model_path (str): Local directory containing the model files.
        bucket (str): S3 bucket that holds the archive.
        model_key (str): S3 object key of the archive.

    Returns:
        str: S3 URI of the model archive.

    Raises:
        Exception: errors from the S3 client, packaging or upload are logged
            and re-raised unchanged.
    """
    model_uri = f's3://{bucket}/{model_key}'
    try:
        s3 = boto3.client('s3')

        try:
            s3.head_object(Bucket=bucket, Key=model_key)
        except s3.exceptions.ClientError as err:
            error_code = str(err.response.get("Error", {}).get("Code", ""))
            if error_code not in ("404", "NoSuchKey", "NotFound"):
                # Not a definite "missing" answer (HeadObject can return 403
                # for a missing key without ListBucket permission); do not
                # guess, leave the outcome to the deployment step.
                logger.warning(f"Could not verify model archive {model_uri} ({error_code}): {err}")
                return model_uri
        else:
            logger.info("Model archive already exists in S3")
            return model_uri

        if not os.path.isdir(model_path):
            logger.warning(
                f"Model archive not found in S3 and local model directory "
                f"'{model_path}' does not exist; deployment will fail unless "
                f"{model_uri} is uploaded"
            )
            return model_uri

        logger.info(f"Model archive not found in S3, packaging '{model_path}' and uploading to {model_uri}")
        with tempfile.TemporaryDirectory() as tmp_dir:
            archive_path = os.path.join(tmp_dir, "model.tar.gz")
            with tarfile.open(archive_path, "w:gz") as tar:
                # Add each entry under its own name so the model files sit at
                # the archive root rather than under a model_path/ prefix.
                for entry in sorted(os.listdir(model_path)):
                    tar.add(os.path.join(model_path, entry), arcname=entry)
            s3.upload_file(archive_path, bucket, model_key)
        logger.info("Model archive uploaded to S3")
        return model_uri
    except Exception as e:
        logger.error(f"Error creating model archive: {str(e)}")
        raise
def deploy_app(acc_id, region_name, role_arn, ecr_repo_name, endpoint_name="customer-support-chatbot"):
    """
    Deploy a Gradio app as a SageMaker endpoint using an ECR image.

    The image ``{acc_id}.dkr.ecr.{region_name}.amazonaws.com/{ecr_repo_name}:latest``
    is wrapped in a SageMaker ``Model`` together with the model archive returned
    by ``create_model_archive`` and deployed to one ``ml.m5.large`` instance.
    If an endpoint called ``endpoint_name`` already exists, it is updated in place.

    Args:
        acc_id (str): AWS account ID that owns the ECR repository.
        region_name (str): AWS region of the ECR repository.
        role_arn (str): IAM role ARN that SageMaker uses for the model.
        ecr_repo_name (str): ECR repository name (the ``latest`` tag is used).
        endpoint_name (str): SageMaker endpoint name to create or update.

    Returns:
        The predictor object returned by ``Model.deploy``.

    Raises:
        Exception: any failure is logged and re-raised unchanged.
    """
    try:
        logger.info("Starting SageMaker deployment process...")
        sagemaker_session = sagemaker.Session()

        ecr_image = f"{acc_id}.dkr.ecr.{region_name}.amazonaws.com/{ecr_repo_name}:latest"
        logger.info(f"Using ECR image: {ecr_image}")

        model_data = create_model_archive("models/customer_support_gpt")

        # All container environment variables must be set on the Model:
        # Model.deploy() has no `environment` parameter, so values passed there
        # never reach the container.
        model_environment = {
            "MODEL_PATH": "/opt/ml/model",
            "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/code",
            "SAGEMAKER_PROGRAM": "inference.py",
            "SAGEMAKER_CONTAINER_PORT": "8080",
            "FLASK_HEALTHCHECK_PORT": "8081",
        }

        logger.info("Creating SageMaker model...")
        model = Model(
            image_uri=ecr_image,
            model_data=model_data,
            role=role_arn,
            sagemaker_session=sagemaker_session,
            env=model_environment,
            enable_network_isolation=False,
        )

        deployment_config = {
            "initial_instance_count": 1,
            "instance_type": "ml.m5.large",
            "endpoint_name": endpoint_name,
            # Seconds the container is given to pass the SageMaker startup
            # health check (the documented name for this deploy() option).
            "container_startup_health_check_timeout": 180,
        }
        # Request an in-place update only when the endpoint already exists, so
        # a fresh deployment does not depend on deploy() accepting this flag
        # (NOTE(review): older SageMaker SDK releases may not accept it).
        if _endpoint_exists(sagemaker_session, endpoint_name):
            deployment_config["update_endpoint"] = True

        logger.info(f"Deploying model to endpoint: {endpoint_name}")
        logger.info(f"Deployment configuration: {deployment_config}")
        predictor = model.deploy(**deployment_config)

        logger.info(f"Successfully deployed to endpoint: {endpoint_name}")
        return predictor
    except Exception as e:
        logger.error(f"Deployment failed: {str(e)}")
        raise
def _endpoint_exists(sagemaker_session, endpoint_name): | |
"""Check if SageMaker endpoint already exists""" | |
client = sagemaker_session.boto_session.client('sagemaker') | |
try: | |
client.describe_endpoint(EndpointName=endpoint_name) | |
return True | |
except client.exceptions.ClientError: | |
return False | |
def main():
    """
    Command-line entry point: parse deployment options and call ``deploy_app``.

    A deployment failure is logged and then re-raised, so the process exits
    with an error status.
    """
    parser = argparse.ArgumentParser(description="Deploy Gradio app to SageMaker")

    # (flag, help text) for each option the caller must supply.
    required_flags = (
        ("--account_id", "AWS Account ID"),
        ("--region", "AWS Region"),
        ("--role_arn", "IAM Role ARN for SageMaker"),
        ("--ecr_repo_name", "ECR Repository name"),
    )
    for flag, help_text in required_flags:
        parser.add_argument(flag, type=str, required=True, help=help_text)
    parser.add_argument(
        "--endpoint_name",
        type=str,
        default="customer-support-chatbot",
        help="SageMaker Endpoint Name",
    )
    cli_args = parser.parse_args()

    try:
        logger.info("Starting deployment process...")
        deploy_app(
            acc_id=cli_args.account_id,
            region_name=cli_args.region,
            role_arn=cli_args.role_arn,
            ecr_repo_name=cli_args.ecr_repo_name,
            endpoint_name=cli_args.endpoint_name,
        )
    except Exception as exc:
        logger.error(f"Deployment failed: {str(exc)}")
        raise
    else:
        logger.info("Deployment completed successfully!")


if __name__ == "__main__":
    main()