MLflow TensorFlow Integration
Introduction
TensorFlow is an end-to-end open source platform for machine learning developed by Google. It provides a comprehensive ecosystem for building and deploying ML models, from research prototypes to production systems. TensorFlow's Keras API offers an intuitive interface for building neural networks while its powerful backend enables efficient computation across CPUs, GPUs, and TPUs.
MLflow's TensorFlow integration provides experiment tracking, model versioning, and deployment capabilities for deep learning workflows.
Why MLflow + TensorFlow?
Autologging
Enable comprehensive experiment tracking with one line: mlflow.tensorflow.autolog() automatically logs metrics, parameters, and models.
Experiment Tracking
Track training metrics, hyperparameters, model architectures, and artifacts across all TensorFlow experiments.
Model Registry
Version, stage, and deploy TensorFlow models with MLflow's model registry and serving infrastructure.
Reproducibility
Capture model states, training configurations, and environments for reproducible experiments.
Autologging
Enable comprehensive autologging with a single line:
import mlflow
import numpy as np
import tensorflow as tf
from tensorflow import keras

# A single call turns on automatic logging of metrics, parameters,
# optimizer configuration, and the trained model.
mlflow.tensorflow.autolog()

# Synthetic inputs: 20 RGB 28x28 images with binary labels.
features = np.random.uniform(size=[20, 28, 28, 3])
targets = np.random.randint(2, size=20)

# Small CNN classifier: conv -> pool -> flatten -> dense -> softmax.
model = keras.Sequential(
    [
        keras.Input([28, 28, 3]),
        keras.layers.Conv2D(8, 2),
        keras.layers.MaxPool2D(2),
        keras.layers.Flatten(),
        keras.layers.Dense(2),
        keras.layers.Softmax(),
    ]
)
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(0.001),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Everything captured by autolog is attached to this run.
with mlflow.start_run():
    model.fit(features, targets, batch_size=5, epochs=2)
Autologging captures training metrics, model parameters, optimizer configuration, and model artifacts automatically. Requires TensorFlow >= 2.3.0 and the model.fit() Keras API.
Configure autologging behavior:
# Fine-tune what autologging captures.
mlflow.tensorflow.autolog(
    log_models=True,  # save the trained model as a run artifact
    log_input_examples=True,  # store a sample of the training input
    log_model_signatures=True,  # infer and record the model signature
    log_every_n_steps=1,  # granularity of per-step metric logging
)
Manual Logging with Keras Callback
For more control, attach an mlflow.tensorflow.MlflowCallback (constructed with the active run) to model.fit():
import mlflow
import numpy as np
from tensorflow import keras

# Synthetic dataset: 100 RGB 28x28 images with binary labels.
inputs = np.random.uniform(size=[100, 28, 28, 3])
targets = np.random.randint(2, size=100)

# Compact CNN classifier with a softmax output head.
model = keras.Sequential(
    [
        keras.Input([28, 28, 3]),
        keras.layers.Conv2D(8, 3),
        keras.layers.MaxPool2D(2),
        keras.layers.Flatten(),
        keras.layers.Dense(2, activation="softmax"),
    ]
)
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=keras.optimizers.Adam(0.001),
    metrics=["accuracy"],
)

# Open a run explicitly and hand it to the callback, which streams
# training metrics into that run as fit() progresses.
with mlflow.start_run() as run:
    model.fit(
        inputs,
        targets,
        batch_size=32,
        epochs=10,
        callbacks=[mlflow.tensorflow.MlflowCallback(run)],
    )
Custom Callback
Create custom logging logic by subclassing keras.callbacks.Callback:
from tensorflow import keras
import math
import mlflow
class CustomMlflowCallback(keras.callbacks.Callback):
    """Keras callback that mirrors training progress into MLflow.

    Logs the running epoch counter, every epoch-end metric (plus its
    natural log when the value is positive), and the model's total
    parameter count once training finishes.
    """

    def on_epoch_begin(self, epoch, logs=None):
        # Track which epoch is currently running.
        mlflow.log_metric("current_epoch", epoch)

    def on_epoch_end(self, epoch, logs=None):
        for name, value in (logs or {}).items():
            # math.log is only defined for positive values, so the
            # log-scale companion metric is guarded.
            if value > 0:
                mlflow.log_metric(f"log_{name}", math.log(value), step=epoch)
            mlflow.log_metric(name, value, step=epoch)

    def on_train_end(self, logs=None):
        # Record model size as a final summary metric.
        weight_arrays = self.model.get_weights()
        mlflow.log_metric("total_parameters", sum(w.size for w in weight_arrays))
