# mypy: allow-untyped-decorators

"""
This module implements TorchDynamo's core frame conversion functionality, transforming Python
frames into FX graphs. It handles:

- Frame analysis and bytecode transformation
- Guard creation and management for dynamic behaviors
- Cache management for recompilation
- Error handling and fallback mechanisms

Key classes:
- ConvertFrame: Main entry point for frame conversion with error handling
- ConvertFrameAssert: Implements core frame to graph conversion logic
- Tracker: Tracks input/output code objects during conversion
- CatchErrorsWrapper: Provides error handling and suppression logic

The conversion process preserves program semantics while enabling optimizations
through torch.compile() and related systems.
"""

from __future__ import annotations

import collections
import contextlib
import cProfile
import dis
import functools
import gc
import itertools
import logging
import os
import pstats
import random
import subprocess
import sys
import threading
import time
import traceback
import typing
import weakref
from pathlib import Path
from types import CellType, CodeType, FunctionType, ModuleType
from typing import Any, Callable, Optional, TypeVar, Union
from typing_extensions import ParamSpec
from weakref import ReferenceType

import torch
import torch._logging
from torch._C._dynamo.guards import GlobalStateGuard
from torch._dynamo.distributed import get_compile_pg
from torch._dynamo.symbolic_convert import TensorifyState
from torch._guards import compile_context, CompileContext, CompileId, tracing
from torch._logging import structured
from torch._utils_internal import (
    compile_time_strobelight_meta,
    justknobs_check,
    maybe_upload_prof_stats_to_manifold,
    signpost_event,
)
from torch.fx._lazy_graph_module import _use_lazy_graph_module
from torch.fx.experimental.symbolic_shapes import (
    ConstraintViolationError,
    GuardOnDataDependentSymNode,
)
from torch.fx.graph_module import _forward_from_src as original_forward_from_src
from torch.monitor import _WaitCounter
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils._python_dispatch import (
    _disable_current_modes,
    is_in_torch_dispatch_mode,
)
from torch.utils._traceback import CapturedTraceback, format_traceback_short

from . import config, exc, graph_break_hints, trace_rules
from .bytecode_analysis import remove_dead_code, remove_pointless_jumps
from .bytecode_transformation import (
    check_inst_exn_tab_entries_valid,
    Instruction,
    is_generator,
    propagate_inst_exn_table_entries,
    transform_code_object,
)
from .cache_size import (
    CacheSizeRelevantForFrame,
    compute_cache_size,
    exceeds_recompile_limit,
    is_recompilation,
)
from .eval_frame import (
    always_optimize_code_objects,
    dynamo_tls,
    skip_code,
    TorchPatcher,
)
from .exc import (
    augment_exc_message,
    BackendCompilerFailed,
    FailOnRecompileLimitHit,
    format_error_msg,
    InternalTorchDynamoError,
    RecompileLimitExceeded,
    ShortenTraceback,
    SkipCodeRecursiveException,
    TorchRuntimeError,
    UncapturedHigherOrderOpError,
    unimplemented_v2,
    Unsupported,
)
from .guards import (
    CheckFunctionManager,
    get_and_maybe_log_recompilation_reasons,
    GuardedCode,
)
from .hooks import Hooks
from .pgo import put_code_state
from .replay_record import ExecutionRecord
from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX
from .symbolic_convert import (
    DistributedState,
    ExceptionStack,
    InstructionTranslator,
    LocalState,
    SpeculationLog,
)
from .trace_rules import is_numpy
from .types import ConvertFrameReturn, FrameAction, FrameExecStrategy, wrap_guarded_code
from .utils import (
    chromium_event_timed,
    CleanupManager,
    CompileTimeInstructionCounter,
    counters,
    dynamo_timed,
    format_bytecode,
    gen_record_file_name,
    get_metrics_context,
    increment_frame,
    is_namedtuple,
    istype,
    LazyString,
    orig_code_map,
    reset_graph_break_dup_checker,
    setup_compile_debug,
    to_int_us,
    troubleshooting_url,
    write_record_to_file,
)
from .variables.torch_function import torch_function_mode_stack_state_mgr


np: Optional[ModuleType]
try:
    import numpy as np
except ModuleNotFoundError:
    np = None


if typing.TYPE_CHECKING:
    from .backends.registry import CompilerFn
    from .repro.after_dynamo import WrapBackendDebug
    from .types import BytecodeHook, CacheEntry, DynamoFrameType
    from .variables.builder import FrameStateSizeEntry


log = logging.getLogger(__name__)
bytecode_log = torch._logging.getArtifactLogger(__name__, "bytecode")
graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")


compile_lock = threading.RLock()

_T = TypeVar("_T")
_P = ParamSpec("_P")


class TODO_UNKNOWN:
    pass


class Tracker:
    def __init__(self) -> None:
        self.seen: list[ReferenceType[CodeType]] = []
        self.seen_ids: set[int] = set()

    def add(self, strong_obj: CodeType) -> None:
        idx = id(strong_obj)
        if idx not in self.seen_ids:
            obj = weakref.ref(strong_obj, lambda _: self.seen_ids.remove(idx))
            self.seen.append(obj)
            self.seen_ids.add(idx)

    def __contains__(self, item: CodeType) -> bool:
        return id(item) in self.seen_ids

    def clear(self) -> None:
        self.seen.clear()
        self.seen_ids.clear()
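

# Illustrative note: Tracker deduplicates code objects by id() while holding
# only weak references, so an entry's id is dropped from `seen_ids`
# automatically once the code object is garbage collected. A minimal sketch
# (with a placeholder function name):
#
#     t = Tracker()
#     t.add(some_fn.__code__)
#     assert some_fn.__code__ in t
#     t.clear()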


input_codes = Tracker()
output_codes = Tracker()

initial_global_state: Optional[GlobalStateGuard] = None


@functools.wraps(original_forward_from_src)
def fx_forward_from_src_skip_result(
    src: str, globals: dict[str, Any], co_fields: Optional[dict[str, str]] = None
) -> FunctionType:
    # we monkey patch FX to prevent infinite loop of trying to convert
    # our generated code
    result = original_forward_from_src(src, globals, co_fields)
    skip_code(result.__code__)
    return result


def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
    """
    Decorator to:
    1) Save/restore torch.is_grad_enabled() state
    2) Save/restore python random state
    3) Save/restore torch random state
    4) Monkey patch torch.fx.graph_module._forward_from_src
    """

    @functools.wraps(fn)
    def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        guards = GlobalStateGuard()
        prior_grad_mode = torch.is_grad_enabled()
        # Just in case we get left in a bad dispatch state we want to restore
        # it. This can happen because the dispatch bits aren't a true
        # stack/counter - so we can't just increment/decrement them as we enter
        # and leave.
        with torch._C._PreserveDispatchKeyGuard():
            prior_inference_mode = torch.is_inference_mode_enabled()
            prior_deterministic = torch.are_deterministic_algorithms_enabled()
            prior_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
            prior_mobile_allocator_state = (
                torch._C._is_default_mobile_cpu_allocator_set()
            )
            py_rng_state = random.getstate()
            prior_dtype = torch.get_default_dtype()
            torch_rng_state = torch.random.get_rng_state()
            cuda_rng_state = None
            if torch.cuda.is_available():
                cuda_rng_state = torch.cuda.get_rng_state()
            allow_tf32 = torch._C._get_cublas_allow_tf32()
            prior_fwd_from_src = torch.fx.graph_module._forward_from_src
            torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
            cleanup = setup_compile_debug()
            exit_stack = contextlib.ExitStack()
            exit_stack.enter_context(
                torch.fx._symbolic_trace._maybe_revert_all_patches()
            )
            exit_stack.enter_context(torch_function_mode_stack_state_mgr)
            try:
                return fn(*args, **kwargs)
            finally:
                cleanup.close()
                assert torch._C._len_torch_function_stack() == 0, (
                    "Torch function mode stack state changed while dynamo tracing, please report a bug"
                )
                exit_stack.close()
                torch._C._set_grad_enabled(prior_grad_mode)
                torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)
                torch.use_deterministic_algorithms(
                    prior_deterministic, warn_only=prior_warn_only
                )
                random.setstate(py_rng_state)
                torch.random.set_rng_state(torch_rng_state)
                torch.set_default_dtype(prior_dtype)
                curr_mobile_allocator_state = (
                    torch._C._is_default_mobile_cpu_allocator_set()
                )
                if prior_mobile_allocator_state != curr_mobile_allocator_state:
                    torch._C._unset_default_mobile_cpu_allocator()
                if cuda_rng_state is not None:
                    torch.cuda.set_rng_state(cuda_rng_state)
                torch._C._set_cublas_allow_tf32(allow_tf32)
                torch.fx.graph_module._forward_from_src = prior_fwd_from_src
                assert guards.check(), (
                    f"Global {guards.reason()}state changed while dynamo tracing, please report a bug"
                )

    _fn._torchdynamo_orig_callable = fn  # type: ignore[attr-defined]
    return _fn
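

# Illustrative sketch: `preserve_global_state` is applied as a plain decorator
# (see `transform` inside `_compile` below). Global state mutated by the
# wrapped function (grad mode, RNG state, default dtype, the patched
# fx._forward_from_src, ...) is restored when the call returns:
#
#     @preserve_global_state
#     def run_analysis() -> None:
#         torch.set_default_dtype(torch.float64)  # undone after the call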


@TorchPatcher.suppress_torch_distributed_warnings
def has_tensor_in_frame(frame: DynamoFrameType) -> bool:
    """Check if the frame has torch.* related bits"""
    # Check if the function was decorated using torch._dynamo.optimize
    if frame.f_code in always_optimize_code_objects:
        return True

    # Check if there is global import of torch.*
    for co_name in frame.f_code.co_names:
        if co_name in frame.f_globals:
            obj = frame.f_globals[co_name]
            if isinstance(obj, ModuleType) and (
                obj.__name__.startswith("torch.") or obj is torch
            ):
                return True
            # ... or a global import of numpy.*
            if np and config.trace_numpy and (obj is np or is_numpy(obj)):
                return True

    seen_ids: dict[int, bool] = {}

    def has_tensor(obj: object) -> bool:
        """Recursively check if the obj has a tensor"""
        obj_id = id(obj)
        if obj_id in seen_ids:
            return seen_ids[obj_id]
        seen_ids[obj_id] = False

        if isinstance(obj, (torch.Tensor, torch.nn.Module)) or (
            istype(obj, type) and issubclass(obj, torch.nn.Module)
        ):
            seen_ids[obj_id] = True
            return seen_ids[obj_id]
        elif (
            config.trace_numpy
            and np
            and (istype(obj, np.ndarray) or isinstance(obj, np.generic))
        ):
            seen_ids[obj_id] = True
            return seen_ids[obj_id]
        elif istype(obj, (list, tuple)):
            seen_ids[obj_id] = any(has_tensor(v) for v in obj)
            return seen_ids[obj_id]
        elif istype(obj, dict):
            # Some packages like pytest can be updated during runtime. So, make a
            # copy of values to avoid issues like "RuntimeError: dictionary
            # changed size during iteration"
            values = list(obj.values())
            seen_ids[obj_id] = any(has_tensor(v) for v in values)
            return seen_ids[obj_id]
        elif istype(obj, (str, int, float, type(None), bool)):
            seen_ids[obj_id] = False
            return seen_ids[obj_id]
        elif is_namedtuple(obj) and hasattr(obj, "_fields"):
            seen_ids[obj_id] = any(has_tensor(getattr(obj, v)) for v in obj._fields)
            return seen_ids[obj_id]
        else:
            # if config.debug:
            #     print(
            #         f"Assuming that object of type {type(obj)} does not have a tensor"
            #     )
            return False

    # Check if the passed arguments are of type Tensor
    for value in frame.f_locals.values():
        if has_tensor(value):
            return True

    log.debug(
        "skipping because no torch.* %s \
            %s %s",
        frame.f_code.co_name,
        frame.f_code.co_filename,
        frame.f_code.co_firstlineno,
    )

    return False


def exception_handler(
    e: Exception,
    code: CodeType,
    frame: Optional[DynamoFrameType] = None,
    export: bool = False,
) -> None:
    record_filename = None
    if hasattr(e, "exec_record"):
        record_filename = gen_record_file_name(e, code)
        write_record_to_file(record_filename, e.exec_record)
        e.record_filename = record_filename  # type: ignore[attr-defined]

    augment_exc_message(e, export=export)


FRAME_COUNTER = 0
FRAME_COMPILE_COUNTER: typing.Counter[Union[int, FrameStateSizeEntry]] = (
    collections.Counter()
)


def maybe_cprofile(func: Callable[_P, _T]) -> Callable[_P, _T]:
    if config.cprofile:
        return cprofile_wrapper(func)
    return func
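

# Illustrative sketch: `maybe_cprofile` is a no-op unless the config flag
# checked above is set, e.g. in a local debugging session
#
#     import torch._dynamo.config as dynamo_config
#     dynamo_config.cprofile = True   # checked by maybe_cprofile
#     torch.compile(fn)(x)            # cprofile_wrapper below dumps /tmp/*.profile
#
# where `fn` and `x` are placeholders for the compiled function and its input.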


def cprofile_wrapper(func: Callable[_P, _T]) -> Callable[_P, _T]:
    @functools.wraps(func)
    def profile_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        trace_id = CompileContext.current_trace_id()
        assert trace_id, "Trace id is None"
        profile_path = Path(
            f"/tmp/{func.__name__}_{str(trace_id).replace('/', '_')}.profile"
        )
        prof = cProfile.Profile()
        prof.enable()
        start_ts = time.time()
        retval = prof.runcall(func, *args, **kwargs)
        profile_latency = time.time() - start_ts
        prof.disable()
        log.warning(
            "### Cprofile for %s trace id [%s] took %.3f seconds ###",
            func.__name__,
            trace_id,
            profile_latency,
        )
        ps = pstats.Stats(prof)
        try:
            prof.dump_stats(profile_path)
        except PermissionError:
            log.exception("Cannot write to %s", profile_path)
        log.warning("Raw profile at %s", profile_path)
        svg_path = profile_path.with_suffix(".svg")
        try:
            gprof2dot_process = subprocess.Popen(
                [
                    "gprof2dot",
                    "-f",
                    "pstats",
                    "--node-label=total-time-percentage",
                    "--node-label=self-time-percentage",
                    "--node-label=total-time",
                    str(profile_path),
                ],
                stdout=subprocess.PIPE,
            )
            subprocess.check_call(
                ["dot", "-Tsvg", "-o", str(svg_path)],
                stdin=gprof2dot_process.stdout,
            )
            log.warning("Generated SVG from profile at %s", svg_path)
        except FileNotFoundError:
            log.warning(
                "Failed to generate SVG from profile -- dumping stats instead. "
                "Try installing gprof2dot and dot for a better visualization"
            )
            ps.sort_stats(pstats.SortKey.TIME).print_stats(20)
            ps.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(20)

        if manifold_link := maybe_upload_prof_stats_to_manifold(
            str(profile_path)
        ):  # fb-only
            torch._logging.trace_structured(
                "link",
                lambda: {"name": "cprofile_manifold_url", "url": manifold_link},
            )
        return retval

    return profile_wrapper


class ConvertFrameAssert:
    def __init__(
        self,
        compiler_fn: CompilerFn,
        one_graph: bool = True,
        export: bool = False,
        export_constraints: Optional[typing.Never] = None,
    ) -> None:
        # assert export_constraints is None
        reset_graph_break_dup_checker()
        self._torchdynamo_orig_callable = compiler_fn
        self._one_graph = one_graph
        self._export = export
        self._export_constraints = export_constraints

    @property
    def _clone_with_backend(self) -> Callable[[CompilerFn], ConvertFrameAssert]:
        return lambda backend: convert_frame_assert(
            backend,
            self._one_graph,
            self._export,
            self._export_constraints,
        )

    def __call__(
        self,
        frame: DynamoFrameType,
        cache_entry: Optional[CacheEntry],
        hooks: Hooks,
        frame_state: dict[str, Union[int, FrameStateSizeEntry]],
        *,
        skip: int = 0,
    ) -> ConvertFrameReturn:
        increment_frame()

        code = frame.f_code

        cache_size = compute_cache_size(frame, cache_entry)
        input_codes.add(code)
        if code in output_codes:
            return ConvertFrameReturn()
        if (
            os.environ.get("TORCHDYNAMO_DEBUG_FUNCTION")
            and os.environ.get("TORCHDYNAMO_DEBUG_FUNCTION") != code.co_name
        ):
            return ConvertFrameReturn()
        if code.co_name == "<genexpr>" and code.co_filename.endswith(
            (
                "transformers/file_utils.py",
                "transformers/utils/generic.py",
                "diffusers/utils/outputs.py",
            )
        ):
            # not needed, but cleans up torchbench error stats
            return ConvertFrameReturn()
        if code.co_name == "__setattr__":
            # setattr could be tricky to handle generally,
            # but also not likely useful to compile - skip the whole frame
            return ConvertFrameReturn()
        if code.co_name == "__init__" and code.co_filename.startswith(
            os.path.dirname(torch.optim.__file__)
        ):
            # optimizer support is still incomplete, see
            # test_state_dict in test/dynamo/test_optimizers.py
            return ConvertFrameReturn()

        # Check if the frame is generated by an exec builtin call
        # TODO - Running an exec-generated frame seems to propagate f_globals
        # to the next frames.
        if code.co_name == "<module>" and code.co_filename == "<string>":
            return ConvertFrameReturn()

        if (
            code.co_name == "<lambda>"
            and code.co_filename == "<string>"
            and not bool(frame.f_builtins)
        ):
            # namedtuple subclass constructor. Empty builtins cause issue with
            # len keyword in LIST_LEN guard.
            return ConvertFrameReturn()

        if is_generator(code):
            unimplemented_v2(
                gb_type="Attempt to trace generator",
                context="",
                explanation="Generators cannot be compiled directly with `torch.compile`.",
                hints=[
                    "Call a generator from inside of a non-generator Python function and "
                    "compile that function instead.",
                    *graph_break_hints.FUNDAMENTAL,
                ],
            )

        if not has_tensor_in_frame(frame):
            return ConvertFrameReturn()

        global initial_global_state
        initial_global_state = GlobalStateGuard()

        global FRAME_COUNTER
        if "_id" not in frame_state:
            frame_state["_id"] = FRAME_COUNTER
            FRAME_COUNTER += 1
        frame_id = frame_state["_id"]
        assert isinstance(frame_id, int)

        frame_compile_id = FRAME_COMPILE_COUNTER[frame_id]
        FRAME_COMPILE_COUNTER[frame_id] += 1

        compiled_autograd_id = None
        if prior := CompileContext.current_compile_id():
            compiled_autograd_id = prior.compiled_autograd_id
        compile_id = CompileId(
            compiled_autograd_id=compiled_autograd_id,
            frame_id=frame_id,
            frame_compile_id=frame_compile_id,
        )

        signpost_event(
            "dynamo",
            "_convert_frame_assert._compile",
            {
                "co_name": code.co_name,
                "frame_id": frame_id,
                "compile_id": str(compile_id),
                "co_filename": code.co_filename,
                "co_firstlineno": code.co_firstlineno,
                "cache_size": cache_size.num_cache_entries_with_same_id_matched_objs,
                "accumulated_cache_size": cache_size.num_cache_entries,
            },
        )

        # Record traced frames, skipping Dynamo generated ones.
        if not code.co_name.startswith(TORCH_DYNAMO_RESUME_IN_PREFIX):
            info = f"{code.co_name} {code.co_filename}:{code.co_firstlineno}"
            dynamo_tls.traced_frame_infos.append(info)

        with compile_context(CompileContext(compile_id)):
            return _compile(
                frame.f_code,
                frame.f_globals,
                frame.f_locals,
                frame.f_builtins,
                frame.closure,
                self._torchdynamo_orig_callable,
                self._one_graph,
                self._export,
                self._export_constraints,
                hooks,
                cache_entry,
                cache_size,
                frame,
                frame_state=frame_state,
                compile_id=compile_id,
                skip=skip + 1,
            )


def convert_frame_assert(
    compiler_fn: CompilerFn,
    one_graph: bool = True,
    export: bool = False,
    export_constraints: Optional[typing.Never] = None,
) -> ConvertFrameAssert:
    """Fully convert a frame into an FX graph"""
    return ConvertFrameAssert(compiler_fn, one_graph, export, export_constraints)


from collections import OrderedDict

from torch.utils.hooks import RemovableHandle


if typing.TYPE_CHECKING:
    from .output_graph import OutputGraph

# we have to use `OrderedDict` to make `RemovableHandle` work.
_bytecode_hooks: dict[int, BytecodeHook] = OrderedDict()


def register_bytecode_hook(hook: BytecodeHook) -> RemovableHandle:
    """Register hooks for bytecode generated by Dynamo. The hook can do some
    logging, as well as return a new code object to be used. Please refer
    to `BytecodeHook` for the hook signature.
    """
    handle = RemovableHandle(_bytecode_hooks)
    _bytecode_hooks[handle.id] = hook
    return handle
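

# Illustrative sketch: a bytecode hook receives the original code object and
# the Dynamo-generated one; returning None keeps the generated code, while
# returning a CodeType replaces it (see the hook loop in _compile_inner below):
#
#     def my_hook(code, out_code):
#         print(f"dynamo rewrote {code.co_name}")
#         return None  # or return a replacement CodeType
#
#     handle = register_bytecode_hook(my_hook)
#     ...
#     handle.remove()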


def _compile(
    code: CodeType,
    globals: dict[str, object],
    locals: dict[str, object],
    builtins: dict[str, object],
    closure: tuple[CellType],
    compiler_fn: CompilerFn,
    one_graph: bool,
    export: bool,
    export_constraints: Optional[typing.Never],
    hooks: Hooks,
    cache_entry: Optional[CacheEntry],
    cache_size: CacheSizeRelevantForFrame,
    frame: Optional[DynamoFrameType] = None,
    frame_state: Optional[dict[str, Union[int, FrameStateSizeEntry]]] = None,
    *,
    compile_id: CompileId,
    skip: int = 0,
) -> ConvertFrameReturn:
    from torch.fx.experimental.validator import (
        bisect,
        BisectValidationException,
        translation_validation_enabled,
        ValidationException,
    )

    # Only nonlocal defs here please!
    # Time spent compiling this frame before restarting or failing analysis
    dynamo_time_before_restart: float = 0.0
    output: Optional[OutputGraph] = None
    tracer: Optional[InstructionTranslator] = None

    tf_mode_stack: list[torch.overrides.TorchFunctionMode] = (
        torch.overrides._get_current_function_mode_stack()
    )

    @preserve_global_state
    def transform(
        instructions: list[Instruction], code_options: dict[str, object]
    ) -> None:
        nonlocal output
        nonlocal tracer
        speculation_log.restart()
        exn_vt_stack = ExceptionStack()
        tracer = InstructionTranslator(
            instructions,
            code,
            locals,
            globals,
            builtins,
            closure,
            tf_mode_stack,
            code_options,
            compiler_fn,
            one_graph,
            export,
            export_constraints,
            frame_state=frame_state,
            speculation_log=speculation_log,
            exn_vt_stack=exn_vt_stack,
            distributed_state=distributed_state,
        )

        try:
            with tracing(tracer.output.tracing_context), tracer.set_current_tx():
                tracer.run()
        except exc.UnspecializeRestartAnalysis:
            speculation_log.clear()
            raise
        except (
            exc.SpeculationRestartAnalysis,
            exc.TensorifyScalarRestartAnalysis,
            exc.SkipFrame,
        ):
            raise
        except Exception:
            if translation_validation_enabled():
                bisect(tracer.output.shape_env)
            raise
        finally:
            tracer.output.call_cleanup_hooks()

        output = tracer.output
        assert output is not None
        assert output.output_instructions
        instructions[:] = output.output_instructions
        code_options.update(output.code_options)
        propagate_inst_exn_table_entries(instructions)
        check_inst_exn_tab_entries_valid(instructions)
        instructions[:] = remove_pointless_jumps(remove_dead_code(instructions))

    @compile_time_strobelight_meta(phase_name="compile_inner")
    def compile_inner(
        code: CodeType,
        one_graph: bool,
        hooks: Hooks,
        transform: Callable[[list[Instruction], dict[str, Any]], Any],
    ) -> ConvertFrameReturn:
        with contextlib.ExitStack() as stack:
            stack.enter_context(
                dynamo_timed(
                    "_compile.compile_inner",
                    phase_name="entire_frame_compile",
                    dynamo_compile_column_us="dynamo_cumulative_compile_time_us",
                )
            )
            stack.enter_context(
                _WaitCounter("pytorch.wait_counter.dynamo_compile").guard()
            )
            stack.enter_context(torch._dynamo.callback_handler.install_callbacks())
            stack.enter_context(CompileTimeInstructionCounter.record())
            return _compile_inner(code, one_graph, hooks, transform)

        return (
            ConvertFrameReturn()
        )  # dead, but see https://github.com/python/mypy/issues/7577

    @maybe_cprofile
    def _compile_inner(
        code: CodeType,
        one_graph: bool,
        hooks: Hooks,
        transform: Callable[[list[Instruction], dict[str, Any]], Any],
    ) -> ConvertFrameReturn:
        nonlocal dynamo_time_before_restart
        last_attempt_start_time = start_time = time.time()

        def log_bytecode(
            prefix: str, name: str, filename: str, line_no: int, code: CodeType
        ) -> None:
            if bytecode_log.isEnabledFor(logging.DEBUG):
                bytecode_log.debug(
                    format_bytecode(prefix, name, filename, line_no, code)
                )

        log_bytecode(
            "ORIGINAL BYTECODE",
            code.co_name,
            code.co_filename,
            code.co_firstlineno,
            code,
        )

        out_code = None
        for attempt in itertools.count():
            CompileContext.get().attempt = attempt
            try:
                out_code = transform_code_object(code, transform)
                break
            except exc.RestartAnalysis as e:
                if not isinstance(e, exc.TensorifyScalarRestartAnalysis):
                    TensorifyState.clear()
                log.info(
                    "Restarting analysis due to %s",
                    LazyString(format_traceback_short, e.__traceback__),
                )
                # If restart reason is None just log the type of the exception
                restart_reasons.add(e.restart_reason or str(type(e)))
                # We now have a new "last attempt", reset the clock
                last_attempt_start_time = time.time()
                if attempt > 100:
                    unimplemented_v2(
                        gb_type="Excessive RestartAnalysis() calls",
                        context="",
                        explanation="Dynamo attempted to trace the same frame 100+ times. "
                        "Giving up on compiling as the compile time tradeoff is likely not "
                        "worth the performance gain.",
                        hints=[],
                    )
            except exc.SkipFrame as e:
                if not isinstance(e, exc.TensorifyScalarRestartAnalysis):
                    TensorifyState.clear()
                log.debug(
                    "Skipping frame %s %s \
                    %s %s",
                    e,
                    code.co_name,
                    code.co_filename,
                    code.co_firstlineno,
                )
                if one_graph:
                    log.debug("No graph captured with one_graph=True")
                return ConvertFrameReturn()

        assert distributed_state is None or distributed_state.all_states is not None, (
            "compiler collective wasn't run before compilation completed"
        )

        assert out_code is not None
        log_bytecode(
            "MODIFIED BYTECODE",
            code.co_name,
            code.co_filename,
            code.co_firstlineno,
            out_code,
        )

        for hook in _bytecode_hooks.values():
            hook_output = hook(code, out_code)
            if hook_output is not None:
                out_code = hook_output

        orig_code_map[out_code] = code
        output_codes.add(out_code)
        dynamo_time_before_restart = last_attempt_start_time - start_time
        assert output is not None

        # Tests for new code objects.
        # The rationale for these tests can be found in torch/csrc/dynamo/eval_frame.c
        # Only test once the code object is created.
        # They are not tested during runtime.

        def count_args(code: CodeType) -> int:
            import inspect

            return (
                code.co_argcount
                + code.co_kwonlyargcount
                + bool(code.co_flags & inspect.CO_VARARGS)
                + bool(code.co_flags & inspect.CO_VARKEYWORDS)
            )

        assert out_code is not None

        total_argcount_old = count_args(code)
        total_argcount_new = count_args(out_code)
        msg = "arg mismatch: "
        msg += f"old code object has args {code.co_varnames[:total_argcount_old]}, "
        msg += f"new code object has args {out_code.co_varnames[:total_argcount_new]}"
        assert (
            code.co_varnames[:total_argcount_old]
            == out_code.co_varnames[:total_argcount_new]
        ), msg

        msg = "free var mismatch: "
        msg += f"old code object has free var {code.co_freevars}, "
        msg += f"new code object has free var {out_code.co_freevars}"
        assert code.co_freevars == out_code.co_freevars, msg

        msg = "cell var mismatch: "
        msg += f"old code object has cell var {code.co_cellvars}, "
        msg += f"new code object has cell var {out_code.co_cellvars}"
        assert code.co_cellvars == out_code.co_cellvars, msg

        # Skipping Dynamo on a frame without any extracted graph.
        # This does not affect eager functionality. But this is necessary
        # for export for cases where Dynamo-reconstructed bytecode can create
        # new function frames, confusing export into thinking that there
        # are extra graphs now.

        if output.export and output.is_empty_graph():
            return ConvertFrameReturn()

        assert output.guards is not None
        CleanupManager.instance[out_code] = output.cleanups
        nonlocal cache_entry
        check_fn = CheckFunctionManager(
            code,
            output,
            cache_entry,
            hooks.guard_fail_fn if hooks else None,
        )

        compile_id_str = str(compile_id) if compile_id is not None else "Unknown"
        annotation_str = "Torch-Compiled Region: " + compile_id_str
        guarded_code = GuardedCode(
            out_code,
            check_fn.guard_manager,  # type: ignore[arg-type]
            compile_id,
            annotation_str,
        )

        if not output.is_empty_graph() and hooks.guard_export_fn is not None:
            # We should not run the guard_export_fn when Dynamo does not
            # generate any graph. This can happen in export when TorchDynamo
            # generated bytecode has some reconstruction logic for mutated
            # variables which can trigger TorchDynamo on the children frames but
            # they are benign and do not generate any new graphs.
            hooks.guard_export_fn(output.guards)

        return wrap_guarded_code(guarded_code)

    metrics_context = get_metrics_context()
    with (
        _use_lazy_graph_module(config.use_lazy_graph_module),
        compile_context(CompileContext(compile_id)),
        chromium_event_timed(
            "dynamo", reset_event_log_on_exit=True, log_pt2_compile_event=True
        ),
        metrics_context,
    ):
        restart_reasons: set[str] = set()
        # This is shared across restarts
        speculation_log = SpeculationLog()
        if compile_pg := get_compile_pg():
            distributed_state = DistributedState(compile_pg, LocalState())
        else:
            distributed_state = None

        # Check recompilations
        recompile_reason: Optional[str] = None
        if is_recompilation(cache_size) and frame:
            reasons = get_and_maybe_log_recompilation_reasons(cache_entry, frame)
            recompile_reason = (
                "Unable to find recompilation reasons" if not reasons else reasons[0]
            )
        metrics_context.update_outer({"recompile_reason": recompile_reason})

        exceeded, limit_type = exceeds_recompile_limit(cache_size, compile_id)
        if exceeded:

            def format_func_info(code: CodeType) -> str:
                return f"'{code.co_name}' ({code.co_filename}:{code.co_firstlineno})"

            log.warning(
                "torch._dynamo hit config.%s (%s)\n"
                "   function: %s\n"
                "   last reason: %s\n"
                'To log all recompilation reasons, use TORCH_LOGS="recompiles".\n'
                "To diagnose recompilation issues, see %s.",
                limit_type,
                getattr(config, limit_type),
                format_func_info(code),
                recompile_reason,
                troubleshooting_url,
            )
            if config.fail_on_recompile_limit_hit:
                raise FailOnRecompileLimitHit(
                    f"{limit_type} reached, because fail_on_recompile_limit_hit = True this is a HARD failure"
                )
            elif one_graph:
                raise FailOnRecompileLimitHit(
                    f"{limit_type} reached with one_graph=True. Excessive recompilations can degrade "
                    "performance due to the compilation overhead of each recompilation. To monitor "
                    "recompilations, enable TORCH_LOGS=recompiles. If recompilations are expected, consider "
                    "increasing torch._dynamo.config.cache_size_limit to an appropriate value."
                )
            elif justknobs_check(
                "pytorch/compiler:skip_code_recursive_on_recompile_limit_hit"
            ):
                raise RecompileLimitExceeded(f"{limit_type} reached")
            else:
                # do not recursively skip frames
                unimplemented_v2(
                    gb_type="Dynamo cache limit exceeded",
                    context=f"Limit type: {limit_type}",
                    explanation="Dynamo attempted to recompile the code object too many times, "
                    f"exceeding the {limit_type} cache size limit. "
                    "Giving up on compiling as the compile time tradeoff is likely not "
                    "worth the performance gain.",
                    hints=[],
                )

        log.debug(
            "torchdynamo start compiling %s %s:%s, stack (elided %s frames):\n%s",
            code.co_name,
            code.co_filename,
            code.co_firstlineno,
            skip + 2,
            # -2: omit current frame, omit contextlib decorator
            "".join(CapturedTraceback.extract(skip=2 + skip).format()),
        )
        # -4: -2 as above, plus trace_structured frames
        #
        # NB: the frame looks like this:
        #
        #   # handled by skip argument
        #   torch/_dynamo/convert_frame.py:1069 in catch_errors
        #   torch/_dynamo/convert_frame.py:910 in _convert_frame
        #   torch/_dynamo/convert_frame.py:464 in _convert_frame_assert
        #   torch/_utils_internal.py:70 in wrapper_function
        #
        #   # 2 current frame and context lib
        #   env/lib/python3.10/contextlib.py:79 in inner
        #   torch/_dynamo/convert_frame.py:776 in _compile
        #
        #   # 2 extra here
        #   torch/_logging/_internal.py:1064 in trace_structured
        #   torch/_dynamo/convert_frame.py:780 in <lambda>
        convert_frame_intern = structured.intern_string(__file__)
        # Initialize the ChromiumEventLogger on start
        torch._logging.trace_structured(
            "dynamo_start",
            lambda: {
                "stack": list(
                    itertools.takewhile(
                        lambda f: f["filename"] != convert_frame_intern,
                        structured.from_traceback(
                            CapturedTraceback.extract(skip=4 + skip).summary()
                        ),
                    )
                )
                + [
                    {
                        "line": code.co_firstlineno,
                        "name": code.co_name,
                        "filename": structured.intern_string(code.co_filename),
                    }
                ]
            },
        )
        start_time_ns = time.time_ns()
        fail_type: Optional[str] = None
        fail_reason: Optional[str] = None
        fail_user_frame_filename: Optional[str] = None
        fail_user_frame_lineno: Optional[int] = None
        torch._dynamo.utils.ReinplaceCounters.clear()
        guarded_code = None
        try:
            guarded_code = compile_inner(code, one_graph, hooks, transform)

            # NB: We only put_code_state in success case. Success case here
            # does include graph breaks; specifically, if a graph break still
            # resulted in a partially compiled graph, we WILL return here. An
            # Unsupported exception will only bubble to the top level if we
            # are unable to compile the frame at all. In this case, there's
            # no point in uploading the code state, because we will always
            # fail exactly the same way even without the update. (It's useful
            # to upload for graph break though, because this can prevent
            # extra graph break compilations.)
            put_code_state()

            return guarded_code
        except Exception as e:
            # NB: e's msg is mutated here to add user stack, but we DON'T want
            # that stack in the Scuba logged fail_reason. So we grab the fail
            # info here and add it to the metrics context below.
            fail_type = type(e).__qualname__
            fail_reason = str(e)
            exception_handler(e, code, frame, export=export)
            # NB: this is the post-mutation exception
            torch._logging.trace_structured(
                "artifact",
                metadata_fn=lambda: {
                    "name": "dynamo_error",
                    "encoding": "string",
                },
                payload_fn=lambda: traceback.format_exc(),
            )
            fail_user_frame_filename, fail_user_frame_lineno = exc.get_exc_message(
                e, compile_id
            )
            if isinstance(
                e,
                (
                    Unsupported,
                    TorchRuntimeError,
                    BackendCompilerFailed,
                    AssertionError,
                    ConstraintViolationError,
                    GuardOnDataDependentSymNode,
                    ValidationException,
                    UncapturedHigherOrderOpError,
                    BisectValidationException,
                    ShortenTraceback,
                ),
            ):
                raise
            else:
                # Rewrap for clarity
                raise InternalTorchDynamoError(
                    f"{type(e).__qualname__}: {str(e)}"
                ).with_traceback(e.__traceback__) from None
        finally:
            # === WARNING WARNING WARNING ===
            # If you commit a bug here, it will suppress writing to
            # dynamo_compile table, and we will not have telemetry.
            # Be extra careful when making changes here!

            if torch._dynamo.config.run_gc_after_compile:
                with dynamo_timed("gc", dynamo_compile_column_us="gc_time_us"):
                    log.info("run_gc_after_compile: running gc")
                    gc.collect(1)

            if tracer:
                tracer.output.local_scope = {}

            from .utils import curr_frame

            frame_key = str(curr_frame)
            if fail_reason is None and output is not None:
                guard_count = len(output.guards)
                shape_env_guard_count = len(output.shape_env.guards)
                graph_op_count = output.count_calls()
                graph_node_count = len(output.graph.nodes)
                graph_input_count = len(output.placeholders)
                non_compliant_ops = {op.__qualname__ for op in output.non_compliant_ops}
                compliant_custom_ops = {
                    op.__qualname__ for op in output.compliant_custom_ops
                }
                torch._dynamo.utils.ReinplaceCounters.log()
            else:
                guard_count = None
                shape_env_guard_count = None
                graph_op_count = None
                graph_node_count = None
                graph_input_count = None
                non_compliant_ops = set({})
                compliant_custom_ops = set({})
                restart_reasons = set()
                # If compilation failed, the entire time is wasted
                dynamo_time_before_restart = (time.time_ns() - start_time_ns) / 1e9

            metrics = {
                "frame_key": frame_key,
                "co_name": code.co_name,
                "co_filename": code.co_filename,
                "co_firstlineno": code.co_firstlineno,
                "cache_size": cache_size.num_cache_entries_with_same_id_matched_objs,
                "accumulated_cache_size": cache_size.num_cache_entries,
                "guard_count": guard_count,
                "shape_env_guard_count": shape_env_guard_count,
                "graph_op_count": graph_op_count,
                "graph_node_count": graph_node_count,
                "graph_input_count": graph_input_count,
                "fail_type": fail_type,
                "fail_reason": fail_reason,
                "fail_user_frame_filename": fail_user_frame_filename,
                "fail_user_frame_lineno": fail_user_frame_lineno,
                "non_compliant_ops": non_compliant_ops,
                "compliant_custom_ops": compliant_custom_ops,
                "restart_reasons": restart_reasons,
                "dynamo_time_before_restart_s": dynamo_time_before_restart,
                "has_guarded_code": guarded_code is not None,
                "config_suppress_errors": config.suppress_errors,
                "config_inline_inbuilt_nn_modules": config.inline_inbuilt_nn_modules,
                "specialize_float": config.specialize_float,
                "is_forward": True,
                "dynamo_compile_time_before_restart_us": to_int_us(
                    dynamo_time_before_restart
                ),
            }
            # TODO: replace with CompileEventLogger.compilation_metrics
            # There are some columns here not in PT2 Compile Events
            # so we need to slightly change it
            metrics_context.update_outer(metrics)
            # === END WARNING WARNING WARNING ===


class ConvertFrame:
    def __init__(
        self,
        compiler_fn: CompilerFn,
        hooks: Hooks,
    ) -> None:
        self._torchdynamo_orig_callable = compiler_fn
        self._inner_convert = convert_frame_assert(compiler_fn, one_graph=False)
        self._hooks = hooks

    @property
    def _clone_with_backend(self) -> Callable[[WrapBackendDebug], ConvertFrame]:
        return lambda backend: convert_frame(backend, self._hooks)

    def __call__(
        self,
        frame: DynamoFrameType,
        cache_entry: Optional[CacheEntry],
        hooks: Hooks,
        frame_state: dict[str, Union[int, FrameStateSizeEntry]],
        skip: int = 0,
    ) -> ConvertFrameReturn:
        counters["frames"]["total"] += 1
        try:
            result = self._inner_convert(
                frame, cache_entry, hooks, frame_state, skip=skip + 1
            )
            counters["frames"]["ok"] += 1
            return result
        except Exception as e:
            # These exception types are "soft" failures, in the sense that
            # we know this is due to something we didn't implement all the
            # way, so scare the user less about it. That being said, if you
            # are trying to understand why a graph break happened, it's still
            # important to have this information, so offer it.
            #
            # NB: NotImplementedError used to be on this list, but actually
            # it is impossible for it to reach here, as it is converted into
            # InternalTorchDynamoError. This behavior seemed reasonable
            # to me (ezyang, Aug 2023) so I kept it, but maybe at some point
            # someone wanted these to also get suppressed. If so, you'll
            # need to make these exceptions not get wrapped

            # We intentionally don't want to suppress error here.
            if isinstance(e, UncapturedHigherOrderOpError):
                raise

            soft_fail = isinstance(e, Unsupported)

            # This is a soft failure, in the sense that the code path reaches
            # here when we do not support graph breaks on bytecodes like
            # LOAD_ATTR, BUILD_SET etc. In such cases, we can fall back to
            # eager without scaring users.
            if soft_fail and graph_break_log.isEnabledFor(logging.DEBUG):
                # Log this message in the graph break. Also use the string
                # "skip: " to tell that the whole frame is falling back to
                # eager.
                if hasattr(e, "compile_id") and hasattr(e, "real_stack"):
                    with compile_context(CompileContext(e.compile_id)):  # type: ignore[attr-defined]
                        user_stack = e.real_stack
                        user_stack_formatted = "".join(
                            traceback.format_list(user_stack)
                        )
                        user_stack_trace = f"Graph break: skip: from user code at:\n{user_stack_formatted}"
                        torch._logging.trace_structured(
                            "artifact",
                            metadata_fn=lambda: {
                                "name": "dynamo_graph_break_reason",
                                "encoding": "string",
                            },
                            payload_fn=lambda: f"{user_stack_trace}\n{traceback.format_exc()}",
                        )
                        graph_break_log.debug(
                            user_stack_trace,
                            exc_info=True,
                        )

            if not config.suppress_errors and not soft_fail:
                raise

            # Suppress the error. NB: It's very important to do the
            # suppression logging HERE, where the actual suppression
            # happens. Previously it was somewhere else and so it was
            # possible to accidentally not log at all.
            record_filename = getattr(e, "record_filename", None)
            code = frame.f_code
            error_msg = format_error_msg(e, code, record_filename, frame)

            if soft_fail:
                log.info(error_msg, exc_info=True)
            else:
                log.warning(error_msg, exc_info=True)

            if isinstance(e, SkipCodeRecursiveException):
                return ConvertFrameReturn(
                    frame_exec_strategy=FrameExecStrategy(
                        FrameAction.SKIP, FrameAction.SKIP
                    )
                )
            elif isinstance(e, RecompileLimitExceeded):
                return ConvertFrameReturn(
                    frame_exec_strategy=FrameExecStrategy(
                        FrameAction.RUN_ONLY, FrameAction.RUN_ONLY
                    )
                )

            return ConvertFrameReturn()


def convert_frame(compiler_fn: CompilerFn, hooks: Hooks) -> ConvertFrame:
    """Try to convert a frame into an FX graph; on error, leave the frame unmodified."""
    return ConvertFrame(compiler_fn, hooks)


# TODO mlazos: add support for same args, or record them
def replay(filename: str) -> None:
    from .backends.debugging import eager

    original_replay_val = config.replay_record_enabled
    config.replay_record_enabled = False
    with open(filename, "rb") as in_file:
        record = ExecutionRecord.load(in_file)
    record.globals = dict(itertools.chain(record.globals.items(), globals().items()))

    try:
        _compile(
            record.code,
            record.globals,
            record.locals,
            record.builtins,
            record.closure,
            compiler_fn=eager,
            one_graph=False,
            export=False,
            export_constraints=None,
            hooks=Hooks(),
            cache_size=CacheSizeRelevantForFrame(0, 0),
            cache_entry=None,
            frame=None,
            frame_state={},
            compile_id=CompileId(frame_id=42, frame_compile_id=999),
        )
    finally:
        config.replay_record_enabled = original_replay_val


def first_real_inst_idx(code: CodeType) -> int:
    if sys.version_info < (3, 11):
        return 0
    for inst in dis.get_instructions(code):
        if inst.opname == "RESUME":
            return inst.offset // 2
    raise RuntimeError("RESUME instruction not found in code")


class ConvertFrameProtocol(typing.Protocol):
    def __call__(
        self,
        frame: DynamoFrameType,
        cache_entry: Optional[CacheEntry],
        hooks: Hooks,
        frame_state: dict[str, Union[int, FrameStateSizeEntry]],
        *,
        skip: int = 0,
    ) -> ConvertFrameReturn: ...


class CatchErrorsWrapper:
    def __init__(self, callback: ConvertFrameProtocol, hooks: Hooks) -> None:
        functools.wraps(callback)(self)
        self._torchdynamo_orig_callable = callback
        self.hooks = hooks

    def __call__(
        self,
        frame: DynamoFrameType,
        cache_entry: Optional[CacheEntry],
        frame_state: dict[str, Union[int, FrameStateSizeEntry]],
    ) -> ConvertFrameReturn:
        assert frame_state is not None

        is_skipfile = trace_rules.check(frame.f_code)
        if sys.version_info >= (3, 13):
            has_started_execution = frame.f_lasti > first_real_inst_idx(frame.f_code)
        else:
            has_started_execution = frame.f_lasti >= first_real_inst_idx(frame.f_code)
        if (
            # TODO: the first condition is not covered by any test
            has_started_execution
            or is_skipfile
            or config.disable
            or (
                is_in_torch_dispatch_mode(include_infra_modes=False)
                and not getattr(self._torchdynamo_orig_callable, "_export", False)
            )
        ):
            if log.isEnabledFor(logging.DEBUG):
                if has_started_execution:
                    skip_reason = "traced frame already"
                elif trace_rules.check(frame.f_code):
                    skip_reason = "in skipfiles"
                elif is_in_torch_dispatch_mode(include_infra_modes=False):
                    skip_reason = "non-infra torch dispatch mode present, this is not supported today in torch.compile"
                else:
                    skip_reason = "dynamo tracing is disabled"

                log.debug(
                    "skipping: %s (reason: %s, file: %s)",
                    frame.f_code.co_name,
                    skip_reason,
                    frame.f_code.co_filename,
                )
            return ConvertFrameReturn()

        if frame.f_code.co_filename == "<string>" and frame.f_code.co_name == "__new__":
            # namedtuple constructor
            return ConvertFrameReturn()
        if torch._dynamo.utils.get_optimize_ddp_mode() == "ddp_optimizer":
            ddp_module = DistributedDataParallel._get_active_ddp_module()
            if ddp_module:
                with compile_lock:
                    from torch._dynamo.backends.distributed import DDPOptimizer

                    ddp_optimizer = DDPOptimizer(
                        bucket_bytes_cap=ddp_module.bucket_bytes_cap,
                        backend_compile_fn=self._torchdynamo_orig_callable._torchdynamo_orig_callable,  # type: ignore[attr-defined]
                    )
                    assert hasattr(
                        self._torchdynamo_orig_callable, "_clone_with_backend"
                    ), (
                        "DDPOptimizer only supports callback fns that know how to clone themselves."
                    )
                    hijacked_callback = (
                        self._torchdynamo_orig_callable._clone_with_backend(
                            ddp_optimizer.compile_fn,
                        )
                    )
                    return hijacked_callback(
                        frame, cache_entry, self.hooks, frame_state
                    )

        with compile_lock, _disable_current_modes():
            # skip=1: skip this frame
            return self._torchdynamo_orig_callable(
                frame, cache_entry, self.hooks, frame_state, skip=1
            )


def catch_errors_wrapper(
    callback: ConvertFrameProtocol, hooks: Hooks
) -> CatchErrorsWrapper:
    return CatchErrorsWrapper(callback, hooks)
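

# Illustrative composition sketch (not the definitive wiring): the pieces above
# are typically stacked roughly as
#
#     callback = catch_errors_wrapper(convert_frame(backend_fn, hooks), hooks)
#     callback(frame, cache_entry, frame_state)
#
# where `backend_fn` is a compiler backend and `hooks` is a Hooks instance;
# torch._dynamo.eval_frame performs the actual wiring when torch.compile is
# used.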