Source code for mlflow.tracking._model_registry.fluent

from mlflow.tracking.client import MlflowClient
from mlflow.exceptions import MlflowException
from mlflow.entities.model_registry import ModelVersion
from mlflow.entities.model_registry import RegisteredModel
from mlflow.protos.databricks_pb2 import RESOURCE_ALREADY_EXISTS, ErrorCode
from mlflow.store.artifact.runs_artifact_repo import RunsArtifactRepository
from mlflow.utils.logging_utils import eprint
from mlflow.utils import get_results_from_paginated_fn
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
from mlflow.store.model_registry import SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT
from typing import Any, Dict, Optional, List


def register_model(
    model_uri,
    name,
    await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
    *,
    tags: Optional[Dict[str, Any]] = None,
) -> ModelVersion:
    """
    Register the model files located at ``model_uri`` as a new model version.

    The registered model called ``name`` is created first when it does not exist yet; a new
    version is then added to it. This function assumes that the model registry backend URI is
    the same as the tracking backend URI.

    :param model_uri: URI referring to the MLmodel directory. Pass a ``runs:/`` URI to have the
                      run ID recorded with the model version in the registry. ``models:/`` URIs
                      are currently not supported.
    :param name: Name of the registered model that receives the new version. A registered
                 model with this name is created automatically if none exists.
    :param await_registration_for: Number of seconds to wait for the model version to finish
                                   being created and reach ``READY`` status. By default, the
                                   function waits for five minutes. Specify 0 or None to skip
                                   waiting.
    :param tags: A dictionary of key-value pairs that are converted into
                 :py:class:`mlflow.entities.model_registry.ModelVersionTag` objects.
    :return: Single :py:class:`mlflow.entities.model_registry.ModelVersion` object created by
             the backend.

    .. test-code-block:: python
        :caption: Example

        import mlflow.sklearn
        from sklearn.ensemble import RandomForestRegressor

        mlflow.set_tracking_uri("sqlite:////tmp/mlruns.db")
        params = {"n_estimators": 3, "random_state": 42}

        # Log MLflow entities
        with mlflow.start_run() as run:
            rfr = RandomForestRegressor(**params).fit([[0, 1]], [1])
            mlflow.log_params(params)
            mlflow.sklearn.log_model(rfr, artifact_path="sklearn-model")

        model_uri = "runs:/{}/sklearn-model".format(run.info.run_id)
        mv = mlflow.register_model(model_uri, "RandomForestRegressionModel")
        print("Name: {}".format(mv.name))
        print("Version: {}".format(mv.version))

    .. code-block:: text
        :caption: Output

        Name: RandomForestRegressionModel
        Version: 1
    """
    registry_client = MlflowClient()

    # Step 1: make sure the registered model exists. An "already exists" error from the
    # backend is expected on every registration after the first one and is not fatal.
    try:
        registered = registry_client.create_registered_model(name)
        eprint("Successfully registered model '%s'." % registered.name)
    except MlflowException as e:
        if e.error_code != ErrorCode.Name(RESOURCE_ALREADY_EXISTS):
            raise e
        eprint(
            "Registered model '%s' already exists. Creating a new version of this model..."
            % name
        )

    # Step 2: work out the source location and originating run. A ``runs:/`` URI is resolved
    # to its underlying URI and its run ID is kept; any other URI is used verbatim with no
    # run attached.
    if RunsArtifactRepository.is_runs_uri(model_uri):
        version_source = RunsArtifactRepository.get_underlying_uri(model_uri)
        source_run_id, _ = RunsArtifactRepository.parse_runs_uri(model_uri)
    else:
        version_source, source_run_id = model_uri, None

    # Step 3: create the version and report it on stderr.
    model_version = registry_client.create_model_version(
        name,
        source=version_source,
        run_id=source_run_id,
        tags=tags,
        await_creation_for=await_registration_for,
    )
    created_message = "Created version '{version}' of model '{model_name}'.".format(
        version=model_version.version, model_name=model_version.name
    )
    eprint(created_message)
    return model_version
