"""
The ``mlflow.sagemaker`` module provides an API for deploying MLflow models to Amazon SageMaker.
"""
import json
import logging
import os
import platform
import signal
import sys
import tarfile
import time
import urllib.parse
import uuid
from subprocess import Popen
from typing import Any, Optional
import mlflow
import mlflow.version
from mlflow import pyfunc
from mlflow.deployments import BaseDeploymentClient, PredictionsResponse
from mlflow.environment_variables import (
    MLFLOW_DEPLOYMENT_FLAVOR_NAME,
    MLFLOW_SAGEMAKER_DEPLOY_IMG_URL,
)
from mlflow.exceptions import MlflowException
from mlflow.models import Model
from mlflow.models.container import (
    SERVING_ENVIRONMENT,
)
from mlflow.models.container import (
    SUPPORTED_FLAVORS as SUPPORTED_DEPLOYMENT_FLAVORS,
)
from mlflow.models.model import MLMODEL_FILE_NAME
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, RESOURCE_DOES_NOT_EXIST
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.utils.file_utils import TempDir
from mlflow.utils.proto_json_utils import dump_input_data
# Docker image name used when the caller does not supply one (e.g. ``push_image_to_ecr``).
DEFAULT_IMAGE_NAME = "mlflow-pyfunc"

# Accepted values for the ``mode`` argument of deployment operations.
DEPLOYMENT_MODE_ADD = "add"
DEPLOYMENT_MODE_REPLACE = "replace"
DEPLOYMENT_MODE_CREATE = "create"
DEPLOYMENT_MODES = [
    DEPLOYMENT_MODE_CREATE,
    DEPLOYMENT_MODE_ADD,
    DEPLOYMENT_MODE_REPLACE,
]

# Fallback values used when the caller does not specify them explicitly.
DEFAULT_BUCKET_NAME_PREFIX = "mlflow-sagemaker"
DEFAULT_SAGEMAKER_INSTANCE_TYPE = "ml.m4.xlarge"
DEFAULT_SAGEMAKER_INSTANCE_COUNT = 1
DEFAULT_REGION_NAME = "us-west-2"

SAGEMAKER_SERVING_ENVIRONMENT = "SageMaker"
SAGEMAKER_APP_NAME_TAG_KEY = "app_name"

# Module-level logger.
_logger = logging.getLogger(__name__)

# Fully qualified ECR image URI: <account>.dkr.ecr.<region>.amazonaws.com/<image>:<version>
_full_template = "{account}.dkr.ecr.{region}.amazonaws.com/{image}:{version}"
def _get_preferred_deployment_flavor(model_config):
    """
    Choose the flavor MLflow should use when deploying ``model_config``.

    Only the pyfunc flavor (``pyfunc.FLAVOR_NAME``) is selected by this function; a model
    that does not declare it cannot be deployed without an explicit, supported flavor.

    Args:
        model_config: An MLflow ``Model`` object whose ``flavors`` mapping is inspected.

    Returns:
        ``pyfunc.FLAVOR_NAME`` when the model declares that flavor.

    Raises:
        MlflowException: With ``RESOURCE_DOES_NOT_EXIST`` when the model does not contain
            the pyfunc flavor.
    """
    available_flavors = model_config.flavors
    if pyfunc.FLAVOR_NAME in available_flavors:
        return pyfunc.FLAVOR_NAME

    missing_flavor_message = (
        "The specified model does not contain any of the supported flavors for"
        " deployment. The model contains the following flavors: "
        f"{available_flavors.keys()}. Supported flavors: {SUPPORTED_DEPLOYMENT_FLAVORS}"
    )
    raise MlflowException(
        message=missing_flavor_message,
        error_code=RESOURCE_DOES_NOT_EXIST,
    )
def _validate_deployment_flavor(model_config, flavor):
    """
    Ensure ``flavor`` can be used to deploy ``model_config``.

    The flavor must be listed in ``SUPPORTED_DEPLOYMENT_FLAVORS``, and it must also be one of
    the flavors declared by the model. Support is checked first.

    Args:
        model_config: An MLflow ``Model`` object.
        flavor: The name of the deployment flavor to validate.

    Raises:
        MlflowException: With ``INVALID_PARAMETER_VALUE`` if the flavor is not supported for
            deployment, or with ``RESOURCE_DOES_NOT_EXIST`` if the model does not contain it.
    """
    if flavor not in SUPPORTED_DEPLOYMENT_FLAVORS:
        unsupported_message = (
            f"The specified flavor: `{flavor}` is not supported for deployment."
            f" Please use one of the supported flavors: {SUPPORTED_DEPLOYMENT_FLAVORS}"
        )
        raise MlflowException(message=unsupported_message, error_code=INVALID_PARAMETER_VALUE)

    if flavor not in model_config.flavors:
        absent_message = (
            "The specified model does not contain the specified deployment flavor:"
            f" `{flavor}`. Please use one of the following deployment flavors"
            f" that the model contains: {model_config.flavors.keys()}"
        )
        raise MlflowException(message=absent_message, error_code=RESOURCE_DOES_NOT_EXIST)
def push_image_to_ecr(image=DEFAULT_IMAGE_NAME):
    """
    Push local Docker image to AWS ECR.

    The image is pushed under the currently active AWS account and to the currently active
    AWS region. If the boto3 session has no region configured, ``DEFAULT_REGION_NAME`` is
    used. The ECR repository named ``image`` is created if it does not exist yet. The image
    is tagged and pushed as
    ``{account}.dkr.ecr.{region}.amazonaws.com/{image}:{mlflow version}``.

    Args:
        image: Docker image name.

    Raises:
        MlflowException: If the shell command that logs in to ECR, tags the image and pushes
            it exits with a non-zero status.
    """
    import boto3

    _logger.info("Pushing image to ECR")
    client = boto3.client("sts")
    account = client.get_caller_identity()["Account"]
    region = boto3.session.Session().region_name or DEFAULT_REGION_NAME
    fullname = _full_template.format(
        account=account, region=region, image=image, version=mlflow.version.VERSION
    )
    _logger.info("Pushing docker image %s to %s", image, fullname)
    ecr_client = boto3.client("ecr")
    try:
        # Only used to detect whether the repository exists.
        ecr_client.describe_repositories(repositoryNames=[image])
    except ecr_client.exceptions.RepositoryNotFoundException:
        ecr_client.create_repository(repositoryName=image)
        _logger.info("Created new ECR repository: %s", image)
    # TODO: it would be nice to translate the docker login, tag and push to the python API
    # (e.g. via ``ecr_client.get_authorization_token()``) instead of shelling out.
    # Pass ``--region`` explicitly so the token is issued for the same region as the registry
    # we log in to, even when the AWS CLI's default region differs or is unset.
    docker_login_cmd = (
        f"aws ecr get-login-password --region {region}"
        " | docker login  --username AWS "
        "--password-stdin "
        f"{account}.dkr.ecr.{region}.amazonaws.com"
    )
    docker_tag_cmd = f"docker tag {image} {fullname}"
    docker_push_cmd = f"docker push {fullname}"
    # "&&" works in both POSIX sh and Windows cmd. It stops the chain at the first failing
    # step, so a failed login or tag does not go on to attempt a push.
    cmd = " && ".join([docker_login_cmd, docker_tag_cmd, docker_push_cmd])
    _logger.info("Executing: %s", cmd)
    exit_status = os.system(cmd)
    if exit_status != 0:
        raise MlflowException(
            f"Failed to push docker image {image} to {fullname}: the command exited with "
            f"status {exit_status}."
        )
