# Source code for mlflow.data.delta_dataset_source

from typing import Any, Dict, Optional

from mlflow.data.dataset_source import DatasetSource
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.utils.annotations import experimental
from mlflow.utils.databricks_utils import is_in_databricks_runtime

# Catalog name of the Databricks Hive metastore.
DATABRICKS_HIVE_METASTORE_NAME = "hive_metastore"
# These two catalog names both point to the workspace-local default HMS (Hive metastore).
DATABRICKS_LOCAL_METASTORE_NAMES = [DATABRICKS_HIVE_METASTORE_NAME, "spark_catalog"]
# The samples catalog is managed by Databricks for hosting public datasets such as the
# NYC taxi dataset. It is neither a UC (Unity Catalog) nor a local metastore catalog.
DATABRICKS_SAMPLES_CATALOG_NAME = "samples"


@experimental
class DeltaDatasetSource(DatasetSource):
    """
    Represents the source of a dataset stored in a Delta table.

    The table is identified either by a ``path`` or by a ``delta_table_name``
    (exactly one of the two), optionally pinned to a ``delta_table_version``.
    """

    def __init__(
        self,
        path: Optional[str] = None,
        delta_table_name: Optional[str] = None,
        delta_table_version: Optional[int] = None,
    ):
        """
        :param path: Path of the Delta table. Mutually exclusive with
                     ``delta_table_name``.
        :param delta_table_name: Name of the Delta table. Mutually exclusive with ``path``.
        :param delta_table_version: Optional version of the Delta table to read.
        :raises MlflowException: If neither or both of ``path`` and ``delta_table_name``
                                 are specified.
        """
        # Exactly one of the two identifiers must be provided (i.e. exactly one is None).
        if (path, delta_table_name).count(None) != 1:
            raise MlflowException(
                'Must specify exactly one of "path" or "table_name"',
                INVALID_PARAMETER_VALUE,
            )
        self._path = path
        self._delta_table_name = delta_table_name
        self._delta_table_version = delta_table_version

    @staticmethod
    def _get_source_type() -> str:
        """Return the string identifying this kind of dataset source."""
        return "delta_table"

    def load(self, **kwargs):  # pylint: disable=unused-argument
        """
        Loads the dataset source as a Delta Dataset Source.

        :param kwargs: Accepted but not used.
        :return: An instance of ``pyspark.sql.DataFrame``.
        """
        # Imported lazily so that pyspark is only required when data is actually loaded.
        from pyspark.sql import SparkSession

        spark = SparkSession.builder.getOrCreate()

        spark_read_op = spark.read.format("delta")
        # Explicit None check: 0 is a legitimate version to pin to.
        if self._delta_table_version is not None:
            spark_read_op = spark_read_op.option("versionAsOf", self._delta_table_version)

        if self._path:
            return spark_read_op.load(self._path)
        else:
            return spark_read_op.table(self._delta_table_name)

    @property
    def path(self) -> Optional[str]:
        """The path given at construction, or ``None``."""
        return self._path

    @property
    def delta_table_name(self) -> Optional[str]:
        """The Delta table name given at construction, or ``None``."""
        return self._delta_table_name

    @property
    def delta_table_version(self) -> Optional[int]:
        """The Delta table version given at construction, or ``None``."""
        return self._delta_table_version

    @staticmethod
    def _can_resolve(raw_source: Any):
        """Always return ``False``: this source is never resolved from a raw source."""
        return False

    @classmethod
    def _resolve(cls, raw_source: str) -> "DeltaDatasetSource":
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    # check if table is in the Databricks Unity Catalog
    def _is_databricks_uc_table(self) -> bool:
        """
        Return ``True`` if running in the Databricks runtime and the table name's catalog
        (the part before the first ``.``) is neither a local metastore catalog nor the
        samples catalog; ``False`` otherwise, including for path-based sources.
        """
        if is_in_databricks_runtime() and self._delta_table_name is not None:
            catalog_name = self._delta_table_name.split(".", 1)[0]
            return (
                catalog_name not in DATABRICKS_LOCAL_METASTORE_NAMES
                and catalog_name != DATABRICKS_SAMPLES_CATALOG_NAME
            )
        # Explicit result instead of an implicit None when the check does not apply.
        return False

    def _to_dict(self) -> Dict[Any, Any]:
        """
        Serialize this source to a dictionary.

        Only fields that are set are included. ``is_databricks_uc_table`` is added (as
        ``True``) only when :meth:`_is_databricks_uc_table` reports a Unity Catalog table.
        """
        info = {}
        if self._path:
            info["path"] = self._path
        if self._delta_table_name:
            info["delta_table_name"] = self._delta_table_name
        # Compare against None rather than truthiness so that version 0 is not dropped
        # (and therefore survives a _to_dict / _from_dict round trip).
        if self._delta_table_version is not None:
            info["delta_table_version"] = self._delta_table_version
        if self._is_databricks_uc_table():
            info["is_databricks_uc_table"] = True
        return info

    @classmethod
    def _from_dict(cls, source_dict: Dict[Any, Any]) -> "DeltaDatasetSource":
        """
        Build a source from a dictionary such as the one produced by :meth:`_to_dict`.

        Missing keys are treated as ``None``; any other keys are ignored.
        """
        return cls(
            path=source_dict.get("path"),
            delta_table_name=source_dict.get("delta_table_name"),
            delta_table_version=source_dict.get("delta_table_version"),
        )