# mypy: ignore-errors

"""
This module provides iterator-related variable tracking functionality for Dynamo.
It implements variable classes for handling Python iterators and itertools functions
during symbolic execution and tracing.

The module includes:
- Base iterator variable classes for tracking iterator state
- Implementations of built-in iterators (zip, map, filter)
- Support for itertools functions (product, accumulate, combinations, etc.)
- Mutation tracking and reconstruction capabilities for iterator operations

These classes integrate with Dynamo's variable tracking system to enable proper
handling of iterator operations during code transformation and optimization.
"""

import itertools
import operator
import sys
from typing import Optional, TYPE_CHECKING, Union

from .. import polyfills, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import (
    handle_observed_exception,
    ObservedUserStopIteration,
    raise_observed_exception,
    unimplemented,
    UserError,
)
from .base import ValueMutationNew, VariableTracker
from .constant import ConstantVariable


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


# Maximum number of items CycleIteratorVariable buffers from its input
# iterator before bailing out via unimplemented().
MAX_ITERATOR_LIMIT = 100 * 1024  # 100k


class ItertoolsVariable(VariableTracker):
    """Tracks a reference to an ``itertools`` callable (e.g. ``itertools.product``).

    ``self.value`` is the itertools function object itself; calling this
    variable during tracing is handled by ``call_function``.
    """

    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)
        self.value = value

    def __repr__(self) -> str:
        return f"ItertoolsVariable({self.value})"

    def as_python_constant(self):
        return self.value

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Model ``self.value(*args, **kwargs)`` during tracing.

        product/accumulate/combinations/groupby are evaluated eagerly over
        unpacked sequences and returned as a ``ListIteratorVariable``;
        repeat/count/cycle return dedicated lazy iterator variables (bounded
        ``repeat`` is delegated to ``polyfills.repeat``). Unsupported argument
        forms call ``unimplemented``; other functions defer to the base class.
        """
        # See also: module `torch._dynamo.polyfills.itertools`
        if (
            self.value is itertools.product
            and not kwargs
            and all(arg.has_unpack_var_sequence(tx) for arg in args)
        ):
            seqs = [arg.unpack_var_sequence(tx) for arg in args]
            items = [
                variables.TupleVariable(list(item))
                for item in itertools.product(*seqs)
            ]
            return variables.ListIteratorVariable(
                items, mutation_type=ValueMutationNew()
            )
        elif self.value is itertools.accumulate:
            from .builtin import BuiltinVariable

            if any(key not in ["initial", "func"] for key in kwargs.keys()):
                unimplemented(
                    "Unsupported kwargs for itertools.accumulate: "
                    f"{','.join(set(kwargs.keys()) - {'initial', 'func'})}"
                )

            def is_none_constant(vt):
                # CPython treats `initial=None` and `func=None` as "not given".
                return (
                    isinstance(vt, ConstantVariable)
                    and vt.as_python_constant() is None
                )

            if len(args) not in (1, 2) or not args[0].has_unpack_var_sequence(tx):
                unimplemented("Unsupported arguments for itertools.accumulate")
            if len(args) == 2 and "func" in kwargs:
                # `func` supplied both positionally and by keyword.
                unimplemented(
                    "itertools.accumulate can only accept one of: `func` kwarg, pos 2 arg"
                )
            seq = args[0].unpack_var_sequence(tx)

            func_vt = args[1] if len(args) == 2 else kwargs.get("func")
            if func_vt is None or is_none_constant(func_vt):
                # Default to operator.add
                func = BuiltinVariable(operator.add).call_function
            else:
                func = func_vt.call_function

            acc = kwargs.get("initial")
            if acc is not None and is_none_constant(acc):
                acc = None

            items = []
            if acc is not None:
                items.append(acc)
            for item in seq:
                if acc is None:
                    acc = item
                else:
                    try:
                        acc = func(tx, [acc, item], {})
                    except Exception as e:
                        unimplemented(
                            "Unexpected failure in invoking function during accumulate. "
                            f"Failed running func {func}({acc}, {item})",
                            from_exc=e,
                        )
                items.append(acc)

            return variables.ListIteratorVariable(
                items, mutation_type=ValueMutationNew()
            )
        elif (
            self.value is itertools.combinations
            and not kwargs
            and len(args) == 2
            and args[0].has_unpack_var_sequence(tx)
            and args[1].is_python_constant()
        ):
            iterable = args[0].unpack_var_sequence(tx)
            r = args[1].as_python_constant()

            items = [
                variables.TupleVariable(list(item))
                for item in itertools.combinations(iterable, r)
            ]
            return variables.ListIteratorVariable(
                items, mutation_type=ValueMutationNew()
            )
        elif self.value is itertools.groupby:
            if any(kw != "key" for kw in kwargs.keys()):
                unimplemented(
                    "Unsupported kwargs for itertools.groupby: "
                    f"{','.join(set(kwargs.keys()) - {'key'})}"
                )

            def retrieve_const_key(key):
                # Group keys must be concrete Python values so that
                # itertools.groupby can compare them at trace time.
                if isinstance(key, variables.SymNodeVariable):
                    return key.evaluate_expr()
                elif isinstance(key, variables.ConstantVariable):
                    return key.as_python_constant()
                else:
                    unimplemented(
                        "Unsupported key type for itertools.groupby: " + str(type(key))
                    )

            if len(args) == 1 and args[0].has_unpack_var_sequence(tx):
                seq = args[0].unpack_var_sequence(tx)
            else:
                unimplemented("Unsupported arguments for itertools.groupby")

            if "key" in kwargs:

                def keyfunc(x):
                    return retrieve_const_key(
                        kwargs.get("key").call_function(tx, [x], {})
                    )

            else:

                def keyfunc(x):
                    return retrieve_const_key(x)

            result = []
            try:
                for k, v in itertools.groupby(seq, key=keyfunc):
                    result.append(
                        variables.TupleVariable(
                            [
                                variables.ConstantVariable.create(k)
                                if variables.ConstantVariable.is_literal(k)
                                else k,
                                variables.ListIteratorVariable(
                                    list(v), mutation_type=ValueMutationNew()
                                ),
                            ],
                            mutation_type=ValueMutationNew(),
                        )
                    )
            except Exception as e:
                unimplemented(
                    "Unexpected failure when calling itertools.groupby",
                    from_exc=e,
                )
            return variables.ListIteratorVariable(
                result, mutation_type=ValueMutationNew()
            )
        elif self.value is itertools.repeat:
            if len(args) == 1 and not kwargs:
                # Unbounded repeat(obj): modelled lazily.
                return variables.RepeatIteratorVariable(
                    *args, mutation_type=ValueMutationNew()
                )

            # Any other form (e.g. a positional or keyword `times`) is
            # delegated to the polyfill.
            return tx.inline_user_function_return(
                VariableTracker.build(tx, polyfills.repeat), args, kwargs
            )
        elif self.value is itertools.count:
            # Bind itertools.count(start=0, step=1) positional/keyword args.
            bound = dict(zip(("start", "step"), args))
            if (
                len(args) > 2
                or set(kwargs) - {"start", "step"}
                or set(kwargs) & set(bound)
            ):
                unimplemented("Unsupported arguments for itertools.count")
            bound.update(kwargs)
            return variables.CountIteratorVariable(
                bound.get("start", 0),
                bound.get("step", 1),
                mutation_type=ValueMutationNew(),
            )
        elif self.value is itertools.cycle:
            return variables.CycleIteratorVariable(
                *args, mutation_type=ValueMutationNew()
            )
        else:
            return super().call_function(tx, args, kwargs)


class IteratorVariable(VariableTracker):
    """Base class for variables that model lazily-consumed Python iterators.

    Subclasses implement ``next_variable`` to produce the next item, or to
    raise an observed ``StopIteration`` when exhausted.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def next_variable(self, tx):
        unimplemented("abstract method, must implement")

    # NOTE: only call when unpacking this iterator safely done eagerly!
    # Normally, iterators are accessed lazily.
    # Example of safe eager unpacking: list(map(f, seq))
    # Example of unsafe eager unpacking: list(islice(map(f, seq), 5))
    def force_unpack_var_sequence(self, tx) -> list[VariableTracker]:
        """Drain the iterator by calling ``next_variable`` until StopIteration."""
        result = []
        while True:
            try:
                result.append(self.next_variable(tx))
            except ObservedUserStopIteration:
                handle_observed_exception(tx)
                break
        return result

    # don't call force_unpack_var_sequence since it can mutate
    # IteratorVariable state!
    def has_force_unpack_var_sequence(self, tx) -> bool:
        return True


