# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
LangChain CallbackHandler that prints to streamlit.
This is a special API that's imported and used by LangChain itself. Any updates
to the public API (the StreamlitCallbackHandler constructor, and the entirety
of LLMThoughtLabeler) *must* remain backwards-compatible to avoid breaking
LangChain.
This means that it's acceptable to add new optional kwargs to StreamlitCallbackHandler,
but no new positional args or required kwargs should be added, and no existing
args should be removed. If we need to overhaul the API, we must ensure that a
compatible API continues to exist.
Any major change to the StreamlitCallbackHandler should be tested by importing
the API *from LangChain itself*.
This module is lazy-loaded.
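
For orientation, a hedged sketch of the import paths involved (the exact
LangChain-side location has moved between LangChain releases, so treat the
second import as illustrative rather than guaranteed)::

    # Canonical implementation, shipped with Streamlit:
    from streamlit.external.langchain import StreamlitCallbackHandler

    # Re-exported on the LangChain side (module path varies by version), e.g.:
    # from langchain_community.callbacks import StreamlitCallbackHandler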
"""
# NOTE: We ignore all mypy import-not-found errors as top-level since
# this module is optional and the langchain dependency is not installed
# by default.
# mypy: disable-error-code="import-not-found, unused-ignore, misc"
# Deactivate unused argument errors for this file since we need lots of
# unused arguments to comply with the LangChain callback interface.
# ruff: noqa: ARG002
from __future__ import annotations
import time
from enum import Enum
from typing import TYPE_CHECKING, Any, NamedTuple
from langchain.callbacks.base import (
BaseCallbackHandler,
)
from streamlit.runtime.metrics_util import gather_metrics
if TYPE_CHECKING:
from langchain.schema import (
AgentAction,
AgentFinish,
LLMResult,
)
from streamlit.delta_generator import DeltaGenerator
from streamlit.elements.lib.mutable_status_container import StatusContainer
def _convert_newlines(text: str) -> str:
"""Convert newline characters to markdown newline sequences
(space, space, newline).
"""
return text.replace("\n", "  \n")
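# For example, _convert_newlines("foo\nbar") returns "foo  \nbar"; the two
# trailing spaces make markdown render a hard line break instead of
# collapsing the newline into a space.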
# The maximum length of the "input_str" portion of a tool label.
# Strings that are longer than this will be truncated with "..."
MAX_TOOL_INPUT_STR_LENGTH = 60
class LLMThoughtState(Enum):
# The LLM is thinking about what to do next. We don't know which tool we'll run.
THINKING = "THINKING"
# The LLM has decided to run a tool. We don't have results from the tool yet.
RUNNING_TOOL = "RUNNING_TOOL"
# We have results from the tool.
COMPLETE = "COMPLETE"
# The LLM completed with an error.
ERROR = "ERROR"
class ToolRecord(NamedTuple):
name: str
input_str: str
class LLMThoughtLabeler:
"""
Generates markdown labels for LLMThought containers. Pass a custom
subclass of this to StreamlitCallbackHandler to override its default
labeling logic.
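
A minimal, illustrative sketch of such a subclass (``VerboseLabeler`` is
hypothetical, and ``st`` stands for an imported ``streamlit`` module)::

    class VerboseLabeler(LLMThoughtLabeler):
        def get_initial_label(self) -> str:
            return "The agent is deciding what to do next..."

        def get_final_agent_thought_label(self) -> str:
            return "**The agent has finished.**"

    handler = StreamlitCallbackHandler(
        st.container(), thought_labeler=VerboseLabeler()
    )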
"""
def get_initial_label(self) -> str:
"""Return the markdown label for a new LLMThought that doesn't have
an associated tool yet.
"""
return "Thinking..."
def get_tool_label(self, tool: ToolRecord, is_complete: bool) -> str:
"""Return the label for an LLMThought that has an associated
tool.
Parameters
----------
tool
The tool's ToolRecord
is_complete
True if the thought is complete; False if the thought
is still receiving input.
Returns
-------
The markdown label for the thought's container.
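
Examples
--------
Illustrative only; the tool name and input below are made up, but the
formatting shown is what this default implementation produces:

>>> LLMThoughtLabeler().get_tool_label(
...     ToolRecord(name="Search", input_str="weather in Paris"),
...     is_complete=True,
... )
'**Search:** weather in Paris'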
"""
input_str = tool.input_str
name = tool.name
if name == "_Exception":
name = "Parsing error"
input_str_len = min(MAX_TOOL_INPUT_STR_LENGTH, len(input_str))
input_str = input_str[:input_str_len]
if len(tool.input_str) > input_str_len:
input_str = input_str + "..."
input_str = input_str.replace("\n", " ")
return f"**{name}:** {input_str}"
def get_final_agent_thought_label(self) -> str:
"""Return the markdown label for the agent's final thought -
the "Now I have the answer" thought, that doesn't involve
a tool.
"""
return "**Complete!**"
class LLMThought:
"""Encapsulates the Streamlit UI for a single LLM 'thought' during a LangChain Agent
run. Each tool usage gets its own thought, and runs also generally have a
concluding thought where the Agent determines that it has an answer to the prompt.
Each thought gets its own expander UI.
"""
def __init__(
self,
parent_container: DeltaGenerator,
labeler: LLMThoughtLabeler,
expanded: bool,
collapse_on_complete: bool,
) -> None:
self._container = parent_container.status(
labeler.get_initial_label(), expanded=expanded
)
self._state = LLMThoughtState.THINKING
self._llm_token_stream = ""
self._llm_token_stream_placeholder: DeltaGenerator | None = None
self._last_tool: ToolRecord | None = None
self._collapse_on_complete = collapse_on_complete
self._labeler = labeler
@property
def container(self) -> StatusContainer:
"""The container we're writing into."""
return self._container
@property
def last_tool(self) -> ToolRecord | None:
"""The last tool executed by this thought."""
return self._last_tool
def _reset_llm_token_stream(self) -> None:
if self._llm_token_stream_placeholder is not None:
self._llm_token_stream_placeholder.markdown(self._llm_token_stream)
self._llm_token_stream = ""
self._llm_token_stream_placeholder = None
def on_llm_start(self, serialized: dict[str, Any], prompts: list[str]) -> None:
self._reset_llm_token_stream()
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
# This is only called when the LLM is initialized with `streaming=True`
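# (On the LangChain side that typically means constructing the model with a
# streaming flag, e.g. something like ChatOpenAI(streaming=True); the exact
# argument belongs to LangChain, not to this module.)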
self._llm_token_stream += _convert_newlines(token)
if self._llm_token_stream_placeholder is None:
self._llm_token_stream_placeholder = self._container.empty()
# Show a cursor character at the end of the stream to indicate that the
# LLM is still producing output.
self._llm_token_stream_placeholder.markdown(self._llm_token_stream + "▌")
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
# `response` is the concatenation of all the tokens received by the LLM.
# If we're receiving streaming tokens from `on_llm_new_token`, this response
# data is redundant
self._reset_llm_token_stream()
# set the container status to complete
self.complete(self._labeler.get_final_agent_thought_label())
def on_llm_error(self, error: BaseException, *args: Any, **kwargs: Any) -> None:
self._container.exception(error)
self._state = LLMThoughtState.ERROR
self.complete("LLM encountered an error...")
def on_tool_start(
self, serialized: dict[str, Any], input_str: str, **kwargs: Any
) -> None:
# Called with the name of the tool we're about to run (in `serialized["name"]`),
# and its input. We change our container's label to be the tool name.
self._state = LLMThoughtState.RUNNING_TOOL
tool_name = serialized["name"]
self._last_tool = ToolRecord(name=tool_name, input_str=input_str)
self._container.update(
label=self._labeler.get_tool_label(self._last_tool, is_complete=False),
state="running",
)
if len(input_str) > MAX_TOOL_INPUT_STR_LENGTH:
# The label above truncates long inputs, so show the full input here;
# the tool's output is appended later, in on_tool_end.
self._container.markdown(f"**Input:**\n\n{input_str}\n\n**Output:**")
def on_tool_end(
self,
output: str,
color: str | None = None,
observation_prefix: str | None = None,
llm_prefix: str | None = None,
**kwargs: Any,
) -> None:
self._container.markdown(output)
def on_tool_error(self, error: BaseException, *args: Any, **kwargs: Any) -> None:
self._container.markdown("**Tool encountered an error...**")
self._container.exception(error)
self._container.update(state="error")
def on_agent_action(
self, action: AgentAction, color: str | None = None, **kwargs: Any
) -> Any:
# Called when we're about to kick off a new tool. The `action` data
# tells us the tool we're about to use, and the input we'll give it.
# We don't output anything here, because we'll receive this same data
# when `on_tool_start` is called immediately after.
pass
def complete(self, final_label: str | None = None) -> None:
"""Finish the thought."""
if final_label is None and self._state == LLMThoughtState.RUNNING_TOOL:
if self._last_tool is None:
raise RuntimeError(
"_last_tool should never be null when _state == RUNNING_TOOL"
)
final_label = self._labeler.get_tool_label(
self._last_tool, is_complete=True
)
if self._last_tool and self._last_tool.name == "_Exception":
self._state = LLMThoughtState.ERROR
elif self._state != LLMThoughtState.ERROR:
self._state = LLMThoughtState.COMPLETE
if self._collapse_on_complete:
# Add a quick delay to show the user the final output before we collapse
time.sleep(0.25)
self._container.update(
label=final_label,
expanded=False if self._collapse_on_complete else None,
state="error" if self._state == LLMThoughtState.ERROR else "complete",
)
class StreamlitCallbackHandler(BaseCallbackHandler):
@gather_metrics("external.langchain.StreamlitCallbackHandler")
def __init__(
self,
parent_container: DeltaGenerator,
*,
max_thought_containers: int = 4,
expand_new_thoughts: bool = False,
collapse_completed_thoughts: bool = False,
thought_labeler: LLMThoughtLabeler | None = None,
) -> None:
"""Construct a new StreamlitCallbackHandler. This CallbackHandler is geared
towards use with a LangChain Agent; it displays the Agent's LLM and tool-usage
"thoughts" inside a series of Streamlit expanders.
Parameters
----------
parent_container
The `st.container` that will contain all the Streamlit elements that the
Handler creates.
max_thought_containers
.. note::
This parameter is deprecated and is ignored in the latest version of
the callback handler.
The max number of completed LLM thought containers to show at once. When
this threshold is reached, a new thought will cause the oldest thoughts to
be collapsed into a "History" expander. Defaults to 4.
expand_new_thoughts
Each LLM "thought" gets its own `st.expander`. This param controls whether
that expander is expanded by default. Defaults to False.
collapse_completed_thoughts
If True, LLM thought expanders will be collapsed when completed.
Defaults to False.
thought_labeler
An optional custom LLMThoughtLabeler instance. If unspecified, the handler
will use the default thought labeling logic. Defaults to None.
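
Examples
--------
A minimal usage sketch. ``agent`` stands in for any LangChain agent or
runnable that accepts ``callbacks``; it is not part of this module, and
the exact invocation depends on your LangChain version::

    import streamlit as st

    st_callback = StreamlitCallbackHandler(
        st.container(), collapse_completed_thoughts=True
    )
    # response = agent.run(prompt, callbacks=[st_callback])
    # st.write(response)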
"""
self._parent_container = parent_container
self._history_parent = parent_container.container()
self._current_thought: LLMThought | None = None
self._completed_thoughts: list[LLMThought] = []
self._max_thought_containers = max(max_thought_containers, 1)
self._expand_new_thoughts = expand_new_thoughts
self._collapse_completed_thoughts = collapse_completed_thoughts
self._thought_labeler = thought_labeler or LLMThoughtLabeler()
def _require_current_thought(self) -> LLMThought:
"""Return our current LLMThought. Raise an error if we have no current
thought.
"""
if self._current_thought is None:
raise RuntimeError("Current LLMThought is unexpectedly None!")
return self._current_thought
def _get_last_completed_thought(self) -> LLMThought | None:
"""Return our most recent completed LLMThought, or None if we don't have one."""
if len(self._completed_thoughts) > 0:
return self._completed_thoughts[len(self._completed_thoughts) - 1]
return None
def _complete_current_thought(self, final_label: str | None = None) -> None:
"""Complete the current thought, optionally assigning it a new label.
Add it to our _completed_thoughts list.
"""
thought = self._require_current_thought()
thought.complete(final_label)
self._completed_thoughts.append(thought)
self._current_thought = None
def on_llm_start(
self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
) -> None:
if self._current_thought is None:
self._current_thought = LLMThought(
parent_container=self._parent_container,
expanded=self._expand_new_thoughts,
collapse_on_complete=self._collapse_completed_thoughts,
labeler=self._thought_labeler,
)
self._current_thought.on_llm_start(serialized, prompts)
# We don't prune_old_thought_containers here, because our container won't
# be visible until it has a child.
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
self._require_current_thought().on_llm_new_token(token, **kwargs)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
self._require_current_thought().on_llm_end(response, **kwargs)
def on_llm_error(self, error: BaseException, *args: Any, **kwargs: Any) -> None:
self._require_current_thought().on_llm_error(error, **kwargs)
def on_tool_start(
self, serialized: dict[str, Any], input_str: str, **kwargs: Any
) -> None:
self._require_current_thought().on_tool_start(serialized, input_str, **kwargs)
def on_tool_end(
self,
output: str,
color: str | None = None,
observation_prefix: str | None = None,
llm_prefix: str | None = None,
**kwargs: Any,
) -> None:
self._require_current_thought().on_tool_end(
output, color, observation_prefix, llm_prefix, **kwargs
)
self._complete_current_thought()
def on_tool_error(self, error: BaseException, *args: Any, **kwargs: Any) -> None:
self._require_current_thought().on_tool_error(error, **kwargs)
def on_agent_action(
self, action: AgentAction, color: str | None = None, **kwargs: Any
) -> Any:
self._require_current_thought().on_agent_action(action, color, **kwargs)
def on_agent_finish(
self, finish: AgentFinish, color: str | None = None, **kwargs: Any
) -> None:
if self._current_thought is not None:
self._current_thought.complete(
self._thought_labeler.get_final_agent_thought_label()
)
self._current_thought = None