def _deploy(
    app_name,
    model_uri,
    execution_role_arn=None,
    assume_role_arn=None,
    bucket=None,
    image_url=None,
    region_name=DEFAULT_REGION_NAME,
    mode=DEPLOYMENT_MODE_CREATE,
    archive=False,
    instance_type=DEFAULT_SAGEMAKER_INSTANCE_TYPE,
    instance_count=DEFAULT_SAGEMAKER_INSTANCE_COUNT,
    vpc_config=None,
    flavor=None,
    synchronous=True,
    timeout_seconds=1200,
    data_capture_config=None,
    variant_name=None,
    async_inference_config=None,
    serverless_config=None,
    env=None,
    tags=None,
):
    """
    Deploy an MLflow model on AWS SageMaker.

    The currently active AWS account must have correct permissions set up.

    This function creates a SageMaker endpoint, or updates an existing one when ``mode`` is
    ``add`` or ``replace``. For more information about the input data formats accepted by
    this endpoint, see the
    `MLflow deployment tools documentation <../../deployment/deploy-model-to-sagemaker.html>`_.

    Args:
        app_name: Name of the deployed application.
        model_uri: The location, in URI format, of the MLflow model to deploy to SageMaker.
            For example:

            - ``/Users/me/path/to/local/model``
            - ``relative/path/to/local/model``
            - ``s3://my_bucket/path/to/model``
            - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
            - ``models:/<model_name>/<model_version>``
            - ``models:/<model_name>/<stage>``

            For more information about supported URI schemes, see
            `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
            artifact-locations>`_.
        execution_role_arn: The name of an IAM role granting the SageMaker service permissions to
            access the specified Docker image and S3 bucket containing MLflow
            model artifacts. If unspecified, the currently-assumed role will be
            used. This execution role is passed to the SageMaker service when
            creating a SageMaker model from the specified MLflow model. It is
            passed as the ``ExecutionRoleArn`` parameter of the `SageMaker
            CreateModel API call <https://docs.aws.amazon.com/sagemaker/latest/
            dg/API_CreateModel.html>`_. This role is *not* assumed for any other
            call. For more information about SageMaker execution roles for model
            creation, see
            https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html.
        assume_role_arn: The name of an IAM cross-account role to be assumed to deploy SageMaker
            to another AWS account. If unspecified, SageMaker will be deployed to
            the currently active AWS account.
        bucket: S3 bucket where model artifacts will be stored. Defaults to a
            SageMaker-compatible bucket name.
        image_url: URL of the ECR-hosted Docker image the model should be deployed into, produced
            by ``mlflow sagemaker build-and-push-container``. This parameter can also
            be specified by the environment variable ``MLFLOW_SAGEMAKER_DEPLOY_IMG_URL``.
        region_name: Name of the AWS region to which to deploy the application.
        mode: The mode in which to deploy the application. Must be one of the following:

            ``mlflow.sagemaker.DEPLOYMENT_MODE_CREATE``
                Create an application with the specified name and model. This fails if an
                application of the same name already exists.

            ``mlflow.sagemaker.DEPLOYMENT_MODE_REPLACE``
                If an application of the specified name exists, its model(s) is replaced with
                the specified model. If no such application exists, it is created with the
                specified name and model.

            ``mlflow.sagemaker.DEPLOYMENT_MODE_ADD``
                Add the specified model to a pre-existing application with the specified name,
                if one exists. If the application does not exist, a new application is created
                with the specified name and model. NOTE: If the application **already exists**,
                the specified model is added to the application's corresponding SageMaker
                endpoint with an initial weight of zero (0). To route traffic to the model,
                update the application's associated endpoint configuration using either the
                AWS console or the ``UpdateEndpointWeightsAndCapacities`` function defined in
                https://docs.aws.amazon.com/sagemaker/latest/dg/API_UpdateEndpointWeightsAndCapacities.html.

        archive: If ``True``, any pre-existing SageMaker application resources that become
            inactive (i.e. as a result of deploying in
            ``mlflow.sagemaker.DEPLOYMENT_MODE_REPLACE`` mode) are preserved.
            These resources may include unused SageMaker models and endpoint configurations
            that were associated with a prior version of the application endpoint. If
            ``False``, these resources are deleted. In order to use ``archive=False``,
            ``deploy()`` must be executed synchronously with ``synchronous=True``.
        instance_type: The type of SageMaker ML instance on which to deploy the model. For a list
            of supported instance types, see
            https://aws.amazon.com/sagemaker/pricing/instance-types/.
        instance_count: The number of SageMaker ML instances on which to deploy the model.
        vpc_config: A dictionary specifying the VPC configuration to use when creating the
            new SageMaker model associated with this application. The acceptable values
            for this parameter are identical to those of the ``VpcConfig`` parameter in
            the `SageMaker boto3 client's create_model method
            <https://boto3.readthedocs.io/en/latest/reference/services/sagemaker.html
            #SageMaker.Client.create_model>`_. For more information, see
            https://docs.aws.amazon.com/sagemaker/latest/dg/API_VpcConfig.html.

            .. code-block:: python
                :caption: Example

                import mlflow.sagemaker as mfs

                vpc_config = {
                    "SecurityGroupIds": [
                        "sg-123456abc",
                    ],
                    "Subnets": [
                        "subnet-123456abc",
                    ],
                }
                mfs.deploy(..., vpc_config=vpc_config)

        flavor: The name of the flavor of the model to use for deployment. Must be either
            ``None`` or one of mlflow.sagemaker.SUPPORTED_DEPLOYMENT_FLAVORS. If ``None``,
            a flavor is automatically selected from the model's available flavors. If the
            specified flavor is not present or not supported for deployment, an exception
            will be thrown.
        synchronous: If ``True``, this function will block until the deployment process succeeds
            or encounters an irrecoverable failure. If ``False``, this function will
            return immediately after starting the deployment process. It will not wait
            for the deployment process to complete; in this case, the caller is
            responsible for monitoring the health and status of the pending deployment
            via native SageMaker APIs or the AWS console.
        timeout_seconds: If ``synchronous`` is ``True``, the deployment process will return after
            the specified number of seconds if no definitive result (success or
            failure) is achieved. Once the function returns, the caller is
            responsible for monitoring the health and status of the pending
            deployment using native SageMaker APIs or the AWS console. If
            ``synchronous`` is ``False``, this parameter is ignored.
        data_capture_config: A dictionary specifying the data capture configuration to use when
            creating the new SageMaker model associated with this application.
            For more information, see
            https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DataCaptureConfig.html.

            .. code-block:: python
                :caption: Example

                import mlflow.sagemaker as mfs

                data_capture_config = {
                    "EnableCapture": True,
                    "InitialSamplingPercentage": 100,
                    "DestinationS3Uri": "s3://my-bucket/path",
                    "CaptureOptions": [{"CaptureMode": "Output"}],
                }
                mfs.deploy(..., data_capture_config=data_capture_config)

        variant_name: The name to assign to the new production variant.
        async_inference_config: An optional dictionary specifying the asynchronous inference
            configuration of the SageMaker endpoint. It is forwarded to the helper that
            creates or updates the endpoint configuration.

            .. code-block:: python
                :caption: Example

                {
                    "AsyncInferenceConfig": {
                        "ClientConfig": {"MaxConcurrentInvocationsPerInstance": 4},
                        "OutputConfig": {
                            "S3OutputPath": "s3://<path-to-output-bucket>",
                            "NotificationConfig": {},
                        },
                    }
                }

        serverless_config: An optional dictionary specifying the serverless configuration.

            .. code-block:: python
                :caption: Example

                {
                    "ServerlessConfig": {
                        "MemorySizeInMB": 2048,
                        "MaxConcurrency": 20,
                    }
                }

        env: An optional dictionary of environment variables to set for the model.
        tags: An optional dictionary of tags to apply to the endpoint.

    Returns:
        A ``(app_name, flavor)`` tuple: the application name and the name of the flavor that
        was used for deployment.

    Raises:
        MlflowException: In any of these cases:

            - ``archive=False`` is combined with ``synchronous=False``.
            - ``mode`` is not one of ``DEPLOYMENT_MODES``.
            - The model directory has no ``MLMODEL_FILE_NAME`` configuration file.
            - The requested flavor is unsupported or missing from the model.
            - An application named ``app_name`` already exists and ``mode`` is ``create``.
            - A synchronous deployment does not finish in the succeeded state.
    """
    import boto3

    # Cleaning up replaced resources requires waiting for the deployment to finish.
    if (not archive) and (not synchronous):
        raise MlflowException(
            message=(
                "Resources must be archived when `deploy()` is executed in non-synchronous mode."
                " Either set `synchronous=True` or `archive=True`."
            ),
            error_code=INVALID_PARAMETER_VALUE,
        )
    if mode not in DEPLOYMENT_MODES:
        raise MlflowException(
            message="`mode` must be one of: {deployment_modes}".format(
                deployment_modes=",".join(DEPLOYMENT_MODES)
            ),
            error_code=INVALID_PARAMETER_VALUE,
        )

    model_path = _download_artifact_from_uri(model_uri)
    model_config_path = os.path.join(model_path, MLMODEL_FILE_NAME)
    if not os.path.exists(model_config_path):
        raise MlflowException(
            message=(
                f"Failed to find {MLMODEL_FILE_NAME} configuration within the specified model's "
                "root directory."
            ),
            error_code=INVALID_PARAMETER_VALUE,
        )
    model_config = Model.load(model_config_path)

    if flavor is None:
        flavor = _get_preferred_deployment_flavor(model_config)
    else:
        _validate_deployment_flavor(model_config, flavor)
    _logger.info("Using the %s flavor for deployment!", flavor)

    # All AWS clients share the (optionally) assumed cross-account credentials.
    assume_role_credentials = _assume_role_and_get_credentials(assume_role_arn=assume_role_arn)
    s3_client = boto3.client("s3", region_name=region_name, **assume_role_credentials)
    sage_client = boto3.client("sagemaker", region_name=region_name, **assume_role_credentials)

    endpoint_exists = _find_endpoint(endpoint_name=app_name, sage_client=sage_client) is not None
    if endpoint_exists and mode == DEPLOYMENT_MODE_CREATE:
        raise MlflowException(
            message=(
                f"You are attempting to deploy an application with name: {app_name} in"
                f" '{DEPLOYMENT_MODE_CREATE}' mode. However, an application with the same name"
                " already exists. If you want to update this application, deploy in"
                f" '{DEPLOYMENT_MODE_ADD}' or '{DEPLOYMENT_MODE_REPLACE}' mode."
            ),
            error_code=INVALID_PARAMETER_VALUE,
        )

    model_name = _get_sagemaker_model_name(endpoint_name=app_name)
    if not image_url:
        image_url = _get_default_image_url(region_name=region_name)
    if not execution_role_arn:
        execution_role_arn = _get_assumed_role_arn(**assume_role_credentials)
    if not bucket:
        _logger.info("No model data bucket specified, using the default bucket")
        bucket = _get_default_s3_bucket(region_name, **assume_role_credentials)
    model_s3_path = _upload_s3(
        local_model_path=model_path,
        bucket=bucket,
        prefix=model_name,
        region_name=region_name,
        s3_client=s3_client,
        **assume_role_credentials,
    )

    if endpoint_exists:
        deployment_operation = _update_sagemaker_endpoint(
            endpoint_name=app_name,
            model_name=model_name,
            model_s3_path=model_s3_path,
            model_uri=model_uri,
            image_url=image_url,
            flavor=flavor,
            instance_type=instance_type,
            instance_count=instance_count,
            vpc_config=vpc_config,
            mode=mode,
            role=execution_role_arn,
            sage_client=sage_client,
            s3_client=s3_client,
            variant_name=variant_name,
            async_inference_config=async_inference_config,
            serverless_config=serverless_config,
            data_capture_config=data_capture_config,
            env=env,
            tags=tags,
        )
    else:
        deployment_operation = _create_sagemaker_endpoint(
            endpoint_name=app_name,
            model_name=model_name,
            model_s3_path=model_s3_path,
            model_uri=model_uri,
            image_url=image_url,
            flavor=flavor,
            instance_type=instance_type,
            instance_count=instance_count,
            vpc_config=vpc_config,
            data_capture_config=data_capture_config,
            role=execution_role_arn,
            sage_client=sage_client,
            variant_name=variant_name,
            async_inference_config=async_inference_config,
            serverless_config=serverless_config,
            env=env,
            tags=tags,
        )

    if synchronous:
        _logger.info("Waiting for the deployment operation to complete...")
        operation_status = deployment_operation.await_completion(timeout_seconds=timeout_seconds)
        if operation_status.state == _SageMakerOperationStatus.STATE_SUCCEEDED:
            _logger.info(
                'The deployment operation completed successfully with message: "%s"',
                operation_status.message,
            )
        else:
            raise MlflowException(
                "The deployment operation failed with the following error message:"
                f' "{operation_status.message}"'
            )
        # Clean-up only runs after a successful deployment; failures raise above.
        if not archive:
            deployment_operation.clean_up()
    return app_name, flavor
def _delete(
    app_name,
    region_name=DEFAULT_REGION_NAME,
    assume_role_arn=None,
    archive=False,
    synchronous=True,
    timeout_seconds=300,
):
    """
    Delete a SageMaker application.

    Deletes the SageMaker endpoint named ``app_name``. Unless ``archive`` is ``True``, the
    endpoint configuration and every model referenced by its production variants are also
    deleted once the endpoint deletion has completed.

    Args:
        app_name: Name of the deployed application.
        region_name: Name of the AWS region in which the application is deployed.
        assume_role_arn: The name of an IAM cross-account role to be assumed in order to delete
            an application deployed in another AWS account. If unspecified, the
            application is deleted from the currently active AWS account.
        archive: If ``True``, resources associated with the specified application, such
            as its associated models and endpoint configuration, are preserved.
            If ``False``, these resources are deleted. In order to use
            ``archive=False``, ``delete()`` must be executed synchronously with
            ``synchronous=True``.
        synchronous: If `True`, this function blocks until the deletion process succeeds
            or encounters an irrecoverable failure. If `False`, this function
            returns immediately after starting the deletion process. It will not wait
            for the deletion process to complete; in this case, the caller is
            responsible for monitoring the status of the deletion process via native
            SageMaker APIs or the AWS console.
        timeout_seconds: If `synchronous` is `True`, the deletion process returns after the
            specified number of seconds if no definitive result (success or failure)
            is achieved. Once the function returns, the caller is responsible
            for monitoring the status of the deletion process via native SageMaker
            APIs or the AWS console. If `synchronous` is False, this parameter
            is ignored.

    Raises:
        MlflowException: If ``archive=False`` is combined with ``synchronous=False``, or if a
            synchronous deletion does not finish in the succeeded state.
    """
    import boto3

    # Cleaning up the endpoint's resources requires waiting for the deletion to finish.
    if (not archive) and (not synchronous):
        raise MlflowException(
            message=(
                "Resources must be archived when `delete()` is executed in non-synchronous mode."
                " Either set `synchronous=True` or `archive=True`."
            ),
            error_code=INVALID_PARAMETER_VALUE,
        )

    assume_role_credentials = _assume_role_and_get_credentials(assume_role_arn=assume_role_arn)
    s3_client = boto3.client("s3", region_name=region_name, **assume_role_credentials)
    sage_client = boto3.client("sagemaker", region_name=region_name, **assume_role_credentials)

    # Capture the endpoint description *before* deleting it. ``cleanup_fn`` needs its
    # ``EndpointConfigName`` after the endpoint itself is gone.
    endpoint_info = sage_client.describe_endpoint(EndpointName=app_name)
    endpoint_arn = endpoint_info["EndpointArn"]
    sage_client.delete_endpoint(EndpointName=app_name)
    _logger.info("Deleted endpoint with arn: %s", endpoint_arn)

    def status_check_fn():
        # Deletion is considered complete once ``_find_endpoint`` no longer finds the endpoint.
        # A distinct local name avoids shadowing the ``endpoint_info`` captured above.
        current_info = _find_endpoint(endpoint_name=app_name, sage_client=sage_client)
        if current_info is not None:
            return _SageMakerOperationStatus.in_progress(
                "Deletion is still in progress. Current endpoint status: {endpoint_status}".format(
                    endpoint_status=current_info["EndpointStatus"]
                )
            )
        return _SageMakerOperationStatus.succeeded(
            "The SageMaker endpoint was deleted successfully."
        )

    def cleanup_fn():
        _logger.info("Cleaning up unused resources...")
        config_name = endpoint_info["EndpointConfigName"]
        config_info = sage_client.describe_endpoint_config(EndpointConfigName=config_name)
        config_arn = config_info["EndpointConfigArn"]
        sage_client.delete_endpoint_config(EndpointConfigName=config_name)
        _logger.info("Deleted associated endpoint configuration with arn: %s", config_arn)
        for pv in config_info["ProductionVariants"]:
            model_name = pv["ModelName"]
            model_arn = _delete_sagemaker_model(model_name, sage_client, s3_client)
            _logger.info("Deleted associated model with arn: %s", model_arn)

    delete_operation = _SageMakerOperation(status_check_fn=status_check_fn, cleanup_fn=cleanup_fn)

    if synchronous:
        _logger.info("Waiting for the delete operation to complete...")
        operation_status = delete_operation.await_completion(timeout_seconds=timeout_seconds)
        if operation_status.state == _SageMakerOperationStatus.STATE_SUCCEEDED:
            _logger.info(
                'The deletion operation completed successfully with message: "%s"',
                operation_status.message,
            )
        else:
            raise MlflowException(
                "The deletion operation failed with the following error message:"
                f' "{operation_status.message}"'
            )
        if not archive:
            delete_operation.clean_up()
