from torch._functorch.apis import grad, grad_and_value, vmap from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_ from torch._functorch.eager_transforms import ( debug_unwrap, functionalize, hessian, jacfwd, jacrev, jvp, linearize, vjp, ) from torch._functorch.functional_call import functional_call, stack_module_state __all__ = [ "grad", "grad_and_value", "vmap", "replace_all_batch_norm_modules_", "functionalize", "hessian", "jacfwd", "jacrev", "jvp", "linearize", "vjp", "functional_call", "stack_module_state", "debug_unwrap", ]