def search_registered_models(
    max_results: Optional[int] = None,
    filter_string: Optional[str] = None,
    order_by: Optional[List[str]] = None,
) -> List[RegisteredModel]:
    """
    Return the registered models that match the given filter criteria.

    Results are collected page by page through ``get_results_from_paginated_fn``; every page
    request uses the same ``filter_string`` and ``order_by``.

    :param filter_string: Filter query string
        (e.g., ``"name = 'a_model_name' and tag.key = 'value1'"``),
        defaults to searching for all registered models. The following identifiers,
        comparators, and logical operators are supported.

        Identifiers
          - ``name``: registered model name.
          - ``tags.<tag_key>``: registered model tag. If ``tag_key`` contains spaces, it must
            be wrapped with backticks (e.g., ``"tags.`extra key`"``).

        Comparators
          - ``=``: Equal to.
          - ``!=``: Not equal to.
          - ``LIKE``: Case-sensitive pattern match.
          - ``ILIKE``: Case-insensitive pattern match.

        Logical operators
          - ``AND``: Combines two sub-queries and returns True if both of them are True.

    :param max_results: If passed, specifies the maximum number of models desired. If not
                        passed, all models will be returned.
    :param order_by: List of column names with ASC|DESC annotation, to be used for ordering
                     matching search results.
    :return: A list of :py:class:`mlflow.entities.model_registry.RegisteredModel` objects
             that satisfy the search expressions.

    .. test-code-block:: python
        :caption: Example

        import mlflow
        from sklearn.linear_model import LogisticRegression

        with mlflow.start_run():
            mlflow.sklearn.log_model(
                LogisticRegression(),
                "Cordoba",
                registered_model_name="CordobaWeatherForecastModel",
            )
            mlflow.sklearn.log_model(
                LogisticRegression(),
                "Boston",
                registered_model_name="BostonWeatherForecastModel",
            )

        # Get search results filtered by the registered model name
        filter_string = "name = 'CordobaWeatherForecastModel'"
        results = mlflow.search_registered_models(filter_string=filter_string)
        print("-" * 80)
        for res in results:
            for mv in res.latest_versions:
                print("name={}; run_id={}; version={}".format(mv.name, mv.run_id, mv.version))

        # Get search results filtered by the registered model name that matches
        # prefix pattern
        filter_string = "name LIKE 'Boston%'"
        results = mlflow.search_registered_models(filter_string=filter_string)
        print("-" * 80)
        for res in results:
            for mv in res.latest_versions:
                print("name={}; run_id={}; version={}".format(mv.name, mv.run_id, mv.version))

        # Get all registered models and order them by ascending order of the names
        results = mlflow.search_registered_models(order_by=["name ASC"])
        print("-" * 80)
        for res in results:
            for mv in res.latest_versions:
                print("name={}; run_id={}; version={}".format(mv.name, mv.run_id, mv.version))

    .. code-block:: text
        :caption: Output

        --------------------------------------------------------------------------------
        name=CordobaWeatherForecastModel; run_id=248c66a666744b4887bdeb2f9cf7f1c6; version=1
        --------------------------------------------------------------------------------
        name=BostonWeatherForecastModel; run_id=248c66a666744b4887bdeb2f9cf7f1c6; version=1
        --------------------------------------------------------------------------------
        name=BostonWeatherForecastModel; run_id=248c66a666744b4887bdeb2f9cf7f1c6; version=1
        name=CordobaWeatherForecastModel; run_id=248c66a666744b4887bdeb2f9cf7f1c6; version=1
    """
    # Search arguments that stay the same for every page request.
    fixed_search_args = {"filter_string": filter_string, "order_by": order_by}

    def fetch_registered_models_page(number_to_get, next_page_token):
        # A new client is built for each page, exactly as many results as requested are
        # asked for, and the page token tells the backend where to resume.
        page_client = MlflowClient()
        return page_client.search_registered_models(
            max_results=number_to_get,
            page_token=next_page_token,
            **fixed_search_args,
        )

    # The helper drives the page loop; the registry default is passed as its second argument
    # and the caller's overall ``max_results`` limit as its third.
    return get_results_from_paginated_fn(
        fetch_registered_models_page,
        SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT,
        max_results,
    )