21 lines
693 B
Python
21 lines
693 B
Python
|
from typing import Callable, TypeVar
|
||
|
|
||
|
|
||
|
# Unconstrained type variable: lets `exposed_in` declare that its decorator
# returns an object of the same type it was given.
F = TypeVar("F")
|
||
|
|
||
|
|
||
|
# Allows one to expose an API in a private submodule publicly, as per the
# definition in PyTorch's public API policy.
#
# It is a temporary solution while we figure out if it should be the long-term
# solution or if we should amend PyTorch's public API policy. The concern is
# that this approach may not be very robust because it's not clear what
# `__module__` is used for. However, both NumPy and JAX overwrite the
# `__module__` attribute of their APIs without problem, so it seems fine.
|
||
|
def exposed_in(module: str) -> Callable[[F], F]:
    """Return a decorator that overwrites the ``__module__`` of its target.

    The returned decorator assigns ``module`` to the ``__module__`` attribute
    of the object it receives and hands back that very same object; no
    wrapper object is created around it.

    Args:
        module: The string to store in the decorated object's ``__module__``.

    Returns:
        A decorator that mutates its argument in place and returns it.
    """

    def wrapper(fn: F) -> F:
        # Mutate in place (rather than wrapping) so the decorated object keeps
        # its identity and every other attribute.
        fn.__module__ = module
        return fn

    return wrapper
|