def push_model_to_sagemaker(
    model_name,
    model_uri,
    execution_role_arn=None,
    assume_role_arn=None,
    bucket=None,
    image_url=None,
    region_name=DEFAULT_REGION_NAME,
    vpc_config=None,
    flavor=None,
):
    """
    Create a SageMaker Model from an MLflow model artifact.

    The currently active AWS account must have correct permissions set up. The model
    artifacts are uploaded to S3 and a SageMaker model named ``model_name`` is created from
    them. No endpoint is created.

    Args:
        model_name: Name of the Sagemaker model.
        model_uri: The location, in URI format, of the MLflow model to deploy to SageMaker.
            For example:

            - ``/Users/me/path/to/local/model``
            - ``relative/path/to/local/model``
            - ``s3://my_bucket/path/to/model``
            - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
            - ``models:/<model_name>/<model_version>``
            - ``models:/<model_name>/<stage>``

            For more information about supported URI schemes, see
            `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
            artifact-locations>`_.
        execution_role_arn: The name of an IAM role granting the SageMaker service permissions to
            access the specified Docker image and S3 bucket containing MLflow
            model artifacts. If unspecified, the currently-assumed role will be
            used. This execution role is passed to the SageMaker service when
            creating a SageMaker model from the specified MLflow model. It is
            passed as the ``ExecutionRoleArn`` parameter of the `SageMaker
            CreateModel API call <https://docs.aws.amazon.com/sagemaker/latest/
            dg/API_CreateModel.html>`_. This role is *not* assumed for any other
            call. For more information about SageMaker execution roles for model
            creation, see
            https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html.
        assume_role_arn: The name of an IAM cross-account role to be assumed to create the
            SageMaker model in another AWS account. If unspecified, the model is created
            in the currently active AWS account.
        bucket: S3 bucket where model artifacts will be stored. Defaults to a
            SageMaker-compatible bucket name.
        image_url: URL of the ECR-hosted Docker image the model should be deployed into, produced
            by ``mlflow sagemaker build-and-push-container``. This parameter can also
            be specified by the environment variable ``MLFLOW_SAGEMAKER_DEPLOY_IMG_URL``.
        region_name: Name of the AWS region in which to create the SageMaker model.
        vpc_config: A dictionary specifying the VPC configuration to use when creating the
            new SageMaker model. The acceptable values for this parameter are identical
            to those of the ``VpcConfig`` parameter in the `SageMaker boto3 client's
            create_model method
            <https://boto3.readthedocs.io/en/latest/reference/services/sagemaker.html
            #SageMaker.Client.create_model>`_. For more information, see
            https://docs.aws.amazon.com/sagemaker/latest/dg/API_VpcConfig.html.

            .. code-block:: python
                :caption: Example

                import mlflow.sagemaker as mfs

                vpc_config = {
                    "SecurityGroupIds": [
                        "sg-123456abc",
                    ],
                    "Subnets": [
                        "subnet-123456abc",
                    ],
                }
                mfs.push_model_to_sagemaker(..., vpc_config=vpc_config)

        flavor: The name of the flavor of the model to use for deployment. Must be either
            ``None`` or one of mlflow.sagemaker.SUPPORTED_DEPLOYMENT_FLAVORS. If ``None``,
            a flavor is automatically selected from the model's available flavors. If the
            specified flavor is not present or not supported for deployment, an exception
            will be thrown.

    Raises:
        MlflowException: If the model directory has no ``MLMODEL_FILE_NAME`` configuration
            file, if the requested flavor is unsupported or missing from the model, or if a
            SageMaker model named ``model_name`` already exists.
    """
    import boto3

    model_path = _download_artifact_from_uri(model_uri)
    model_config_path = os.path.join(model_path, MLMODEL_FILE_NAME)
    if not os.path.exists(model_config_path):
        raise MlflowException(
            message=(
                f"Failed to find {MLMODEL_FILE_NAME} configuration within the specified model's"
                " root directory."
            ),
            error_code=INVALID_PARAMETER_VALUE,
        )
    model_config = Model.load(model_config_path)

    if flavor is None:
        flavor = _get_preferred_deployment_flavor(model_config)
    else:
        _validate_deployment_flavor(model_config, flavor)
    _logger.info("Using the %s flavor for deployment!", flavor)

    # All AWS clients share the (optionally) assumed cross-account credentials.
    assume_role_credentials = _assume_role_and_get_credentials(assume_role_arn=assume_role_arn)
    s3_client = boto3.client("s3", region_name=region_name, **assume_role_credentials)
    sage_client = boto3.client("sagemaker", region_name=region_name, **assume_role_credentials)

    # Fail before uploading anything if the target model name is already taken.
    if _does_model_exist(model_name=model_name, sage_client=sage_client):
        raise MlflowException(
            message=(
                f"You are attempting to create a Sagemaker model with name: {model_name}. "
                "However, a model with the same name already exists."
            ),
            error_code=INVALID_PARAMETER_VALUE,
        )

    if not image_url:
        image_url = _get_default_image_url(region_name=region_name)
    if not execution_role_arn:
        execution_role_arn = _get_assumed_role_arn(**assume_role_credentials)
    if not bucket:
        _logger.info("No model data bucket specified, using the default bucket")
        bucket = _get_default_s3_bucket(region_name, **assume_role_credentials)
    model_s3_path = _upload_s3(
        local_model_path=model_path,
        bucket=bucket,
        prefix=model_name,
        region_name=region_name,
        s3_client=s3_client,
        **assume_role_credentials,
    )
    model_response = _create_sagemaker_model(
        model_name=model_name,
        model_s3_path=model_s3_path,
        model_uri=model_uri,
        flavor=flavor,
        vpc_config=vpc_config,
        image_url=image_url,
        execution_role=execution_role_arn,
        sage_client=sage_client,
        env={},
        tags={},
    )
    _logger.info("Created Sagemaker model with arn: %s", model_response["ModelArn"])
def run_local(name, model_uri, flavor=None, config=None):
    """
    Serve the model locally in a SageMaker compatible Docker container.

    Note that models deployed locally cannot be managed by other deployment APIs
    (e.g. ``update_deployment``, ``delete_deployment``, etc).

    Args:
        name: Name of the local serving application. (Not forwarded to the Docker container by
            this implementation.)
        model_uri: The location, in URI format, of the MLflow model to deploy locally.
                        For example:

                        - ``/Users/me/path/to/local/model``
                        - ``relative/path/to/local/model``
                        - ``s3://my_bucket/path/to/model``
                        - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
                        - ``models:/<model_name>/<model_version>``
                        - ``models:/<model_name>/<stage>``

                        For more information about supported URI schemes, see
                        `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
                        artifact-locations>`_.
        flavor: The name of the flavor of the model to use for deployment. Must be either
                    ``None`` or one of mlflow.sagemaker.SUPPORTED_DEPLOYMENT_FLAVORS.
                    If ``None``, a flavor is automatically selected from the model's available
                    flavors. If the specified flavor is not present or not supported for
                    deployment, an exception will be thrown.
        config: Configuration parameters. May be ``None``, which is treated like an empty
                    dictionary. The supported parameters are:

                    - ``image``: The name of the Docker image to use for model serving. Defaults
                                    to ``"mlflow-pyfunc"``.
                    - ``port``: The port at which to expose the model server on the local host.
                                Defaults to ``5000``.

    This call blocks until the Docker process exits. While it runs, SIGTERM is forwarded to the
    container as SIGINT; the previous SIGTERM handler is restored afterwards.

    .. code-block:: python
        :caption: Python example

        from mlflow.models import build_docker
        from mlflow.deployments import get_deploy_client

        build_docker(name="mlflow-pyfunc")

        client = get_deploy_client("sagemaker")
        client.run_local(
            name="my-local-deployment",
            model_uri="/mlruns/0/abc/model",
            flavor="python_function",
            config={
                "port": 5000,
                "image": "mlflow-pyfunc",
            },
        )

    .. code-block:: bash
        :caption:  Command-line example

        mlflow models build-docker --name "mlflow-pyfunc"
        mlflow deployments run-local --target sagemaker \\
                --name my-local-deployment \\
                --model-uri "/mlruns/0/abc/model" \\
                --flavor python_function \\
                -C port=5000 \\
                -C image="mlflow-pyfunc"
    """
    # ``config`` defaults to None; treat that as "no overrides" instead of crashing on ``.get``.
    if config is None:
        config = {}
    model_path = _download_artifact_from_uri(model_uri)
    model_config_path = os.path.join(model_path, MLMODEL_FILE_NAME)
    model_config = Model.load(model_config_path)
    if flavor is None:
        flavor = _get_preferred_deployment_flavor(model_config)
    else:
        _validate_deployment_flavor(model_config, flavor)
    _logger.info("Using the %s flavor for local serving!", flavor)
    image = config.get("image", DEFAULT_IMAGE_NAME)
    port = int(config.get("port", 5000))
    deployment_config = _get_deployment_config(flavor_name=flavor)
    _logger.info("launching docker image with path %s", model_path)
    # Mount the model where the SageMaker container expects it and map the container's
    # serving port (8080) to the requested host port.
    cmd = ["docker", "run", "-v", f"{model_path}:/opt/ml/model/", "-p", f"{port}:8080"]
    for key, value in deployment_config.items():
        cmd += ["-e", f"{key}={value}"]
    cmd += ["--rm", image, "serve"]
    _logger.info("executing: %s", " ".join(cmd))
    proc = Popen(cmd, stdout=sys.stdout, stderr=sys.stderr, text=True)

    def _sigterm_handler(*_):
        _logger.info("received termination signal => killing docker process")
        proc.send_signal(signal.SIGINT)

    previous_handler = signal.signal(signal.SIGTERM, _sigterm_handler)
    try:
        proc.wait()
    finally:
        # Restore the previous handler so the process does not keep ignoring SIGTERM
        # (forwarding it to an already-finished docker process) after we return.
        # ``signal.signal`` returns None when the old handler was not installed from Python,
        # and None cannot be re-installed.
        if previous_handler is not None:
            signal.signal(signal.SIGTERM, previous_handler)
def target_help():
    """
    Provide help information for the SageMaker deployment client.

    Returns:
        A multi-line help string describing the accepted target URI formats and the
        ``create``/``update``/``delete`` commands.
    """
    return """\
    For detailed documentation on the SageMaker deployment client, please visit
    https://mlflow.org/docs/latest/python_api/mlflow.sagemaker.html#mlflow.sagemaker.SageMakerDeploymentClient
    The target URI must follow the following formats:
    - sagemaker
    - sagemaker:/region_name
    - sagemaker:/region_name/assume_role_arn
    When the region_name or assume_role_arn are provided, they will be used as the default region
    and assumed role ARN when executing the commands.
    The `create` and `update` commands require a deployment name and a model_uri. The model flavor
    and deployment configuration can be optionally provided. These commands can also be executed
    in synchronous or asynchronous mode.
    The `delete` command accepts configurations to archive a model instead of deleting, execute
    in asynchronous mode and timeout period.
    """
def _get_default_image_url(region_name):
    """
    Resolve the Docker image URL used for SageMaker deployments.

    An explicit URL configured through ``MLFLOW_SAGEMAKER_DEPLOY_IMG_URL`` wins. Otherwise the
    ECR repository named ``DEFAULT_IMAGE_NAME`` in ``region_name`` is looked up, and its URI is
    suffixed with the installed MLflow version as the image tag.
    """
    import boto3

    override_url = MLFLOW_SAGEMAKER_DEPLOY_IMG_URL.get()
    if override_url:
        return override_url

    ecr = boto3.client("ecr", region_name=region_name)
    described = ecr.describe_repositories(repositoryNames=[DEFAULT_IMAGE_NAME])
    repository_uri = described["repositories"][0]["repositoryUri"]
    return (repository_uri + ":{version}").format(version=mlflow.version.VERSION)
def _get_account_id(**assume_role_credentials):
    """
    Returns:
        The AWS account ID reported by STS ``GetCallerIdentity`` for the supplied
        (possibly assumed-role) credentials.
    """
    import boto3

    sts = boto3.Session().client("sts", **assume_role_credentials)
    caller = sts.get_caller_identity()
    return caller["Account"]
def _get_assumed_role_arn(**assume_role_credentials):
    """
    Look up the IAM role behind the current caller identity.

    The role name is taken from the second ``/``-separated segment of the caller's STS ARN
    (presumably an ``assumed-role/<role>/<session>`` ARN; other principal types are not
    handled). It is then resolved to a full role ARN through IAM ``GetRole``.

    Returns:
        ARN of the user's current IAM role.
    """
    import boto3

    session = boto3.Session()
    caller_arn = session.client("sts", **assume_role_credentials).get_caller_identity()["Arn"]
    role_name = caller_arn.split("/")[1]
    iam = session.client("iam", **assume_role_credentials)
    return iam.get_role(RoleName=role_name)["Role"]["Arn"]
def _assume_role_and_get_credentials(assume_role_arn=None):
    """
    Assume an AWS IAM role and return boto3-style credential keyword arguments for it.

    When ``assume_role_arn`` is ``None`` or an empty string, no role is assumed and an empty
    dictionary is returned, so the result can always be splatted into ``boto3.client(...)``.

    Args:
        assume_role_arn: Optional ARN of the role that will be assumed.

    Returns:
        Dict with ``aws_access_key_id``, ``aws_secret_access_key`` and ``aws_session_token``
        of the assumed role, or ``{}``.
    """
    import boto3

    if not assume_role_arn:
        return {}

    assumed = boto3.client("sts").assume_role(
        RoleArn=assume_role_arn, RoleSessionName="mlflow-sagemaker"
    )
    _logger.info("Assuming role %s for deployment!", assume_role_arn)
    creds = assumed["Credentials"]
    return {
        "aws_access_key_id": creds["AccessKeyId"],
        "aws_secret_access_key": creds["SecretAccessKey"],
        "aws_session_token": creds["SessionToken"],
    }
def _get_default_s3_bucket(region_name, **assume_role_credentials):
    """
    Return the default model-data bucket for ``region_name``, creating it if necessary.

    The bucket is named ``<DEFAULT_BUCKET_NAME_PREFIX>-<region_name>-<account_id>``. It is
    created only when no bucket with that name appears in the caller's ``ListBuckets`` output.

    Args:
        region_name: AWS region the bucket should live in.
        **assume_role_credentials: Optional boto3 credential kwargs (see
            ``_assume_role_and_get_credentials``).

    Returns:
        The bucket name.
    """
    import boto3

    sess = boto3.Session()
    account_id = _get_account_id(**assume_role_credentials)
    bucket_name = f"{DEFAULT_BUCKET_NAME_PREFIX}-{region_name}-{account_id}"
    # Pin the client to the target region. CreateBucket must be sent to the endpoint of the
    # region that will host the bucket; otherwise S3 rejects the request's (possibly absent)
    # LocationConstraint when the session's default region differs from ``region_name``.
    s3 = sess.client("s3", region_name=region_name, **assume_role_credentials)
    existing_buckets = {bucket["Name"] for bucket in s3.list_buckets()["Buckets"]}
    if bucket_name in existing_buckets:
        _logger.info("Default bucket `%s` already exists. Skipping creation.", bucket_name)
        return bucket_name

    _logger.info("Default bucket `%s` not found. Creating...", bucket_name)
    bucket_creation_kwargs = {
        "ACL": "bucket-owner-full-control",
        "Bucket": bucket_name,
    }
    if region_name != "us-east-1":
        # The location constraint is required during bucket creation for all regions
        # outside of us-east-1. This constraint cannot be specified in us-east-1;
        # specifying it in this region results in a failure, so we will only
        # add it if we are deploying outside of us-east-1.
        # See https://docs.aws.amazon.com/cli/latest/reference/s3api/create-bucket.html#examples
        bucket_creation_kwargs["CreateBucketConfiguration"] = {"LocationConstraint": region_name}
    response = s3.create_bucket(**bucket_creation_kwargs)
    _logger.info("Bucket creation response: %s", response)
    return bucket_name
def _make_tarfile(output_filename, source_dir):
    """
    create a tar.gz from a directory.
    """
    with tarfile.open(output_filename, "w:gz") as tar:
        for f in os.listdir(source_dir):
            tar.add(os.path.join(source_dir, f), arcname=f)
