import logging
import operator
from functools import partial
from typing import Any, Callable, Optional, Union

import sympy
from sympy import Expr

import torch
from torch.utils._sympy.value_ranges import (
    bound_sympy,
    SymPyValueRangeAnalysis,
    ValueRanges,
)

from ..utils._sympy.functions import PowByNatural
from ..utils._sympy.numbers import int_oo
from .loop_body import InterpreterShim, LoopBody, LoopBodyBlock
from .ops_handler import DefaultHandler, ReductionType, StoreMode
from .utils import cache_on_self, dominated_nodes
from .virtualized import V

log = logging.getLogger(__name__)

class BoundVars:
    """
    Performs Value Range Analysis on LoopBody's fx graph by calling BoundVars.get_bounds()
    It exposes the ranges of the nodes in the `_bounds` variable

    Note: a current limitation of this analysis is that it only works on a per-loop basis.
    We should be able to propagate the bounds across the whole graph. This may benefit
    the case where a bounded variable is returned by a kernel and fed into another.
    """

    def __init__(self, loop_body: LoopBody) -> None:
        def upper_bound(v: Union[Expr, int]) -> int:
            return bound_sympy(v).upper if isinstance(v, Expr) else v

        self.loop_body = loop_body
        self.replacement_vals = {
            k: ValueRanges[Expr](0, upper_bound(v) - 1)
            for k, v in loop_body.var_ranges.items()
        }
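        # e.g. a (hypothetical) var_ranges entry {i0: 128} becomes
        # {i0: ValueRanges(0, 127)}: each iteration variable ranges over
        # [0, size - 1]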
        # avoid computing these values; pessimistically assume that they are unbounded
        self.unbounded_vars = dominated_nodes(
            node
            for node in self.loop_body.get_nodes()
            if node.target in ["load", "reduction", operator.getitem]
            or "masked_subblock" in node.target
        )
        # To access this variable, call `get_bounds()`
        self._bounds: dict[torch.fx.Node, ValueRanges[Expr]] = {}

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"loop_body={self.loop_body},\n "
            f"replacement_vals={self.replacement_vals}, \n"
            f"unbounded_vars={self.unbounded_vars}, \n"
            f"_bounds={self._bounds})"
        )

    @cache_on_self
    def get_bounds(self) -> dict[torch.fx.Node, ValueRanges[Expr]]:
        submodules = self.swap_submodules(self.loop_body.submodules)

        # Initialize the environment with the unbounded variables
        for node in self.unbounded_vars:
            # masked_subblock nodes must be evaluated so we recurse into them,
            # and set_indirect nodes must run so indirect variables get their ranges
            if not isinstance(node.target, str) or (
                "masked_subblock" not in node.target
                and "set_indirect" not in node.target
            ):
                self._bounds[node] = ValueRanges[Expr].unknown()

        with V.set_ops_handler(ValueRangeAnalysis()):
            interpreter = InterpreterShim(self.loop_body.root_block.graph, submodules)
            log.debug("get_bounds:\n%s", self.loop_body.root_block.graph)
            interpreter.run(V.get_ops_handler(), initial_env=self._bounds)
        return self._bounds

    def swap_submodules(
        self, submodules: dict[str, Callable[..., Any]]
    ) -> dict[str, Callable[..., ValueRanges[Expr]]]:
        result: dict[str, Callable[..., ValueRanges[Expr]]] = {}
        for key in submodules.keys():
            if key == "get_index":
                result[key] = self.get_index
            elif "masked_subblock" in key:
                subblock = self.loop_body.subblocks[key]

                # Bind subblock through make_fn: Python lambdas close over
                # variables by reference, so a lambda defined directly in the
                # loop would see only the final value of subblock. Closing
                # over `result` by reference is intentional: the lambda should
                # see the complete set of submodules once the loop finishes.
                def make_fn(
                    subblock: LoopBodyBlock,
                ) -> Callable[[Any, Any], ValueRanges[Expr]]:
                    return lambda mask, value: self.masked_subblock(
                        subblock, self._bounds, mask, value, result
                    )

                result[key] = make_fn(subblock)
elif "set_indirect" in key:
                idx = int(key[len("set_indirect") :])
                var = self.loop_body.indirect_vars[idx]
                indirect = partial(self.set_indirect, var)
                result[key] = indirect
            else:
                assert "scan" in key
                result[key] = submodules[key]
        return result

    def masked_subblock(
        self,
        subblock: LoopBodyBlock,
        env: dict[torch.fx.Node, ValueRanges[Expr]],
        mask: Any,
        value: Any,
        submodules: dict[str, Callable[..., Any]],
    ) -> ValueRanges[Expr]:
        interp = InterpreterShim(subblock.graph, submodules)
        interp.run(V.get_ops_handler(), initial_env=env)
        output = [node for node in subblock.graph.nodes if node.target == "output"]
        assert len(output) == 1
        # don't bother unioning with value since the load from the buffer will
        # be pessimistically assumed to be inf anyway
        return interp.env[output[0]]

    def set_indirect(self, old: Expr, new: ValueRanges[Expr]) -> ValueRanges[Expr]:
        assert isinstance(new, ValueRanges)
        self.replacement_vals[old] = new
        return new

    def get_index(self, name: str) -> ValueRanges[Expr]:
        expr = self.loop_body.indexing_exprs[name]
        bound = self.replacement_vals.get(expr)
        if bound is None:
            bound = bound_sympy(expr, self.replacement_vals)
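            # e.g. with replacement_vals {i0: ValueRanges(0, 127)}, a
            # (hypothetical) index expression i0 + 5 bounds to ValueRanges(5, 132)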
        # The following assertion is true at the time of this writing.
        # We don't assert it, so as not to execute bound_sympy when bound is not None:
        # assert bound is None or bound == bound_sympy(expr, self.replacement_vals)
        self.replacement_vals[name] = bound
        return bound


class ValueRangeAnalysis(SymPyValueRangeAnalysis, DefaultHandler):
    def __init__(self) -> None:
        self.name = "ValueRangeAnalysis"
        boolean_operators = (
            "xor",
            "logical_and",
            "logical_or",
            "logical_not",
        )
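        # Each boolean op is bound to bool_handler, which pessimistically
        # returns the full bool range rather than modeling the operator.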
        for op in boolean_operators:
            setattr(self, op, self.bool_handler)

    @staticmethod
    def bool_handler(*args: Any, **kwargs: Any) -> ValueRanges[Any]:
        # just assuming bools can have both values
        return ValueRanges(sympy.false, sympy.true)  # type: ignore[arg-type]

    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
        # many ops are unlikely to show up in optimizable indexing compute,
        # so we don't have full coverage
        return ValueRanges.unknown()

    def load(self, name: str, index: sympy.Expr) -> ValueRanges[Any]:
        return ValueRanges.unknown()

    def store(
        self, name: str, index: sympy.Expr, value: Any, mode: StoreMode = None
    ) -> None:
        return

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Any,
    ) -> ValueRanges[Any]:
        return ValueRanges.unknown()

    @classmethod
    def index_expr(cls, index: Any, dtype: torch.dtype) -> ValueRanges[Any]:
        assert isinstance(index, ValueRanges)
        return cls.to_dtype(index, dtype)

    @staticmethod
    def to_dtype(
        x: Any,
        dtype: torch.dtype,
        src_dtype: Optional[torch.dtype] = None,
        use_compute_types: bool = True,
    ) -> ValueRanges[Any]:
        x = ValueRanges.wrap(x)
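        # e.g. to_dtype(ValueRanges(sympy.false, sympy.true), torch.float32)
        # yields ValueRanges(0.0, 1.0); a singleton input maps to a singleton output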
        if dtype == torch.bool:
            if x.is_singleton():
                return ValueRanges.wrap(x.lower != 0)
            elif x.is_bool:
                return x
            elif 0 not in x:
                return ValueRanges.wrap(sympy.true)
            else:
                return ValueRanges(sympy.false, sympy.true)

        def cast(x: Any, dtype: torch.dtype) -> sympy.Expr:
            # dtype is int or float
            if dtype.is_floating_point:
                return sympy.Float(x)
            else:
                if x in (int_oo, -int_oo):
                    return x
                try:
                    return sympy.Integer(x)
                except TypeError:
                    # inf cannot be cast to Integer
                    return x

        if x.is_bool:
            if x.is_singleton():
                val = 1 if x.lower else 0
                return ValueRanges.wrap(cast(val, dtype))
            else:
                return ValueRanges(cast(0, dtype), cast(1, dtype))
        else:
            # int to float or float to int
            return ValueRanges(cast(x.lower, dtype), cast(x.upper, dtype))

    @staticmethod
    def square(x: Any) -> ValueRanges[Any]:
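        # e.g. square(ValueRanges(-3, 2)) = ValueRanges(0, 9): x**2 is convex
        # with its minimum at 0, which convex_min_zero_map accounts for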
        return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2))

    @staticmethod
    def neg(x: Any) -> ValueRanges[Any]:
        return ValueRanges.decreasing_map(x, operator.neg)

    # TODO: this is slightly inaccurate because truncdiv operates at integer
    # precision, but we're going through float truediv, which means we can
    # potentially lose precision on the bounds
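    # e.g. truncdiv(ValueRanges(1, 5), ValueRanges(2, 2)) goes through
    # truediv -> [0.5, 2.5] and trunc -> [0, 2]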
    @classmethod
    def truncdiv(cls, a: Any, b: Any) -> ValueRanges[Any]:
        x = cls.truediv(a, b)
        if x == ValueRanges.unknown():
            return x
        return cls.trunc(x)

    @classmethod
    def sub(cls, a: Any, b: Any) -> ValueRanges[Any]:
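        # a - b is computed as a + (-b),
        # e.g. [0, 4] - [1, 3] = [0, 4] + [-3, -1] = [-3, 3]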
        return cls.add(a, cls.neg(b))