class RepeatIteratorVariable(IteratorVariable):
    """Models ``itertools.repeat(item)`` with no count: ``next`` always
    returns ``item``; reconstructs as ``itertools.repeat(item)``."""

    def __init__(self, item: VariableTracker, **kwargs) -> None:
        super().__init__(**kwargs)
        self.item = item

    # Repeat needs no mutation, clone self
    def next_variable(self, tx):
        return self.item

    def reconstruct(self, codegen):
        codegen.add_push_null(
            lambda: codegen.extend_output(
                [
                    codegen.create_load_python_module(itertools),
                    codegen.create_load_attr("repeat"),
                ]
            )
        )
        codegen(self.item)
        codegen.extend_output(create_call_function(1, False))


class CountIteratorVariable(IteratorVariable):
    """Models ``itertools.count(item, step)``.

    ``next`` returns the current ``item`` and advances it with
    ``item.__add__(step)``, recording the mutation. Plain Python values passed
    for ``item``/``step`` are wrapped in ``ConstantVariable``.
    """

    def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None:
        super().__init__(**kwargs)
        if not isinstance(item, VariableTracker):
            item = ConstantVariable.create(item)
        if not isinstance(step, VariableTracker):
            step = ConstantVariable.create(step)
        self.item = item
        self.step = step

    def next_variable(self, tx):
        assert self.is_mutable()
        old_item = self.item
        tx.output.side_effects.mutation(self)
        self.item = self.item.call_method(tx, "__add__", [self.step], {})
        return old_item

    def reconstruct(self, codegen):
        codegen.add_push_null(
            lambda: codegen.extend_output(
                [
                    codegen.create_load_python_module(itertools),
                    codegen.create_load_attr("count"),
                ]
            )
        )
        codegen(self.item)
        codegen(self.step)
        codegen.extend_output(create_call_function(2, False))