def _upload_s3(local_model_path, bucket, prefix, region_name, s3_client, **assume_role_credentials):  # noqa: D417
    """
    Upload dir to S3 as .tar.gz.

    The archive is stored at ``<prefix>/model.tar.gz`` and tagged with ``SageMaker=true``.

    Args:
        local_model_path: Local path to a dir.
        bucket: S3 bucket where to store the data.
        prefix: Path within the bucket.
        region_name: The AWS region in which to upload data to S3.
        s3_client: A boto3 client for S3 (used to tag the uploaded object).
        **assume_role_credentials: Optional boto3 credential kwargs used for the upload.

    Returns:
        S3 path of the uploaded artifact.
    """
    import posixpath

    import boto3

    sess = boto3.Session(region_name=region_name, **assume_role_credentials)
    # S3 object keys always use "/" as separator; ``os.path.join`` would produce "\" on Windows.
    key = posixpath.join(prefix, "model.tar.gz")
    with TempDir() as tmp:
        model_data_file = tmp.path("model.tar.gz")
        _make_tarfile(model_data_file, local_model_path)
        with open(model_data_file, "rb") as fobj:
            obj = sess.resource("s3").Bucket(bucket).Object(key)
            obj.upload_fileobj(fobj)
    response = s3_client.put_object_tagging(
        Bucket=bucket, Key=key, Tagging={"TagSet": [{"Key": "SageMaker", "Value": "true"}]}
    )
    _logger.info("tag response: %s", response)
    return f"s3://{bucket}/{key}"
def _get_deployment_config(flavor_name, env_override=None):
    """
    Build the environment variables passed to the serving container.

    The mapping starts with the deployment flavor and the SageMaker serving-environment marker,
    then applies ``env_override``. Finally, any ``http_proxy``/``https_proxy``/``no_proxy``
    variables set in the current process are copied in; they are applied last, so they win over
    ``env_override`` entries with the same name.

    Returns:
        The deployment configuration as a dictionary
    """
    config = {
        MLFLOW_DEPLOYMENT_FLAVOR_NAME.name: flavor_name,
        SERVING_ENVIRONMENT: SAGEMAKER_SERVING_ENVIRONMENT,
    }
    config.update(env_override or {})
    for proxy_var in ("http_proxy", "https_proxy", "no_proxy"):
        proxy_value = os.getenv(proxy_var)
        if proxy_value is not None:
            config[proxy_var] = proxy_value
    return config
def _truncate_name(name, max_length):
    # NB: Sagemaker prevents the registration of models and configurations whose names
    # exceed 63 characters in length. For reference:
    # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_Model.html
    # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_TransformJob.html
    # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ModelConfiguration.html
    # This function middle-truncates the name provided to
    # ensure that the least critical name information is not lost
    if len(name) <= max_length:
        return name
    available_length = max_length - 3
    start_len = available_length // 2
    end_len = available_length - start_len
    truncated_name = f"{name[:start_len]}---{name[-end_len:]}"
    _logger.warning(
        f"Truncated name {name} to {truncated_name} to coerce total character counts to < 64"
    )
    return truncated_name
def _get_unique_name(base_name, unique_suffix, unique_id_length=20):
    """
    Build ``<base_name><unique_suffix><random hex id>``, truncating ``base_name`` so that the
    whole result fits into SageMaker's 63-character name limit.

    Args:
        base_name: Human-readable prefix (e.g. an endpoint name).
        unique_suffix: Fixed marker inserted before the random id (e.g. ``"-model-"``).
        unique_id_length: Number of hex characters taken from a random UUID4.
    """
    random_part = uuid.uuid4().hex[:unique_id_length]
    tail = unique_suffix + random_part
    return _truncate_name(base_name, 63 - len(tail)) + tail
def _get_sagemaker_model_name(endpoint_name):
    """Return a unique SageMaker model name derived from ``endpoint_name`` (``-model-`` marker)."""
    return _get_unique_name(base_name=endpoint_name, unique_suffix="-model-")
def _get_sagemaker_transform_model_name(job_name):
    """Return a unique SageMaker model name derived from a batch transform ``job_name``."""
    return _get_unique_name(base_name=job_name, unique_suffix="-model-")
def _get_sagemaker_config_name(endpoint_name):
    """Return a unique endpoint-configuration name derived from ``endpoint_name``."""
    return _get_unique_name(base_name=endpoint_name, unique_suffix="-config-")
def _get_sagemaker_config_tags(endpoint_name):
    """Return the SageMaker tag list that records which MLflow app (endpoint) owns a resource."""
    app_tag = dict(Key=SAGEMAKER_APP_NAME_TAG_KEY, Value=endpoint_name)
    return [app_tag]
def _prepare_sagemaker_tags(
    config_tags: list[dict[str, str]],
    sagemaker_tags: Optional[dict[str, str]] = None,
):
    if not sagemaker_tags:
        return config_tags
    if SAGEMAKER_APP_NAME_TAG_KEY in sagemaker_tags:
        raise MlflowException.invalid_parameter_value(
            f"Duplicate tag provided for '{SAGEMAKER_APP_NAME_TAG_KEY}'"
        )
    parsed = [{"Key": key, "Value": str(value)} for key, value in sagemaker_tags.items()]
    return config_tags + parsed
def _create_sagemaker_transform_job(
    job_name,
    model_name,
    model_s3_path,
    model_uri,
    image_url,
    flavor,
    vpc_config,
    role,
    sage_client,
    s3_client,
    instance_type,
    instance_count,
    s3_input_data_type,
    s3_input_uri,
    content_type,
    compression_type,
    split_type,
    s3_output_path,
    accept,
    assemble_with,
    input_filter,
    output_filter,
    join_resource,
):
    """
    Create a SageMaker model and start a batch transform job that uses it.

    Args:
        job_name: Name of the deployed Sagemaker batch transform job.
        model_name: The name to assign the new SageMaker model that will be associated with the
            specified batch transform job.
        model_s3_path: S3 path where we stored the model artifacts.
        model_uri: URI of the MLflow model to associate with the specified SageMaker batch
            transform job.
        image_url: URL of the ECR-hosted docker image the model is being deployed into.
        flavor: The name of the flavor of the model to use for deployment.
        vpc_config: A dictionary specifying the VPC configuration to use when creating the
            new SageMaker model associated with this SageMaker batch transform job.
        role: SageMaker execution ARN role.
        sage_client: A boto3 client for SageMaker.
        s3_client: A boto3 client for S3.
        instance_type: The type of SageMaker ML instance on which to deploy the model.
        instance_count: The number of SageMaker ML instances on which to deploy the model.
        s3_input_data_type: Input data type for the transform job.
        s3_input_uri: S3 key name prefix or a manifest of the input data.
        content_type: The multipurpose internet mail extension (MIME) type of the data.
        compression_type: The compression type of the transform data.
        split_type: The method to split the transform job's data files into smaller batches.
        s3_output_path: The S3 path to store the output results of the Sagemaker transform job.
        accept: The multipurpose internet mail extension (MIME) type of the output data.
        assemble_with: The method to assemble the results of the transform job as a single
            S3 object.
        input_filter: A JSONPath expression used to select a portion of the input data for the
            transform job.
        output_filter: A JSONPath expression used to select a portion of the output data from
            the transform job.
        join_resource: The source of the data to join with the transformed data.

    Returns:
        A ``_SageMakerOperation``. Its status check reports success once the job is
        ``Completed``. Any status other than ``InProgress``/``Completed`` is reported as a
        failure. Its cleanup deletes the job's SageMaker model and the model's S3 data.
    """
    _logger.info("Creating new batch transform job with name: %s ...", job_name)
    model_response = _create_sagemaker_model(
        model_name=model_name,
        model_s3_path=model_s3_path,
        model_uri=model_uri,
        flavor=flavor,
        vpc_config=vpc_config,
        image_url=image_url,
        execution_role=role,
        sage_client=sage_client,
        env={},
        tags={},
    )
    _logger.info("Created model with arn: %s", model_response["ModelArn"])
    transform_input = {
        "DataSource": {"S3DataSource": {"S3DataType": s3_input_data_type, "S3Uri": s3_input_uri}},
        "ContentType": content_type,
        "CompressionType": compression_type,
        "SplitType": split_type,
    }
    transform_output = {
        "S3OutputPath": s3_output_path,
        "Accept": accept,
        "AssembleWith": assemble_with,
    }
    transform_resources = {"InstanceType": instance_type, "InstanceCount": instance_count}
    data_processing = {
        "InputFilter": input_filter,
        "OutputFilter": output_filter,
        "JoinSource": join_resource,
    }
    transform_job_response = sage_client.create_transform_job(
        TransformJobName=job_name,
        ModelName=model_name,
        TransformInput=transform_input,
        TransformOutput=transform_output,
        TransformResources=transform_resources,
        DataProcessing=data_processing,
        Tags=[{"Key": "model_name", "Value": model_name}],
    )
    _logger.info(
        "Created batch transform job with arn: %s", transform_job_response["TransformJobArn"]
    )

    def status_check_fn():
        transform_job_info = sage_client.describe_transform_job(TransformJobName=job_name)
        if transform_job_info is None:
            return _SageMakerOperationStatus.in_progress(
                "Waiting for batch transform job to be created..."
            )
        transform_job_status = transform_job_info["TransformJobStatus"]
        if transform_job_status == "InProgress":
            # Single sentence separator; the previous literal carried a long run of stray
            # spaces left over from a backslash line continuation.
            return _SageMakerOperationStatus.in_progress(
                'Waiting for batch transform job to reach the "Completed" state. '
                f'Current batch transform job status: "{transform_job_status}"'
            )
        elif transform_job_status == "Completed":
            return _SageMakerOperationStatus.succeeded(
                "The SageMaker batch transform job was processed successfully."
            )
        else:
            failure_reason = transform_job_info.get(
                "FailureReason",
                "An unknown SageMaker failure occurred. Please see the SageMaker console logs"
                " for more information.",
            )
            return _SageMakerOperationStatus.failed(failure_reason)

    def cleanup_fn():
        _logger.info("Cleaning up Sagemaker model and S3 model artifacts...")
        transform_job_info = sage_client.describe_transform_job(TransformJobName=job_name)
        # Use a distinct name rather than shadowing the enclosing ``model_name`` parameter.
        job_model_name = transform_job_info["ModelName"]
        model_arn = _delete_sagemaker_model(job_model_name, sage_client, s3_client)
        _logger.info("Deleted associated model with arn: %s", model_arn)

    return _SageMakerOperation(status_check_fn=status_check_fn, cleanup_fn=cleanup_fn)
def _create_sagemaker_endpoint(  # noqa: D417
    endpoint_name,
    model_name,
    model_s3_path,
    model_uri,
    image_url,
    flavor,
    instance_type,
    vpc_config,
    data_capture_config,
    instance_count,
    role,
    sage_client,
    variant_name=None,
    async_inference_config=None,
    serverless_config=None,
    env=None,
    tags=None,
):
    """
    Create a SageMaker model, an endpoint configuration with a single production variant that
    serves it, and an endpoint that uses that configuration.

    Args:
        endpoint_name: The name of the SageMaker endpoint to create.
        model_name: The name to assign the new SageMaker model that will be associated with the
            specified endpoint.
        model_s3_path: S3 path where we stored the model artifacts.
        model_uri: URI of the MLflow model to associate with the specified SageMaker endpoint.
        image_url: URL of the ECR-hosted docker image the model is being deployed into.
        flavor: The name of the flavor of the model to use for deployment.
        instance_type: The type of SageMaker ML instance on which to deploy the model. Ignored
            when ``serverless_config`` is provided.
        vpc_config: A dictionary specifying the VPC configuration to use when creating the
            new SageMaker model associated with this SageMaker endpoint.
        data_capture_config: A dictionary specifying the data capture configuration to apply to
            the new endpoint configuration, or ``None``.
        instance_count: The number of SageMaker ML instances on which to deploy the model.
            Ignored when ``serverless_config`` is provided.
        role: SageMaker execution ARN role.
        sage_client: A boto3 client for SageMaker.
        variant_name: The name to assign to the new production variant. Defaults to
            ``model_name``.
        async_inference_config: Optional async inference configuration for the endpoint
            configuration.
        serverless_config: Optional serverless configuration for the production variant; used
            instead of ``instance_type``/``instance_count`` when provided.
        env: A dictionary of environment variables to set for the model.
        tags: A dictionary of tags to apply to the new SageMaker model and to the endpoint.
            The caller's dictionary is not modified.

    Returns:
        A ``_SageMakerOperation`` whose status check reports success once the endpoint is
        ``InService``.

    Raises:
        MlflowException: If ``tags`` redefines the reserved app-name tag key. This is checked
            before any AWS resource is created.
    """
    _logger.info("Creating new endpoint with name: %s ...", endpoint_name)
    # Validate/assemble the endpoint tags up front. Invalid user tags then fail before the
    # SageMaker model exists, and the tag list is built from the caller's tags only (not from
    # anything added during model creation).
    config_tags = _get_sagemaker_config_tags(endpoint_name)
    tags_list = _prepare_sagemaker_tags(config_tags, tags)
    model_response = _create_sagemaker_model(
        model_name=model_name,
        model_s3_path=model_s3_path,
        model_uri=model_uri,
        flavor=flavor,
        vpc_config=vpc_config,
        image_url=image_url,
        execution_role=role,
        sage_client=sage_client,
        env=env or {},
        # Pass a copy: model creation may add a "model_uri" entry to the dict it receives, and
        # the caller's dict must not be modified.
        tags=dict(tags or {}),
    )
    _logger.info("Created model with arn: %s", model_response["ModelArn"])
    if not variant_name:
        variant_name = model_name
    production_variant = {
        "VariantName": variant_name,
        "ModelName": model_name,
        "InitialVariantWeight": 1,
    }
    if serverless_config:
        production_variant["ServerlessConfig"] = serverless_config
    else:
        production_variant["InstanceType"] = instance_type
        production_variant["InitialInstanceCount"] = instance_count
    config_name = _get_sagemaker_config_name(endpoint_name)
    endpoint_config_kwargs = {
        "EndpointConfigName": config_name,
        "ProductionVariants": [production_variant],
        "Tags": config_tags,
    }
    if async_inference_config:
        endpoint_config_kwargs["AsyncInferenceConfig"] = async_inference_config
    if data_capture_config is not None:
        endpoint_config_kwargs["DataCaptureConfig"] = data_capture_config
    endpoint_config_response = sage_client.create_endpoint_config(**endpoint_config_kwargs)
    _logger.info(
        "Created endpoint configuration with arn: %s", endpoint_config_response["EndpointConfigArn"]
    )
    endpoint_response = sage_client.create_endpoint(
        EndpointName=endpoint_name,
        EndpointConfigName=config_name,
        Tags=tags_list or [],
    )
    _logger.info("Created endpoint with arn: %s", endpoint_response["EndpointArn"])

    def status_check_fn():
        endpoint_info = _find_endpoint(endpoint_name=endpoint_name, sage_client=sage_client)
        if endpoint_info is None:
            return _SageMakerOperationStatus.in_progress("Waiting for endpoint to be created...")
        endpoint_status = endpoint_info["EndpointStatus"]
        if endpoint_status == "Creating":
            return _SageMakerOperationStatus.in_progress(
                'Waiting for endpoint to reach the "InService" state. Current endpoint status:'
                f' "{endpoint_status}"'
            )
        elif endpoint_status == "InService":
            return _SageMakerOperationStatus.succeeded(
                "The SageMaker endpoint was created successfully."
            )
        else:
            failure_reason = endpoint_info.get(
                "FailureReason",
                "An unknown SageMaker failure occurred. Please see the SageMaker console logs"
                " for more information.",
            )
            return _SageMakerOperationStatus.failed(failure_reason)

    def cleanup_fn():
        # Nothing to clean up after a fresh endpoint creation.
        pass

    return _SageMakerOperation(status_check_fn=status_check_fn, cleanup_fn=cleanup_fn)
