VenkateshRoshan commited on
Commit
671ee28
·
1 Parent(s): a562c0d

dockerfile updated

Browse files
Files changed (2) hide show
  1. dockerfile +0 -3
  2. src/deploy_sagemaker.py +122 -32
dockerfile CHANGED
@@ -31,9 +31,6 @@ FROM python:3.10-slim
31
  # # Run the application
32
  # CMD ["python", "app.py"]
33
 
34
- # Use NVIDIA CUDA base image
35
- # FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
36
-
37
  # Set environment variables
38
  ENV PYTHONUNBUFFERED=TRUE
39
  ENV PYTHONDONTWRITEBYTECODE=TRUE
 
31
  # # Run the application
32
  # CMD ["python", "app.py"]
33
 
 
 
 
34
  # Set environment variables
35
  ENV PYTHONUNBUFFERED=TRUE
36
  ENV PYTHONDONTWRITEBYTECODE=TRUE
src/deploy_sagemaker.py CHANGED
@@ -7,9 +7,40 @@ import os
7
  from datetime import datetime
8
 
9
  # Set up logging
10
- logging.basicConfig(level=logging.INFO)
 
 
 
11
  logger = logging.getLogger(__name__)
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def deploy_app(acc_id, region_name, role_arn, ecr_repo_name, endpoint_name="customer-support-chatbot"):
14
  """
15
  Deploys a Gradio app as a SageMaker endpoint using an ECR image.
@@ -19,40 +50,99 @@ def deploy_app(acc_id, region_name, role_arn, ecr_repo_name, endpoint_name="cust
19
  region_name (str): AWS region name
20
  role_arn (str): IAM role ARN for SageMaker
21
  ecr_repo_name (str): ECR repository name
22
- endpoint_name (str): SageMaker endpoint name (default: "customer-support-chatbot")
23
  """
24
- # Initialize SageMaker session
25
- sagemaker_session = sagemaker.Session()
26
-
27
- # Define the image URI in ECR
28
- ecr_image = f"{acc_id}.dkr.ecr.{region_name}.amazonaws.com/{ecr_repo_name}:latest"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- # Define model
31
- model = Model(
32
- image_uri=ecr_image,
33
- role=role_arn,
34
- sagemaker_session=sagemaker_session,
35
- entry_point="serve",
36
- )
 
37
 
38
- # Deploy model as a SageMaker endpoint
39
- logger.info(f"Starting deployment of Gradio app to SageMaker endpoint {endpoint_name}...")
40
- predictor = model.deploy(
41
- initial_instance_count=1,
42
- instance_type="ml.t3.large", #"ml.g4dn.xlarge",
43
- endpoint_name=endpoint_name
44
- )
45
- logger.info(f"Gradio app deployed successfully to endpoint: {endpoint_name}")
46
-
47
- if __name__ == "__main__":
48
- # Parse arguments from CLI
49
  parser = argparse.ArgumentParser(description="Deploy Gradio app to SageMaker")
50
- parser.add_argument("--account_id", type=str, required=True, help="AWS Account ID")
51
- parser.add_argument("--region", type=str, required=True, help="AWS Region")
52
- parser.add_argument("--role_arn", type=str, required=True, help="IAM Role ARN for SageMaker")
53
- parser.add_argument("--ecr_repo_name", type=str, required=True, help="ECR Repository name")
54
- parser.add_argument("--endpoint_name", type=str, default="customer-support-chatbot", help="SageMaker Endpoint Name")
 
 
 
 
 
 
 
55
  args = parser.parse_args()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- # Deploy the Gradio app to SageMaker
58
- deploy_app(args.account_id, args.region, args.role_arn, args.ecr_repo_name, args.endpoint_name)
 
7
  from datetime import datetime
8
 
9
  # Set up logging
10
+ logging.basicConfig(
11
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
12
+ level=logging.INFO
13
+ )
14
  logger = logging.getLogger(__name__)
15
 
16
+ def create_model_archive(model_path):
17
+ """
18
+ Create a model archive if needed
19
+
20
+ Args:
21
+ model_path (str): Path to model files
22
+
23
+ Returns:
24
+ str: S3 URI of the model archive
25
+ """
26
+ try:
27
+ # Initialize S3 client
28
+ s3 = boto3.client('s3')
29
+ bucket = 'customer-support-gpt'
30
+ model_key = 'models/model.tar.gz'
31
+
32
+ # Check if model archive exists in S3
33
+ try:
34
+ s3.head_object(Bucket=bucket, Key=model_key)
35
+ logger.info("Model archive already exists in S3")
36
+ except:
37
+ logger.info("Model archive not found in S3, will be created during deployment")
38
+
39
+ return f's3://{bucket}/{model_key}'
40
+ except Exception as e:
41
+ logger.error(f"Error creating model archive: {str(e)}")
42
+ raise
43
+
44
  def deploy_app(acc_id, region_name, role_arn, ecr_repo_name, endpoint_name="customer-support-chatbot"):