class CycleIteratorVariable(IteratorVariable):
    """Models ``itertools.cycle(iterator)``.

    While ``iterator`` is live, items are pulled from it and appended to
    ``saved``. Once it is exhausted (``iterator`` set to ``None``), items are
    replayed from ``saved`` in order; ``saved_index`` is the index of the
    next item to replay. An empty ``saved`` raises StopIteration.
    """

    def __init__(
        self,
        iterator: IteratorVariable,
        saved: Optional[list[VariableTracker]] = None,
        saved_index: int = 0,
        item: Optional[VariableTracker] = None,
        **kwargs,
    ) -> None:
        if saved is None:
            saved = []
        super().__init__(**kwargs)
        self.iterator = iterator
        self.saved = saved
        self.saved_index = saved_index
        self.item = item

    def next_variable(self, tx):
        assert self.is_mutable()

        if self.iterator is not None:
            try:
                new_item = self.iterator.next_variable(tx)
                if len(self.saved) > MAX_ITERATOR_LIMIT:
                    unimplemented(
                        "input iterator to itertools.cycle has too many items"
                    )
                tx.output.side_effects.mutation(self)
                self.saved.append(new_item)
                self.item = new_item
                if self.item is None:
                    return self.next_variable(tx)
                return self.item
            except ObservedUserStopIteration:
                handle_observed_exception(tx)
                self.iterator = None
                return self.next_variable(tx)
        elif self.saved:
            tx.output.side_effects.mutation(self)
            # Replay the saved item, then advance (wrapping) to the next one.
            self.item = self.saved[self.saved_index]
            self.saved_index = (self.saved_index + 1) % len(self.saved)
            return self.item
        else:
            raise_observed_exception(StopIteration, tx)


