Skip to main content

MLflow TensorFlow Integration

Introduction

TensorFlow is an end-to-end open source platform for machine learning developed by Google. It provides a comprehensive ecosystem for building and deploying ML models, from research prototypes to production systems. TensorFlow's Keras API offers an intuitive interface for building neural networks while its powerful backend enables efficient computation across CPUs, GPUs, and TPUs.

MLflow's TensorFlow integration provides experiment tracking, model versioning, and deployment capabilities for deep learning workflows.

Why MLflow + TensorFlow?

Autologging

Enable comprehensive experiment tracking with one line: mlflow.tensorflow.autolog() automatically logs metrics, parameters, and models.

Experiment Tracking

Track training metrics, hyperparameters, model architectures, and artifacts across all TensorFlow experiments.

Model Registry

Version, stage, and deploy TensorFlow models with MLflow's model registry and serving infrastructure.

Reproducibility

Capture model states, training configurations, and environments for reproducible experiments.

Autologging

Enable comprehensive autologging with a single line:

python
import mlflow
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Enable autologging: metrics, params, and the model are captured automatically
# on model.fit() — no explicit logging calls needed below.
mlflow.tensorflow.autolog()

# Prepare sample data: 20 random RGB-like images (28x28x3) with binary labels.
data = np.random.uniform(size=[20, 28, 28, 3])
label = np.random.randint(2, size=20)

# Define a minimal CNN classifier.
model = keras.Sequential(
    [
        keras.Input([28, 28, 3]),
        keras.layers.Conv2D(8, 2),
        keras.layers.MaxPool2D(2),
        keras.layers.Flatten(),
        keras.layers.Dense(2),
        keras.layers.Softmax(),
    ]
)

model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(0.001),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Training with automatic logging — autolog hooks into fit() inside the run.
with mlflow.start_run():
    model.fit(data, label, batch_size=5, epochs=2)

Autologging captures training metrics, model parameters, optimizer configuration, and model artifacts automatically. Requires TensorFlow >= 2.3.0 and the model.fit() Keras API.

Configure autologging behavior:

python
# Configure autologging behavior: keep models, attach input examples and
# signatures to the logged model, and log metrics every training step.
mlflow.tensorflow.autolog(
    log_models=True,
    log_input_examples=True,
    log_model_signatures=True,
    log_every_n_steps=1,
)

Manual Logging with Keras Callback

For more control, use mlflow.tensorflow.MlflowCallback():

python
import mlflow
import numpy as np
from tensorflow import keras

# Prepare sample data: 100 random images (28x28x3) with binary labels.
data = np.random.uniform(size=[100, 28, 28, 3])
labels = np.random.randint(2, size=100)

# Define and compile your model.
model = keras.Sequential(
    [
        keras.Input([28, 28, 3]),
        keras.layers.Conv2D(8, 3),
        keras.layers.MaxPool2D(2),
        keras.layers.Flatten(),
        keras.layers.Dense(2, activation="softmax"),
    ]
)

model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=keras.optimizers.Adam(0.001),
    metrics=["accuracy"],
)

# Create an MLflow run and attach the callback so per-epoch metrics are
# logged to this specific run during fit().
with mlflow.start_run() as run:
    model.fit(
        data,
        labels,
        batch_size=32,
        epochs=10,
        callbacks=[mlflow.tensorflow.MlflowCallback(run)],
    )

Custom Callback

Create custom logging logic by subclassing keras.callbacks.Callback:

python
from tensorflow import keras
import math
import mlflow


class CustomMlflowCallback(keras.callbacks.Callback):
    """Keras callback that logs custom metrics to the active MLflow run."""

    def on_epoch_begin(self, epoch, logs=None):
        # Record which epoch is currently running.
        mlflow.log_metric("current_epoch", epoch)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # Log each metric raw, plus a log-scale copy for positive values
        # (math.log would raise on zero/negative metrics such as some losses).
        for k, v in logs.items():
            if v > 0:
                mlflow.log_metric(f"log_{k}", math.log(v), step=epoch)
            mlflow.log_metric(k, v, step=epoch)

    def on_train_end(self, logs=None):
        # Log final model weights statistics: total parameter count across
        # all weight arrays.
        weights = self.model.get_weights()
        mlflow.log_metric("total_parameters", sum(w.size for w in weights))