def _update_sagemaker_endpoint(  # noqa: D417
    endpoint_name,
    model_name,
    model_uri,
    image_url,
    model_s3_path,
    flavor,
    instance_type,
    instance_count,
    vpc_config,
    mode,
    role,
    sage_client,
    s3_client,
    variant_name=None,
    async_inference_config=None,
    serverless_config=None,
    data_capture_config=None,
    env=None,
    tags=None,
):
    """
    Point an existing SageMaker endpoint at a newly created SageMaker model.

    A new model and a new endpoint configuration are created, and the endpoint is updated to
    the new configuration. In ``add`` mode, the new production variant (weight 0) is appended to
    the currently deployed variants. In ``replace`` mode, it becomes the only variant (weight 1).
    The returned operation's cleanup deletes the previously deployed endpoint configuration.
    In ``replace`` mode, it also deletes the models of the old production variants, including
    their S3 model data.

    Args:
        endpoint_name: The name of the SageMaker endpoint to update.
        model_name: The name to assign the new SageMaker model that will be associated with the
            specified endpoint.
        model_uri: URI of the MLflow model to associate with the specified SageMaker endpoint.
        image_url: URL of the ECR-hosted Docker image the model is being deployed into
        model_s3_path: S3 path where we stored the model artifacts
        flavor: The name of the flavor of the model to use for deployment.
        instance_type: The type of SageMaker ML instance on which to deploy the model.
        instance_count: The number of SageMaker ML instances on which to deploy the model.
        vpc_config: A dictionary specifying the VPC configuration to use when creating the
            new SageMaker model associated with this SageMaker endpoint.
        mode: either mlflow.sagemaker.DEPLOYMENT_MODE_ADD or
            mlflow.sagemaker.DEPLOYMENT_MODE_REPLACE.
        role: SageMaker execution ARN role.
        sage_client: A boto3 client for SageMaker.
        s3_client: A boto3 client for S3.
        variant_name: The name to assign to the new production variant if it doesn't already exist.
            Defaults to ``model_name``.
        async_inference_config: A dictionary specifying the async inference configuration to use.
            For more information, see https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AsyncInferenceConfig.html.
            Defaults to ``None``.
        serverless_config: A dictionary specifying the serverless configuration of the new
            production variant. When provided, it is used instead of ``instance_type`` and
            ``instance_count``. Defaults to ``None``.
        data_capture_config: A dictionary specifying the data capture configuration to use.
            For more information, see https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DataCaptureConfig.html.
            Defaults to ``None``.
        env: A dictionary of environment variables to set for the model.
        tags: A dictionary of tags to apply to the new SageMaker model. The caller's dictionary
            is not modified.

    Returns:
        A ``_SageMakerOperation`` that tracks the endpoint update.

    Raises:
        ValueError: If ``mode`` is neither ``add`` nor ``replace``.
    """
    if mode not in [DEPLOYMENT_MODE_ADD, DEPLOYMENT_MODE_REPLACE]:
        msg = f"Invalid mode `{mode}` for deployment to a pre-existing application"
        raise ValueError(msg)
    endpoint_info = sage_client.describe_endpoint(EndpointName=endpoint_name)
    endpoint_arn = endpoint_info["EndpointArn"]
    deployed_config_name = endpoint_info["EndpointConfigName"]
    deployed_config_info = sage_client.describe_endpoint_config(
        EndpointConfigName=deployed_config_name
    )
    deployed_config_arn = deployed_config_info["EndpointConfigArn"]
    deployed_production_variants = deployed_config_info["ProductionVariants"]
    _logger.info("Found active endpoint with arn: %s. Updating...", endpoint_arn)
    new_model_response = _create_sagemaker_model(
        model_name=model_name,
        model_s3_path=model_s3_path,
        model_uri=model_uri,
        flavor=flavor,
        vpc_config=vpc_config,
        image_url=image_url,
        execution_role=role,
        sage_client=sage_client,
        env=env or {},
        # Pass a copy: model creation may add a "model_uri" entry to the dict it receives.
        tags=dict(tags or {}),
    )
    _logger.info("Created new model with arn: %s", new_model_response["ModelArn"])
    if not variant_name:
        variant_name = model_name
    if mode == DEPLOYMENT_MODE_ADD:
        new_model_weight = 0
        # Copy so that appending the new variant leaves ``deployed_production_variants``
        # (captured by cleanup_fn) describing only the previously deployed variants.
        production_variants = list(deployed_production_variants)
    elif mode == DEPLOYMENT_MODE_REPLACE:
        new_model_weight = 1
        production_variants = []
    new_production_variant = {
        "VariantName": variant_name,
        "ModelName": model_name,
        "InitialVariantWeight": new_model_weight,
    }
    if serverless_config:
        new_production_variant["ServerlessConfig"] = serverless_config
    else:
        new_production_variant["InstanceType"] = instance_type
        new_production_variant["InitialInstanceCount"] = instance_count
    production_variants.append(new_production_variant)
    # Create the new endpoint configuration and update the endpoint
    # to adopt the new configuration
    new_config_name = _get_sagemaker_config_name(endpoint_name)
    config_tags = _get_sagemaker_config_tags(endpoint_name)
    endpoint_config_kwargs = {
        "EndpointConfigName": new_config_name,
        "ProductionVariants": production_variants,
        "Tags": config_tags,
    }
    if async_inference_config:
        endpoint_config_kwargs["AsyncInferenceConfig"] = async_inference_config
    if data_capture_config is not None:
        endpoint_config_kwargs["DataCaptureConfig"] = data_capture_config
    endpoint_config_response = sage_client.create_endpoint_config(**endpoint_config_kwargs)
    _logger.info(
        "Created new endpoint configuration with arn: %s",
        endpoint_config_response["EndpointConfigArn"],
    )
    sage_client.update_endpoint(EndpointName=endpoint_name, EndpointConfigName=new_config_name)
    _logger.info("Updated endpoint with new configuration!")
    operation_start_time = time.time()

    def status_check_fn():
        if time.time() - operation_start_time < 20:
            # Wait at least 20 seconds before checking the status of the update; this ensures
            # that we don't consider the operation to have failed if small delays occur at
            # initialization time
            return _SageMakerOperationStatus.in_progress()
        endpoint_info = sage_client.describe_endpoint(EndpointName=endpoint_name)
        # An endpoint that is back "InService" on a config other than the new one means
        # SageMaker rolled the update back.
        endpoint_update_was_rolled_back = (
            endpoint_info["EndpointStatus"] == "InService"
            and endpoint_info["EndpointConfigName"] != new_config_name
        )
        if endpoint_update_was_rolled_back or endpoint_info["EndpointStatus"] == "Failed":
            failure_reason = endpoint_info.get(
                "FailureReason",
                "An unknown SageMaker failure occurred."
                " Please see the SageMaker console logs for"
                " more information.",
            )
            return _SageMakerOperationStatus.failed(failure_reason)
        elif endpoint_info["EndpointStatus"] == "InService":
            return _SageMakerOperationStatus.succeeded(
                "The SageMaker endpoint was updated successfully."
            )
        else:
            return _SageMakerOperationStatus.in_progress(
                "The update operation is still in progress. Current endpoint status:"
                ' "{endpoint_status}"'.format(endpoint_status=endpoint_info["EndpointStatus"])
            )

    def cleanup_fn():
        _logger.info("Cleaning up unused resources...")
        if mode == DEPLOYMENT_MODE_REPLACE:
            for pv in deployed_production_variants:
                deployed_model_arn = _delete_sagemaker_model(
                    model_name=pv["ModelName"], sage_client=sage_client, s3_client=s3_client
                )
                _logger.info("Deleted model with arn: %s", deployed_model_arn)
        sage_client.delete_endpoint_config(EndpointConfigName=deployed_config_name)
        _logger.info("Deleted endpoint configuration with arn: %s", deployed_config_arn)

    return _SageMakerOperation(status_check_fn=status_check_fn, cleanup_fn=cleanup_fn)
def _create_sagemaker_model(
    model_name,
    model_s3_path,
    model_uri,
    flavor,
    vpc_config,
    image_url,
    execution_role,
    sage_client,
    env,
    tags,
):
    """
    Register a new SageMaker model backed by the given image and S3 model data.

    Args:
        model_name: The name to assign the new SageMaker model that is created.
        model_s3_path: S3 path where the model artifacts are stored.
        model_uri: URI of the MLflow model associated with the new SageMaker model.
        flavor: The name of the flavor of the model.
        vpc_config: A dictionary specifying the VPC configuration to use when creating the
            new SageMaker model associated with this SageMaker endpoint.
        image_url: URL of the ECR-hosted Docker image that will serve as the
            model's container,
        execution_role: The ARN of the role that SageMaker will assume when creating the model.
        sage_client: A boto3 client for SageMaker.
        env: A dictionary of environment variables to set for the model.
        tags: A dictionary of tags to apply to the SageMaker model. A ``model_uri`` tag is
            always added (overriding any user value). The passed dictionary is not modified;
            ``None`` is treated as empty.

    Returns:
        AWS response containing metadata associated with the new model.
    """
    # Build a fresh mapping instead of writing into ``tags``: callers pass their own dicts
    # (possibly reused for other resources), which must not pick up an injected "model_uri".
    model_tags = {**(tags or {}), "model_uri": str(model_uri)}
    create_model_args = {
        "ModelName": model_name,
        "PrimaryContainer": {
            "Image": image_url,
            "ModelDataUrl": model_s3_path,
            "Environment": _get_deployment_config(flavor_name=flavor, env_override=env),
        },
        "ExecutionRoleArn": execution_role,
        "Tags": [{"Key": key, "Value": str(value)} for key, value in model_tags.items()],
    }
    if vpc_config is not None:
        create_model_args["VpcConfig"] = vpc_config
    return sage_client.create_model(**create_model_args)
def _delete_sagemaker_model(model_name, sage_client, s3_client):  # noqa: D417
    """
    Args:
        sage_client: A boto3 client for SageMaker.
        s3_client: A boto3 client for S3.
    Returns:
        ARN of the deleted model.
    """
    model_info = sage_client.describe_model(ModelName=model_name)
    model_arn = model_info["ModelArn"]
    model_data_url = model_info["PrimaryContainer"]["ModelDataUrl"]
    # Parse the model data url to obtain a bucket path. The following
    # procedure is safe due to the well-documented structure of the `ModelDataUrl`
    # (see https://docs.aws.amazon.com/sagemaker/latest/dg/API_ContainerDefinition.html)
    parsed_data_url = urllib.parse.urlparse(model_data_url)
    bucket_name = parsed_data_url.netloc
    bucket_key = parsed_data_url.path.lstrip("/")
    s3_client.delete_object(Bucket=bucket_name, Key=bucket_key)
    sage_client.delete_model(ModelName=model_name)
    return model_arn
def _delete_sagemaker_endpoint_configuration(endpoint_config_name, sage_client):  # noqa: D417
    """
    Args:
        sage_client: A boto3 client for SageMaker.
    Returns:
        ARN of the deleted endpoint configuration.
    """
    endpoint_config_info = sage_client.describe_endpoint_config(
        EndpointConfigName=endpoint_config_name
    )
    sage_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    return endpoint_config_info["EndpointConfigArn"]
def _find_endpoint(endpoint_name, sage_client):  # noqa: D417
    """
    Finds a SageMaker endpoint with the specified name in the caller's AWS account, returning a
    NoneType if the endpoint is not found.
    Args:
        sage_client: A boto3 client for SageMaker.
    Returns:
        If the endpoint exists, a dictionary of endpoint attributes. If the endpoint does not
        exist, ``None``.
    """
    endpoints_page = sage_client.list_endpoints(MaxResults=100, NameContains=endpoint_name)
    while True:
        for endpoint in endpoints_page["Endpoints"]:
            if endpoint["EndpointName"] == endpoint_name:
                return endpoint
        if "NextToken" in endpoints_page:
            endpoints_page = sage_client.list_endpoints(
                MaxResults=100, NextToken=endpoints_page["NextToken"], NameContains=endpoint_name
            )
        else:
            return None