45
  """
46
  Deploys a Gradio app as a SageMaker endpoint using an ECR image.
 
50
  region_name (str): AWS region name
51
  role_arn (str): IAM role ARN for SageMaker
52
  ecr_repo_name (str): ECR repository name
53
+ endpoint_name (str): SageMaker endpoint name
54
  """
55
+ try:
56
+ logger.info("Starting SageMaker deployment process...")
57
+
58
+ # Initialize SageMaker session
59
+ sagemaker_session = sagemaker.Session()
60
+
61
+ # Define the image URI in ECR
62
+ ecr_image = f"{acc_id}.dkr.ecr.{region_name}.amazonaws.com/{ecr_repo_name}:latest"
63
+ logger.info(f"Using ECR image: {ecr_image}")
64
+
65
+ # Get model archive S3 URI
66
+ model_data = create_model_archive("models/customer_support_gpt")
67
+
68
+ # Define model configuration
69
+ model_environment = {
70
+ "MODEL_PATH": "/opt/ml/model",
71
+ "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/code",
72
+ "SAGEMAKER_PROGRAM": "inference.py"
73
+ }
74
+
75
+ # Create model
76
+ logger.info("Creating SageMaker model...")
77
+ model = Model(
78
+ image_uri=ecr_image,
79
+ model_data=model_data,
80
+ role=role_arn,
81
+ sagemaker_session=sagemaker_session,
82
+ env=model_environment,
83
+ enable_network_isolation=False
84
+ )
85
+
86
+ # Define deployment configuration
87
+ deployment_config = {
88
+ "initial_instance_count": 1,
89
+ "instance_type": "ml.t3.large",
90
+ "endpoint_name": endpoint_name,
91
+ "update_endpoint": True if _endpoint_exists(sagemaker_session, endpoint_name) else False
92
+ }
93
+
94
+ # Deploy model
95
+ logger.info(f"Deploying model to endpoint: {endpoint_name}")
96
+ logger.info(f"Deployment configuration: {deployment_config}")
97
+
98
+ predictor = model.deploy(**deployment_config)
99
+
100
+ logger.info(f"Successfully deployed to endpoint: {endpoint_name}")
101
+ return predictor
102
+
103
+ except Exception as e:
104
+ logger.error(f"Deployment failed: {str(e)}")
105
+ raise
106
 
107
+ def _endpoint_exists(sagemaker_session, endpoint_name):
108
+ """Check if SageMaker endpoint already exists"""
109
+ client = sagemaker_session.boto_session.client('sagemaker')
110
+ try:
111
+ client.describe_endpoint(EndpointName=endpoint_name)
112
+ return True
113
+ except client.exceptions.ClientError:
114
+ return False
115
 
116
+ def main():
 
 
 
 
 
 
 
 
 
 
117
  parser = argparse.ArgumentParser(description="Deploy Gradio app to SageMaker")
118
+ parser.add_argument("--account_id", type=str, required=True,
119
+ help="AWS Account ID")
120
+ parser.add_argument("--region", type=str, required=True,
121
+ help="AWS Region")
122
+ parser.add_argument("--role_arn", type=str, required=True,
123
+ help="IAM Role ARN for SageMaker")
124
+ parser.add_argument("--ecr_repo_name", type=str, required=True,
125
+ help="ECR Repository name")
126
+ parser.add_argument("--endpoint_name", type=str,
127
+ default="customer-support-chatbot",
128
+ help="SageMaker Endpoint Name")
129
+
130
  args = parser.parse_args()
131
+
132
+ try:
133
+ logger.info("Starting deployment process...")
134
+ deploy_app(
135
+ args.account_id,
136
+ args.region,
137
+ args.role_arn,
138
+ args.ecr_repo_name,
139
+ args.endpoint_name
140
+ )
141
+ logger.info("Deployment completed successfully!")
142
+
143
+ except Exception as e:
144
+ logger.error(f"Deployment failed: {str(e)}")
145
+ raise
146
 
147
+ if __name__ == "__main__":
148
+ main()