LangChain within MLflow (Experimental)

Attention

The langchain flavor is currently under active development and is marked as Experimental. Public APIs are evolving, and new features are being added to enhance its functionality.

Overview

LangChain is a Python framework for creating applications powered by language models. It offers unique features for developing context-aware applications that utilize language models for reasoning and generating responses. This integration with MLflow streamlines the development and deployment of complex NLP applications.

LangChain’s Technical Essence

  • Context-Aware Applications: LangChain specializes in connecting language models to various sources of context, enabling them to produce more relevant and accurate outputs.

  • Reasoning Capabilities: It uses the power of language models to reason about the given context and take appropriate actions based on it.

  • Flexible Chain Composition: The LangChain Expression Language (LCEL) allows for easy construction of complex chains from basic components, supporting functionalities like streaming, parallelism, and logging.

Building Chains with LangChain

  • Basic Components: LangChain facilitates chaining together components like prompt templates, models, and output parsers to create complex workflows.

  • Example - Joke Generator: A basic chain can take a topic and generate a joke using a combination of a prompt template, a ChatOpenAI model, and an output parser. The components are chained with the | operator, similar to a Unix pipe, so that the output of one component feeds into the next (see the sketch after this list).

  • Advanced Use Cases: LangChain also supports more complex setups, such as Retrieval-Augmented Generation (RAG) chains, which add relevant context when responding to questions.
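
For illustration, a minimal sketch of such a joke-generator chain might look like the following (the prompt wording, temperature, and topic are illustrative assumptions, not part of any MLflow API):

from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser

# Prompt template -> chat model -> output parser, composed with the | operator
prompt = ChatPromptTemplate.from_template("Tell me a short joke about {topic}")
model = ChatOpenAI(temperature=0.7)
chain = prompt | model | StrOutputParser()

# Requires OPENAI_API_KEY to be set in the environment
print(chain.invoke({"topic": "ice cream"}))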

Integration with MLflow

  • Simplified Logging and Loading: MLflow’s langchain flavor provides functions like log_model() and load_model(), enabling easy logging and retrieval of LangChain models within the MLflow ecosystem.

  • Simplified Deployment: LangChain models logged in MLflow can be interpreted as generic Python functions, simplifying their deployment and use in diverse applications. With dependency management incorporated directly into your logged model, you can deploy your application knowing that the environment that you used to train the model is what will be used to serve it.

  • Versatile Model Interaction: The integration allows developers to leverage LangChain’s unique features in conjunction with MLflow’s robust model tracking and management capabilities.

  • Autologging: MLflow’s langchain flavor provides autologging for LangChain models, automatically logging artifacts, metrics, and models during inference.

The langchain model flavor enables logging of LangChain models in MLflow format via the mlflow.langchain.save_model() and mlflow.langchain.log_model() functions. Use of these functions also adds the python_function flavor to the MLflow Models that they produce, allowing the model to be interpreted as a generic Python function for inference via mlflow.pyfunc.load_model().

You can also use the mlflow.langchain.load_model() function to load a saved or logged MLflow Model with the langchain flavor, which returns the native LangChain object (for example, the original chain, agent, or retriever) for direct use with LangChain APIs.
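
As a brief sketch, assuming a LangChain model has already been logged (the model_uri below is a placeholder), the two loading paths look like this:

import mlflow

# Placeholder URI of a previously logged LangChain model
model_uri = "runs:/<run_id>/langchain_model"

# Load the native LangChain object (e.g., the original chain) via the langchain flavor
native_chain = mlflow.langchain.load_model(model_uri)

# Load the same model as a generic python_function model for framework-agnostic inference
pyfunc_model = mlflow.pyfunc.load_model(model_uri)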

Basic Example: Logging a LangChain LLMChain in MLflow

import os

from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

import mlflow

# Ensure the OpenAI API key is set in the environment
assert "OPENAI_API_KEY" in os.environ, "Please set the OPENAI_API_KEY environment variable."

# Initialize the OpenAI model and the prompt template
llm = OpenAI(temperature=0.9)
prompt = PromptTemplate(
    input_variables=["product"],
    template="What is a good name for a company that makes {product}?",
)

# Create the LLMChain with the specified model and prompt
chain = LLMChain(llm=llm, prompt=prompt)

# Log the LangChain LLMChain in an MLflow run
with mlflow.start_run():
    logged_model = mlflow.langchain.log_model(chain, "langchain_model")

# Load the logged model using MLflow's Python function flavor
loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)

# Predict using the loaded model
print(loaded_model.predict([{"product": "colorful socks"}]))

The output of the example is shown below:

Output
["\n\nColorful Cozy Creations."]

What the Simple LLMChain Example Showcases

  • Integration Flexibility: The example highlights how LangChain’s LLMChain, consisting of an OpenAI model and a custom prompt template, can be easily logged in MLflow.

  • Simplified Model Management: Through MLflow’s langchain flavor, the chain is logged, enabling version control, tracking, and easy retrieval.

  • Ease of Deployment: The logged LangChain model is loaded using MLflow’s pyfunc module, illustrating the straightforward deployment process for LangChain models within MLflow.

  • Practical Application: The final prediction step demonstrates the model’s functionality in a real-world scenario, generating a company name based on a given product.

Logging a LangChain Agent with MLflow

What is an Agent?

Agents in LangChain leverage language models to dynamically determine and execute a sequence of actions, contrasting with the hardcoded sequences in chains. To learn more about Agents and see additional examples within LangChain, you can read the LangChain docs on Agents.

Key Components of Agents

Agent
  • The core chain driving decision-making, utilizing a language model and a prompt.

  • Receives inputs like tool descriptions, user objectives, and previously executed steps.

  • Outputs the next action set (AgentActions) or the final response (AgentFinish).

Tools
  • Functions invoked by agents to fulfill tasks.

  • Essential to provide appropriate tools and accurately describe them for effective use.

Toolkits
  • Collections of tools tailored for specific tasks.

  • LangChain offers a range of built-in toolkits and supports custom toolkit creation.

AgentExecutor
  • The runtime environment executing agent decisions.

  • Handles complexities such as tool errors and agent output parsing.

  • Ensures comprehensive logging and observability.

Additional Agent Runtimes
  • Beyond AgentExecutor, LangChain supports experimental runtimes like Plan-and-execute Agent, Baby AGI, and Auto GPT.

  • Custom runtime logic creation is also facilitated.

An Example of Logging a LangChain Agent

This example illustrates the process of logging a LangChain Agent in MLflow, highlighting the integration of LangChain’s complex agent functionalities with MLflow’s robust model management.

import os

from langchain.agents import AgentType, initialize_agent, load_tools
from langchain.llms import OpenAI

import mlflow

# Note: Ensure that the package 'google-search-results' is installed (via PyPI) to run this example
# and that you have accounts with SerpAPI and OpenAI to use their APIs.

# Ensuring necessary API keys are set
assert "OPENAI_API_KEY" in os.environ, "Please set the OPENAI_API_KEY environment variable."
assert "SERPAPI_API_KEY" in os.environ, "Please set the SERPAPI_API_KEY environment variable."

# Load the language model for agent control
llm = OpenAI(temperature=0)

# Next, let's load some tools to use. Note that the `llm-math` tool uses an LLM, so we need to pass that in.
tools = load_tools(["serpapi", "llm-math"], llm=llm)

# Finally, let's initialize an agent with the tools, the language model, and the type of agent we want to use.
agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)

# Log the agent in an MLflow run
with mlflow.start_run():
    logged_model = mlflow.langchain.log_model(agent, "langchain_model")

# Load the logged agent model for prediction
loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)

# Generate an inference result using the loaded model
question = "What was the high temperature in SF yesterday in Fahrenheit? What is that number raised to the .023 power?"

answer = loaded_model.predict([{"input": question}])

print(answer)

The output of the example above is shown below:

Output
["1.1044000282035853"]

What the Simple Agent Example Showcases

  • Complex Agent Logging: Demonstrates how LangChain’s sophisticated agent, which utilizes multiple tools and a language model, can be logged in MLflow.

  • Integration of Advanced Tools: Showcases the use of additional tools like ‘serpapi’ and ‘llm-math’ with a LangChain agent, emphasizing the framework’s capability to integrate complex functionalities.

  • Agent Initialization and Usage: Details the initialization process of a LangChain agent with specific tools and model settings, and how it can be used to perform complex queries.

  • Efficient Model Management and Deployment: Illustrates the ease with which complex LangChain agents can be managed and deployed using MLflow, from logging to prediction.

Enhanced Management of RetrievalQA Chains with MLflow

LangChain’s integration with MLflow introduces a more efficient way to manage and utilize the RetrievalQA chains, a key aspect of LangChain’s capabilities. These chains adeptly combine data retrieval with question-answering processes, leveraging the strength of language models.

Key Insights into RetrievalQA Chains

  • RetrievalQA Chain Functionality: These chains represent a sophisticated LangChain feature where information retrieval is seamlessly blended with language model-based question answering. They excel in scenarios requiring the language model to consult specific data or documents for accurate responses.

  • Role of the Retriever Object: At the core of RetrievalQA chains lies the retriever object, tasked with sourcing relevant documents or data in response to queries.

Detailed Overview of the RAG Process

  • Document Loaders: Facilitate loading documents from a diverse array of sources, boasting over 100 loaders and integrations.

  • Document Transformers: Prepare documents for retrieval by transforming and segmenting them into manageable parts.

  • Text Embedding Models: Generate semantic embeddings of texts, enhancing the relevance and efficiency of data retrieval.

  • Vector Stores: Specialized databases that store and facilitate the search of text embeddings.

  • Retrievers: Employ various retrieval techniques, ranging from simple semantic searches to more sophisticated methods like the Parent Document Retriever and Ensemble Retriever.

Clarifying Vector Database Management with MLflow

  • Traditional LangChain Serialization: LangChain typically requires manual management for the serialization of retriever objects, including handling of the vector database.

  • MLflow’s Simplification: The langchain flavor in MLflow substantially simplifies this process. It automates serialization, managing the contents of the persist_dir and the pickling of the loader_fn function.

Key MLflow Components and VectorDB Logging

  1. persist_dir: The directory where the retriever object, including the vector database, is stored.

  2. loader_fn: The function for loading the retriever object from its storage location.

Important Considerations

  • VectorDB Logging: MLflow, through its langchain flavor, does manage the vector database as part of the retriever object. However, the vector database itself is not explicitly logged as a separate entity in MLflow.

  • Runtime VectorDB Maintenance: It’s essential to maintain consistency in the vector database between the training and runtime environments. While MLflow manages the serialization of the retriever object, ensuring that the same vector database is accessible during runtime remains crucial for consistent performance.

An Example of logging a LangChain RetrievalQA Chain

import os
import tempfile

from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS

import mlflow

assert "OPENAI_API_KEY" in os.environ, "Please set the OPENAI_API_KEY environment variable."

with tempfile.TemporaryDirectory() as temp_dir:
    persist_dir = os.path.join(temp_dir, "faiss_index")

    # Create the vector db, persist the db to a local fs folder
    loader = TextLoader("tests/langchain/state_of_the_union.txt")
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    docs = text_splitter.split_documents(documents)
    embeddings = OpenAIEmbeddings()
    db = FAISS.from_documents(docs, embeddings)
    db.save_local(persist_dir)

    # Create the RetrievalQA chain
    retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=db.as_retriever())

    # Log the retrievalQA chain
    def load_retriever(persist_directory):
        embeddings = OpenAIEmbeddings()
        vectorstore = FAISS.load_local(persist_directory, embeddings)
        return vectorstore.as_retriever()

    with mlflow.start_run() as run:
        logged_model = mlflow.langchain.log_model(
            retrievalQA,
            artifact_path="retrieval_qa",
            loader_fn=load_retriever,
            persist_dir=persist_dir,
        )

# Load the retrievalQA chain
loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)
print(loaded_model.predict([{"query": "What did the president say about Ketanji Brown Jackson"}]))

The output of the example above is shown below:

Output (truncated)
[" The president said..."]

Logging and Evaluating a LangChain Retriever in MLflow

The langchain flavor in MLflow extends its functionalities to include the logging and individual evaluation of retriever objects. This capability is particularly valuable for assessing the quality of documents retrieved by a retriever without needing to process them through a large language model (LLM).

Purpose of Logging Individual Retrievers

  • Independent Evaluation: Allows for the assessment of a retriever’s performance in fetching relevant documents, independent of their subsequent use in LLMs.

  • Quality Assurance: Facilitates the evaluation of the retriever’s effectiveness in sourcing accurate and contextually appropriate documents.

Requirements for Logging Retrievers in MLflow

  • persist_dir: Specifies where the retriever object is stored.

  • loader_fn: Details the function used to load the retriever object from its storage location.

  • These requirements align with those for logging RetrievalQA chains, ensuring consistency in the process.

An example of logging a LangChain Retriever

import os
import tempfile

from langchain.document_loaders import TextLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS

import mlflow

assert "OPENAI_API_KEY" in os.environ, "Please set the OPENAI_API_KEY environment variable."

with tempfile.TemporaryDirectory() as temp_dir:
    persist_dir = os.path.join(temp_dir, "faiss_index")

    # Create the vector database and persist it to a local filesystem folder
    loader = TextLoader("tests/langchain/state_of_the_union.txt")
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    docs = text_splitter.split_documents(documents)
    embeddings = OpenAIEmbeddings()
    db = FAISS.from_documents(docs, embeddings)
    db.save_local(persist_dir)

    # Define a loader function to recall the retriever from the persisted vectorstore
    def load_retriever(persist_directory):
        embeddings = OpenAIEmbeddings()
        vectorstore = FAISS.load_local(persist_directory, embeddings)
        return vectorstore.as_retriever()

    # Log the retriever with the loader function
    with mlflow.start_run() as run:
        logged_model = mlflow.langchain.log_model(
            db.as_retriever(),
            artifact_path="retriever",
            loader_fn=load_retriever,
            persist_dir=persist_dir,
        )

# Load the retriever chain
loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)
print(loaded_model.predict([{"query": "What did the president say about Ketanji Brown Jackson"}]))

The output of the example above is shown below:

Output (truncated)
[
    [
        {
            "page_content": "Tonight. I call...",
            "metadata": {"source": "/state.txt"},
        },
        {
            "page_content": "A former top...",
            "metadata": {"source": "/state.txt"},
        },
    ]
]

MLflow Langchain Autologging

The MLflow langchain flavor supports autologging of LangChain models, which provides the following benefits:

  • Streamlined Logging Process: Autologging eliminates the manual effort required to log LangChain models and metrics in MLflow by seamlessly integrating the MlflowCallbackHandler, which automatically records metrics and artifacts.

  • Effortless Artifact Logging: Autologging simplifies the process by automatically logging artifacts that encapsulate crucial details about the LangChain model. This includes information about various tools, chains, retrievers, agents, and llms used during inference, along with configurations and other relevant metadata.

  • Seamless Metrics Recording: Autologging effortlessly captures metrics for evaluating generated texts, as well as key objects such as llms and agents employed during inference.

  • Automated Input and Output Logging: Autologging takes care of logging inputs and outputs of the LangChain model during inference. The recorded results are neatly organized into an inference_inputs_outputs.json file, providing a comprehensive overview of the model’s inference history.

Note

To use MLflow LangChain autologging, upgrade langchain to version 0.1.0 or higher. Depending on your environment, you may also need to manually install langchain_community>=0.0.16 to enable the automatic logging of artifacts and metrics (this behavior will be changed to an optional import in the future). If autologging does not log artifacts as expected, check the warning messages in the stdout logs. For langchain_community==0.0.16, you will need to install the textstat and spacy libraries manually and restart any active interactive environment (e.g., a notebook environment). On Databricks, you can do this by executing dbutils.library.restartPython(), which forces the Python REPL to restart so that the newly installed libraries become available.

MLflow langchain autologging injects an MlflowCallbackHandler into the LangChain model inference process to log metrics and artifacts automatically. The model itself is logged only if log_models is set to True when calling mlflow.langchain.autolog() and the object being invoked is one of the supported model types: Chain, AgentExecutor, BaseRetriever, RunnableSequence, RunnableParallel, RunnableBranch, SimpleChatModel, ChatPromptTemplate, RunnableLambda, RunnablePassthrough. Additional model types will be supported in the future.
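
As a minimal sketch, enabling autologging (with model logging turned on) before invoking a supported object looks like this; the complete example at the end of this section shows additional options:

import mlflow

# Enable LangChain autologging; with log_models=True, the invoked object is also
# logged as an MLflow model when it is one of the supported types listed above.
mlflow.langchain.autolog(log_models=True)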

Note

We patch the invoke function for all supported LangChain models, the __call__ function for Chain and AgentExecutor models, and the get_relevant_documents function for BaseRetriever models, so MLflow autologs metrics and artifacts only when one of those functions is called. Autologging of models that contain retrievers is not supported, because loading such a model requires saving loader_fn and persist_dir; please log those models manually if needed.

The following metrics and artifacts are logged by default (depending on the models involved):

Artifacts:

  • table_action_records.html: Each action’s details, including chains, tools, llms, agents, and retrievers.

  • table_session_analysis.html: Details about the prompt and output for each prompt step, token usage, and text analysis metrics.

  • chat_html.html: LLM input and output details.

  • llm_start_x_prompt_y.json: The prompt and kwargs passed during the LLM generate call.

  • llm_end_x_generation_y.json: The llm_output of the LLM result.

  • ent-<hash string of generation.text>.html: Visualization of the generation text using the spacy “en_core_web_sm” model with style ent (if spacy is installed and the model is downloaded).

  • dep-<hash string of generation.text>.html: Visualization of the generation text using the spacy “en_core_web_sm” model with style dep (if spacy is installed and the model is downloaded).

  • llm_new_tokens_x.json: New tokens added to the LLM during inference.

  • chain_start_x.json: The inputs and chain-related information during inference.

  • chain_end_x.json: The chain outputs.

  • tool_start_x.json: The tool’s name and description information during inference.

  • tool_end_x.json: The observation of the tool.

  • retriever_start_x.json: The retriever’s information during inference.

  • retriever_end_x.json: The retriever’s result documents.

  • agent_finish_x.json: The final return value of the ActionAgent, including output and log.

  • agent_action_x.json: The ActionAgent’s action details.

  • on_text_x.json: The text recorded during inference.

  • inference_inputs_outputs.json: Input and output details for each inference call (logged by default; can be turned off by setting log_inputs_outputs=False when enabling autologging).

Metrics:

  • Basic Metrics: step, starts, ends, errors, text_ctr, chain_starts, chain_ends, llm_starts, llm_ends, llm_streams, tool_starts, tool_ends, agent_ends, retriever_starts, retriever_ends (the invocation count of each component).

  • Text Analysis Metrics: flesch_reading_ease, flesch_kincaid_grade, smog_index, coleman_liau_index, automated_readability_index, dale_chall_readability_score, difficult_words, linsear_write_formula, gunning_fog, fernandez_huerta, szigriszt_pazos, gutierrez_polini, crawford, gulpease_index, osman (text analysis metrics of the generated text, logged if the textstat library is installed).

Note

Each inference call logs its artifacts into a separate directory named artifacts-<session_id>-<idx>, where session_id is a randomly generated UUID and idx is the index of the inference call. The session_id is also preserved in the inference_inputs_outputs.json file, so you can easily find the corresponding artifacts for each inference call.
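
As a small sketch, assuming you have the run_id of an autologged run (the value below is a placeholder), you can list the top-level artifact directories with the standard MLflow tracking client:

from mlflow.tracking import MlflowClient

client = MlflowClient()
run_id = "<run_id>"  # placeholder: the run that autologging wrote to

# Top-level artifacts include the artifacts-<session_id>-<idx> directories
for artifact in client.list_artifacts(run_id):
    print(artifact.path)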

If you encounter any unexpected issues, please feel free to open an issue in the MLflow GitHub repo.

An example of MLflow langchain autologging

import os
from operator import itemgetter

from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableLambda

import mlflow

# Uncomment the following to use the full capabilities of LangChain autologging
# %pip install "langchain_community>=0.0.16"
# These two libraries enable autologging to log text analysis related artifacts
# %pip install textstat spacy

assert "OPENAI_API_KEY" in os.environ, "Please set the OPENAI_API_KEY environment variable."

# Enable mlflow langchain autologging
# Note: We only support auto-logging models that do not contain retrievers
mlflow.langchain.autolog(
    log_input_examples=True,
    log_model_signatures=True,
    log_models=True,
    log_inputs_outputs=True,
    registered_model_name="lc_model",
)

prompt_with_history_str = """
Here is a history between you and a human: {chat_history}

Now, please answer this question: {question}
"""
prompt_with_history = PromptTemplate(
    input_variables=["chat_history", "question"], template=prompt_with_history_str
)


def extract_question(input):
    return input[-1]["content"]


def extract_history(input):
    return input[:-1]


llm = OpenAI(temperature=0.9)

# Build a chain with LCEL
chain_with_history = (
    {
        "question": itemgetter("messages") | RunnableLambda(extract_question),
        "chat_history": itemgetter("messages") | RunnableLambda(extract_history),
    }
    | prompt_with_history
    | llm
    | StrOutputParser()
)

inputs = {"messages": [{"role": "user", "content": "Who owns MLflow?"}]}

print(chain_with_history.invoke(inputs))
# sample output:
# "1. Databricks\n2. Microsoft\n3. Google\n4. Amazon\n\nEnter your answer: 1\n\n
# Correct! MLflow is an open source project developed by Databricks. ...

# We automatically log the model and trace related artifacts
# A model with name `lc_model` is registered, we can load it back as a PyFunc model
model_name = "lc_model"
model_version = 1
loaded_model = mlflow.pyfunc.load_model(f"models:/{model_name}/{model_version}")
print(loaded_model.predict(inputs))