def _find_transform_job(job_name, sage_client):  # noqa: D417
    """
    Finds a SageMaker batch transform job with the specified name in the caller's AWS account,
    returning a NoneType if the transform job is not found.
    Args:
        sage_client: A boto3 client for SageMaker.
    Returns:
        If the transform job exists, a dictionary of transform job attributes. If the
        transform job does not exist, ``None``.
    """
    transform_jobs_page = sage_client.list_transform_jobs(MaxResults=100, NameContains=job_name)
    while True:
        for transform_job in transform_jobs_page["TransformJobSummaries"]:
            if transform_job["TransformJobName"] == job_name:
                return transform_job
        if "NextToken" in transform_jobs_page:
            transform_jobs_page = sage_client.list_transform_jobs(
                MaxResults=100,
                NextToken=transform_jobs_page["NextToken"],
                NameContains=job_name,
            )
        else:
            return None
def _does_model_exist(model_name, sage_client):  # noqa: D417
    """
    Determines whether a SageMaker model exists with the specified name in the caller's AWS account,
    returning True if the model exists, returning False if the model does not exist.
    Args:
        sage_client: A boto3 client for SageMaker.
    Returns:
        If the model exists, ``True``. If the model does not
        exist, ``False``.
    """
    try:
        response = sage_client.describe_model(ModelName=model_name)
    except sage_client.exceptions.ClientError as error:
        if "Could not find model" in error.response["Error"]["Message"]:
            return False
    else:
        return bool(response)
[docs]class SageMakerDeploymentClient(BaseDeploymentClient):
    """
    Initialize a deployment client for SageMaker. The default region and assumed role ARN will
    be set according to the value of the `target_uri`.
    This class is meant to supersede the other ``mlflow.sagemaker`` real-time serving API's.
    It is also designed to be used through the :py:mod:`mlflow.deployments` module.
    This means that you can deploy to SageMaker using the
    `mlflow deployments CLI <https://www.mlflow.org/docs/latest/cli.html#mlflow-deployments>`_ and
    get a client through the :py:mod:`mlflow.deployments.get_deploy_client` function.
    Args:
        target_uri: A URI that follows one of the following formats:
            - ``sagemaker``: This will set the default region to `us-west-2` and
              the default assumed role ARN to `None`.
            - ``sagemaker:/region_name``: This will set the default region to
              `region_name` and the default assumed role ARN to `None`.
            - ``sagemaker:/region_name/assumed_role_arn``: This will set the default
              region to `region_name` and the default assumed role ARN to
              `assumed_role_arn`.
            When an `assumed_role_arn` is provided without a `region_name`,
            an MlflowException will be raised.
    """
    def __init__(self, target_uri):
        super().__init__(target_uri=target_uri)
        # Fallback values for a bare ``sagemaker`` / ``sagemaker:/`` target_uri;
        # _get_values_from_target_uri() overrides them when the URI carries a region or role.
        self.region_name, self.assumed_role_arn = DEFAULT_REGION_NAME, None
        self._get_values_from_target_uri()
    def _get_values_from_target_uri(self):
        parsed = urllib.parse.urlparse(self.target_uri)
        values_str = parsed.path.strip("/")
        if not parsed.scheme or not values_str:
            return
        separator_index = values_str.find("/")
        if separator_index == -1:
            # values_str would look like us-east-1
            self.region_name = values_str
        else:
            # values_str could look like us-east-1/arn:aws:1234:role/assumed_role
            self.region_name = values_str[:separator_index]
            self.assumed_role_arn = values_str[separator_index + 1 :]
            # if values_str contains multiple interior slashes such as
            # us-east-1/////arn:aws:1234:role/assumed_role, remove
            # the extra slashes that come before "arn"
            self.assumed_role_arn = self.assumed_role_arn.strip("/")
        if self.region_name.startswith("arn"):
            raise MlflowException(
                message=(
                    "It looks like the target_uri contains an IAM role ARN without a region name.\n"
                    "A region name must be provided when the target_uri contains a role ARN.\n"
                    "In this case, the target_uri must follow the format: "
                    "sagemaker:/region_name/assumed_role_arn.\n"
                    f"The provided target_uri is: {self.target_uri}\n"
                ),
                error_code=INVALID_PARAMETER_VALUE,
            )
    def _default_deployment_config(self, create_mode=True):
        """
        Returns a fresh dictionary of baseline deployment settings, before user overrides.

        Args:
            create_mode: Selects the ``mode`` entry: ``DEPLOYMENT_MODE_CREATE`` when ``True``
                (used by ``create_deployment``), ``DEPLOYMENT_MODE_REPLACE`` otherwise (used
                by ``update_deployment``).
        """
        defaults = dict(
            assume_role_arn=self.assumed_role_arn,
            execution_role_arn=None,
            bucket=None,
            image_url=None,
            region_name=self.region_name,
            archive=False,
            instance_type=DEFAULT_SAGEMAKER_INSTANCE_TYPE,
            instance_count=DEFAULT_SAGEMAKER_INSTANCE_COUNT,
            vpc_config=None,
            data_capture_config=None,
            synchronous=True,
            timeout_seconds=1200,
            variant_name=None,
            env=None,
            tags=None,
            async_inference_config={},
            serverless_config={},
        )
        defaults["mode"] = DEPLOYMENT_MODE_CREATE if create_mode else DEPLOYMENT_MODE_REPLACE
        return defaults
    def _apply_custom_config(self, config, custom_config):
        int_fields = {"instance_count", "timeout_seconds"}
        bool_fields = {"synchronous", "archive"}
        dict_fields = {
            "vpc_config",
            "data_capture_config",
            "tags",
            "env",
            "async_inference_config",
            "serverless_config",
        }
        for key, value in custom_config.items():
            if key not in config:
                continue
            if key in int_fields and not isinstance(value, int):
                value = int(value)
            elif key in bool_fields and not isinstance(value, bool):
                value = value == "True"
            elif key in dict_fields and not isinstance(value, dict):
                value = json.loads(value)
            config[key] = value
[docs]    def create_deployment(self, name, model_uri, flavor=None, config=None, endpoint=None):
        """
        Deploy an MLflow model on AWS SageMaker.
        The currently active AWS account must have correct permissions set up.
        This function creates a SageMaker endpoint. For more information about the input data
        formats accepted by this endpoint, see the
        `MLflow deployment tools documentation <../../deployment/deploy-model-to-sagemaker.html>`_.
        Args:
            name: Name of the deployed application.
            model_uri: The location, in URI format, of the MLflow model to deploy to SageMaker.
                For example:
                - ``/Users/me/path/to/local/model``
                - ``relative/path/to/local/model``
                - ``s3://my_bucket/path/to/model``
                - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
                - ``models:/<model_name>/<model_version>``
                - ``models:/<model_name>/<stage>``
                For more information about supported URI schemes, see
                `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
                artifact-locations>`_.
            flavor: The name of the flavor of the model to use for deployment. Must be either
                ``None`` or one of mlflow.sagemaker.SUPPORTED_DEPLOYMENT_FLAVORS.
                If ``None``, a flavor is automatically selected from the model's available
                flavors. If the specified flavor is not present or not supported for
                deployment, an exception will be thrown.
            config: Configuration parameters. The supported parameters are:
                - ``assume_role_arn``: The name of an IAM cross-account role to be assumed
                  to deploy SageMaker to another AWS account. If this parameter is not
                  specified, the role given in the ``target_uri`` will be used. If the
                  role is not given in the ``target_uri``, defaults to ``us-west-2``.
                - ``execution_role_arn``: The name of an IAM role granting the SageMaker
                  service permissions to access the specified Docker image and S3 bucket
                  containing MLflow model artifacts. If unspecified, the currently-assumed
                  role will be used. This execution role is passed to the SageMaker service
                  when creating a SageMaker model from the specified MLflow model. It is
                  passed as the ``ExecutionRoleArn`` parameter of the `SageMaker
                  CreateModel API call <https://docs.aws.amazon.com/sagemaker/latest/
                  dg/API_CreateModel.html>`_. This role is *not* assumed for any other
                  call. For more information about SageMaker execution roles for model
                  creation, see
                  https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html.
                - ``bucket``: S3 bucket where model artifacts will be stored. Defaults to a
                  SageMaker-compatible bucket name.
                - ``image_url``: URL of the ECR-hosted Docker image the model should be
                  deployed into, produced by ``mlflow sagemaker build-and-push-container``.
                  This parameter can also be specified by the environment variable
                  ``MLFLOW_SAGEMAKER_DEPLOY_IMG_URL``.
                - ``region_name``: Name of the AWS region to which to deploy the application.
                  If unspecified, use the region name given in the ``target_uri``.
                  If it is also not specified in the ``target_uri``,
                  defaults to ``us-west-2``.
                - ``archive``: If ``True``, any pre-existing SageMaker application resources
                  that become inactive (i.e. as a result of deploying in
                  ``mlflow.sagemaker.DEPLOYMENT_MODE_REPLACE`` mode) are preserved.
                  These resources may include unused SageMaker models and endpoint
                  configurations that were associated with a prior version of the
                  application endpoint. If ``False``, these resources are deleted.
                  In order to use ``archive=False``, ``create_deployment()`` must be executed
                  synchronously with ``synchronous=True``. Defaults to ``False``.
                - ``instance_type``: The type of SageMaker ML instance on which to deploy the
                  model. For a list of supported instance types, see
                  https://aws.amazon.com/sagemaker/pricing/instance-types/.
                  Defaults to ``ml.m4.xlarge``.
                - ``instance_count``: The number of SageMaker ML instances on which to deploy
                  the model. Defaults to ``1``.
                - ``synchronous``: If ``True``, this function will block until the deployment
                  process succeeds or encounters an irrecoverable failure. If ``False``,
                  this function will return immediately after starting the deployment
                  process. It will not wait for the deployment process to complete;
                  in this case, the caller is responsible for monitoring the health and
                  status of the pending deployment via native SageMaker APIs or the AWS
                  console. Defaults to ``True``.
                - ``timeout_seconds``: If ``synchronous`` is ``True``, the deployment process
                  will return after the specified number of seconds if no definitive result
                  (success or failure) is achieved. Once the function returns, the caller is
                  responsible for monitoring the health and status of the pending
                  deployment using native SageMaker APIs or the AWS console. If
                  ``synchronous`` is ``False``, this parameter is ignored.
                  Defaults to ``300``.
                - ``vpc_config``: A dictionary specifying the VPC configuration to use when
                  creating the new SageMaker model associated with this application.
                  The acceptable values for this parameter are identical to those of the
                  ``VpcConfig`` parameter in the `SageMaker boto3 client's create_model
                  method <https://boto3.readthedocs.io/en/latest/reference/services/sagemaker.html
                  #SageMaker.Client.create_model>`_. For more information, see
                  https://docs.aws.amazon.com/sagemaker/latest/dg/API_VpcConfig.html.
                  Defaults to ``None``.
                - ``data_capture_config``: A dictionary specifying the data capture
                  configuration to use when creating the new SageMaker model associated with
                  this application.
                  For more information, see
                  https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DataCaptureConfig.html.
                  Defaults to ``None``.
                - ``variant_name``: A string specifying the desired name when creating a production
                  variant.  Defaults to ``None``.
                - ``async_inference_config``: A dictionary specifying the
                  async_inference_configuration
                - ``serverless_config``: A dictionary specifying the serverless_configuration
                - ``env``: A dictionary specifying environment variables as key-value
                  pairs to be set for the deployed model. Defaults to ``None``.
                - ``tags``: A dictionary of key-value pairs representing additional
                  tags to be set for the deployed model. Defaults to ``None``.
            endpoint: (optional) Endpoint to create the deployment under. Currently unsupported
        .. code-block:: python
            :caption: Python example
            from mlflow.deployments import get_deploy_client
            vpc_config = {
                "SecurityGroupIds": [
                    "sg-123456abc",
                ],
                "Subnets": [
                    "subnet-123456abc",
                ],
            }
            config = dict(
                assume_role_arn="arn:aws:123:role/assumed_role",
                execution_role_arn="arn:aws:456:role/execution_role",
                bucket_name="my-s3-bucket",
                image_url="1234.dkr.ecr.us-east-1.amazonaws.com/mlflow-test:1.23.1",
                region_name="us-east-1",
                archive=False,
                instance_type="ml.m5.4xlarge",
                instance_count=1,
                synchronous=True,
                timeout_seconds=300,
                vpc_config=vpc_config,
                variant_name="prod-variant-1",
                env={"DISABLE_NGINX": "true", "GUNICORN_CMD_ARGS": '"--timeout 60"'},
                tags={"training_timestamp": "2022-11-01T05:12:26"},
            )
            client = get_deploy_client("sagemaker")
            client.create_deployment(
                "my-deployment",
                model_uri="/mlruns/0/abc/model",
                flavor="python_function",
                config=config,
            )
        .. code-block:: bash
            :caption:  Command-line example
            mlflow deployments create --target sagemaker:/us-east-1/arn:aws:123:role/assumed_role \\
                    --name my-deployment \\
                    --model-uri /mlruns/0/abc/model \\
                    --flavor python_function\\
                    -C execution_role_arn=arn:aws:456:role/execution_role \\
                    -C bucket_name=my-s3-bucket \\
                    -C image_url=1234.dkr.ecr.us-east-1.amazonaws.com/mlflow-test:1.23.1 \\
                    -C region_name=us-east-1 \\
                    -C archive=False \\
                    -C instance_type=ml.m5.4xlarge \\
                    -C instance_count=1 \\
                    -C synchronous=True \\
                    -C timeout_seconds=300 \\
                    -C variant_name=prod-variant-1 \\
                    -C vpc_config='{"SecurityGroupIds": ["sg-123456abc"], \\
                    "Subnets": ["subnet-123456abc"]}' \\
                    -C data_capture_config='{"EnableCapture": True, \\
                    'InitialSamplingPercentage': 100, 'DestinationS3Uri": 's3://my-bucket/path', \\
                    'CaptureOptions': [{'CaptureMode': 'Output'}]}'
                    -C env='{"DISABLE_NGINX": "true", "GUNICORN_CMD_ARGS": "\"--timeout 60\""}' \\
                    -C tags='{"training_timestamp": "2022-11-01T05:12:26"}' \\
        """
        final_config = self._default_deployment_config()
        if config:
            self._apply_custom_config(final_config, config)
        app_name, flavor = _deploy(
            app_name=name,
            model_uri=model_uri,
            flavor=flavor,
            execution_role_arn=final_config["execution_role_arn"],
            assume_role_arn=final_config["assume_role_arn"],
            bucket=final_config["bucket"],
            image_url=final_config["image_url"],
            region_name=final_config["region_name"],
            mode=mlflow.sagemaker.DEPLOYMENT_MODE_CREATE,
            archive=final_config["archive"],
            instance_type=final_config["instance_type"],
            instance_count=final_config["instance_count"],
            vpc_config=final_config["vpc_config"],
            data_capture_config=final_config["data_capture_config"],
            synchronous=final_config["synchronous"],
            timeout_seconds=final_config["timeout_seconds"],
            variant_name=final_config["variant_name"],
            async_inference_config=final_config["async_inference_config"],
            serverless_config=final_config["serverless_config"],
            env=final_config["env"],
            tags=final_config["tags"],
        )
        return {"name": app_name, "flavor": flavor} 