class ZipVariable(IteratorVariable):
    """
    Represents zip(*iterables)

    Each entry of ``iterables`` is either a plain list of VariableTrackers,
    read at position ``self.index``, or a VariableTracker advanced through
    ``next_variable``.
    """

    _nonvar_fields = {
        "index",
        "strict",
        *IteratorVariable._nonvar_fields,
    }

    def __init__(
        self,
        iterables: list[Union[list[VariableTracker], VariableTracker]],
        strict: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        assert isinstance(iterables, list)
        # can be list[Variable] or VariableTracker (with next_variable implemented)
        self.iterables = iterables
        self.index = 0
        self.strict = strict

    def python_type(self):
        return zip

    def has_unpack_var_sequence(self, tx) -> bool:
        return all(
            isinstance(it, list) or it.has_unpack_var_sequence(tx)
            for it in self.iterables
        )

    def unpack_var_sequence(self, tx) -> list["VariableTracker"]:
        assert self.has_unpack_var_sequence(tx)
        iterables = []
        for it in self.iterables:
            if isinstance(it, list):
                iterables.append(it[self.index :])
            else:
                iterables.append(it.unpack_var_sequence(tx))
        kwargs = {"strict": self.strict} if self.strict else {}
        zipped = zip(*iterables, **kwargs)
        return [variables.TupleVariable(list(var)) for var in zipped]

    def next_variable(self, tx):
        assert self.is_mutable()
        old_index = self.index
        args = []

        def get_item(it):
            if isinstance(it, list):
                if old_index >= len(it):
                    raise_observed_exception(StopIteration, tx)
                return it[old_index]
            else:
                return it.next_variable(tx)

        try:
            for idx, it in enumerate(self.iterables):
                args.append(get_item(it))
        except ObservedUserStopIteration:
            if self.strict:
                if idx == 0:
                    # all other iterables should be exhausted
                    for it in self.iterables:
                        try:
                            get_item(it)
                        except ObservedUserStopIteration:
                            handle_observed_exception(tx)
                            continue
                        # no ObservedUserStopIteration - fall through to UserError
                        break
                    else:
                        # all iterables exhausted, raise original error
                        raise
                handle_observed_exception(tx)
                raise UserError(
                    ValueError,
                    "zip() has one argument of len differing from others",
                ) from None
            raise

        tx.output.side_effects.mutation(self)
        self.index += 1
        return variables.TupleVariable(args)

    def reconstruct_items(self, codegen):
        """Emit one stack value per iterable: remaining list items packed
        into a tuple, or the iterator variable itself."""
        for it in self.iterables:
            if isinstance(it, list):
                remaining_items = it[self.index :]
                codegen.foreach(remaining_items)
                codegen.append_output(
                    create_instruction("BUILD_TUPLE", arg=len(remaining_items))
                )
            else:
                codegen(it)

    def reconstruct(self, codegen):
        codegen.add_push_null(
            lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True
        )
        self.reconstruct_items(codegen)
        codegen.append_output(
            create_instruction("BUILD_TUPLE", arg=len(self.iterables))
        )
        if sys.version_info >= (3, 10):
            # Pass strict=<bool> through a kwargs dict to CALL_FUNCTION_EX.
            codegen.extend_output(
                [
                    codegen.create_load_const("strict"),
                    codegen.create_load_const(self.strict),
                    create_instruction("BUILD_MAP", arg=1),
                    create_instruction("CALL_FUNCTION_EX", arg=1),
                ]
            )
        else:
            codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=0))


