# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
LangChain CallbackHandler that prints to streamlit.

This is a special API that's imported and used by LangChain itself. Any updates
to the public API (the StreamlitCallbackHandler constructor, and the entirety of
LLMThoughtLabeler) *must* remain backwards-compatible to avoid breaking LangChain.

This means that it's acceptable to add new optional kwargs to
StreamlitCallbackHandler, but no new positional args or required kwargs should be
added, and no existing args should be removed. If we need to overhaul the API, we
must ensure that a compatible API continues to exist.

Any major change to the StreamlitCallbackHandler should be tested by importing
the API *from LangChain itself*.

This module is lazy-loaded.
"""

# NOTE: We ignore all mypy import-not-found errors at the top level since
# this module is optional and the langchain dependency is not installed
# by default.
# mypy: disable-error-code="import-not-found, unused-ignore, misc"

# Deactivate unused argument errors for this file since we need lots of
# unused arguments to comply with the LangChain callback interface.
# ruff: noqa: ARG002

from __future__ import annotations

import time
from enum import Enum
from typing import TYPE_CHECKING, Any, NamedTuple

from langchain.callbacks.base import (
    BaseCallbackHandler,
)

from streamlit.runtime.metrics_util import gather_metrics

if TYPE_CHECKING:
    from langchain.schema import (
        AgentAction,
        AgentFinish,
        LLMResult,
    )

    from streamlit.delta_generator import DeltaGenerator
    from streamlit.elements.lib.mutable_status_container import StatusContainer


def _convert_newlines(text: str) -> str:
    """Convert newline characters to markdown newline sequences
    (space, space, newline).
    """
    return text.replace("\n", "  \n")


# The maximum length of the "input_str" portion of a tool label.
# Strings that are longer than this will be truncated with "..."
MAX_TOOL_INPUT_STR_LENGTH = 60


class LLMThoughtState(Enum):
    # The LLM is thinking about what to do next. We don't know which tool we'll run.
    THINKING = "THINKING"
    # The LLM has decided to run a tool. We don't have results from the tool yet.
    RUNNING_TOOL = "RUNNING_TOOL"
    # We have results from the tool.
    COMPLETE = "COMPLETE"
    # The LLM completed with an error.
    ERROR = "ERROR"


class ToolRecord(NamedTuple):
    name: str
    input_str: str


class LLMThoughtLabeler:
    """
    Generates markdown labels for LLMThought containers. Pass a custom
    subclass of this to StreamlitCallbackHandler to override its default
    labeling logic.
    """

    def get_initial_label(self) -> str:
        """Return the markdown label for a new LLMThought that doesn't have
        an associated tool yet.
        """
        return "Thinking..."

    def get_tool_label(self, tool: ToolRecord, is_complete: bool) -> str:
        """Return the label for an LLMThought that has an associated tool.

        Parameters
        ----------
        tool
            The tool's ToolRecord

        is_complete
            True if the thought is complete; False if the thought is still
            receiving input.

        Returns
        -------
        The markdown label for the thought's container.
        """
        input_str = tool.input_str
        name = tool.name
        if name == "_Exception":
            name = "Parsing error"
        input_str_len = min(MAX_TOOL_INPUT_STR_LENGTH, len(input_str))
        input_str = input_str[:input_str_len]
        if len(tool.input_str) > input_str_len:
            input_str = input_str + "..."
        input_str = input_str.replace("\n", " ")
        return f"**{name}:** {input_str}"

    def get_final_agent_thought_label(self) -> str:
        """Return the markdown label for the agent's final thought -
        the "Now I have the answer" thought, that doesn't involve a tool.
        """
        return "**Complete!**"


class LLMThought:
    """Encapsulates the Streamlit UI for a single LLM 'thought' during a
    LangChain Agent run. Each tool usage gets its own thought, and runs also
    generally have a concluding thought where the Agent determines that it has
    an answer to the prompt.

    Each thought gets its own expander UI.
    """

    def __init__(
        self,
        parent_container: DeltaGenerator,
        labeler: LLMThoughtLabeler,
        expanded: bool,
        collapse_on_complete: bool,
    ) -> None:
        self._container = parent_container.status(
            labeler.get_initial_label(), expanded=expanded
        )
        self._state = LLMThoughtState.THINKING
        self._llm_token_stream = ""
        self._llm_token_stream_placeholder: DeltaGenerator | None = None
        self._last_tool: ToolRecord | None = None
        self._collapse_on_complete = collapse_on_complete
        self._labeler = labeler

    @property
    def container(self) -> StatusContainer:
        """The container we're writing into."""
        return self._container

    @property
    def last_tool(self) -> ToolRecord | None:
        """The last tool executed by this thought."""
        return self._last_tool

    def _reset_llm_token_stream(self) -> None:
        if self._llm_token_stream_placeholder is not None:
            self._llm_token_stream_placeholder.markdown(self._llm_token_stream)

        self._llm_token_stream = ""
        self._llm_token_stream_placeholder = None

    def on_llm_start(self, serialized: dict[str, Any], prompts: list[str]) -> None:
        self._reset_llm_token_stream()

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # This is only called when the LLM is initialized with `streaming=True`.
        self._llm_token_stream += _convert_newlines(token)
        if self._llm_token_stream_placeholder is None:
            self._llm_token_stream_placeholder = self._container.empty()
        self._llm_token_stream_placeholder.markdown(self._llm_token_stream + "▕")

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        # `response` is the concatenation of all the tokens received by the LLM.
        # If we're receiving streaming tokens from `on_llm_new_token`, this
        # response data is redundant.
        self._reset_llm_token_stream()

        # Set the container status to complete.
        self.complete(self._labeler.get_final_agent_thought_label())

    def on_llm_error(self, error: BaseException, *args: Any, **kwargs: Any) -> None:
        self._container.exception(error)
        self._state = LLMThoughtState.ERROR
        self.complete("LLM encountered an error...")

    def on_tool_start(
        self, serialized: dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        # Called with the name of the tool we're about to run (in
        # `serialized["name"]`), and its input. We change our container's label
        # to be the tool name.
        self._state = LLMThoughtState.RUNNING_TOOL
        tool_name = serialized["name"]
        self._last_tool = ToolRecord(name=tool_name, input_str=input_str)
        self._container.update(
            label=self._labeler.get_tool_label(self._last_tool, is_complete=False),
            state="running",
        )
        if len(input_str) > MAX_TOOL_INPUT_STR_LENGTH:
            # Output is printed later in on_tool_end.
            self._container.markdown(f"**Input:**\n\n{input_str}\n\n**Output:**")

    def on_tool_end(
        self,
        output: str,
        color: str | None = None,
        observation_prefix: str | None = None,
        llm_prefix: str | None = None,
        **kwargs: Any,
    ) -> None:
        self._container.markdown(output)

    def on_tool_error(self, error: BaseException, *args: Any, **kwargs: Any) -> None:
        self._container.markdown("**Tool encountered an error...**")
        self._container.exception(error)
        self._container.update(state="error")

    def on_agent_action(
        self, action: AgentAction, color: str | None = None, **kwargs: Any
    ) -> Any:
        # Called when we're about to kick off a new tool. The `action` data
        # tells us the tool we're about to use, and the input we'll give it.
        # We don't output anything here, because we'll receive this same data
        # when `on_tool_start` is called immediately after.
        pass

    def complete(self, final_label: str | None = None) -> None:
        """Finish the thought."""
        if final_label is None and self._state == LLMThoughtState.RUNNING_TOOL:
            if self._last_tool is None:
                raise RuntimeError(
                    "_last_tool should never be null when _state == RUNNING_TOOL"
                )
            final_label = self._labeler.get_tool_label(
                self._last_tool, is_complete=True
            )

        if self._last_tool and self._last_tool.name == "_Exception":
            self._state = LLMThoughtState.ERROR
        elif self._state != LLMThoughtState.ERROR:
            self._state = LLMThoughtState.COMPLETE

        if self._collapse_on_complete:
            # Add a quick delay to show the user the final output before we collapse.
            time.sleep(0.25)

        self._container.update(
            label=final_label,
            expanded=False if self._collapse_on_complete else None,
            state="error" if self._state == LLMThoughtState.ERROR else "complete",
        )


class StreamlitCallbackHandler(BaseCallbackHandler):
    @gather_metrics("external.langchain.StreamlitCallbackHandler")
    def __init__(
        self,
        parent_container: DeltaGenerator,
        *,
        max_thought_containers: int = 4,
        expand_new_thoughts: bool = False,
        collapse_completed_thoughts: bool = False,
        thought_labeler: LLMThoughtLabeler | None = None,
    ) -> None:
        """Construct a new StreamlitCallbackHandler. This CallbackHandler is
        geared towards use with a LangChain Agent; it displays the Agent's LLM
        and tool-usage "thoughts" inside a series of Streamlit expanders.

        Parameters
        ----------
        parent_container
            The `st.container` that will contain all the Streamlit elements that
            the Handler creates.

        max_thought_containers
            .. note::
                This parameter is deprecated and is ignored in the latest version
                of the callback handler.

            The max number of completed LLM thought containers to show at once.
            When this threshold is reached, a new thought will cause the oldest
            thoughts to be collapsed into a "History" expander. Defaults to 4.

        expand_new_thoughts
            Each LLM "thought" gets its own `st.expander`. This param controls
            whether that expander is expanded by default. Defaults to False.

        collapse_completed_thoughts
            If True, LLM thought expanders will be collapsed when completed.
            Defaults to False.

        thought_labeler
            An optional custom LLMThoughtLabeler instance. If unspecified, the
            handler will use the default thought labeling logic. Defaults to None.
        """
        self._parent_container = parent_container
        self._history_parent = parent_container.container()

        self._current_thought: LLMThought | None = None
        self._completed_thoughts: list[LLMThought] = []

        self._max_thought_containers = max(max_thought_containers, 1)
        self._expand_new_thoughts = expand_new_thoughts
        self._collapse_completed_thoughts = collapse_completed_thoughts
        self._thought_labeler = thought_labeler or LLMThoughtLabeler()

    def _require_current_thought(self) -> LLMThought:
        """Return our current LLMThought. Raise an error if we have no current
        thought.
        """
        if self._current_thought is None:
            raise RuntimeError("Current LLMThought is unexpectedly None!")
        return self._current_thought

    def _get_last_completed_thought(self) -> LLMThought | None:
        """Return our most recent completed LLMThought, or None if we don't
        have one.
        """
        if len(self._completed_thoughts) > 0:
            return self._completed_thoughts[len(self._completed_thoughts) - 1]
        return None

    def _complete_current_thought(self, final_label: str | None = None) -> None:
        """Complete the current thought, optionally assigning it a new label.
        Add it to our _completed_thoughts list.
        """
        thought = self._require_current_thought()
        thought.complete(final_label)
        self._completed_thoughts.append(thought)
        self._current_thought = None

    def on_llm_start(
        self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
    ) -> None:
        if self._current_thought is None:
            self._current_thought = LLMThought(
                parent_container=self._parent_container,
                expanded=self._expand_new_thoughts,
                collapse_on_complete=self._collapse_completed_thoughts,
                labeler=self._thought_labeler,
            )

        self._current_thought.on_llm_start(serialized, prompts)

        # We don't prune_old_thought_containers here, because our container won't
        # be visible until it has a child.

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self._require_current_thought().on_llm_new_token(token, **kwargs)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        self._require_current_thought().on_llm_end(response, **kwargs)

    def on_llm_error(self, error: BaseException, *args: Any, **kwargs: Any) -> None:
        self._require_current_thought().on_llm_error(error, **kwargs)

    def on_tool_start(
        self, serialized: dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        self._require_current_thought().on_tool_start(serialized, input_str, **kwargs)

    def on_tool_end(
        self,
        output: str,
        color: str | None = None,
        observation_prefix: str | None = None,
        llm_prefix: str | None = None,
        **kwargs: Any,
    ) -> None:
        self._require_current_thought().on_tool_end(
            output, color, observation_prefix, llm_prefix, **kwargs
        )
        self._complete_current_thought()

    def on_tool_error(self, error: BaseException, *args: Any, **kwargs: Any) -> None:
        self._require_current_thought().on_tool_error(error, **kwargs)

    def on_agent_action(
        self, action: AgentAction, color: str | None = None, **kwargs: Any
    ) -> Any:
        self._require_current_thought().on_agent_action(action, color, **kwargs)

    def on_agent_finish(
        self, finish: AgentFinish, color: str | None = None, **kwargs: Any
    ) -> None:
        if self._current_thought is not None:
            self._current_thought.complete(
                self._thought_labeler.get_final_agent_thought_label()
            )
            self._current_thought = None