[docs]    def update_deployment(self, name, model_uri, flavor=None, config=None, endpoint=None):
        """
        Update a deployment on AWS SageMaker. This function can replace or add a new model to
        an existing SageMaker endpoint. By default, this function replaces the existing model
        with the new one. The currently active AWS account must have correct permissions set up.
        Args:
            name: Name of the deployed application.
            model_uri: The location, in URI format, of the MLflow model to deploy to SageMaker.
                For example:
                - ``/Users/me/path/to/local/model``
                - ``relative/path/to/local/model``
                - ``s3://my_bucket/path/to/model``
                - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
                - ``models:/<model_name>/<model_version>``
                - ``models:/<model_name>/<stage>``
                For more information about supported URI schemes, see
                `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
                artifact-locations>`_.
            flavor: The name of the flavor of the model to use for deployment. Must be either
                ``None`` or one of mlflow.sagemaker.SUPPORTED_DEPLOYMENT_FLAVORS.
                If ``None``, a flavor is automatically selected from the model's available
                flavors. If the specified flavor is not present or not supported for
                deployment, an exception will be thrown.
            config: Configuration parameters. The supported parameters are:
                - ``assume_role_arn``: The name of an IAM cross-account role to be assumed
                  to deploy SageMaker to another AWS account. If this parameter is not
                  specified, the role given in the ``target_uri`` will be used. If the
                  role is not given in the ``target_uri``, defaults to ``us-west-2``.
                - ``execution_role_arn``: The name of an IAM role granting the SageMaker
                  service permissions to access the specified Docker image and S3 bucket
                  containing MLflow model artifacts. If unspecified, the currently-assumed
                  role will be used. This execution role is passed to the SageMaker service
                  when creating a SageMaker model from the specified MLflow model. It is
                  passed as the ``ExecutionRoleArn`` parameter of the `SageMaker
                  CreateModel API call <https://docs.aws.amazon.com/sagemaker/latest/
                  dg/API_CreateModel.html>`_. This role is *not* assumed for any other
                  call. For more information about SageMaker execution roles for model
                  creation, see
                  https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html.
                - ``bucket``: S3 bucket where model artifacts will be stored. Defaults to a
                  SageMaker-compatible bucket name.
                - ``image_url``: URL of the ECR-hosted Docker image the model should be
                  deployed into, produced by ``mlflow sagemaker build-and-push-container``.
                  This parameter can also be specified by the environment variable
                  ``MLFLOW_SAGEMAKER_DEPLOY_IMG_URL``.
                - ``region_name``: Name of the AWS region to which to deploy the application.
                  If unspecified, use the region name given in the ``target_uri``.
                  If it is also not specified in the ``target_uri``,
                  defaults to ``us-west-2``.
                - ``mode``: The mode in which to deploy the application.
                  Must be one of the following:
                  ``mlflow.sagemaker.DEPLOYMENT_MODE_REPLACE``
                      If an application of the specified name exists, its model(s) is
                      replaced with the specified model. If no such application exists,
                      it is created with the specified name and model.
                      This is the default mode.
                  ``mlflow.sagemaker.DEPLOYMENT_MODE_ADD``
                      Add the specified model to a pre-existing application with the
                      specified name, if one exists. If the application does not exist,
                      a new application is created with the specified name and model.
                      NOTE: If the application **already exists**, the specified model is
                      added to the application's corresponding SageMaker endpoint with an
                      initial weight of zero (0). To route traffic to the model,
                      update the application's associated endpoint configuration using
                      either the AWS console or the ``UpdateEndpointWeightsAndCapacities``
                      function defined in https://docs.aws.amazon.com/sagemaker/latest/dg/API_UpdateEndpointWeightsAndCapacities.html.
                - ``archive``: If ``True``, any pre-existing SageMaker application resources
                  that become inactive (i.e. as a result of deploying in
                  ``mlflow.sagemaker.DEPLOYMENT_MODE_REPLACE`` mode) are preserved.
                  These resources may include unused SageMaker models and endpoint
                  configurations that were associated with a prior version of the
                  application endpoint. If ``False``, these resources are deleted.
                  In order to use ``archive=False``, ``update_deployment()`` must be executed
                  synchronously with ``synchronous=True``. Defaults to ``False``.
                - ``instance_type``: The type of SageMaker ML instance on which to deploy the
                  model. For a list of supported instance types, see
                  https://aws.amazon.com/sagemaker/pricing/instance-types/.
                  Defaults to ``ml.m4.xlarge``.
                - ``instance_count``: The number of SageMaker ML instances on which to deploy
                  the model. Defaults to ``1``.
                - ``synchronous``: If ``True``, this function will block until the deployment
                  process succeeds or encounters an irrecoverable failure. If ``False``,
                  this function will return immediately after starting the deployment
                  process. It will not wait for the deployment process to complete;
                  in this case, the caller is responsible for monitoring the health and
                  status of the pending deployment via native SageMaker APIs or the AWS
                  console. Defaults to ``True``.
                - ``timeout_seconds``: If ``synchronous`` is ``True``, the deployment process
                  will return after the specified number of seconds if no definitive result
                  (success or failure) is achieved. Once the function returns, the caller is
                  responsible for monitoring the health and status of the pending
                  deployment using native SageMaker APIs or the AWS console. If
                  ``synchronous`` is ``False``, this parameter is ignored.
                  Defaults to ``300``.
                - ``variant_name``: A string specifying the desired name when creating a
                  production variant.  Defaults to ``None``.
                - ``vpc_config``: A dictionary specifying the VPC configuration to use when
                  creating the new SageMaker model associated with this application.
                  The acceptable values for this parameter are identical to those of the
                  ``VpcConfig`` parameter in the `SageMaker boto3 client's create_model
                  method <https://boto3.readthedocs.io/en/latest/reference/services/sagemaker.html
                  #SageMaker.Client.create_model>`_. For more information, see
                  https://docs.aws.amazon.com/sagemaker/latest/dg/API_VpcConfig.html.
                  Defaults to ``None``.
                - ``data_capture_config``: A dictionary specifying the data capture
                  configuration to use when creating the new SageMaker model associated with
                  this application.
                  For more information, see
                  https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DataCaptureConfig.html.
                  Defaults to ``None``.
                - ``async_inference_config``: A dictionary specifying the async config
                  configuration. Defaults to ``None``.
                - ``env``: A dictionary specifying environment variables as key-value pairs
                  to be set for the deployed model. Defaults to ``None``.
                - ``tags``: A dictionary of key-value pairs representing additional tags
                  to be set for the deployed model. Defaults to ``None``.
            endpoint: (optional) Endpoint containing the deployment to update. Currently unsupported
        .. code-block:: python
            :caption: Python example
            from mlflow.deployments import get_deploy_client
            vpc_config = {
                "SecurityGroupIds": [
                    "sg-123456abc",
                ],
                "Subnets": [
                    "subnet-123456abc",
                ],
            }
            data_capture_config = {
                "EnableCapture": True,
                "InitialSamplingPercentage": 100,
                "DestinationS3Uri": "s3://my-bucket/path",
                "CaptureOptions": [{"CaptureMode": "Output"}],
            }
            config = dict(
                assume_role_arn="arn:aws:123:role/assumed_role",
                execution_role_arn="arn:aws:456:role/execution_role",
                bucket_name="my-s3-bucket",
                image_url="1234.dkr.ecr.us-east-1.amazonaws.com/mlflow-test:1.23.1",
                region_name="us-east-1",
                mode="replace",
                archive=False,
                instance_type="ml.m5.4xlarge",
                instance_count=1,
                synchronous=True,
                timeout_seconds=300,
                variant_name="prod-variant-1",
                vpc_config=vpc_config,
                data_capture_config=data_capture_config,
                env={"DISABLE_NGINX": "true", "GUNICORN_CMD_ARGS": '"--timeout 60"'},
                tags={"training_timestamp": "2022-11-01T05:12:26"},
            )
            client = get_deploy_client("sagemaker")
            client.update_deployment(
                "my-deployment",
                model_uri="/mlruns/0/abc/model",
                flavor="python_function",
                config=config,
            )
        .. code-block:: bash
            :caption:  Command-line example
            mlflow deployments update --target sagemaker:/us-east-1/arn:aws:123:role/assumed_role \\
                    --name my-deployment \\
                    --model-uri /mlruns/0/abc/model \\
                    --flavor python_function\\
                    -C execution_role_arn=arn:aws:456:role/execution_role \\
                    -C bucket_name=my-s3-bucket \\
                    -C image_url=1234.dkr.ecr.us-east-1.amazonaws.com/mlflow-test:1.23.1 \\
                    -C region_name=us-east-1 \\
                    -C mode=replace \\
                    -C archive=False \\
                    -C instance_type=ml.m5.4xlarge \\
                    -C instance_count=1 \\
                    -C synchronous=True \\
                    -C timeout_seconds=300 \\
                    -C variant_name=prod-variant-1 \\
                    -C vpc_config='{"SecurityGroupIds": ["sg-123456abc"], \\
                    "Subnets": ["subnet-123456abc"]}' \\
                    -C data_capture_config='{"EnableCapture": True, \\
                    "InitialSamplingPercentage": 100, "DestinationS3Uri": "s3://my-bucket/path", \\
                    "CaptureOptions": [{"CaptureMode": "Output"}]}'
                    -C env='{"DISABLE_NGINX": "true", "GUNICORN_CMD_ARGS": "\"--timeout 60\""}' \\
                    -C tags='{"training_timestamp": "2022-11-01T05:12:26"}' \\
        """
        final_config = self._default_deployment_config(create_mode=False)
        if config:
            self._apply_custom_config(final_config, config)
        if model_uri is None:
            raise MlflowException(
                message="A model_uri must be provided when updating a SageMaker deployment",
                error_code=INVALID_PARAMETER_VALUE,
            )
        if final_config["mode"] not in [DEPLOYMENT_MODE_ADD, DEPLOYMENT_MODE_REPLACE]:
            raise MlflowException(
                message=(
                    f"Invalid mode `{final_config['mode']}` for deployment"
                    " to a pre-existing application"
                ),
                error_code=INVALID_PARAMETER_VALUE,
            )
        app_name, flavor = _deploy(
            app_name=name,
            model_uri=model_uri,
            flavor=flavor,
            execution_role_arn=final_config["execution_role_arn"],
            assume_role_arn=final_config["assume_role_arn"],
            bucket=final_config["bucket"],
            image_url=final_config["image_url"],
            region_name=final_config["region_name"],
            mode=final_config["mode"],
            archive=final_config["archive"],
            instance_type=final_config["instance_type"],
            instance_count=final_config["instance_count"],
            vpc_config=final_config["vpc_config"],
            data_capture_config=final_config["data_capture_config"],
            synchronous=final_config["synchronous"],
            timeout_seconds=final_config["timeout_seconds"],
            variant_name=final_config["variant_name"],
            async_inference_config=final_config["async_inference_config"],
            serverless_config=final_config["serverless_config"],
            env=final_config["env"],
            tags=final_config["tags"],
        )
        return {"name": app_name, "flavor": flavor} 
[docs]    def delete_deployment(self, name, config=None, endpoint=None):
        """
        Delete a SageMaker application.
        Args:
            name: Name of the deployed application.
            config: Configuration parameters. The supported parameters are:
                - ``assume_role_arn``: The name of an IAM role to be assumed to delete
                  the SageMaker deployment.
                - ``region_name``: Name of the AWS region in which the application
                  is deployed. Defaults to ``us-west-2`` or the region provided in
                  the `target_uri`.
                - ``archive``: If `True`, resources associated with the specified
                  application, such as its associated models and endpoint configuration,
                  are preserved. If `False`, these resources are deleted. In order to use
                  ``archive=False``, ``delete()`` must be executed synchronously with
                  ``synchronous=True``. Defaults to ``False``.
                - ``synchronous``: If `True`, this function blocks until the deletion process
                  succeeds or encounters an irrecoverable failure. If `False`, this function
                  returns immediately after starting the deletion process. It will not wait
                  for the deletion process to complete; in this case, the caller is
                  responsible for monitoring the status of the deletion process via native
                  SageMaker APIs or the AWS console. Defaults to ``True``.
                - ``timeout_seconds``: If `synchronous` is `True`, the deletion process
                  returns after the specified number of seconds if no definitive result
                  (success or failure) is achieved. Once the function returns, the caller
                  is responsible for monitoring the status of the deletion process via native
                  SageMaker APIs or the AWS console. If `synchronous` is False, this
                  parameter is ignored. Defaults to ``300``.
            endpoint: (optional) Endpoint containing the deployment to delete. Currently unsupported
        .. code-block:: python
            :caption: Python example
            from mlflow.deployments import get_deploy_client
            config = dict(
                assume_role_arn="arn:aws:123:role/assumed_role",
                region_name="us-east-1",
                archive=False,
                synchronous=True,
                timeout_seconds=300,
            )
            client = get_deploy_client("sagemaker")
            client.delete_deployment("my-deployment", config=config)
        .. code-block:: bash
            :caption: Command-line example
            mlflow deployments delete --target sagemaker \\
                    --name my-deployment \\
                    -C assume_role_arn=arn:aws:123:role/assumed_role \\
                    -C region_name=us-east-1 \\
                    -C archive=False \\
                    -C synchronous=True \\
                    -C timeout_seconds=300
        """
        final_config = {
            "region_name": self.region_name,
            "archive": False,
            "synchronous": True,
            "timeout_seconds": 300,
            "assume_role_arn": self.assumed_role_arn,
        }
        if config:
            self._apply_custom_config(final_config, config)
        _delete(
            name,
            region_name=final_config["region_name"],
            assume_role_arn=final_config["assume_role_arn"],
            archive=final_config["archive"],
            synchronous=final_config["synchronous"],
            timeout_seconds=final_config["timeout_seconds"],
        ) 
