# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
LangChain CallbackHandler that prints to streamlit.

This is a special API that's imported and used by LangChain itself. Any updates
to the public API (the StreamlitCallbackHandler constructor, and the entirety
of LLMThoughtLabeler) *must* remain backwards-compatible to avoid breaking
LangChain.

This means that it's acceptable to add new optional kwargs to StreamlitCallbackHandler,
but no new positional args or required kwargs should be added, and no existing
args should be removed. If we need to overhaul the API, we must ensure that a
compatible API continues to exist.

Any major change to the StreamlitCallbackHandler should be tested by importing
the API *from LangChain itself*.

This module is lazy-loaded.
"""

# NOTE: We ignore all mypy import-not-found errors as top-level since
# this module is optional and the langchain dependency is not installed
# by default.
# mypy: disable-error-code="import-not-found, unused-ignore, misc"

# Deactivate unused argument errors for this file since we need lots of
# unused arguments to comply with the LangChain callback interface.
# ruff: noqa: ARG002

from __future__ import annotations

import time
from enum import Enum
from typing import TYPE_CHECKING, Any, NamedTuple

from langchain.callbacks.base import (
    BaseCallbackHandler,
)

from streamlit.runtime.metrics_util import gather_metrics

if TYPE_CHECKING:
    from langchain.schema import (
        AgentAction,
        AgentFinish,
        LLMResult,
    )

    from streamlit.delta_generator import DeltaGenerator
    from streamlit.elements.lib.mutable_status_container import StatusContainer


def _convert_newlines(text: str) -> str:
    """Convert newline characters to markdown newline sequences
    (space, space, newline).
    """
    return text.replace("\n", "  \n")
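
# Illustration only: _convert_newlines("foo\nbar") == "foo  \nbar". Markdown
# renders the two trailing spaces before the newline as a hard line break, so
# multi-line streamed tokens show up on separate lines in the container.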


# The maximum length of the "input_str" portion of a tool label.
# Strings that are longer than this will be truncated with "..."
MAX_TOOL_INPUT_STR_LENGTH = 60


class LLMThoughtState(Enum):
    # The LLM is thinking about what to do next. We don't know which tool we'll run.
    THINKING = "THINKING"
    # The LLM has decided to run a tool. We don't have results from the tool yet.
    RUNNING_TOOL = "RUNNING_TOOL"
    # We have results from the tool.
    COMPLETE = "COMPLETE"
    # The LLM completed with an error.
    ERROR = "ERROR"


class ToolRecord(NamedTuple):
    name: str
    input_str: str


class LLMThoughtLabeler:
    """
    Generates markdown labels for LLMThought containers. Pass a custom
    subclass of this to StreamlitCallbackHandler to override its default
    labeling logic.
    """

    def get_initial_label(self) -> str:
        """Return the markdown label for a new LLMThought that doesn't have
        an associated tool yet.
        """
        return "Thinking..."

    def get_tool_label(self, tool: ToolRecord, is_complete: bool) -> str:
        """Return the label for an LLMThought that has an associated
        tool.

        Parameters
        ----------
        tool
            The tool's ToolRecord

        is_complete
            True if the thought is complete; False if the thought
            is still receiving input.

        Returns
        -------
        The markdown label for the thought's container.

        """
        input_str = tool.input_str
        name = tool.name
        if name == "_Exception":
            name = "Parsing error"
        input_str_len = min(MAX_TOOL_INPUT_STR_LENGTH, len(input_str))
        input_str = input_str[:input_str_len]
        if len(tool.input_str) > input_str_len:
            input_str = input_str + "..."
        input_str = input_str.replace("\n", " ")
        return f"**{name}:** {input_str}"

    def get_final_agent_thought_label(self) -> str:
        """Return the markdown label for the agent's final thought -
        the "Now I have the answer" thought, that doesn't involve
        a tool.
        """
        return "**Complete!**"


class LLMThought:
    """Encapsulates the Streamlit UI for a single LLM 'thought' during a LangChain Agent
    run. Each tool usage gets its own thought, and runs also generally have a
    concluding thought where the Agent determines that it has an answer to the prompt.

    Each thought gets its own expander UI.
    """

    def __init__(
        self,
        parent_container: DeltaGenerator,
        labeler: LLMThoughtLabeler,
        expanded: bool,
        collapse_on_complete: bool,
    ) -> None:
        self._container = parent_container.status(
            labeler.get_initial_label(), expanded=expanded
        )

        self._state = LLMThoughtState.THINKING
        self._llm_token_stream = ""
        self._llm_token_stream_placeholder: DeltaGenerator | None = None
        self._last_tool: ToolRecord | None = None
        self._collapse_on_complete = collapse_on_complete
        self._labeler = labeler

    @property
    def container(self) -> StatusContainer:
        """The container we're writing into."""
        return self._container

    @property
    def last_tool(self) -> ToolRecord | None:
        """The last tool executed by this thought."""
        return self._last_tool

    def _reset_llm_token_stream(self) -> None:
        if self._llm_token_stream_placeholder is not None:
            self._llm_token_stream_placeholder.markdown(self._llm_token_stream)

        self._llm_token_stream = ""
        self._llm_token_stream_placeholder = None

    def on_llm_start(self, serialized: dict[str, Any], prompts: list[str]) -> None:
        self._reset_llm_token_stream()

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # This is only called when the LLM is initialized with `streaming=True`
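        # Descriptive note: tokens accumulate in self._llm_token_stream and are
        # re-rendered into a single st.empty() placeholder with a trailing
        # cursor character, so the text appears to grow in place.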
        self._llm_token_stream += _convert_newlines(token)
        if self._llm_token_stream_placeholder is None:
            self._llm_token_stream_placeholder = self._container.empty()
        self._llm_token_stream_placeholder.markdown(self._llm_token_stream + "▕")

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        # `response` is the concatenation of all the tokens received by the LLM.
        # If we're receiving streaming tokens from `on_llm_new_token`, this response
        # data is redundant
        self._reset_llm_token_stream()
        # set the container status to complete
        self.complete(self._labeler.get_final_agent_thought_label())

    def on_llm_error(self, error: BaseException, *args: Any, **kwargs: Any) -> None:
        self._container.exception(error)
        self._state = LLMThoughtState.ERROR
        self.complete("LLM encountered an error...")

    def on_tool_start(
        self, serialized: dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        # Called with the name of the tool we're about to run (in `serialized[name]`),
        # and its input. We change our container's label to be the tool name.
        self._state = LLMThoughtState.RUNNING_TOOL
        tool_name = serialized["name"]
        self._last_tool = ToolRecord(name=tool_name, input_str=input_str)
        self._container.update(
            label=self._labeler.get_tool_label(self._last_tool, is_complete=False),
            state="running",
        )
        if len(input_str) > MAX_TOOL_INPUT_STR_LENGTH:
            # output is printed later in on_tool_end
            self._container.markdown(f"**Input:**\n\n{input_str}\n\n**Output:**")

    def on_tool_end(
        self,
        output: str,
        color: str | None = None,
        observation_prefix: str | None = None,
        llm_prefix: str | None = None,
        **kwargs: Any,
    ) -> None:
        self._container.markdown(output)

    def on_tool_error(self, error: BaseException, *args: Any, **kwargs: Any) -> None:
        self._container.markdown("**Tool encountered an error...**")
        self._container.exception(error)
        self._container.update(state="error")

    def on_agent_action(
        self, action: AgentAction, color: str | None = None, **kwargs: Any
    ) -> Any:
        # Called when we're about to kick off a new tool. The `action` data
        # tells us the tool we're about to use, and the input we'll give it.
        # We don't output anything here, because we'll receive this same data
        # when `on_tool_start` is called immediately after.
        pass

    def complete(self, final_label: str | None = None) -> None:
        """Finish the thought."""
        if final_label is None and self._state == LLMThoughtState.RUNNING_TOOL:
            if self._last_tool is None:
                raise RuntimeError(
                    "_last_tool should never be null when _state == RUNNING_TOOL"
                )
            final_label = self._labeler.get_tool_label(
                self._last_tool, is_complete=True
            )

        if self._last_tool and self._last_tool.name == "_Exception":
            self._state = LLMThoughtState.ERROR
        elif self._state != LLMThoughtState.ERROR:
            self._state = LLMThoughtState.COMPLETE

        if self._collapse_on_complete:
            # Add a quick delay to show the user the final output before we collapse
            time.sleep(0.25)

        self._container.update(
            label=final_label,
            expanded=False if self._collapse_on_complete else None,
            state="error" if self._state == LLMThoughtState.ERROR else "complete",
        )


class StreamlitCallbackHandler(BaseCallbackHandler):
    @gather_metrics("external.langchain.StreamlitCallbackHandler")
    def __init__(
        self,
        parent_container: DeltaGenerator,
        *,
        max_thought_containers: int = 4,
        expand_new_thoughts: bool = False,
        collapse_completed_thoughts: bool = False,
        thought_labeler: LLMThoughtLabeler | None = None,
    ) -> None:
        """Construct a new StreamlitCallbackHandler. This CallbackHandler is geared
        towards use with a LangChain Agent; it displays the Agent's LLM and tool-usage
        "thoughts" inside a series of Streamlit expanders.

        Parameters
        ----------
        parent_container
            The `st.container` that will contain all the Streamlit elements that the
            Handler creates.

        max_thought_containers

            .. note::
                This parameter is deprecated and is ignored in the latest version of
                the callback handler.

            The max number of completed LLM thought containers to show at once. When
            this threshold is reached, a new thought will cause the oldest thoughts to
            be collapsed into a "History" expander. Defaults to 4.

        expand_new_thoughts
            Each LLM "thought" gets its own `st.expander`. This param controls whether
            that expander is expanded by default. Defaults to False.

        collapse_completed_thoughts
            If True, LLM thought expanders will be collapsed when completed.
            Defaults to False.

        thought_labeler
            An optional custom LLMThoughtLabeler instance. If unspecified, the handler
            will use the default thought labeling logic. Defaults to None.
        """
        self._parent_container = parent_container
        self._history_parent = parent_container.container()
        self._current_thought: LLMThought | None = None
        self._completed_thoughts: list[LLMThought] = []
        self._max_thought_containers = max(max_thought_containers, 1)
        self._expand_new_thoughts = expand_new_thoughts
        self._collapse_completed_thoughts = collapse_completed_thoughts
        self._thought_labeler = thought_labeler or LLMThoughtLabeler()

    def _require_current_thought(self) -> LLMThought:
        """Return our current LLMThought. Raise an error if we have no current
        thought.
        """
        if self._current_thought is None:
            raise RuntimeError("Current LLMThought is unexpectedly None!")
        return self._current_thought

    def _get_last_completed_thought(self) -> LLMThought | None:
        """Return our most recent completed LLMThought, or None if we don't have one."""
        if len(self._completed_thoughts) > 0:
            return self._completed_thoughts[len(self._completed_thoughts) - 1]
        return None

    def _complete_current_thought(self, final_label: str | None = None) -> None:
        """Complete the current thought, optionally assigning it a new label.
        Add it to our _completed_thoughts list.
        """
        thought = self._require_current_thought()
        thought.complete(final_label)
        self._completed_thoughts.append(thought)
        self._current_thought = None

    def on_llm_start(
        self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
    ) -> None:
        if self._current_thought is None:
            self._current_thought = LLMThought(
                parent_container=self._parent_container,
                expanded=self._expand_new_thoughts,
                collapse_on_complete=self._collapse_completed_thoughts,
                labeler=self._thought_labeler,
            )

        self._current_thought.on_llm_start(serialized, prompts)

        # We don't prune_old_thought_containers here, because our container won't
        # be visible until it has a child.

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self._require_current_thought().on_llm_new_token(token, **kwargs)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        self._require_current_thought().on_llm_end(response, **kwargs)

    def on_llm_error(self, error: BaseException, *args: Any, **kwargs: Any) -> None:
        self._require_current_thought().on_llm_error(error, **kwargs)

    def on_tool_start(
        self, serialized: dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        self._require_current_thought().on_tool_start(serialized, input_str, **kwargs)

    def on_tool_end(
        self,
        output: str,
        color: str | None = None,
        observation_prefix: str | None = None,
        llm_prefix: str | None = None,
        **kwargs: Any,
    ) -> None:
        self._require_current_thought().on_tool_end(
            output, color, observation_prefix, llm_prefix, **kwargs
        )
        self._complete_current_thought()

    def on_tool_error(self, error: BaseException, *args: Any, **kwargs: Any) -> None:
        self._require_current_thought().on_tool_error(error, **kwargs)

    def on_agent_action(
        self, action: AgentAction, color: str | None = None, **kwargs: Any
    ) -> Any:
        self._require_current_thought().on_agent_action(action, color, **kwargs)

    def on_agent_finish(
        self, finish: AgentFinish, color: str | None = None, **kwargs: Any
    ) -> None:
        if self._current_thought is not None:
            self._current_thought.complete(
                self._thought_labeler.get_final_agent_thought_label()
            )
            self._current_thought = None