Serving LLMs with MLflow: Leveraging Custom PyFunc

Introduction

In this tutorial, we’ll explore how to save and serve a custom Large Language Model (LLM) with MLflow when the built-in MLflow ‘transformers’ flavor does not directly support the model type and its dependencies. In that situation, we need to write a custom pyfunc so that the model can still be deployed seamlessly.

Through this tutorial, we aim to provide you with:

  • An understanding of why certain models might need custom pyfunc definitions.

  • A walk-through of creating a custom pyfunc to handle model dependencies and interface data.

  • Insight into how a custom pyfunc can offer a simplified interface to end-users in a deployed environment.

The Challenge with Default Implementations

MLflow’s transformers flavor provides a standardized way to handle models from the Hugging Face Transformers library. However, not every model or configuration fits neatly into this standardized format.

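In our scenario, the model cannot use the default pipeline type because of incompatibilities with its requirements. For example, MPT-7B loads its own modeling code (which requires trust_remote_code=True) and expects prompts in a specific instruction template. This poses a challenge: how do we save, load, and serve the model with MLflow under these constraints?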

The Power of Custom PyFunc

The solution lies in MLflow’s support for custom pyfunc models. By writing a custom pyfunc, we can:

  • Define how the model loads its dependencies.

  • Customize the inference process.

  • Manipulate interface data to create specific inputs for the model.
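The three responsibilities above can be seen in a deliberately tiny, dependency-free sketch. `ToyPyFunc` is hypothetical: it only mirrors the method names and call shape of `mlflow.pyfunc.PythonModel` (`load_context(context)` and `predict(context, model_input, params=None)`) and does not import MLflow. A `SimpleNamespace` stands in for the MLflow context, a plain dict stands in for the pandas DataFrame, and the `config` artifact, `prefix` key, and `upper` param are illustrative assumptions.

```python
import json
import os
import tempfile
from types import SimpleNamespace


class ToyPyFunc:
    """Hypothetical stand-in that mirrors the method names of mlflow.pyfunc.PythonModel."""

    def load_context(self, context):
        # 1. Load dependencies from an artifact path supplied when the model was logged
        with open(context.artifacts["config"]) as f:
            self.config = json.load(f)

    def _preprocess(self, text):
        # 2. Turn raw interface data into the input format the wrapped model expects
        return f"{self.config['prefix']}{text.strip()}"

    def predict(self, context, model_input, params=None):
        # 3. Customize inference, honoring optional per-request params
        upper = (params or {}).get("upper", False)
        outputs = [self._preprocess(p) for p in model_input["prompt"]]
        return {"candidates": [o.upper() if upper else o for o in outputs]}


with tempfile.TemporaryDirectory() as tmp:
    # Write a throwaway "artifact" and point the fake context at it
    path = os.path.join(tmp, "config.json")
    with open(path, "w") as f:
        json.dump({"prefix": "Q: "}, f)
    ctx = SimpleNamespace(artifacts={"config": path})

    model = ToyPyFunc()
    model.load_context(ctx)
    result = model.predict(ctx, {"prompt": ["  hello "]}, params={"upper": True})

print(result)
```

The real MPT wrapper later in this notebook follows the same pattern: `load_context()` reads the model snapshot from `context.artifacts`, and `predict()` reshapes the user’s prompt before calling the model.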

Let’s dive into the code to see this in action.

Important Considerations Before Proceeding

Hardware Recommendations

This guide demonstrates the use of a particularly large and complex LLM. Given its size:

  • GPU Requirement: It’s strongly advised to run this example on a system with a CUDA-capable GPU that possesses at least 64GB of VRAM.

  • CPU Caution: While technically feasible, running the model on a CPU can make inference extremely slow, potentially taking tens of minutes for a single prediction, even on top-tier CPUs. For this reason, the final cell of this notebook is deliberately not executed. With a sufficiently powerful GPU, however, the whole notebook runs end to end in roughly 8 minutes.
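For a sense of scale, here is back-of-the-envelope arithmetic for the model weights alone. It assumes a nominal 7 billion parameters for MPT-7B, 2 bytes per parameter in bfloat16 (the `torch_dtype` the wrapper class loads with) and 4 bytes in float32. Activations, the KV cache built up during generation, and framework overhead all come on top, so treat this as a lower bound rather than a sizing guide.

```python
# Rough lower bound on the memory needed just to hold model weights.
# Assumes a nominal 7 billion parameters; excludes activations, the KV cache,
# and framework overhead, all of which add to the real footprint.
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2}


def weight_memory_gib(n_params: float, dtype: str) -> float:
    """Return the size of n_params weights stored as dtype, in GiB."""
    return n_params * BYTES_PER_PARAM[dtype] / 2**30


for dtype in ("float32", "bfloat16"):
    print(f"{dtype}: {weight_memory_gib(7e9, dtype):.1f} GiB")
```

Halving the bytes per parameter halves the weight footprint (about 26.1 GiB in float32 versus 13.0 GiB in bfloat16), which is one reason the model is loaded in bfloat16.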

Execution Recommendations

If you’re considering running the code in this notebook:

  • Performance: For a smoother experience, run the notebook on hardware that meets the GPU recommendation above.

  • Dependencies: Ensure you’ve installed the recommended dependencies for optimal model performance. These are crucial for efficient model loading, initialization, attention computations, and inference processing:

pip install xformers==0.0.20 einops==0.6.1 flash-attn==v1.0.3.post0 triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python
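Before loading the model, it can help to confirm which of these optional packages are actually present. This is a small sketch using only the standard library’s `importlib.metadata`; the distribution names are taken from the pip command above, and `installed_versions` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names taken from the pip command above
RECOMMENDED = ["xformers", "einops", "flash-attn", "triton-pre-mlir"]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if it is missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions(RECOMMENDED).items():
    print(f"{name}: {ver or 'not installed'}")
```

A missing package does not necessarily prevent the model from loading, but it may mean falling back to slower code paths.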

Learning Objectives

Remember, while hands-on execution provides valuable insights, the primary aim of this guide is to illustrate the effective use of MLflow in the showcased workflow. If you’re unable to run the notebook due to hardware constraints, you can still gain a comprehensive understanding by reviewing and analyzing the code and explanations provided.

[1]:
# Load necessary libraries

import transformers
import mlflow
import torch
from huggingface_hub import snapshot_download
import accelerate

Downloading the Model and Tokenizer

First, we need to download our model and tokenizer. Here’s how we do it:

[2]:
# Download the MPT-7B instruct model and tokenizer to a local directory cache
snapshot_location = snapshot_download(repo_id="mosaicml/mpt-7b-instruct", local_dir="mpt-7b")

Defining the Custom PyFunc

Now, let’s define our custom pyfunc. This will dictate how our model loads its dependencies and how it performs predictions. Notice how we’ve wrapped the intricacies of the model within this class.

[3]:
class MPT(mlflow.pyfunc.PythonModel):
    def load_context(self, context):
        """
        This method initializes the tokenizer and language model
        using the specified model snapshot directory.
        """
        # Initialize tokenizer and language model
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            context.artifacts["snapshot"], padding_side="left"
        )

        config = transformers.AutoConfig.from_pretrained(
            context.artifacts["snapshot"], trust_remote_code=True
        )
        # If you are running this on a system with a sufficiently powerful GPU and available VRAM,
        # uncomment the configuration setting below to use the triton attention implementation.
        # Note that triton can dramatically improve inference speed.

        # config.attn_config["attn_impl"] = "triton"

        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            context.artifacts["snapshot"],
            config=config,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )

        # Use a CUDA device when PyTorch can see one; otherwise fall back to the CPU.
        # Running on the CPU is valid, but inference will be very slow.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(device=self.device)

        self.model.eval()

    def _build_prompt(self, instruction):
        """
        This method generates the prompt for the model.
        """
        INSTRUCTION_KEY = "### Instruction:"
        RESPONSE_KEY = "### Response:"
        INTRO_BLURB = (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request."
        )

        # Join with explicit newlines so the method's indentation does not leak into the prompt
        return f"{INTRO_BLURB}\n{INSTRUCTION_KEY}\n{instruction}\n{RESPONSE_KEY}\n"

    def predict(self, context, model_input, params=None):
        """
        This method generates a prediction for the given input.
        """
        # Only the first row of the input DataFrame is used
        prompt = model_input["prompt"][0]

        # Retrieve or use default values for temperature and max_tokens
        temperature = params.get("temperature", 0.1) if params else 0.1
        max_tokens = params.get("max_tokens", 1000) if params else 1000

        # Wrap the raw user instruction in the model's prompt template
        prompt = self._build_prompt(prompt)

        # Encode the prompt on the same device as the model and generate a prediction
        encoded_input = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        output = self.model.generate(
            encoded_input,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_tokens,
        )

        # generate() returns the prompt tokens followed by the new tokens,
        # so skip the prompt and decode only the generated response
        prompt_length = encoded_input.shape[1]
        generated_response = self.tokenizer.decode(
            output[0][prompt_length:], skip_special_tokens=True
        )

        return {"candidates": [generated_response]}
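The parameter handling and prompt stripping in predict() are plain Python and can be checked without loading the model. The sketch below mirrors that logic with hypothetical helper names (`resolve_params`, `strip_prompt`) and made-up token ids. It relies on the fact that, for a decoder-only model, generate() returns the prompt tokens followed by the newly generated ones, so slicing off the first `prompt_length` ids leaves only the response.

```python
from typing import Optional, Sequence, Tuple

# Defaults matching the ones hard-coded in MPT.predict()
DEFAULT_TEMPERATURE = 0.1
DEFAULT_MAX_TOKENS = 1000


def resolve_params(params: Optional[dict]) -> Tuple[float, int]:
    """Fall back to the defaults when a param is not supplied (or params is None)."""
    params = params or {}
    return (
        params.get("temperature", DEFAULT_TEMPERATURE),
        params.get("max_tokens", DEFAULT_MAX_TOKENS),
    )


def strip_prompt(output_ids: Sequence[int], prompt_length: int) -> list:
    """Keep only the token ids produced after the prompt."""
    return list(output_ids[prompt_length:])


print(resolve_params(None))
print(resolve_params({"temperature": 0.6}))
print(strip_prompt([11, 12, 13, 42, 43], 3))
```

Passing only `temperature` leaves `max_tokens` at its default, which is how the signature’s params let callers override a single setting at inference time.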

Building the Prompt

One key aspect of our custom pyfunc is the construction of a model prompt. Instead of the end-user having to understand and construct this prompt, our custom pyfunc takes care of it. This ensures that regardless of the intricacies of the model’s requirements, the end-user interface remains simple and consistent.

Review the _build_prompt() method in the class above to see how custom input-processing logic in a pyfunc can translate user input into the format the wrapped model expects.
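To see exactly what the model receives, the template can be lifted out of the class as a standalone function. This is a sketch that copies the constants from _build_prompt(); `build_prompt` is a hypothetical module-level name. The end user supplies only the instruction, and the function adds the intro text and the instruction/response markers.

```python
INSTRUCTION_KEY = "### Instruction:"
RESPONSE_KEY = "### Response:"
INTRO_BLURB = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request."
)


def build_prompt(instruction: str) -> str:
    """Wrap a raw user instruction in the instruction/response template."""
    # Explicit "\n" escapes keep any source-code indentation out of the prompt
    return f"{INTRO_BLURB}\n{INSTRUCTION_KEY}\n{instruction}\n{RESPONSE_KEY}\n"


print(build_prompt("What is machine learning?"))
```

Writing the template with explicit newline escapes, rather than an indented triple-quoted string, guarantees that each marker starts at the beginning of its line.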

[4]:
import pandas as pd
import numpy as np
import mlflow
from mlflow.models.signature import ModelSignature
from mlflow.types import DataType, Schema, ColSpec, ParamSchema, ParamSpec

# Define input and output schema
input_schema = Schema(
    [
        ColSpec(DataType.string, "prompt"),
    ]
)
output_schema = Schema([ColSpec(DataType.string, "candidates")])

parameters = ParamSchema(
    [
        ParamSpec("temperature", DataType.float, np.float32(0.1), None),
        ParamSpec("max_tokens", DataType.integer, np.int32(1000), None),
    ]
)

signature = ModelSignature(inputs=input_schema, outputs=output_schema, params=parameters)


# Define input example
input_example = pd.DataFrame({"prompt": ["What is machine learning?"]})

Set the experiment that we’ll log our custom model to

If the experiment doesn’t already exist, MLflow will create a new experiment with this name and log a message saying that it has done so.

[5]:
mlflow.set_experiment(experiment_name="mpt-7b-instruct-evaluation")
2023/10/12 16:54:21 INFO mlflow.tracking.fluent: Experiment with name 'mpt-7b-instruct-evaluation' does not exist. Creating a new experiment.
[5]:
<Experiment: artifact_location='file:///Users/benjamin.wilson/repos/mlflow-fork/mlflow/docs/source/llms/custom-pyfunc-for-llms/notebooks/mlruns/528860847726625085', creation_time=1697144061460, experiment_id='528860847726625085', last_update_time=1697144061460, lifecycle_stage='active', name='mpt-7b-instruct-evaluation', tags={}>
[6]:
# Get the current base version of torch that is installed, without specific version modifiers
torch_version = torch.__version__.split("+")[0]

# Start an MLflow run context and log the MPT-7B model wrapper along with the param-included signature to
# allow for overriding parameters at inference time
with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        "mpt-7b-instruct",
        python_model=MPT(),
        # NOTE: the artifacts dictionary mapping is critical! This dict is used by the load_context() method in our MPT() class.
        artifacts={"snapshot": snapshot_location},
        pip_requirements=[
            f"torch=={torch_version}",
            f"transformers=={transformers.__version__}",
            f"accelerate=={accelerate.__version__}",
            "einops",
            "sentencepiece",
        ],
        input_example=input_example,
        signature=signature,
    )
2023/10/12 16:54:21 INFO mlflow.store.artifact.artifact_repo: The progress bar can be disabled by setting the environment variable MLFLOW_ENABLE_ARTIFACTS_PROGRESS_BAR to false
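The torch_version line in the cell above strips everything after "+" because builds of torch can carry a PEP 440 local version label (for example "+cu118" for a CUDA build); such labels are not used for releases on PyPI, so a pin that includes one may fail to resolve from the default index. Here is a dependency-free sketch of that pinning step; the helper names are hypothetical and the version strings are made-up examples rather than the versions used in this notebook.

```python
def base_version(version: str) -> str:
    """Drop a PEP 440 local version label such as '+cu118'."""
    return version.split("+")[0]


def pinned_requirements(versions: dict, extras=()) -> list:
    """Build exact '==' pins for the given packages, then append unpinned extras."""
    pins = [f"{name}=={base_version(ver)}" for name, ver in versions.items()]
    return pins + list(extras)


# Example version strings (hypothetical, for illustration only)
reqs = pinned_requirements(
    {"torch": "2.0.1+cu118", "transformers": "4.34.0", "accelerate": "0.23.0"},
    extras=["einops", "sentencepiece"],
)
print(reqs)
```

Pinning torch, transformers, and accelerate exactly while leaving einops and sentencepiece unpinned matches the pip_requirements list passed to log_model above.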

Load the saved model

[7]:
loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
Instantiating an MPTForCausalLM model from /Users/benjamin.wilson/.cache/huggingface/modules/transformers_modules/mpt-7b/modeling_mpt.py
You are using config.init_device='cpu', but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.

Test the model for inference

[8]:
# This call is commented out because inference on a CPU is extremely slow.
# If you are running this on a system with a sufficiently powerful GPU, uncomment it to query the model!

# loaded_model.predict(
#     pd.DataFrame({"prompt": ["What is machine learning?"]}),
#     params={"temperature": 0.6},
# )
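Once the model is served (for example with `mlflow models serve -m <model_uri> -p 5000`), the same request can be sent over HTTP. This sketch only builds the JSON body. It assumes MLflow 2.6 or later, whose scoring server accepts per-request `params` next to a `dataframe_split` input; the `max_tokens` value here is an arbitrary example.

```python
import json

# Request body for the /invocations endpoint, mirroring the predict() call above.
# Assumes an MLflow version (2.6+) that accepts a top-level "params" key.
payload = {
    "dataframe_split": {
        "columns": ["prompt"],
        "data": [["What is machine learning?"]],
    },
    "params": {"temperature": 0.6, "max_tokens": 200},
}
body = json.dumps(payload)
print(body)
```

The body would then be POSTed to `http://<host>:<port>/invocations` with the header `Content-Type: application/json`; the `params` values override the defaults declared in the model signature.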

Conclusion

Through this tutorial, we’ve seen the power and flexibility of MLflow’s custom pyfunc. By understanding the specific needs of our model and defining a custom pyfunc to cater to those needs, we can ensure a seamless deployment process and a user-friendly interface.