[docs]    def list_deployments(self, endpoint=None):
        """
        List deployments. This method returns a list of dictionaries that describes each deployment.
        If a region name needs to be specified, the plugin must be initialized
        with the AWS region in the ``target_uri`` such as ``sagemaker:/us-east-1``.
        To assume an IAM role, the plugin must be initialized
        with the AWS region and the role ARN in the ``target_uri`` such as
        ``sagemaker:/us-east-1/arn:aws:1234:role/assumed_role``.
        Args:
            endpoint: (optional) List deployments in the specified endpoint. Currently unsupported
        Returns:
            A list of dictionaries corresponding to deployments.
        .. code-block:: python
            :caption: Python example
            from mlflow.deployments import get_deploy_client
            client = get_deploy_client("sagemaker:/us-east-1/arn:aws:123:role/assumed_role")
            client.list_deployments()
        .. code-block:: bash
            :caption: Command-line example
            mlflow deployments list --target sagemaker:/us-east-1/arn:aws:1234:role/assumed_role
        """
        import boto3
        assume_role_credentials = _assume_role_and_get_credentials(
            assume_role_arn=self.assumed_role_arn
        )
        sage_client = boto3.client(
            "sagemaker", region_name=self.region_name, **assume_role_credentials
        )
        return sage_client.list_endpoints()["Endpoints"] 
[docs]    def get_deployment(self, name, endpoint=None):
        """
        Returns a dictionary describing the specified deployment.
        If a region name needs to be specified, the plugin must be initialized
        with the AWS region in the ``target_uri`` such as ``sagemaker:/us-east-1``.
        To assume an IAM role, the plugin must be initialized
        with the AWS region and the role ARN in the ``target_uri`` such as
        ``sagemaker:/us-east-1/arn:aws:1234:role/assumed_role``.
        A :py:class:`mlflow.exceptions.MlflowException` will also be thrown when an error occurs
        while retrieving the deployment.
        Args:
            name: Name of deployment to retrieve
            endpoint: (optional) Endpoint containing the deployment to get. Currently unsupported
        Returns:
            A dictionary that describes the specified deployment
        .. code-block:: python
            :caption: Python example
            from mlflow.deployments import get_deploy_client
            client = get_deploy_client("sagemaker:/us-east-1/arn:aws:123:role/assumed_role")
            client.get_deployment("my-deployment")
        .. code-block:: bash
            :caption: Command-line example
            mlflow deployments get --target sagemaker:/us-east-1/arn:aws:1234:role/assumed_role \\
                --name my-deployment
        """
        import boto3
        assume_role_credentials = _assume_role_and_get_credentials(
            assume_role_arn=self.assumed_role_arn
        )
        try:
            sage_client = boto3.client(
                "sagemaker", region_name=self.region_name, **assume_role_credentials
            )
            return sage_client.describe_endpoint(EndpointName=name)
        except Exception as exc:
            raise MlflowException(
                message=f"There was an error while retrieving the deployment: {exc}\n"
            ) 
[docs]    def predict(
        self,
        deployment_name=None,
        inputs=None,
        endpoint=None,
        params: Optional[dict[str, Any]] = None,
    ):
        """
        Compute predictions from the specified deployment using the provided PyFunc input.
        The input/output types of this method match the :ref:`MLflow PyFunc prediction
        interface <pyfunc-inference-api>`.
        If a region name needs to be specified, the plugin must be initialized
        with the AWS region in the ``target_uri`` such as ``sagemaker:/us-east-1``.
        To assume an IAM role, the plugin must be initialized
        with the AWS region and the role ARN in the ``target_uri`` such as
        ``sagemaker:/us-east-1/arn:aws:1234:role/assumed_role``.
        Args:
            deployment_name: Name of the deployment to predict against.
            inputs: Input data (or arguments) to pass to the deployment or model endpoint for
                inference. For a complete list of supported input types, see
                :ref:`pyfunc-inference-api`.
            endpoint: Endpoint to predict against. Currently unsupported
            params: Optional parameters to invoke the endpoint with.
        Returns:
            A PyFunc output, such as a Pandas DataFrame, Pandas Series, or NumPy array.
            For a complete list of supported output types, see :ref:`pyfunc-inference-api`.
        .. code-block:: python
            :caption: Python example
            import pandas as pd
            from mlflow.deployments import get_deploy_client
            df = pd.DataFrame(data=[[1, 2, 3]], columns=["feat1", "feat2", "feat3"])
            client = get_deploy_client("sagemaker:/us-east-1/arn:aws:123:role/assumed_role")
            client.predict("my-deployment", df)
        .. code-block:: bash
            :caption: Command-line example
            cat > ./input.json <<- input
            {"feat1": {"0": 1}, "feat2": {"0": 2}, "feat3": {"0": 3}}
            input
            mlflow deployments predict \\
                --target sagemaker:/us-east-1/arn:aws:1234:role/assumed_role \\
                --name my-deployment \\
                --input-path ./input.json
        """
        import boto3
        assume_role_credentials = _assume_role_and_get_credentials(
            assume_role_arn=self.assumed_role_arn
        )
        try:
            sage_client = boto3.client(
                "sagemaker-runtime", region_name=self.region_name, **assume_role_credentials
            )
            response = sage_client.invoke_endpoint(
                EndpointName=deployment_name,
                Body=dump_input_data(inputs, inputs_key="instances", params=params),
                ContentType="application/json",
            )
            response_body = response["Body"].read().decode("utf-8")
            return PredictionsResponse.from_json(response_body)
        except Exception as exc:
            raise MlflowException(
                message=f"There was an error while getting model prediction: {exc}\n"
            ) 
[docs]    def explain(self, deployment_name=None, df=None, endpoint=None):
        """
        *This function has not been implemented and will be coming in the future.*
        """
        raise NotImplementedError("This function is not implemented yet.") 
[docs]    def create_endpoint(self, name, config=None):
        """
        Create an endpoint with the specified target. By default, this method should block until
        creation completes (i.e. until it's possible to create a deployment within the endpoint).
        In the case of conflicts (e.g. if it's not possible to create the specified endpoint
        due to conflict with an existing endpoint), raises a
        :py:class:`mlflow.exceptions.MlflowException`. See target-specific plugin documentation
        for additional detail on support for asynchronous creation and other configuration.
        Args:
            name: Unique name to use for endpoint. If another endpoint exists with the same
                        name, raises a :py:class:`mlflow.exceptions.MlflowException`.
            config: (optional) Dict containing target-specific configuration for the endpoint.
        Returns:
            Dict corresponding to created endpoint, which must contain the 'name' key.
        """
        raise NotImplementedError("This function is not implemented yet.") 
[docs]    def update_endpoint(self, endpoint, config=None):
        """
        Update the endpoint with the specified name. You can update any target-specific attributes
        of the endpoint (via `config`). By default, this method should block until the update
        completes (i.e. until it's possible to create a deployment within the endpoint). See
        target-specific plugin documentation for additional detail on support for asynchronous
        update and other configuration.
        Args:
            endpoint: Unique name of endpoint to update
            config: (optional) dict containing target-specific configuration for the endpoint
        """
        raise NotImplementedError("This function is not implemented yet.") 
[docs]    def delete_endpoint(self, endpoint):
        """
        Delete the endpoint from the specified target. Deletion should be idempotent (i.e. deletion
        should not fail if retried on a non-existent deployment).
        Args:
            endpoint: Name of endpoint to delete
        """
        raise NotImplementedError("This function is not implemented yet.") 
[docs]    def list_endpoints(self):
        """
        List endpoints in the specified target. This method is expected to return an
        unpaginated list of all endpoints (an alternative would be to return a dict with
        an 'endpoints' field containing the actual endpoints, with plugins able to specify
        other fields, e.g. a next_page_token field, in the returned dictionary for pagination,
        and to accept a `pagination_args` argument to this method for passing
        pagination-related args).
        Returns:
            A list of dicts corresponding to endpoints. Each dict is guaranteed to
            contain a 'name' key containing the endpoint name. The other fields of
            the returned dictionary and their types may vary across targets.
        """
        raise NotImplementedError("This function is not implemented yet.") 
[docs]    def get_endpoint(self, endpoint):
        """
        Returns a dictionary describing the specified endpoint, throwing a
        py:class:`mlflow.exception.MlflowException` if no endpoint exists with the provided
        name.
        The dict is guaranteed to contain an 'name' key containing the endpoint name.
        The other fields of the returned dictionary and their types may vary across targets.
        Args:
            endpoint: Name of endpoint to fetch
        """
        raise NotImplementedError("This function is not implemented yet.")  
class _SageMakerOperation:
    """
    Handle for an asynchronous SageMaker operation.

    Progress is observed by repeatedly calling a caller-supplied status function. A
    caller-supplied cleanup function can be run once after the operation has succeeded.
    """

    # Seconds to sleep between consecutive status checks in ``await_completion``.
    _POLL_INTERVAL_SECONDS = 5
    # Emit a progress log line every this many polls (~20 seconds at a 5-second interval).
    _LOG_EVERY_N_POLLS = 4

    def __init__(self, status_check_fn, cleanup_fn):
        """
        Args:
            status_check_fn: Zero-argument callable returning a ``_SageMakerOperationStatus``
                that describes the current state of the operation.
            cleanup_fn: Zero-argument callable invoked by ``clean_up()`` once the operation
                has succeeded.
        """
        self.status_check_fn = status_check_fn
        self.cleanup_fn = cleanup_fn
        self.start_time = time.time()
        self.status = _SageMakerOperationStatus(_SageMakerOperationStatus.STATE_IN_PROGRESS, None)
        self.cleaned_up = False

    def await_completion(self, timeout_seconds):
        """
        Poll ``status_check_fn`` until the operation leaves the in-progress state or
        ``timeout_seconds`` elapse.

        Args:
            timeout_seconds: Maximum number of seconds to wait.

        Returns:
            The first non-in-progress status reported by ``status_check_fn``, or a
            ``timed_out`` status if the deadline passes first. Either way, the returned
            status is also stored in ``self.status``.
        """
        iteration = 0
        begin = time.time()
        while (time.time() - begin) < timeout_seconds:
            status = self.status_check_fn()
            if status.state != _SageMakerOperationStatus.STATE_IN_PROGRESS:
                self.status = status
                return status
            if iteration % self._LOG_EVERY_N_POLLS == 0:
                _logger.info(status.message)
            # Never sleep past the deadline, and never hand ``time.sleep`` a negative value
            # (it raises ValueError) if the status check itself ran past the deadline.
            remaining = timeout_seconds - (time.time() - begin)
            time.sleep(max(0.0, min(self._POLL_INTERVAL_SECONDS, remaining)))
            iteration += 1
        duration_seconds = time.time() - begin
        # Record the timeout so ``self.status`` (and ``clean_up``'s error message) reflects
        # the outcome the caller actually observed instead of the stale initial state.
        self.status = _SageMakerOperationStatus.timed_out(duration_seconds)
        return self.status

    def clean_up(self):
        """
        Run ``cleanup_fn`` once, after the operation has succeeded.

        The ``cleaned_up`` flag is set before ``cleanup_fn`` runs. If the cleanup raises,
        calling ``clean_up`` again will not retry it.

        Raises:
            ValueError: If the last recorded status is not ``succeeded``, or if
                ``clean_up()`` has already been called for this operation.
        """
        if self.status.state != _SageMakerOperationStatus.STATE_SUCCEEDED:
            raise ValueError(
                "Cannot clean up an operation that has not succeeded! Current operation state:"
                f" {self.status.state}"
            )
        if self.cleaned_up:
            raise ValueError("`clean_up()` has already been executed for this operation!")
        self.cleaned_up = True
        self.cleanup_fn()
class _SageMakerOperationStatus:
    STATE_SUCCEEDED = "succeeded"
    STATE_FAILED = "failed"
    STATE_IN_PROGRESS = "in progress"
    STATE_TIMED_OUT = "timed_out"
    def __init__(self, state, message):
        self.state = state
        self.message = message
    @classmethod
    def in_progress(cls, message=None):
        if message is None:
            message = "The operation is still in progress."
        return cls(_SageMakerOperationStatus.STATE_IN_PROGRESS, message)
    @classmethod
    def timed_out(cls, duration_seconds):
        return cls(
            _SageMakerOperationStatus.STATE_TIMED_OUT,
            f"Timed out after waiting {duration_seconds} seconds for the operation to"
            " complete. This operation may still be in progress. Please check the AWS"
            " console for more information.",
        )
    @classmethod
    def failed(cls, message):
        return cls(_SageMakerOperationStatus.STATE_FAILED, message)
    @classmethod
    def succeeded(cls, message):
        return cls(_SageMakerOperationStatus.STATE_SUCCEEDED, message)