# mypy: allow-untyped-defs

# Unpickler restricted to loading only state dicts
# Restrict constructing types to a list defined in _get_allowed_globals()
# Restrict BUILD operation to `Tensor`, `Parameter` and `OrderedDict` types only
# Restrict APPEND/APPENDS to `list`
# In the `GLOBAL` operation do not do class lookup by name, but rather rely on a dictionary
# defined by the `_get_allowed_globals()` function, which contains:
# - torch types (Storage, dtypes, Tensor, `torch.Size`),
# - `torch._utils._rebuild` functions.
# - `torch.nn.Parameter`
# - `collections.Counter`
# - `collections.OrderedDict`
# Additionally, users can allowlist classes they have deemed as safe using
# `_add_safe_globals()` (`torch.serialization.add_safe_globals`)
# `_clear_safe_globals()` (`torch.serialization.clear_safe_globals`)
# `_get_safe_globals()` (`torch.serialization.get_safe_globals`)
# Based on https://github.com/python/cpython/blob/main/Lib/pickle.py
# Expected to be useful for loading PyTorch model weights
# For example:
#   data = urllib.request.urlopen('https://download.pytorch.org/models/resnet50-0676ba61.pth').read()
#   buf = io.BytesIO(data)
#   weights = torch.load(buf, weights_only=True)

import functools as _functools
import warnings
from _codecs import encode
from collections import Counter, OrderedDict
from pickle import (
    APPEND,
    APPENDS,
    BINFLOAT,
    BINGET,
    BININT,
    BININT1,
    BININT2,
    BINPERSID,
    BINPUT,
    BINUNICODE,
    BUILD,
    bytes_types,
    decode_long,
    EMPTY_DICT,
    EMPTY_LIST,
    EMPTY_SET,
    EMPTY_TUPLE,
    GLOBAL,
    LONG1,
    LONG_BINGET,
    LONG_BINPUT,
    MARK,
    NEWFALSE,
    NEWOBJ,
    NEWTRUE,
    NONE,
    PROTO,
    REDUCE,
    SETITEM,
    SETITEMS,
    SHORT_BINSTRING,
    STOP,
    TUPLE,
    TUPLE1,
    TUPLE2,
    TUPLE3,
    UnpicklingError,
)
from struct import unpack
from sys import maxsize
from typing import Any, Callable, Union

import torch
from torch._utils import _sparse_tensors_to_validate, IMPORT_MAPPING, NAME_MAPPING


# modules in this list are never allowed, even if the user attempts to allowlist
# functions/classes from them
_blocklisted_modules = [
    "sys",
    "os",
    "posix",
    "nt",
]

_marked_safe_globals_set: set[Union[Callable, tuple[Callable, str]]] = set()


def _add_safe_globals(safe_globals: list[Union[Callable, tuple[Callable, str]]]):
    global _marked_safe_globals_set
    _marked_safe_globals_set = _marked_safe_globals_set.union(set(safe_globals))


def _get_safe_globals() -> list[Union[Callable, tuple[Callable, str]]]:
    global _marked_safe_globals_set
    return list(_marked_safe_globals_set)


def _clear_safe_globals():
    global _marked_safe_globals_set
    _marked_safe_globals_set = set()


def _remove_safe_globals(
    globals_to_remove: list[Union[Callable, tuple[Callable, str]]],
):
    global _marked_safe_globals_set
    _marked_safe_globals_set = _marked_safe_globals_set - set(globals_to_remove)


class _safe_globals:
    def __init__(self, safe_globals: list[Union[Callable, tuple[Callable, str]]]):
        self.safe_globals = safe_globals

    def __enter__(self):
        _add_safe_globals(self.safe_globals)

    def __exit__(self, type, value, tb):
        _remove_safe_globals(self.safe_globals)


# Separate from _get_allowed_globals because of the lru_cache on _get_allowed_globals
# For example if user had a script like
#   torch.load(file_a)
#   torch.serialization._add_safe_globals([torch.foo])
#   torch.load(file_b)
# the dynamic additions to safe_globals would not be picked up by
# _get_allowed_globals due to the lru_cache
def _get_user_allowed_globals():
    rc: dict[str, Any] = {}
    for f in _marked_safe_globals_set:
        if isinstance(f, tuple):
            if len(f) != 2:
                raise ValueError(
                    f"Expected tuple of length 2 (global, str of callable full path), "
                    f"but got tuple of length: {len(f)}"
                )
            if type(f[1]) is not str:
                raise TypeError(
                    f"Expected second item in tuple to be str of callable full path, but got: {type(f[1])}"
                )
            f, name = f
            rc[name] = f
        else:
            module, name = f.__module__, f.__qualname__
            rc[f"{module}.{name}"] = f
    return rc
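
# For illustration, a user-defined global can be allowlisted either as the
# object itself or as a (callable, "full.path") tuple when the checkpoint
# recorded it under a different name (the class and path below are
# hypothetical, not part of this module):
#
#   import torch.serialization
#
#   class MyConfig:
#       pass
#
#   torch.serialization.add_safe_globals([MyConfig])
#   torch.serialization.add_safe_globals([(MyConfig, "old_pkg.MyConfig")])
#
#   # or scoped to a single load via the context manager:
#   with torch.serialization.safe_globals([MyConfig]):
#       obj = torch.load("checkpoint.pt", weights_only=True)
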
def _tensor_rebuild_functions():
    return {
        torch._utils._rebuild_parameter,
        torch._utils._rebuild_parameter_with_state,
        torch._utils._rebuild_qtensor,
        torch._utils._rebuild_tensor,
        torch._utils._rebuild_tensor_v2,
        torch._utils._rebuild_tensor_v3,
        torch._utils._rebuild_sparse_tensor,
        torch._utils._rebuild_meta_tensor_no_storage,
        torch._utils._rebuild_nested_tensor,
        torch._utils._rebuild_wrapper_subclass,
        # Allowlisting this, but not allowlisting the numpy functions by default.
        # Reasoning is that we don't have control over the numpy functions, but
        # this utility is provided by pytorch
        torch._utils._rebuild_device_tensor_from_numpy,
        # In 2.6, we should no longer have a dependency on numpy and the above
        # _rebuild_device_tensor_from_numpy function.
        torch._utils._rebuild_device_tensor_from_cpu_tensor,
    }


# Unpickling machinery
@_functools.lru_cache(maxsize=1)
def _get_allowed_globals():
    rc: dict[str, Any] = {
        "collections.OrderedDict": OrderedDict,
        "collections.Counter": Counter,
        "torch.nn.parameter.Parameter": torch.nn.Parameter,
        "torch.serialization._get_layout": torch.serialization._get_layout,
        "torch.Size": torch.Size,
        "torch.Tensor": torch.Tensor,
        "torch.device": torch.device,
        "_codecs.encode": encode,  # for bytes
        "builtins.bytearray": bytearray,  # for bytearray
        "builtins.set": set,  # for set
        "builtins.complex": complex,  # for complex
    }
    # dtype
    for t in torch.storage._dtype_to_storage_type_map().keys():
        rc[str(t)] = t
    for t in torch.storage._new_dtypes():
        rc[str(t)] = t
    for t in [getattr(torch, f"uint{x}") for x in range(1, 8)]:
        rc[str(t)] = t
    for t in [getattr(torch, f"int{x}") for x in range(1, 8)]:
        rc[str(t)] = t
    # Tensor classes
    for tt in torch._tensor_classes:
        rc[f"{tt.__module__}.{tt.__name__}"] = tt
    # Storage classes
    for ts in torch._storage_classes:
        if ts not in (torch.storage.TypedStorage, torch.storage.UntypedStorage):
            # Wrap legacy storage types in a dummy class
            rc[f"{ts.__module__}.{ts.__name__}"] = torch.serialization.StorageType(
                ts.__name__
            )
        else:
            rc[f"{ts.__module__}.{ts.__name__}"] = ts
    # Quantization specific
    for qt in [
        torch.per_tensor_affine,
        torch.per_tensor_symmetric,
        torch.per_channel_affine,
        torch.per_channel_symmetric,
        torch.per_channel_affine_float_qparams,
    ]:
        rc[str(qt)] = qt
    # Rebuild functions
    for f in _tensor_rebuild_functions():
        rc[f"torch._utils.{f.__name__}"] = f
    # Handles Tensor Subclasses and Tensors with attributes.
    # NOTE: It calls into above rebuild functions for regular Tensor types.
    rc["torch._tensor._rebuild_from_type_v2"] = torch._tensor._rebuild_from_type_v2
    return rc
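
# A rough sketch of what the default allowlist lookups amount to (illustrative
# only; the exact contents vary across torch versions):
#
#   allowed = _get_allowed_globals()
#   allowed["collections.OrderedDict"]  # -> collections.OrderedDict
#   allowed["torch.float32"]            # dtypes are keyed by str(dtype)
#   allowed["torch.FloatStorage"]       # legacy storages map to a StorageType wrapper
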
rc["torch._tensor._rebuild_from_type_v2"] = torch._tensor._rebuild_from_type_v2 return rc def _read_global_instruction(readline: Callable) -> tuple[str, str]: module = readline()[:-1].decode("utf-8") name = readline()[:-1].decode("utf-8") # Patch since torch.save default protocol is 2 # users will be running this code in python > 3 if (module, name) in NAME_MAPPING: module, name = NAME_MAPPING[(module, name)] elif module in IMPORT_MAPPING: module = IMPORT_MAPPING[module] return module, name def get_globals_in_pkl(file) -> set[str]: globals_in_checkpoint = set() read = file.read readline = file.readline op_to_bytes_to_read = { NEWOBJ[0]: 0, REDUCE[0]: 0, BUILD[0]: 0, APPEND[0]: 0, APPENDS[0]: 0, SETITEM[0]: 0, SETITEMS[0]: 0, MARK[0]: 0, TUPLE[0]: 0, TUPLE1[0]: 0, TUPLE2[0]: 0, TUPLE3[0]: 0, NONE[0]: 0, NEWFALSE[0]: 0, NEWTRUE[0]: 0, EMPTY_TUPLE[0]: 0, EMPTY_LIST[0]: 0, EMPTY_DICT[0]: 0, EMPTY_SET[0]: 0, BINPERSID[0]: 0, BININT[0]: 4, BININT1[0]: 1, BININT2[0]: 2, BINFLOAT[0]: 8, BINGET[0]: 1, LONG_BINGET[0]: 4, BINPUT[0]: 1, LONG_BINPUT[0]: 4, } while True: key = read(1) if not key: raise EOFError assert isinstance(key, bytes_types) if key[0] == GLOBAL[0]: module, name = _read_global_instruction(readline) globals_in_checkpoint.add(f"{module}.{name}") elif key[0] in op_to_bytes_to_read: bytes_to_read = op_to_bytes_to_read[key[0]] if bytes_to_read: read(bytes_to_read) # ops where bytes to read depends on the data elif key[0] == BINUNICODE[0]: strlen = unpack(" maxsize: raise UnpicklingError("String is too long") read(strlen) elif key[0] in {SHORT_BINSTRING[0], LONG1[0]}: strlen = read(1)[0] read(strlen) # first and last op elif key[0] == PROTO[0]: read(1)[0] elif key[0] == STOP[0]: return globals_in_checkpoint else: raise UnpicklingError(f"Unsupported operand {key[0]}") class Unpickler: def __init__(self, file, *, encoding: str = "bytes"): self.encoding = encoding self.readline = file.readline self.read = file.read self.memo: dict[int, Any] = {} self.proto: int = -1 def load(self): """Read a pickled object representation from the open file. Return the reconstituted object hierarchy specified in the file. """ self.metastack = [] self.stack: list[Any] = [] self.append = self.stack.append read = self.read while True: key = read(1) if not key: raise EOFError assert isinstance(key, bytes_types) # Risky operators if key[0] == GLOBAL[0]: module, name = _read_global_instruction(self.readline) full_path = f"{module}.{name}" if module in _blocklisted_modules: raise UnpicklingError( f"Trying to load unsupported GLOBAL {full_path} whose module {module} is blocked." 
class Unpickler:
    def __init__(self, file, *, encoding: str = "bytes"):
        self.encoding = encoding
        self.readline = file.readline
        self.read = file.read
        self.memo: dict[int, Any] = {}
        self.proto: int = -1

    def load(self):
        """Read a pickled object representation from the open file.

        Return the reconstituted object hierarchy specified in the file.
        """
        self.metastack = []
        self.stack: list[Any] = []
        self.append = self.stack.append
        read = self.read
        while True:
            key = read(1)
            if not key:
                raise EOFError
            assert isinstance(key, bytes_types)
            # Risky operators
            if key[0] == GLOBAL[0]:
                module, name = _read_global_instruction(self.readline)
                full_path = f"{module}.{name}"
                if module in _blocklisted_modules:
                    raise UnpicklingError(
                        f"Trying to load unsupported GLOBAL {full_path} whose module {module} is blocked."
                    )
                if full_path in _get_allowed_globals():
                    self.append(_get_allowed_globals()[full_path])
                elif full_path in _get_user_allowed_globals():
                    self.append(_get_user_allowed_globals()[full_path])
                elif full_path in (
                    [
                        "torch.nested._internal.nested_tensor.NestedTensor",
                        "torch.nested._internal.nested_tensor._rebuild_njt",
                        "torch._dynamo.decorators._DimRange",
                    ]
                ):
                    raise UnpicklingError(
                        "``torch.nested`` and ``torch._dynamo`` must be imported to load nested jagged tensors (NJTs)"
                    )
                elif full_path in (
                    [
                        "torch.distributed.device_mesh.DeviceMesh",
                        "torch.distributed.tensor._dtensor_spec.DTensorSpec",
                        "torch.distributed.tensor._dtensor_spec.TensorMeta",
                        "torch.distributed.tensor.DTensor",
                        "torch.distributed.tensor.placement_types.Partial",
                        "torch.distributed.tensor.placement_types.Replicate",
                        "torch.distributed.tensor.placement_types.Shard",
                    ]
                ):
                    raise UnpicklingError(
                        "``torch.distributed.tensor`` must be imported to load DTensors"
                    )
                else:
                    # Strip a leading `builtins.` prefix so the suggested
                    # allowlist call references the name as users would write
                    # it (e.g. `print` instead of `builtins.print`)
                    builtins_name = "builtins"
                    if (
                        builtins_name in full_path
                        and builtins_name == full_path[: len(builtins_name)]
                    ):
                        full_path = full_path[len(builtins_name) :]
                        full_path = (
                            full_path[1:]
                            if len(full_path) > 0 and full_path[0] == "."
                            else builtins_name + full_path
                        )
                    raise UnpicklingError(
                        f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. "
                        f"Please use `torch.serialization.add_safe_globals([{full_path}])` or the "
                        f"`torch.serialization.safe_globals([{full_path}])` context manager to allowlist this global "
                        "if you trust this class/function."
                    )
            elif key[0] == NEWOBJ[0]:
                args = self.stack.pop()
                cls = self.stack.pop()
                if cls is torch.nn.Parameter:
                    self.append(torch.nn.Parameter(*args))
                elif (
                    cls in _get_user_allowed_globals().values()
                    or cls in _get_allowed_globals().values()
                ):
                    result = cls.__new__(cls, *args)
                    if cls in torch._tensor_classes and "sparse" in cls.__module__:
                        _sparse_tensors_to_validate.append(result)
                    self.append(result)
                else:
                    raise UnpicklingError(
                        "Can only create new object for nn.Parameter or classes allowlisted "
                        f"via `add_safe_globals` but got {cls}"
                    )
            elif key[0] == REDUCE[0]:
                args = self.stack.pop()
                func = self.stack[-1]
                if (
                    func not in _get_allowed_globals().values()
                    and func not in _get_user_allowed_globals().values()
                ):
                    raise UnpicklingError(
                        f"Trying to call reduce for unrecognized function {func}"
                    )
                result = func(*args)
                if func in torch._tensor_classes and "sparse" in func.__module__:
                    _sparse_tensors_to_validate.append(result)
                self.stack[-1] = result
            elif key[0] == BUILD[0]:
                state = self.stack.pop()
                inst = self.stack[-1]
                if type(inst) is torch.Tensor:
                    # Legacy unpickling
                    inst.set_(*state)
                elif type(inst) is torch.nn.Parameter:
                    inst.__setstate__(state)
                elif type(inst) is OrderedDict:
                    inst.__dict__.update(state)
                elif (
                    type(inst) in _get_user_allowed_globals().values()
                    or type(inst) in _get_allowed_globals().values()
                ):
                    if hasattr(inst, "__setstate__"):
                        inst.__setstate__(state)
                    else:
                        # mimics load_build in pickle
                        # https://github.com/python/cpython/blob/f0c6fccd08904787a39269367f09f263d496114c/Lib/pickle.py#L1854-L1867
                        slotstate = None
                        if isinstance(state, tuple) and len(state) == 2:
                            state, slotstate = state
                        if state:
                            inst.__dict__.update(state)
                        if slotstate:
                            for k, v in slotstate.items():
                                setattr(inst, k, v)
                else:
                    raise UnpicklingError(
                        "Can only build Tensor, Parameter, OrderedDict or types allowlisted "
                        f"via `add_safe_globals`, but got {type(inst)}"
                    )
            # Stack manipulation
            elif key[0] == APPEND[0]:
                item = self.stack.pop()
                list_obj = self.stack[-1]
                if type(list_obj) is not list:
                    raise UnpicklingError(
                        f"Can only append to lists, but got {type(list_obj)}"
                    )
                list_obj.append(item)
            elif key[0] == APPENDS[0]:
                items = self.pop_mark()
                list_obj = self.stack[-1]
                if type(list_obj) is not list:
                    raise UnpicklingError(
                        f"Can only extend lists, but got {type(list_obj)}"
                    )
                list_obj.extend(items)
            elif key[0] == SETITEM[0]:
                (v, k) = (self.stack.pop(), self.stack.pop())
                self.stack[-1][k] = v
            elif key[0] == SETITEMS[0]:
                items = self.pop_mark()
                for i in range(0, len(items), 2):
                    self.stack[-1][items[i]] = items[i + 1]
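            # MARK saves the current working stack on metastack and starts a
            # fresh one; self.pop_mark later restores it and returns the items
            # accumulated since the MARK. For example, the opcode sequence
            #   MARK BININT1(1) BININT1(2) TUPLE
            # leaves the single object (1, 2) on the stack: TUPLE collects
            # everything pushed since the most recent MARK.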
            elif key[0] == MARK[0]:
                self.metastack.append(self.stack)
                self.stack = []
                self.append = self.stack.append
            elif key[0] == TUPLE[0]:
                items = self.pop_mark()
                self.append(tuple(items))
            elif key[0] == TUPLE1[0]:
                self.stack[-1] = (self.stack[-1],)
            elif key[0] == TUPLE2[0]:
                self.stack[-2:] = [(self.stack[-2], self.stack[-1])]
            elif key[0] == TUPLE3[0]:
                self.stack[-3:] = [(self.stack[-3], self.stack[-2], self.stack[-1])]
            # Basic types construction
            elif key[0] == NONE[0]:
                self.append(None)
            elif key[0] == NEWFALSE[0]:
                self.append(False)
            elif key[0] == NEWTRUE[0]:
                self.append(True)
            elif key[0] == EMPTY_TUPLE[0]:
                self.append(())
            elif key[0] == EMPTY_LIST[0]:
                self.append([])
            elif key[0] == EMPTY_DICT[0]:
                self.append({})
            elif key[0] == EMPTY_SET[0]:
                self.append(set())
            elif key[0] == BININT[0]:
                self.append(unpack("<i", read(4))[0])
            elif key[0] == BININT1[0]:
                self.append(self.read(1)[0])
            elif key[0] == BININT2[0]:
                self.append(unpack("<H", read(2))[0])
            elif key[0] == BINFLOAT[0]:
                self.append(unpack(">d", self.read(8))[0])
            elif key[0] == BINUNICODE[0]:
                strlen = unpack("<I", read(4))[0]
                if strlen > maxsize:
                    raise UnpicklingError("String is too long")
                strval = str(read(strlen), "utf-8", "surrogatepass")
                self.append(strval)
            elif key[0] == SHORT_BINSTRING[0]:
                strlen = read(1)[0]
                strdata = read(strlen)
                if self.encoding != "bytes":
                    strdata = strdata.decode(self.encoding, "strict")
                self.append(strdata)
            elif key[0] == BINPERSID[0]:
                pid = self.stack.pop()
                # Only allow persistent load of storage
                if type(pid) is not tuple and type(pid) is not int:
                    raise UnpicklingError(
                        f"persistent_load id must be tuple or int, but got {type(pid)}"
                    )
                if (
                    type(pid) is tuple
                    and len(pid) > 0
                    and torch.serialization._maybe_decode_ascii(pid[0]) != "storage"
                ):
                    raise UnpicklingError(
                        f"Only persistent_load of storage is allowed, but got {pid[0]}"
                    )
                self.append(self.persistent_load(pid))
            elif key[0] in [BINGET[0], LONG_BINGET[0]]:
                idx = (read(1) if key[0] == BINGET[0] else unpack("<I", read(4)))[0]
                self.append(self.memo[idx])
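            # BINGET/LONG_BINGET resolve back-references: the counterpart
            # BINPUT/LONG_BINPUT ops record stack tops in self.memo, and
            # BINGET fetches the memoized entry, so an object pickled once but
            # referenced repeatedly is shared after loading. For example,
            # `x = [1]; torch.save((x, x), f)` round-trips to a tuple whose
            # two slots alias the same list.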