Model Logging

Save TensorFlow models with mlflow.tensorflow.log_model():

python
import mlflow
import tensorflow as tf
from tensorflow import keras

# Define a minimal CNN classifier.
model = keras.Sequential(
    [
        keras.Input([28, 28, 3]),
        keras.layers.Conv2D(8, 2),
        keras.layers.MaxPool2D(2),
        keras.layers.Flatten(),
        keras.layers.Dense(2),
        keras.layers.Softmax(),
    ]
)

# Train model (code omitted for brevity)

# Log the model to MLflow; model_info carries the model's URI for later loading.
model_info = mlflow.tensorflow.log_model(model, name="model")

# Later, load the model for inference from the tracking store.
loaded_model = mlflow.tensorflow.load_model(model_info.model_uri)
predictions = loaded_model.predict(tf.random.uniform([1, 28, 28, 3]))

Hyperparameter Optimization

Track hyperparameter tuning with MLflow:

python
import mlflow
import tensorflow as tf
from tensorflow import keras
import optuna


def objective(trial, x_train, y_train, x_val, y_val):
    """Optuna objective for TensorFlow hyperparameter tuning.

    Each trial runs inside a nested MLflow run so its parameters and
    validation accuracy are tracked under the parent experiment run.
    Returns the best validation accuracy observed, which Optuna maximizes.
    """
    with mlflow.start_run(nested=True):
        # Define hyperparameter search space.
        params = {
            "learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True),
            "units": trial.suggest_int("units", 32, 512),
            "dropout": trial.suggest_float("dropout", 0.1, 0.5),
        }
        # Record this trial's hyperparameters on its nested run.
        mlflow.log_params(params)

        # Create model with the sampled hyperparameters.
        model = keras.Sequential(
            [
                keras.layers.Input(shape=(28, 28, 3)),
                keras.layers.Flatten(),
                keras.layers.Dense(params["units"], activation="relu"),
                keras.layers.Dropout(params["dropout"]),
                keras.layers.Dense(10, activation="softmax"),
            ]
        )

        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=params["learning_rate"]),
            loss="sparse_categorical_crossentropy",
            metrics=["accuracy"],
        )

        # Train and evaluate.
        history = model.fit(
            x_train, y_train, validation_data=(x_val, y_val), epochs=5, verbose=0
        )

        val_accuracy = max(history.history["val_accuracy"])
        mlflow.log_metric("val_accuracy", val_accuracy)

        return val_accuracy


# Main experiment run. NOTE: x_train/y_train/x_val/y_val must be defined by
# your data-loading code before this point.
with mlflow.start_run(run_name="tensorflow_hyperparameter_optimization"):
    study = optuna.create_study(direction="maximize")
    study.optimize(
        lambda trial: objective(trial, x_train, y_train, x_val, y_val), n_trials=20
    )

    # Log best parameters and results on the parent run.
    mlflow.log_params({f"best_{k}": v for k, v in study.best_params.items()})
    mlflow.log_metric("best_val_accuracy", study.best_value)

Model Registry Integration

Register TensorFlow models for version control and deployment:

python
import mlflow
from tensorflow import keras
from mlflow import MlflowClient

client = MlflowClient()

with mlflow.start_run():
    # Create model for demonstration.
    model = keras.Sequential(
        [
            keras.layers.Conv2D(32, 3, activation="relu", input_shape=(224, 224, 3)),
            keras.layers.MaxPooling2D(2),
            keras.layers.Flatten(),
            keras.layers.Dense(10, activation="softmax"),
        ]
    )

    # Log model and register it in the model registry in one call.
    model_info = mlflow.tensorflow.log_model(
        model, name="tensorflow_model", registered_model_name="ImageClassifier"
    )

    # Tag the run for tracking.
    mlflow.set_tags(
        {"model_type": "cnn", "dataset": "imagenet", "framework": "tensorflow"}
    )

# Set model alias for deployment — alias assignment does not require an
# active run, so it can happen after the run ends.
client.set_registered_model_alias(
    name="ImageClassifier",
    alias="champion",
    version=model_info.registered_model_version,
)

Learn More