51 lines
1.4 KiB
Python
51 lines
1.4 KiB
Python
import importlib.machinery
|
|
import os
|
|
|
|
from torch.hub import _get_torch_home
|
|
|
|
|
|
# Module-level switches; both are off in this module.
IN_FBCODE = False
_USE_SHARDED_DATASETS = False

# Directory "<torch home>/datasets/vision", built from torch.hub's home dir.
_HOME = os.path.join(
    _get_torch_home(),
    "datasets",
    "vision",
)
|
|
|
|
|
|
def _download_file_from_remote_location(fpath: str, url: str) -> None:
    """No-op stub: accepts a file path and a URL and does nothing.

    Args:
        fpath: Ignored by this implementation (presumably the local
            destination path -- confirm against callers).
        url: Ignored by this implementation (presumably the remote
            source -- confirm against callers).
    """
    return None
|
|
|
|
|
|
def _is_remote_location_available() -> bool:
    """Report remote-location availability; this implementation always says no.

    Returns:
        ``False``, unconditionally.
    """
    available = False
    return available
|
|
|
|
|
|
try:
|
|
from torch.hub import load_state_dict_from_url # noqa: 401
|
|
except ImportError:
|
|
from torch.utils.model_zoo import load_url as load_state_dict_from_url # noqa: 401
|
|
|
|
|
|
def _get_extension_path(lib_name: str) -> str:
    """Return the path of the compiled extension module named ``lib_name``.

    The directory containing this file is searched for a file whose name is
    ``lib_name`` plus one of ``importlib.machinery.EXTENSION_SUFFIXES``.
    On Windows (``os.name == "nt"``) that directory is first registered as a
    DLL search directory via ``os.add_dll_directory``.

    Args:
        lib_name: Base name of the extension module, without any suffix.

    Returns:
        The ``origin`` of the module spec found for ``lib_name``, i.e. the
        path of the matching extension file.

    Raises:
        ImportError: If no matching extension file exists in the directory.
    """
    lib_dir = os.path.dirname(__file__)

    if os.name == "nt":
        # Register this package's directory on the default DLL search path.
        import ctypes

        kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
        # 0x0001 is SEM_FAILCRITICALERRORS: suppress the critical-error dialog
        # while the directory is being registered.
        prev_error_mode = kernel32.SetErrorMode(0x0001)
        try:
            os.add_dll_directory(lib_dir)
        finally:
            # Always restore the process error mode, even if registration fails.
            kernel32.SetErrorMode(prev_error_mode)

    # Restrict the finder to compiled extension modules only.
    loader_details = (
        importlib.machinery.ExtensionFileLoader,
        importlib.machinery.EXTENSION_SUFFIXES,
    )
    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
    ext_specs = extfinder.find_spec(lib_name)
    if ext_specs is None:
        raise ImportError(
            f"Could not find extension module {lib_name!r} in {lib_dir!r}",
            name=lib_name,
        )

    return ext_specs.origin
|