MLflow Keras 3.0 Integration
Introduction
Keras 3.0 is a high-level neural network API that runs on top of TensorFlow, JAX, or PyTorch. It provides a user-friendly interface for building and training deep learning models and lets you switch backends without changing your model code.
MLflow's Keras integration provides experiment tracking, model versioning, and deployment capabilities for deep learning workflows.
Why MLflow + Keras?
Autologging
Enable comprehensive experiment tracking with one line: mlflow.tensorflow.autolog() automatically logs metrics, parameters, and models.
Experiment Tracking
Track training metrics, hyperparameters, model architectures, and artifacts across all Keras experiments.
Model Registry
Version, stage, and deploy Keras models with MLflow's model registry and serving infrastructure.
Multi-Backend Support
Track experiments consistently across TensorFlow, JAX, and PyTorch backends.
Autologging
Enable comprehensive autologging with a single line:
import mlflow
import numpy as np
from tensorflow import keras
# Enable autologging
mlflow.tensorflow.autolog()
# Prepare sample data
X_train = np.random.rand(1000, 20)
y_train = np.random.randint(0, 2, 1000)
# Define model (keras.Input declares the input shape, as Keras 3 recommends)
model = keras.Sequential(
    [
        keras.Input(shape=(20,)),
        keras.layers.Dense(64, activation="relu"),
        keras.layers.Dense(32, activation="relu"),
        keras.layers.Dense(1, activation="sigmoid"),
    ]
)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
# Training with automatic logging
with mlflow.start_run():
    model.fit(X_train, y_train, epochs=10, batch_size=32, validation_split=0.2)
Autologging records the per-epoch training and validation metrics, the arguments passed to fit() (such as epochs and batch_size) as parameters, the optimizer configuration, and the trained model as an artifact.
Configure autologging behavior:
mlflow.tensorflow.autolog(
    log_models=True,
    log_input_examples=True,
    log_model_signatures=True,
    log_every_n_steps=1,
)
Manual Logging with Keras Callback
For more control over what is logged, pass mlflow.tensorflow.MlflowCallback to model.fit() inside an active run:
import mlflow
import numpy as np
from tensorflow import keras
# Prepare sample data
X_train = np.random.rand(100, 20)
y_train = np.random.randint(0, 2, 100)
# Define and compile model
model = keras.Sequential(
    [
        keras.Input(shape=(20,)),
        keras.layers.Dense(64, activation="relu"),
        keras.layers.Dense(1, activation="sigmoid"),
    ]
)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
# Create an MLflow run and add the callback
with mlflow.start_run() as run:
    model.fit(
        X_train,
        y_train,
        epochs=10,
        batch_size=32,
        validation_split=0.2,
        callbacks=[mlflow.tensorflow.MlflowCallback(run)],
    )
Model Logging
Log Keras models with mlflow.tensorflow.log_model() and load them back with mlflow.tensorflow.load_model():
import mlflow
import numpy as np
from tensorflow import keras
# Define model
model = keras.Sequential(
    [
        keras.Input(shape=(20,)),
        keras.layers.Dense(64, activation="relu"),
        keras.layers.Dense(1, activation="sigmoid"),
    ]
)
# Train model (code omitted for brevity)
# Log the model to MLflow
model_info = mlflow.tensorflow.log_model(model, name="model")
# Later, load the model for inference
loaded_model = mlflow.tensorflow.load_model(model_info.model_uri)
X_test = np.random.rand(5, 20)  # sample inputs with the model's 20 features
predictions = loaded_model.predict(X_test)
Model Registry Integration
Register Keras models for version control and deployment:
import mlflow
from mlflow import MlflowClient
from tensorflow import keras
with mlflow.start_run():
    # Create a simple model for demonstration
    model = keras.Sequential(
        [
            keras.Input(shape=(28, 28, 1)),
            keras.layers.Conv2D(32, 3, activation="relu"),
            keras.layers.MaxPooling2D(2),
            keras.layers.Flatten(),
            keras.layers.Dense(10, activation="softmax"),
        ]
    )
    # Log the model and register it as a new version of "ImageClassifier"
    model_info = mlflow.tensorflow.log_model(
        model, name="keras_model", registered_model_name="ImageClassifier"
    )
    # Tag the run for tracking
    mlflow.set_tags({"model_type": "cnn", "dataset": "mnist", "framework": "keras"})
# Set alias for production deployment
client = MlflowClient()
client.set_registered_model_alias(
    name="ImageClassifier",
    alias="champion",
    version=model_info.registered_model_version,
)