Model Logging
Save TensorFlow models with mlflow.tensorflow.log_model():
import mlflow
import tensorflow as tf
from tensorflow import keras

# Small CNN classifier used to demonstrate model logging.
model = keras.Sequential(
    [
        keras.Input([28, 28, 3]),
        keras.layers.Conv2D(8, 2),
        keras.layers.MaxPool2D(2),
        keras.layers.Flatten(),
        keras.layers.Dense(2),
        keras.layers.Softmax(),
    ]
)

# Train model (code omitted for brevity)

# Persist the model as an MLflow artifact; the returned ModelInfo
# carries the URI needed to reload it later.
model_info = mlflow.tensorflow.log_model(model, name="model")

# Reload by URI and run inference on a random input.
loaded_model = mlflow.tensorflow.load_model(model_info.model_uri)
predictions = loaded_model.predict(tf.random.uniform([1, 28, 28, 3]))
Hyperparameter Optimization
Track hyperparameter tuning with MLflow:
import mlflow
import tensorflow as tf
from tensorflow import keras
import optuna
def objective(trial, x_train, y_train, x_val, y_val):
    """Optuna objective for TensorFlow hyperparameter tuning.

    Each trial executes inside a nested MLflow run so its sampled
    hyperparameters and resulting validation accuracy are tracked
    individually under the parent study run.

    Args:
        trial: Optuna trial used to sample hyperparameters.
        x_train, y_train: Training features and integer labels.
        x_val, y_val: Validation features and integer labels.

    Returns:
        The best validation accuracy observed across training epochs.
    """
    with mlflow.start_run(nested=True):
        # Define hyperparameter search space
        params = {
            "learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True),
            "units": trial.suggest_int("units", 32, 512),
            "dropout": trial.suggest_float("dropout", 0.1, 0.5),
        }
        # Log the sampled hyperparameters on this nested run; without
        # this, only the metric would be recorded and the trial could
        # not be reproduced from the tracking UI.
        mlflow.log_params(params)

        # Create model with hyperparameters
        model = keras.Sequential(
            [
                keras.layers.Input(shape=(28, 28, 3)),
                keras.layers.Flatten(),
                keras.layers.Dense(params["units"], activation="relu"),
                keras.layers.Dropout(params["dropout"]),
                keras.layers.Dense(10, activation="softmax"),
            ]
        )
        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=params["learning_rate"]),
            loss="sparse_categorical_crossentropy",
            metrics=["accuracy"],
        )

        # Train and evaluate; report the best (not last) epoch so noisy
        # final epochs do not penalize an otherwise good trial.
        history = model.fit(
            x_train, y_train, validation_data=(x_val, y_val), epochs=5, verbose=0
        )
        val_accuracy = max(history.history["val_accuracy"])
        mlflow.log_metric("val_accuracy", val_accuracy)

        return val_accuracy
# Parent run that owns the whole Optuna study; each trial nests beneath it.
# Assumes x_train / y_train / x_val / y_val have already been prepared.
with mlflow.start_run(run_name="tensorflow_hyperparameter_optimization"):
    study = optuna.create_study(direction="maximize")
    study.optimize(
        lambda trial: objective(trial, x_train, y_train, x_val, y_val), n_trials=20
    )

    # Summarize the study on the parent run: winning hyperparameters
    # (prefixed to avoid clashing with per-trial params) and best score.
    best = {f"best_{name}": value for name, value in study.best_params.items()}
    mlflow.log_params(best)
    mlflow.log_metric("best_val_accuracy", study.best_value)
Model Registry Integration
Register TensorFlow models for version control and deployment:
import mlflow
from tensorflow import keras
from mlflow import MlflowClient

client = MlflowClient()

with mlflow.start_run():
    # Demonstration model: a minimal CNN image classifier.
    model = keras.Sequential(
        [
            keras.layers.Conv2D(32, 3, activation="relu", input_shape=(224, 224, 3)),
            keras.layers.MaxPooling2D(2),
            keras.layers.Flatten(),
            keras.layers.Dense(10, activation="softmax"),
        ]
    )

    # Logging with registered_model_name both stores the artifact and
    # creates a new version under "ImageClassifier" in the registry.
    model_info = mlflow.tensorflow.log_model(
        model, name="tensorflow_model", registered_model_name="ImageClassifier"
    )

    # Attach searchable metadata to the run.
    mlflow.set_tags(
        {"model_type": "cnn", "dataset": "imagenet", "framework": "tensorflow"}
    )

    # Point the "champion" alias at the freshly registered version so
    # deployment tooling can resolve it by alias instead of number.
    client.set_registered_model_alias(
        name="ImageClassifier",
        alias="champion",
        version=model_info.registered_model_version,
    )