class MapVariable(ZipVariable):
    """
    Represents map(fn, *iterables)

    Reuses ZipVariable to advance the iterables together, then applies ``fn``
    to each zipped tuple.
    """

    def __init__(
        self,
        fn: VariableTracker,
        iterables: list[Union[list[VariableTracker], VariableTracker]],
        **kwargs,
    ) -> None:
        super().__init__(iterables, **kwargs)
        self.fn = fn

    def python_type(self):
        return map

    def has_unpack_var_sequence(self, tx) -> bool:
        return False

    def next_variable(self, tx):
        args = super().next_variable(tx)
        return self.fn.call_function(tx, args.items, {})

    def reconstruct(self, codegen):
        codegen.add_push_null(
            lambda: codegen.load_import_from("builtins", "map"), call_function_ex=True
        )
        codegen(self.fn)
        self.reconstruct_items(codegen)
        codegen.extend_output(
            [
                create_instruction("BUILD_TUPLE", arg=len(self.iterables) + 1),
                create_instruction("CALL_FUNCTION_EX", arg=0),
            ]
        )


class FilterVariable(IteratorVariable):
    """
    Represents filter(fn, iterable)

    ``iterable`` is either a plain list of VariableTrackers, read at position
    ``self.index``, or a VariableTracker advanced through ``next_variable``.
    """

    _nonvar_fields = {
        "index",
        *IteratorVariable._nonvar_fields,
    }

    def __init__(
        self,
        fn: VariableTracker,
        iterable: Union[list[VariableTracker], VariableTracker],
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.fn = fn
        self.iterable = iterable
        self.index = 0

    def python_type(self):
        return filter

    def has_unpack_var_sequence(self, tx) -> bool:
        return isinstance(self.iterable, list) or self.iterable.has_unpack_var_sequence(
            tx
        )

    def _predicate_holds(self, tx, item) -> bool:
        """Return whether ``fn(item)`` is truthy, via ``polyfills.predicate``."""
        res = self.fn.call_function(tx, [item], {})
        pred_res = variables.UserFunctionVariable(
            polyfills.predicate
        ).call_function(tx, [res], {})
        return pred_res.as_python_constant()

    def unpack_var_sequence(self, tx) -> list["VariableTracker"]:
        """Return the remaining items for which ``fn(item)`` is truthy."""
        assert self.has_unpack_var_sequence(tx)
        if isinstance(self.iterable, list):
            remaining = self.iterable[self.index :]
        else:
            remaining = self.iterable.unpack_var_sequence(tx)
        return [item for item in remaining if self._predicate_holds(tx, item)]

    def next_variable(self, tx):
        def _next():
            old_index = self.index
            if isinstance(self.iterable, list):
                if old_index >= len(self.iterable):
                    raise_observed_exception(StopIteration, tx)
                return self.iterable[old_index]
            else:
                return self.iterable.next_variable(tx)

        # A do-while loop to find elements that make fn return true
        while True:
            item = _next()
            self.index += 1
            if self._predicate_holds(tx, item):
                return item

    def reconstruct_items(self, codegen):
        """Emit the remaining list items as a tuple, or the iterator itself."""
        if isinstance(self.iterable, list):
            remaining_items = self.iterable[self.index :]
            codegen.foreach(remaining_items)
            codegen.append_output(
                create_instruction("BUILD_TUPLE", arg=len(remaining_items))
            )
        else:
            codegen(self.iterable)

    def reconstruct(self, codegen):
        codegen.add_push_null(lambda: codegen.load_import_from("builtins", "filter"))
        codegen(self.fn)
        self.reconstruct_items(codegen)
        codegen.extend_output(create_call_function(2, False))