from typing import Callable, TypeVar

F = TypeVar("F")


# Allows one to expose an API that lives in a private submodule as public, per
# the definition in PyTorch's public API policy.
#
# This is a temporary solution while we figure out whether it should become the
# long-term solution or whether PyTorch's public API policy should be amended.
# The concern is that this approach may not be very robust, because it is not
# clear what __module__ is used for. However, both numpy and jax overwrite the
# __module__ attribute of their APIs without problem, so it seems fine.
def exposed_in(module: str) -> Callable[[F], F]:
    """Build a decorator that reports ``module`` as the home of an object.

    Args:
        module: Dotted module path to store in the decorated object's
            ``__module__`` attribute.

    Returns:
        A decorator that overwrites ``__module__`` on the object it receives
        and hands back that very same object (no wrapping, no copy).
    """

    def wrapper(fn: F) -> F:
        """Overwrite ``fn.__module__`` in place and return ``fn`` itself."""
        setattr(fn, "__module__", module)
        return fn

    return wrapper