Source code for mlflow.dspy.autolog

import importlib
import logging

from packaging.version import Version

import mlflow
from mlflow.dspy.constant import FLAVOR_NAME
from mlflow.tracing.provider import trace_disabled
from mlflow.tracing.utils import construct_full_inputs
from mlflow.utils.autologging_utils import (
    autologging_integration,
    get_autologging_config,
    safe_patch,
)
from mlflow.utils.autologging_utils.safety import exception_safe_function_for_class

_logger = logging.getLogger(__name__)


def autolog(
    log_traces: bool = True,
    log_traces_from_compile: bool = False,
    log_traces_from_eval: bool = True,
    log_compiles: bool = False,
    log_evals: bool = False,
    disable: bool = False,
    silent: bool = False,
):
    """
    Enables (or disables) and configures autologging from DSPy to MLflow.
    Currently, the MLflow DSPy flavor only supports autologging for tracing.

    Args:
        log_traces: If ``True``, traces are logged for DSPy models during inference.
            If ``False``, no traces are collected during inference. Defaults to ``True``.
        log_traces_from_compile: If ``True``, traces are logged when compiling (optimizing)
            DSPy programs. If ``False``, traces are only logged from normal model inference
            and disabled when compiling. Defaults to ``False``.
        log_traces_from_eval: If ``True``, traces are logged for DSPy models when running DSPy's
            `built-in evaluator <https://dspy.ai/learn/evaluation/metrics/#evaluation>`_.
            If ``False``, traces are only logged from normal model inference and disabled
            when running the evaluator. Defaults to ``True``.
        log_compiles: If ``True``, information about the optimization process is logged when
            `Teleprompter.compile()` is called.
        log_evals: If ``True``, information about the evaluation call is logged when
            `Evaluate.__call__()` is called.
        disable: If ``True``, disables the DSPy autologging integration. If ``False``,
            enables the DSPy autologging integration.
        silent: If ``True``, suppresses all event logs and warnings from MLflow during DSPy
            autologging. If ``False``, shows all events and warnings.
    """
    # NB: The @autologging_integration annotation is used for adding shared logic. However, one
    # caveat is that the wrapped function is NOT executed when disable=True is passed. This
    # prevents us from running cleanup logic when autologging is turned off. To work around
    # this, we annotate _autolog() instead of this entrypoint, and define the cleanup logic
    # outside it. This needs to be called before doing any safe-patching (otherwise safe_patch
    # will be a no-op).
    # TODO: since this implementation is inconsistent, explore a universal way to solve the issue.
    _autolog(
        log_traces=log_traces,
        log_traces_from_compile=log_traces_from_compile,
        log_traces_from_eval=log_traces_from_eval,
        log_compiles=log_compiles,
        log_evals=log_evals,
        disable=disable,
        silent=silent,
    )

    import dspy

    from mlflow.dspy.callback import MlflowCallback

    # Enable tracing by registering the MlflowCallback
    if not disable:
        if not any(isinstance(c, MlflowCallback) for c in dspy.settings.callbacks):
            dspy.settings.configure(callbacks=[*dspy.settings.callbacks, MlflowCallback()])
        # DSPy token tracking has an issue before 3.0.4:
        # https://github.com/stanfordnlp/dspy/pull/8831
        if Version(importlib.metadata.version("dspy")) >= Version("3.0.4"):
            dspy.settings.configure(track_usage=True)
    else:
        dspy.settings.configure(
            callbacks=[c for c in dspy.settings.callbacks if not isinstance(c, MlflowCallback)]
        )

    from dspy.teleprompt import Teleprompter

    compile_patch = "compile"
    for cls in Teleprompter.__subclasses__():
        # NB: Iterating over subclasses skips superclasses that are defined only for the
        # purposes of abstraction; the __subclasses__ dunder method targets the concrete
        # optimizer classes we need to patch.
        if hasattr(cls, compile_patch):
            safe_patch(
                FLAVOR_NAME,
                cls,
                compile_patch,
                _patched_compile,
                manage_run=get_autologging_config(FLAVOR_NAME, "log_compiles"),
            )

    from dspy.evaluate import Evaluate

    call_patch = "__call__"
    if hasattr(Evaluate, call_patch):
        safe_patch(
            FLAVOR_NAME,
            Evaluate,
            call_patch,
            _patched_evaluate,
        )
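

# Example usage of the entrypoint above (illustrative sketch; the model name is a
# placeholder, and any DSPy-supported LM works the same way). With log_compiles=True,
# the patched Teleprompter.compile() call runs inside an MLflow run that records
# params, the LM state, the train/val datasets, and the best program state.
#
#   import dspy
#   import mlflow
#
#   mlflow.dspy.autolog(log_compiles=True, log_evals=True)
#   dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))
#
#   program = dspy.Predict("question -> answer")
#   program(question="What is MLflow?")  # traced via MlflowCallback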


# This is required by mlflow.autolog()
autolog.integration_name = FLAVOR_NAME


@autologging_integration(FLAVOR_NAME)
def _autolog(
    log_traces: bool,
    log_traces_from_compile: bool,
    log_traces_from_eval: bool,
    log_compiles: bool,
    log_evals: bool,
    disable: bool = False,
    silent: bool = False,
):
    pass


def _active_callback():
    import dspy

    from mlflow.dspy.callback import MlflowCallback

    for callback in dspy.settings.callbacks:
        if isinstance(callback, MlflowCallback):
            return callback


def _patched_compile(original, self, *args, **kwargs):
    from mlflow.dspy.util import (
        log_dspy_dataset,
        log_dspy_lm_state,
        log_dummy_model_outputs,
        save_dspy_module_state,
    )

    # NB: Since calling mlflow.dspy.autolog() again does not unpatch a function, we need to
    # check this flag at runtime to determine whether we should generate traces.

    # Helper to disable tracing for compile and evaluate by default
    @trace_disabled
    def _trace_disabled_fn(self, *args, **kwargs):
        return original(self, *args, **kwargs)

    def _compile_fn(self, *args, **kwargs):
        if callback := _active_callback():
            callback.optimizer_stack_level += 1
        try:
            if get_autologging_config(FLAVOR_NAME, "log_traces_from_compile"):
                result = original(self, *args, **kwargs)
            else:
                result = _trace_disabled_fn(self, *args, **kwargs)
            return result
        finally:
            if callback:
                callback.optimizer_stack_level -= 1
                if callback.optimizer_stack_level == 0:
                    # Reset the callback state after the completion of the root compile
                    callback.reset()

    if not get_autologging_config(FLAVOR_NAME, "log_compiles"):
        return _compile_fn(self, *args, **kwargs)

    # NB: Log dummy run outputs so that the "Run" tab is shown in the UI. Currently, the
    # GenAI experiment does not show the "Run" tab without this, which is a critical gap for
    # DSPy users. This must be done BEFORE the compile call, because the Run page is used
    # for tracking the compile progress, not only inspected after the compile finishes.
    log_dummy_model_outputs()

    program = _compile_fn(self, *args, **kwargs)

    # Save the state of the best model in json format
    # so that users can see the demonstrations and instructions.
    save_dspy_module_state(program, "best_model.json")

    # Teleprompter.get_params was introduced in dspy 2.6.15
    params = (
        self.get_params()
        if Version(importlib.metadata.version("dspy")) >= Version("2.6.15")
        else {}
    )
    # Construct the dict of arguments passed to the compile call
    inputs = construct_full_inputs(original, self, *args, **kwargs)
    # Update params with the arguments passed to the compile call
    params.update(inputs)
    mlflow.log_params({k: v for k, v in params.items() if isinstance(v, (int, float, str, bool))})
    # Log the current DSPy LM state
    log_dspy_lm_state()
    if trainset := inputs.get("trainset"):
        log_dspy_dataset(trainset, "trainset.json")
    if valset := inputs.get("valset"):
        log_dspy_dataset(valset, "valset.json")
    return program


def _patched_evaluate(original, self, *args, **kwargs):
    # NB: Since calling mlflow.dspy.autolog() again does not unpatch a function, we need to
    # check this flag at runtime to determine whether we should generate traces.
    # Helper to disable tracing for compile and evaluate by default
    @trace_disabled
    def _trace_disabled_fn(self, *args, **kwargs):
        return original(self, *args, **kwargs)

    if not get_autologging_config(FLAVOR_NAME, "log_traces_from_eval"):
        return _trace_disabled_fn(self, *args, **kwargs)

    # Patch the metric call to log assessment results on the prediction traces
    new_kwargs = construct_full_inputs(original, self, *args, **kwargs)
    metric = new_kwargs.get("metric") or self.metric
    new_kwargs["metric"] = _patch_metric(metric)
    # Preserve the original positional/keyword split of the arguments
    args_passed_positional = list(new_kwargs.keys())[: len(args)]
    new_args = [new_kwargs.pop(arg) for arg in args_passed_positional]
    return original(self, *new_args, **new_kwargs)


def _patch_metric(metric):
    """Patch the metric call to log assessment results on the prediction traces."""
    import dspy

    # NB: This patch MUST NOT raise an exception; otherwise it may interrupt the evaluation call.
    @exception_safe_function_for_class
    def _patched(*args, **kwargs):
        # NB: DSPy runs the prediction and the metric call in the same thread, so we can
        # retrieve the prediction trace ID via the last active trace ID.
        # https://github.com/stanfordnlp/dspy/blob/8224a99ca6402863540aae5aa3bc5eddbd2947c4/dspy/evaluate/evaluate.py#L170-L173
        pred_trace_id = mlflow.get_last_active_trace_id(thread_local=True)
        if not pred_trace_id:
            _logger.debug("Tracing during evaluation is enabled, but no prediction trace found.")
            return metric(*args, **kwargs)

        try:
            score = metric(*args, **kwargs)
        except Exception as e:
            _logger.debug("Metric call failed, logging an assessment with error")
            mlflow.log_feedback(trace_id=pred_trace_id, name=metric.__name__, error=e)
            raise

        try:
            if isinstance(score, dspy.Prediction):
                # GEPA metrics return a Prediction object with score and feedback attributes.
                # https://dspy.ai/tutorials/gepa_aime/
                value = getattr(score, "score", None)
                rationale = getattr(score, "feedback", None)
            else:
                value = score
                rationale = None
            mlflow.log_feedback(
                trace_id=pred_trace_id,
                name=metric.__name__,
                value=value,
                rationale=rationale,
            )
        except Exception as e:
            _logger.debug(f"Failed to log feedback for metric on prediction trace: {e}")

        return score

    return _patched
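

# Example of the two metric shapes handled by _patch_metric above (illustrative sketch;
# signatures are abbreviated, and real DSPy metrics may accept extra arguments such as
# trace or pred_name). A plain metric's return value is logged as the feedback value;
# a GEPA-style dspy.Prediction is unpacked into value (score) and rationale (feedback).
#
#   import dspy
#
#   def exact_match(example, pred, trace=None):
#       return float(example.answer == pred.answer)
#
#   def gepa_style_metric(example, pred, trace=None):
#       correct = example.answer == pred.answer
#       return dspy.Prediction(
#           score=float(correct),
#           feedback="correct" if correct else f"expected {example.answer!r}",
#       )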