import pickle
from dataclasses import dataclass
from io import BufferedIOBase
from typing import Any

import torch
import torch._weights_only_unpickler as _weights_only_unpickler
from torch.serialization import _load, _save, DEFAULT_PROTOCOL, MAP_LOCATION

__all__: list[str] = []


@dataclass
class _Entry:
    key: str
    is_storage: bool
    length: int
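

# ``_Entry`` headers are read back with the weights-only unpickler, which only
# reconstructs allowlisted classes, so the dataclass is registered as a safe
# global up front.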
_weights_only_unpickler._add_safe_globals([_Entry])


class _PseudoZipFile:
    """Minimal in-memory record container mimicking the zip writer/reader
    interface expected by ``torch.serialization._save`` and ``_load``.

    The stream layout is a pickled list of ``_Entry`` headers followed by the
    raw bytes of each record, concatenated in the same order as the headers.
    """

    def __init__(self) -> None:
        self.records: dict[str, tuple[object, int]] = {}

    def write_record(self, key: str, data: object, length: int) -> None:
        self.records[key] = (data, length)

    def write_to(self, f: BufferedIOBase) -> None:
        # Header: one _Entry per record describing its key, kind, and size.
        entries = []
        for key, (data, length) in self.records.items():
            entries.append(
                _Entry(
                    key=key,
                    is_storage=isinstance(data, torch.UntypedStorage),
                    length=length,
                )
            )

        pickle.dump(entries, f, protocol=DEFAULT_PROTOCOL)

        # Payloads: written in the same order as the header entries.
        for key, (data, length) in self.records.items():
            if isinstance(data, bytes):
                f.write(data)
            elif isinstance(data, str):
                f.write(data.encode("utf-8"))
            elif isinstance(data, torch.UntypedStorage):
                data._write_file(f, False, False, 1)
            else:
                raise TypeError(f"unknown type: {type(data)}")

    def read_from(self, f: BufferedIOBase) -> None:
        # Read the header first, then consume exactly ``length`` bytes per record.
        entries = _weights_only_unpickler.load(f)

        for entry in entries:
            data = f.read(entry.length)
            if entry.is_storage:
                storage = torch.frombuffer(
                    data,
                    dtype=torch.uint8,
                ).untyped_storage()

                self.records[entry.key] = (
                    storage,
                    entry.length,
                )
            else:
                self.records[entry.key] = (data, entry.length)

    def has_record(self, key: str) -> bool:
        return key in self.records

    def get_record(self, key: str) -> object:
        return self.records[key][0]

    def get_storage_from_record(
        self, key: str, _length: int, _type: int
    ) -> torch.Tensor:
        return torch.tensor(self.records[key][0], dtype=torch.uint8)

    def serialization_id(self) -> str:
        return "torchft"


def _streaming_save(
    obj: object,
    f: BufferedIOBase,
    pickle_module: Any = pickle,
    pickle_protocol: int = DEFAULT_PROTOCOL,
) -> None:
    """
    Save the object to a file-like object in a streaming fashion compatible with
    network sockets.

    This behaves similarly to :func:`torch.save` with a few notable differences:

    * A non-seekable file-like object can be used when loading.
    * No forwards/backwards compatibility is provided for the serialization
      format. This is only intended to be used with a single version of PyTorch
      with transient storage (i.e. sockets or temp files).
    * mmap is not supported.

    See :func:`torch.save` for more details on specific arguments.
    """
    zip_file = _PseudoZipFile()
    _save(
        obj,
        zip_file=zip_file,
        pickle_module=pickle_module,
        pickle_protocol=pickle_protocol,
        _disable_byteorder_record=False,
    )
    zip_file.write_to(f)
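

# Hedged usage sketch: sending a checkpoint over a connected socket, where
# ``conn`` (a socket.socket) and ``state_dict`` are assumed to exist. The file
# object returned by makefile() is non-seekable, which is exactly the case this
# function targets.
#
#     with conn.makefile("wb") as f:
#         _streaming_save(state_dict, f)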


def _streaming_load(
    f: BufferedIOBase,
    map_location: MAP_LOCATION = None,
    pickle_module: Any = None,
    *,
    weights_only: bool = True,
    **pickle_load_args: Any,
) -> object:
    """
    Load the object from a file-like object in a streaming fashion compatible with
    network sockets.

    See :func:`_streaming_save` for more details about the streaming behavior.

    See :func:`torch.load` for more details on specific arguments.
    """
    if weights_only:
        if pickle_module is not None:
            raise RuntimeError(
                "Can not safely load weights when explicit pickle_module is specified"
            )
        pickle_module = _weights_only_unpickler
    else:
        if pickle_module is None:
            pickle_module = pickle

    if "encoding" not in pickle_load_args.keys():
        pickle_load_args["encoding"] = "utf-8"

    zip_file = _PseudoZipFile()
    zip_file.read_from(f)
    return _load(
        zip_file=zip_file,
        map_location=map_location,
        pickle_module=pickle_module,
        **pickle_load_args,
    )
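

if __name__ == "__main__":
    # Minimal self-contained round-trip demo (illustrative sketch; the sample
    # state dict and in-memory buffer are assumptions, not part of the API):
    # save a small state dict into a BytesIO buffer and load it back.
    import io

    state = {"weight": torch.arange(4, dtype=torch.float32)}

    buffer = io.BytesIO()
    _streaming_save(state, buffer)

    buffer.seek(0)
    loaded = _streaming_load(buffer, map_location="cpu")
    assert torch.equal(loaded["weight"], state["weight"])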