# File-viewer metadata: 1225 lines · 41 KiB · Python
"""Test the numpy pickler as a replacement of the standard pickler."""
|
|
|
|
import bz2
|
|
import copy
|
|
import gzip
|
|
import io
|
|
import mmap
|
|
import os
|
|
import pickle
|
|
import random
|
|
import re
|
|
import socket
|
|
import sys
|
|
import warnings
|
|
import zlib
|
|
from contextlib import closing
|
|
from pathlib import Path
|
|
|
|
try:
|
|
import lzma
|
|
except ImportError:
|
|
lzma = None
|
|
|
|
import pytest
|
|
|
|
# numpy_pickle is not a drop-in replacement of pickle, as it takes
|
|
# filenames instead of open files as arguments.
|
|
from joblib import numpy_pickle, register_compressor
|
|
from joblib.compressor import (
|
|
_COMPRESSORS,
|
|
_LZ4_PREFIX,
|
|
LZ4_NOT_INSTALLED_ERROR,
|
|
BinaryZlibFile,
|
|
CompressorWrapper,
|
|
)
|
|
from joblib.numpy_pickle_utils import (
|
|
_IO_BUFFER_SIZE,
|
|
_detect_compressor,
|
|
_ensure_native_byte_order,
|
|
_is_numpy_array_byte_order_mismatch,
|
|
)
|
|
from joblib.test import data
|
|
from joblib.test.common import (
|
|
memory_used,
|
|
np,
|
|
with_lz4,
|
|
with_memory_profiler,
|
|
with_numpy,
|
|
without_lz4,
|
|
)
|
|
from joblib.testing import parametrize, raises, warns
|
|
|
|
###############################################################################
# Define a list of standard types.
# Borrowed from dill, initial author: Michael McKerns:
# http://dev.danse.us/trac/pathos/browser/dill/dill_test2.py

# One sample value per standard type. Each value is bound to its own
# module-level name and then collected, in declaration order, into
# ``typelist``, which parametrizes ``test_standard_types``.
_none = None
_type = type
_bool = bool(1)
_int = int(1)
_float = float(1)
_complex = complex(1)
_string = str(1)
_tuple = ()
_list = []
_dict = {}
_builtin = len

typelist = [
    _none,
    _type,
    _bool,
    _int,
    _float,
    _complex,
    _string,
    _tuple,
    _list,
    _dict,
    _builtin,
]
|
|
|
|
|
|
def _function(x):
    """Generator function sample: yields its single argument once."""
    yield x
|
|
|
|
|
|
class _class:
    """User-defined class sample declared without an explicit base class."""

    def _method(self):
        """Do nothing and return ``None``."""
        return None
|
|
|
|
|
|
class _newclass(object):
    """User-defined class sample declared with an explicit ``object`` base."""

    def _method(self):
        """Do nothing and return ``None``."""
        return None
|
|
|
|
|
|
# Complete the sample list with a function, two classes and one instance
# of each class, in that order.
_instance = _class()
_object = _newclass()
typelist.extend([_function, _class, _newclass, _instance, _object])
|
|
|
|
|
|
###############################################################################
|
|
# Tests
|
|
|
|
|
|
@parametrize("compress", [0, 1])
@parametrize("member", typelist)
def test_standard_types(tmpdir, compress, member):
    """Round-trip one standard-type sample through dump/load."""
    path = tmpdir.join("test.pkl").strpath
    numpy_pickle.dump(member, path, compress=compress)
    reloaded = numpy_pickle.load(path)

    # Equality is only meaningful for values that equal their own deep copy.
    # For example, instances of classes without ``__eq__`` compare by
    # identity and never equal a reloaded copy.
    comparable = member == copy.deepcopy(member)
    if comparable:
        assert member == reloaded
|
|
|
|
|
|
def test_value_error():
    """Passing the object and filename arguments of dump in swapped order fails."""
    bad_filename = dict()
    with raises(ValueError):
        numpy_pickle.dump("foo", bad_filename)
|
|
|
|
|
|
@parametrize("wrong_compress", [-1, 10, dict()])
def test_compress_level_error(wrong_compress):
    """An invalid ``compress`` argument raises a descriptive ValueError."""
    expected = f'Non valid compress level given: "{wrong_compress}"'
    with raises(ValueError) as excinfo:
        numpy_pickle.dump("dummy", "foo", compress=wrong_compress)
    excinfo.match(expected)
|
|
|
|
|
|
@with_numpy
@parametrize("compress", [False, True, 0, 3, "zlib"])
def test_numpy_persistence(tmpdir, compress):
    """Round-trip arrays, containers of arrays, a memmap and an object
    holding several arrays, checking each dump writes a single file."""
    filename = tmpdir.join("test.pkl").strpath
    rnd = np.random.RandomState(0)
    a = rnd.random_sample((10, 2))

    # 'a.T' is included because it is not C-contiguous.
    containers = ((a,), (a.T,), (a, a), [a, a, a])
    for container in containers:
        written = numpy_pickle.dump(container, filename, compress=compress)

        # Everything goes into exactly one file, at the requested path.
        assert len(written) == 1
        assert written[0] == filename
        assert os.path.exists(written[0])

        reloaded = numpy_pickle.load(filename)
        assert all(isinstance(item, np.ndarray) for item in reloaded)
        np.testing.assert_array_equal(np.array(container), np.array(reloaded))

    # Array subclass: a memmap.
    mm = np.memmap(filename + "mmap", mode="w+", shape=4, dtype=np.float64)
    written = numpy_pickle.dump(mm, filename, compress=compress)
    assert len(written) == 1

    reloaded = numpy_pickle.load(filename)
    # NOTE(review): ``mm`` is created with np.memmap, so the first operand
    # is always False and this branch never runs; confirm what the intended
    # subclass check is.
    if type(mm) is not np.memmap and hasattr(mm, "__array_prepare__"):
        # Memmaps are not reconstructed as memmaps.
        assert isinstance(reloaded, type(mm))
    np.testing.assert_array_equal(reloaded, mm)

    # Object holding several numpy arrays as attributes.
    complex_obj = ComplexTestObject()
    written = numpy_pickle.dump(complex_obj, filename, compress=compress)
    assert len(written) == 1

    complex_loaded = numpy_pickle.load(filename)
    assert isinstance(complex_loaded, type(complex_obj))
    for attr in ("array_float", "array_int", "array_obj"):
        np.testing.assert_array_equal(
            getattr(complex_loaded, attr), getattr(complex_obj, attr)
        )
|
|
|
|
|
|
@with_numpy
def test_numpy_persistence_bufferred_array_compression(tmpdir):
    """A compressed array slightly bigger than the I/O buffer round-trips."""
    # ``_IO_BUFFER_SIZE + 100`` one-byte elements: the data does not fit in
    # a single I/O buffer.
    big_array = np.ones(_IO_BUFFER_SIZE + 100, dtype=np.uint8)
    path = tmpdir.join("test.pkl").strpath

    numpy_pickle.dump(big_array, path, compress=True)
    np.testing.assert_array_equal(big_array, numpy_pickle.load(path))
|
|
|
|
|
|
@with_numpy
def test_memmap_persistence(tmpdir):
    """Check memory-mapped loading with the 'r', 'r+' and 'w+' modes.

    - 'r' returns read-only memmaps for fixed-size arrays and a regular
      array for object-dtype arrays.
    - 'r+' returns writable memmaps whose writes are visible to later loads.
    - 'w+' is switched to 'r+', so existing data is kept rather than
      truncated.
    """
    rnd = np.random.RandomState(0)
    a = rnd.random_sample(10)
    filename = tmpdir.join("test1.pkl").strpath
    numpy_pickle.dump(a, filename)
    b = numpy_pickle.load(filename, mmap_mode="r")

    assert isinstance(b, np.memmap)

    # Test with an object containing multiple numpy arrays
    filename = tmpdir.join("test2.pkl").strpath
    obj = ComplexTestObject()
    numpy_pickle.dump(obj, filename)
    obj_loaded = numpy_pickle.load(filename, mmap_mode="r")
    assert isinstance(obj_loaded, type(obj))
    assert isinstance(obj_loaded.array_float, np.memmap)
    assert not obj_loaded.array_float.flags.writeable
    assert isinstance(obj_loaded.array_int, np.memmap)
    assert not obj_loaded.array_int.flags.writeable
    # Memory map not allowed for numpy object arrays
    assert not isinstance(obj_loaded.array_obj, np.memmap)
    np.testing.assert_array_equal(obj_loaded.array_float, obj.array_float)
    np.testing.assert_array_equal(obj_loaded.array_int, obj.array_int)
    np.testing.assert_array_equal(obj_loaded.array_obj, obj.array_obj)

    # Test we can write in memmapped arrays
    obj_loaded = numpy_pickle.load(filename, mmap_mode="r+")
    assert obj_loaded.array_float.flags.writeable
    obj_loaded.array_float[0:10] = 10.0
    assert obj_loaded.array_int.flags.writeable
    obj_loaded.array_int[0:10] = 10

    obj_reloaded = numpy_pickle.load(filename, mmap_mode="r")
    np.testing.assert_array_equal(obj_reloaded.array_float, obj_loaded.array_float)
    np.testing.assert_array_equal(obj_reloaded.array_int, obj_loaded.array_int)

    # Test w+ mode is caught and the mode has switched to r+. The checks
    # must target the object returned by this load; previously its result
    # was discarded and the earlier 'r+' object was re-checked instead.
    obj_w_plus = numpy_pickle.load(filename, mmap_mode="w+")
    assert obj_w_plus.array_int.flags.writeable
    assert obj_w_plus.array_int.mode == "r+"
    assert obj_w_plus.array_float.flags.writeable
    assert obj_w_plus.array_float.mode == "r+"
    # Opening as 'r+' instead of 'w+' must not wipe the stored data.
    np.testing.assert_array_equal(obj_w_plus.array_float, obj_loaded.array_float)
    np.testing.assert_array_equal(obj_w_plus.array_int, obj_loaded.array_int)
|
|
|
|
|
|
@with_numpy
def test_memmap_persistence_mixed_dtypes(tmpdir):
    """Object-dtype sub-arrays must not prevent memmapping of fixed-size ones."""
    rnd = np.random.RandomState(0)
    float_array = rnd.random_sample(10)
    object_array = np.array([1, "b"], dtype=object)
    filename = tmpdir.join("test.pkl").strpath
    numpy_pickle.dump((float_array, object_array), filename)

    float_clone, object_clone = numpy_pickle.load(filename, mmap_mode="r")
    # The floating point array is memory mapped...
    assert isinstance(float_clone, np.memmap)
    # ...while the object-dtype array is loaded in memory.
    assert not isinstance(object_clone, np.memmap)
|
|
|
|
|
|
@with_numpy
def test_masked_array_persistence(tmpdir):
    """Masked arrays still round-trip and come back as masked arrays."""
    # The special-case array pickler does not implement masked arrays; they
    # are delegated to the standard pickler instead.
    rnd = np.random.RandomState(0)
    masked = np.ma.masked_greater(rnd.random_sample(10), 0.5)
    path = tmpdir.join("test.pkl").strpath

    numpy_pickle.dump(masked, path)
    reloaded = numpy_pickle.load(path, mmap_mode="r")
    assert isinstance(reloaded, np.ma.masked_array)
|
|
|
|
|
|
@with_numpy
def test_compress_mmap_mode_warning(tmpdir):
    """Using mmap_mode on a compressed file warns once and loads in memory."""
    rnd = np.random.RandomState(0)
    obj = rnd.random_sample(10)
    this_filename = tmpdir.join("test.pkl").strpath
    numpy_pickle.dump(obj, this_filename, compress=1)

    with warns(UserWarning) as record:
        reloaded_obj = numpy_pickle.load(this_filename, mmap_mode="r+")
    debug_msg = "\n".join(str(w) for w in record)
    messages = [w.message for w in record]

    assert not isinstance(reloaded_obj, np.memmap)
    np.testing.assert_array_equal(obj, reloaded_obj)
    assert len(messages) == 1, debug_msg

    expected = (
        'mmap_mode "r+" is not compatible with compressed '
        f'file {this_filename}. "r+" flag will be ignored.'
    )
    assert str(messages[0]) == expected
|
|
|
|
|
|
@with_numpy
@with_memory_profiler
@parametrize("compress", [True, False])
def test_memory_usage(tmpdir, compress):
    """Verify that dumping and loading stay within buffer-sized memory bounds.

    ``size`` below is in megabytes (``nbytes / 1e6``), and ``memory_used``
    is compared against it, so all bounds are expressed in megabytes.
    ``_IO_BUFFER_SIZE`` is a byte count (elsewhere in this module it is used
    as the element count of a uint8 array). It is therefore converted before
    use. Adding it unconverted, as before, made both bounds vacuous.
    """
    filename = tmpdir.join("test.pkl").strpath
    small_array = np.ones((10, 10))
    big_array = np.ones(shape=100 * int(1e6), dtype=np.uint8)

    io_buffer_mb = _IO_BUFFER_SIZE / 1e6
    # Dumping should not exceed the I/O buffer plus the buffer used to
    # write array chunks (16MB).
    write_buf_size = io_buffer_mb + 16 * 1024**2 / 1e6
    # Loading should not exceed the array size plus the buffers used to
    # read the array chunk by chunk.
    read_buf_size = 32 + io_buffer_mb

    for obj in (small_array, big_array):
        size = obj.nbytes / 1e6
        obj_filename = filename + str(np.random.randint(0, 1000))

        mem_used = memory_used(numpy_pickle.dump, obj, obj_filename, compress=compress)
        assert mem_used <= write_buf_size

        mem_used = memory_used(numpy_pickle.load, obj_filename)
        assert mem_used < size + read_buf_size
|
|
|
|
|
|
@with_numpy
def test_compressed_pickle_dump_and_load(tmpdir):
    """Mixed content (both endiannesses, object arrays, bytes, text) survives
    a compressed round trip, with arrays compared in native byte order."""
    expected_list = [
        np.arange(5, dtype=np.dtype("<i8")),
        np.arange(5, dtype=np.dtype(">i8")),
        np.arange(5, dtype=np.dtype("<f8")),
        np.arange(5, dtype=np.dtype(">f8")),
        np.array([1, "abc", {"a": 1, "b": 2}], dtype="O"),
        np.arange(256, dtype=np.uint8).tobytes(),
        "C'est l'\xe9t\xe9 !",
    ]
    fname = tmpdir.join("temp.pkl.gz").strpath

    dumped_filenames = numpy_pickle.dump(expected_list, fname, compress=1)
    assert len(dumped_filenames) == 1

    result_list = numpy_pickle.load(fname)
    for result, expected in zip(result_list, expected_list):
        if not isinstance(expected, np.ndarray):
            assert result == expected
            continue
        native = _ensure_native_byte_order(expected)
        assert result.dtype == native.dtype
        np.testing.assert_equal(result, native)
|
|
|
|
|
|
@with_numpy
def test_memmap_load(tmpdir):
    """``ensure_native_byte_order`` decides whether loaded arrays keep the
    byte order they were stored with."""
    le_dtype = np.dtype("<i8")
    be_dtype = np.dtype(">i8")
    le_array = np.arange(5, dtype=le_dtype)
    be_array = np.arange(5, dtype=be_dtype)

    fname = tmpdir.join("temp.pkl").strpath
    numpy_pickle.dump([le_array, be_array], fname)

    # Forced conversion: both arrays share one dtype, one of the two inputs.
    le_native, be_native = numpy_pickle.load(fname, ensure_native_byte_order=True)
    assert le_native.dtype == be_native.dtype
    assert le_native.dtype in (le_dtype, be_dtype)

    # No conversion: each array keeps its stored dtype.
    le_raw, be_raw = numpy_pickle.load(fname, ensure_native_byte_order=False)
    assert le_raw.dtype == le_array.dtype
    assert be_raw.dtype == be_array.dtype
|
|
|
|
|
|
def test_invalid_parameters_raise():
    """ensure_native_byte_order=True cannot be combined with an mmap_mode."""
    expected_msg = (
        "Native byte ordering can only be enforced if 'mmap_mode' parameter "
        "is set to None, but got 'mmap_mode=r+' instead."
    )
    pattern = re.escape(expected_msg)
    with raises(ValueError, match=pattern):
        numpy_pickle.load(
            "/path/to/some/dump.pkl", ensure_native_byte_order=True, mmap_mode="r+"
        )
|
|
|
|
|
|
def _check_pickle(filename, expected_list, mmap_mode=None):
    """Load a joblib test pickle and compare its content with ``expected_list``.

    The Python major version that wrote the file is parsed from its name
    (``...pyXY...``). If that version's default pickle protocol is higher
    than the protocol assumed for reading, an "unsupported pickle protocol"
    ValueError is expected. Otherwise the loaded content is compared item by
    item, and the number and text of the emitted warnings are checked.

    Note: currently only pickles containing an iterable are supported
    by this function.
    """
    writer_major = int(re.match(r".+py(\d)(\d).+", filename).group(1))

    default_protocols = {2: 2, 3: 3}
    reading_protocol = default_protocols.get(3, 4)
    writing_protocol = default_protocols.get(writer_major, 4)

    if reading_protocol < writing_protocol:
        # Pickle protocol used for writing is too high. We expect a
        # "unsupported pickle protocol" error message.
        try:
            numpy_pickle.load(filename)
            raise AssertionError(
                "Numpy pickle loading should have raised a ValueError exception"
            )
        except ValueError as e:
            message = "unsupported pickle protocol: {0}".format(writing_protocol)
            assert message in str(e.args)
        return

    try:
        with warnings.catch_warnings(record=True) as warninfo:
            warnings.simplefilter("always")
            result_list = numpy_pickle.load(filename, mmap_mode=mmap_mode)

        filename_base = os.path.basename(filename)
        # Pickles from joblib 0.8.4 / 0.9 emit one deprecation warning.
        is_legacy = "_0.9" in filename_base or "_0.8.4" in filename_base
        expected_nb_deprecation_warnings = 1 if is_legacy else 0
        # Memmapping a pickle whose name matches ``_0.1.+.pkl$`` emits three
        # user warnings.
        mmaps_old_pickle = (
            bool(re.search("_0.1.+.pkl$", filename_base)) and mmap_mode is not None
        )
        expected_nb_user_warnings = 3 if mmaps_old_pickle else 0
        expected_nb_warnings = (
            expected_nb_deprecation_warnings + expected_nb_user_warnings
        )
        assert len(warninfo) == expected_nb_warnings, (
            "Did not get the expected number of warnings. Expected "
            f"{expected_nb_warnings} but got warnings: "
            f"{[w.message for w in warninfo]}"
        )

        deprecation_message = (
            "The file '{0}' has been generated with a joblib "
            "version less than 0.10. Please regenerate this "
            "pickle file.".format(filename)
        )
        escaped_filename = re.escape(filename)
        for record in warninfo:
            if issubclass(record.category, DeprecationWarning):
                assert str(record.message) == deprecation_message
        for record in warninfo:
            if issubclass(record.category, UserWarning):
                assert re.search(
                    f"memmapped.+{escaped_filename}.+segmentation fault",
                    str(record.message),
                )

        for result, expected in zip(result_list, expected_list):
            if isinstance(expected, np.ndarray):
                expected = _ensure_native_byte_order(expected)
                assert result.dtype == expected.dtype
                np.testing.assert_equal(result, expected)
            else:
                assert result == expected
    except Exception as exc:
        if writer_major == 2:
            # When trying to read with python 3 a pickle generated
            # with python 2 we expect a user-friendly error
            assert isinstance(exc, ValueError)
            message = (
                "You may be trying to read with "
                "python 3 a joblib pickle generated with python 2."
            )
            assert message in str(exc)
        elif filename.endswith(".lz4") and with_lz4.args[0]:
            # ``with_lz4.args[0]`` is True when lz4 is not installed, in
            # which case lz4 pickles cannot be read.
            assert isinstance(exc, ValueError)
            assert LZ4_NOT_INSTALLED_ERROR in str(exc)
        else:
            raise
|
|
|
|
|
|
@with_numpy
def test_joblib_pickle_across_python_versions():
    """Load every reference pickle shipped in joblib/test/data.

    The pickles were generated by joblib/test/data/create_numpy_pickle.py
    for the relevant python, joblib and numpy versions, both compressed and
    uncompressed.
    """
    # We need to be specific about dtypes in particular endianness
    # because the pickles can be generated on one architecture and
    # the tests run on another one. See
    # https://github.com/joblib/joblib/issues/279.
    expected_list = [
        np.arange(5, dtype=np.dtype("<i8")),
        np.arange(5, dtype=np.dtype("<f8")),
        np.array([1, "abc", {"a": 1, "b": 2}], dtype="O"),
        np.arange(256, dtype=np.uint8).tobytes(),
        # np.matrix is a subclass of np.ndarray, here we want
        # to verify this type of object is correctly unpickled
        # among versions.
        np.matrix([0, 1, 2], dtype=np.dtype("<i8")),
        "C'est l'\xe9t\xe9 !",
    ]

    test_data_dir = os.path.dirname(os.path.abspath(data.__file__))

    # Every extension carries its leading dot so only real suffixes match;
    # the lz4 entry used to be written as "lz4".
    pickle_extensions = (".pkl", ".gz", ".gzip", ".bz2", ".lz4")
    if lzma is not None:
        pickle_extensions += (".xz", ".lzma")

    # Sorted so that files are checked (and failures reported) in a
    # deterministic order, independent of os.listdir ordering.
    pickle_filenames = [
        os.path.join(test_data_dir, fn)
        for fn in sorted(os.listdir(test_data_dir))
        if fn.endswith(pickle_extensions)
    ]

    for fname in pickle_filenames:
        _check_pickle(fname, expected_list)
|
|
|
|
|
|
@with_numpy
def test_joblib_pickle_across_python_versions_with_mmap():
    """Load the uncompressed reference pickles of joblib/test/data with
    ``mmap_mode="r"`` and check their content."""
    expected_list = [
        np.arange(5, dtype=np.dtype("<i8")),
        np.arange(5, dtype=np.dtype("<f8")),
        np.array([1, "abc", {"a": 1, "b": 2}], dtype="O"),
        np.arange(256, dtype=np.uint8).tobytes(),
        # np.matrix subclasses np.ndarray; it must unpickle correctly
        # across versions too.
        np.matrix([0, 1, 2], dtype=np.dtype("<i8")),
        "C'est l'\xe9t\xe9 !",
    ]

    data_dir = os.path.dirname(os.path.abspath(data.__file__))
    for fn in os.listdir(data_dir):
        if fn.endswith(".pkl"):
            _check_pickle(os.path.join(data_dir, fn), expected_list, mmap_mode="r")
|
|
|
|
|
|
@with_numpy
def test_numpy_array_byte_order_mismatch_detection():
    """Byte-order mismatch detection and conversion to native order.

    For big- and little-endian arrays (plain and structured), check that a
    mismatch is reported exactly when the array's order differs from the
    host's, and that ``_ensure_native_byte_order`` yields a native ('=')
    dtype. For structured dtypes, every field is checked. Those field checks
    previously lacked ``assert`` and so never ran.
    """
    # List of numpy arrays with big endian byteorder.
    be_arrays = [
        np.array([(1, 2.0), (3, 4.0)], dtype=[("", ">i8"), ("", ">f8")]),
        np.arange(3, dtype=np.dtype(">i8")),
        np.arange(3, dtype=np.dtype(">f8")),
    ]
    # List of numpy arrays with little endian byteorder.
    le_arrays = [
        np.array([(1, 2.0), (3, 4.0)], dtype=[("", "<i8"), ("", "<f8")]),
        np.arange(3, dtype=np.dtype("<i8")),
        np.arange(3, dtype=np.dtype("<f8")),
    ]

    for arrays, arrays_byteorder in ((be_arrays, "big"), (le_arrays, "little")):
        for array in arrays:
            # Verify the byteorder mismatch is correctly detected.
            if sys.byteorder == arrays_byteorder:
                assert not _is_numpy_array_byte_order_mismatch(array)
            else:
                assert _is_numpy_array_byte_order_mismatch(array)

            converted = _ensure_native_byte_order(array)
            if converted.dtype.fields:
                # fields values are (dtype, offset[, title]) tuples.
                for field_dtype, *_ in converted.dtype.fields.values():
                    assert field_dtype.byteorder == "="
            else:
                assert converted.dtype.byteorder == "="
|
|
|
|
|
|
@parametrize("compress_tuple", [("zlib", 3), ("gzip", 3)])
def test_compress_tuple_argument(tmpdir, compress_tuple):
    """A (method, level) tuple selects the compression method."""
    method = compress_tuple[0]
    path = tmpdir.join("test.pkl").strpath
    numpy_pickle.dump("dummy", path, compress=compress_tuple)

    # The file must start with the magic number of the chosen method.
    with open(path, "rb") as fh:
        detected = _detect_compressor(fh)
    assert detected == method
|
|
|
|
|
|
@parametrize(
    "compress_tuple,message",
    [
        # Tuple with too many elements.
        (
            ("zlib", 3, "extra"),
            "Compress argument tuple should contain exactly 2 elements",
        ),
        # Unknown compression method.
        (
            ("wrong", 3),
            'Non valid compression method given: "{}"'.format("wrong"),
        ),
        # Non-integer compression level.
        (
            ("zlib", "wrong"),
            'Non valid compress level given: "{}"'.format("wrong"),
        ),
    ],
)
def test_compress_tuple_argument_exception(tmpdir, compress_tuple, message):
    """Malformed compress tuples are rejected with a descriptive ValueError."""
    path = tmpdir.join("test.pkl").strpath
    with raises(ValueError) as excinfo:
        numpy_pickle.dump("dummy", path, compress=compress_tuple)
    excinfo.match(message)
|
|
|
|
|
|
@parametrize("compress_string", ["zlib", "gzip"])
def test_compress_string_argument(tmpdir, compress_string):
    """A method name string selects the compression method."""
    path = tmpdir.join("test.pkl").strpath
    numpy_pickle.dump("dummy", path, compress=compress_string)

    # The file must start with the magic number of the chosen method.
    with open(path, "rb") as fh:
        detected = _detect_compressor(fh)
    assert detected == compress_string
|
|
|
|
|
|
@with_numpy
@parametrize("compress", [1, 3, 6])
@parametrize("cmethod", _COMPRESSORS)
def test_joblib_compression_formats(tmpdir, compress, cmethod):
    """Every registered compressor round-trips a variety of objects and
    writes its own magic number."""
    if cmethod in ("lzma", "xz") and lzma is None:
        pytest.skip("lzma is support not available")
    if cmethod == "lz4" and with_lz4.args[0]:
        # ``with_lz4`` is a skipif marker whose argument is True when lz4
        # is not installed.
        pytest.skip("lz4 is not installed.")

    dump_filename = tmpdir.join("test.pkl").strpath + "." + cmethod
    objects = (
        np.ones(shape=(100, 100), dtype="f8"),
        range(10),
        {"a": 1, 2: "b"},
        [],
        (),
        {},
        0,
        1.0,
    )

    for obj in objects:
        numpy_pickle.dump(obj, dump_filename, compress=(cmethod, compress))

        with open(dump_filename, "rb") as fh:
            assert _detect_compressor(fh) == cmethod

        reloaded = numpy_pickle.load(dump_filename)
        assert isinstance(reloaded, type(obj))
        if isinstance(obj, np.ndarray):
            np.testing.assert_array_equal(reloaded, obj)
        else:
            assert reloaded == obj
|
|
|
|
|
|
def _gzip_file_decompress(source_filename, target_filename):
    """Write the gzip-decompressed content of ``source_filename`` into
    ``target_filename``."""
    with gzip.GzipFile(source_filename, "rb") as compressed:
        payload = compressed.read()
    with open(target_filename, "wb") as raw:
        raw.write(payload)
|
|
|
|
|
|
def _zlib_file_decompress(source_filename, target_filename):
    """Write the zlib-decompressed content of ``source_filename`` into
    ``target_filename``."""
    with open(source_filename, "rb") as compressed:
        raw_bytes = compressed.read()
    payload = zlib.decompress(raw_bytes)
    with open(target_filename, "wb") as raw:
        raw.write(payload)
|
|
|
|
|
|
@parametrize(
    "extension,decompress",
    [(".z", _zlib_file_decompress), (".gz", _gzip_file_decompress)],
)
def test_load_externally_decompressed_files(tmpdir, extension, decompress):
    """joblib's zlib/gzip output can be decompressed by the standard library,
    and the result loads as a plain pickle."""
    obj = "a string to persist"
    filename_raw = tmpdir.join("test.pkl").strpath
    filename_compressed = f"{filename_raw}{extension}"

    # The extension alone selects the compression method.
    numpy_pickle.dump(obj, filename_compressed)
    decompress(filename_compressed, filename_raw)

    assert obj == numpy_pickle.load(filename_raw)
|
|
|
|
|
|
@parametrize(
    "extension,cmethod",
    [
        # Extensions that select a compressor.
        (".z", "zlib"),
        (".gz", "gzip"),
        (".bz2", "bz2"),
        (".lzma", "lzma"),
        (".xz", "xz"),
        # Extensions that do not select any compressor.
        (".pkl", "not-compressed"),
        ("", "not-compressed"),
    ],
)
def test_compression_using_file_extension(tmpdir, extension, cmethod):
    """The file extension alone determines the compression method."""
    if cmethod in ("lzma", "xz") and lzma is None:
        pytest.skip("lzma is missing")

    obj = "object to dump"
    dump_fname = tmpdir.join("test.pkl").strpath + extension
    numpy_pickle.dump(obj, dump_fname)

    with open(dump_fname, "rb") as fh:
        assert _detect_compressor(fh) == cmethod

    reloaded = numpy_pickle.load(dump_fname)
    assert isinstance(reloaded, type(obj))
    assert reloaded == obj
|
|
|
|
|
|
@with_numpy
def test_file_handle_persistence(tmpdir):
    """dump/load work on compressed file objects and raw file handles."""
    objs = [np.random.random((10, 10)), "some data"]
    fobjs = [bz2.BZ2File, gzip.GzipFile]
    if lzma is not None:
        fobjs += [lzma.LZMAFile]
    filename = tmpdir.join("test.pkl").strpath

    for obj in objs:
        is_array = isinstance(obj, np.ndarray)
        for fobj in fobjs:
            with fobj(filename, "wb") as fh:
                numpy_pickle.dump(obj, fh)

            # Reading through the same decompressor: joblib must not try to
            # decompress a second time.
            with fobj(filename, "rb") as fh:
                via_wrapper = numpy_pickle.load(fh)

            # Reading through a raw handle: joblib must pick the right
            # decompressor by itself.
            with open(filename, "rb") as fh:
                via_raw = numpy_pickle.load(fh)

            for reloaded in (via_wrapper, via_raw):
                if is_array:
                    np.testing.assert_array_equal(reloaded, obj)
                else:
                    assert reloaded == obj
|
|
|
|
|
|
@with_numpy
def test_in_memory_persistence():
    """dump/load round-trip through an in-memory bytes buffer."""
    for obj in [np.random.random((10, 10)), "some data"]:
        buffer = io.BytesIO()
        numpy_pickle.dump(obj, buffer)
        # NOTE(review): no seek(0) before loading; this relies on ``load``
        # handling the stream position itself - confirm this is intended.
        reloaded = numpy_pickle.load(buffer)
        if isinstance(obj, np.ndarray):
            np.testing.assert_array_equal(reloaded, obj)
        else:
            assert reloaded == obj
|
|
|
|
|
|
@with_numpy
def test_file_handle_persistence_mmap(tmpdir):
    """Loading from a raw file handle with mmap_mode returns equal data."""
    original = np.random.random((10, 10))
    path = tmpdir.join("test.pkl").strpath

    with open(path, "wb") as fh:
        numpy_pickle.dump(original, fh)
    with open(path, "rb") as fh:
        reloaded = numpy_pickle.load(fh, mmap_mode="r+")

    np.testing.assert_array_equal(reloaded, original)
|
|
|
|
|
|
@with_numpy
def test_file_handle_persistence_compressed_mmap(tmpdir):
    """mmap_mode on a compressed (non-raw) file object warns once."""
    obj = np.random.random((10, 10))
    filename = tmpdir.join("test.pkl").strpath

    with open(filename, "wb") as fh:
        numpy_pickle.dump(obj, fh, compress=("gzip", 3))

    with closing(gzip.GzipFile(filename, "rb")) as f:
        with warns(UserWarning) as warninfo:
            numpy_pickle.load(f, mmap_mode="r+")
        # Built inside the ``with`` so that repr(f) is taken on the same
        # still-open file object that produced the warning.
        expected = (
            '"%(fileobj)r" is not a raw file, mmap_mode "%(mmap_mode)s" '
            "flag will be ignored." % {"fileobj": f, "mmap_mode": "r+"}
        )
        assert len(warninfo) == 1
        assert str(warninfo[0].message) == expected
|
|
|
|
|
|
@with_numpy
def test_file_handle_persistence_in_memory_mmap():
    """mmap_mode on an in-memory buffer warns once and is ignored."""
    buf = io.BytesIO()
    numpy_pickle.dump(np.random.random((10, 10)), buf)

    with warns(UserWarning) as warninfo:
        numpy_pickle.load(buf, mmap_mode="r+")

    expected = (
        "In memory persistence is not compatible with mmap_mode "
        '"%(mmap_mode)s" flag passed. mmap_mode option will be '
        "ignored." % {"mmap_mode": "r+"}
    )
    assert len(warninfo) == 1
    assert str(warninfo[0].message) == expected
|
|
|
|
|
|
@parametrize(
    "data",
    [
        b"a little data as bytes.",
        # More bytes
        10000 * "{}".format(random.randint(0, 1000) * 1000).encode("latin-1"),
    ],
    ids=["a little data as bytes.", "a large data as bytes."],
)
@parametrize("compress_level", [1, 3, 9])
def test_binary_zlibfile(tmpdir, data, compress_level):
    """Exercise BinaryZlibFile on file objects and on filenames, with and
    without a context manager."""
    filename = tmpdir.join("test.pkl").strpath

    # Wrapping an already opened file object for writing.
    with open(filename, "wb") as raw:
        with BinaryZlibFile(raw, "wb", compresslevel=compress_level) as fz:
            assert fz.writable()
            fz.write(data)
            assert fz.fileno() == raw.fileno()
            # A writer can neither read nor seek.
            with raises(io.UnsupportedOperation):
                fz._check_can_read()
            with raises(io.UnsupportedOperation):
                fz._check_can_seek()
        assert fz.closed
        with raises(ValueError):
            fz._check_not_closed()

    # Wrapping an already opened file object for reading.
    with open(filename, "rb") as raw:
        with BinaryZlibFile(raw) as fz:
            assert fz.readable()
            assert fz.seekable()
            assert fz.fileno() == raw.fileno()
            assert fz.read() == data
            # A reader cannot write.
            with raises(io.UnsupportedOperation):
                fz._check_can_write()
            assert fz.seekable()
            fz.seek(0)
            assert fz.tell() == 0
        assert fz.closed

    # Passing a filename instead of a file object.
    with BinaryZlibFile(filename, "wb", compresslevel=compress_level) as fz:
        assert fz.writable()
        fz.write(data)

    with BinaryZlibFile(filename, "rb") as fz:
        assert fz.read() == data
        assert fz.seekable()

    # Explicit close() instead of a context manager.
    writer = BinaryZlibFile(filename, "wb", compresslevel=compress_level)
    assert writer.writable()
    writer.write(data)
    writer.close()

    reader = BinaryZlibFile(filename, "rb")
    assert reader.read() == data
    reader.close()
|
|
|
|
|
|
@parametrize("bad_value", [-1, 10, 15, "a", (), {}])
def test_binary_zlibfile_bad_compression_levels(tmpdir, bad_value):
    """Out-of-range or non-integer compresslevel values are rejected."""
    path = tmpdir.join("test.pkl").strpath
    expected = (
        "'compresslevel' must be an integer between 1 and 9. "
        "You provided 'compresslevel={}'".format(bad_value)
    )
    with raises(ValueError) as excinfo:
        BinaryZlibFile(path, "wb", compresslevel=bad_value)
    excinfo.match(re.escape(expected))
|
|
|
|
|
|
@parametrize("bad_mode", ["a", "x", "r", "w", 1, 2])
def test_binary_zlibfile_invalid_modes(tmpdir, bad_mode):
    """Only the supported binary modes are accepted."""
    path = tmpdir.join("test.pkl").strpath
    with raises(ValueError) as excinfo:
        BinaryZlibFile(path, bad_mode)
    excinfo.match("Invalid mode")
|
|
|
|
|
|
@parametrize("bad_file", [1, (), {}])
def test_binary_zlibfile_invalid_filename_type(bad_file):
    """Non-path, non-file arguments raise a TypeError."""
    expected = "filename must be a str or bytes object, or a file"
    with raises(TypeError) as excinfo:
        BinaryZlibFile(bad_file, "rb")
    excinfo.match(expected)
|
|
|
|
|
|
###############################################################################
# Test dumping array subclasses
if np is not None:

    class SubArray(np.ndarray):
        """ndarray subclass that pickles through ``_load_sub_array``."""

        def __reduce__(self):
            # Pickle a plain-ndarray copy of the data; ``_load_sub_array``
            # rebuilds the subclass when unpickling.
            return _load_sub_array, (np.asarray(self),)

    def _load_sub_array(arr):
        """Build a ``SubArray`` with the shape and values of ``arr``."""
        rebuilt = SubArray(arr.shape)
        rebuilt[:] = arr
        return rebuilt

    class ComplexTestObject:
        """A complex object containing numpy arrays as attributes."""

        def __init__(self):
            # One float64, one int32 and one object-dtype array.
            self.array_float = np.arange(100, dtype="float64")
            self.array_int = np.ones(100, dtype="int32")
            self.array_obj = np.array(["a", 10, 20.0], dtype="object")
|
|
|
|
|
|
@with_numpy
def test_numpy_subclass(tmpdir):
    """A custom ndarray subclass is restored with its own type and values."""
    path = tmpdir.join("test.pkl").strpath
    original = SubArray((10,))
    numpy_pickle.dump(original, path)

    restored = numpy_pickle.load(path)
    assert isinstance(restored, SubArray)
    np.testing.assert_array_equal(restored, original)
|
|
|
|
|
|
def test_pathlib(tmpdir):
    """dump and load accept ``pathlib.Path`` as well as ``str`` filenames."""
    as_str = tmpdir.join("test.pkl").strpath
    as_path = Path(as_str)
    value = 123

    numpy_pickle.dump(value, as_path)
    assert numpy_pickle.load(as_str) == value

    numpy_pickle.dump(value, as_str)
    assert numpy_pickle.load(as_path) == value
|
|
|
|
|
|
@with_numpy
def test_non_contiguous_array_pickling(tmpdir):
    """Check that arrays that are neither C- nor F-contiguous round-trip."""
    filename = tmpdir.join("test.pkl").strpath

    candidates = [
        # Array that triggers a contiguousness issue with nditer, see
        # https://github.com/joblib/joblib/pull/352 and
        # https://github.com/joblib/joblib/pull/353
        np.asfortranarray([[1, 2], [3, 4]])[1:],
        # Non-contiguous array which works fine with nditer
        np.ones((10, 50, 20), order="F")[:, :1, :],
    ]

    for arr in candidates:
        # Sanity-check the fixture: it must be non-contiguous in both layouts.
        assert not arr.flags.c_contiguous
        assert not arr.flags.f_contiguous
        numpy_pickle.dump(arr, filename)
        np.testing.assert_array_equal(numpy_pickle.load(filename), arr)
|
|
|
|
|
|
@with_numpy
def test_pickle_highest_protocol(tmpdir):
    """Check that a numpy array persists correctly with pickle.HIGHEST_PROTOCOL.

    See https://github.com/joblib/joblib/issues/362
    """
    target = tmpdir.join("test.pkl").strpath
    zeros = np.zeros(10)

    numpy_pickle.dump(zeros, target, protocol=pickle.HIGHEST_PROTOCOL)
    np.testing.assert_array_equal(numpy_pickle.load(target), zeros)
|
|
|
|
|
|
@with_numpy
def test_pickle_in_socket():
    """Check that joblib can pickle to, and unpickle from, socket streams."""
    test_array = np.arange(10)

    # Bind to port 0 so the OS picks a free ephemeral port: a hard-coded port
    # can collide with another process or with a concurrent test run.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as listener:
        listener.bind(("localhost", 0))
        listener.listen(1)
        address = listener.getsockname()

        with socket.create_connection(address) as client:
            server, _ = listener.accept()
            # Close the accepted connection even if an assertion fails.
            with server:
                with server.makefile("wb") as sf:
                    numpy_pickle.dump(test_array, sf)

                with client.makefile("rb") as cf:
                    array_reloaded = numpy_pickle.load(cf)

                np.testing.assert_array_equal(array_reloaded, test_array)

                # Check that a byte-aligned numpy array written in a file can
                # be sent over a socket and then read on the other side.
                bytes_to_send = io.BytesIO()
                numpy_pickle.dump(test_array, bytes_to_send)
                # Unlike send(), sendall() keeps sending until every byte
                # has been transmitted.
                server.sendall(bytes_to_send.getvalue())

                with client.makefile("rb") as cf:
                    array_reloaded = numpy_pickle.load(cf)

                np.testing.assert_array_equal(array_reloaded, test_array)
|
|
|
|
|
|
@with_numpy
def test_load_memmap_with_big_offset(tmpdir):
    """Check that memmap offsets larger than mmap.ALLOCATIONGRANULARITY are correct.

    See https://github.com/joblib/joblib/issues/451 and
    https://github.com/numpy/numpy/pull/8443 for more details.
    """
    fname = tmpdir.join("test.mmap").strpath
    granularity = mmap.ALLOCATIONGRANULARITY
    arrays = [
        np.zeros(granularity, dtype="uint8"),
        np.ones(granularity, dtype="uint8"),
    ]
    numpy_pickle.dump(arrays, fname)

    loaded = numpy_pickle.load(fname, mmap_mode="r")
    second = loaded[1]
    # The first array alone holds `granularity` bytes, so the second one is
    # expected to start past that point in the file.
    assert isinstance(second, np.memmap)
    assert second.offset > granularity
    np.testing.assert_array_equal(arrays, loaded)
|
|
|
|
|
|
def test_register_compressor(tmpdir):
    """Check that registering a custom compressor wrapper works."""
    compressor_name = "test-name"
    compressor_prefix = "test-prefix"

    class BinaryCompressorTestFile(io.BufferedIOBase):
        pass

    class BinaryCompressorTestWrapper(CompressorWrapper):
        def __init__(self):
            CompressorWrapper.__init__(
                self, obj=BinaryCompressorTestFile, prefix=compressor_prefix
            )

    register_compressor(compressor_name, BinaryCompressorTestWrapper())
    try:
        assert _COMPRESSORS[compressor_name].fileobj_factory == BinaryCompressorTestFile
        assert _COMPRESSORS[compressor_name].prefix == compressor_prefix
    finally:
        # Always remove this dummy compressor from the global registry, even
        # when an assertion above fails: other tests might fail otherwise.
        _COMPRESSORS.pop(compressor_name, None)
|
|
|
|
|
|
@parametrize("invalid_name", [1, (), {}])
def test_register_compressor_invalid_name(invalid_name):
    """Check that registering a compressor under a non-string name is refused."""
    with raises(ValueError) as exc_info:
        register_compressor(invalid_name, None)
    exc_info.match("Compressor name should be a string")
|
|
|
|
|
|
def test_register_compressor_invalid_fileobj():
    """Check that a wrapper whose factory is not file-like is rejected."""

    class InvalidFileObject:
        pass

    class InvalidFileObjectWrapper(CompressorWrapper):
        def __init__(self):
            super().__init__(obj=InvalidFileObject, prefix=b"prefix")

    expected_message = (
        "Compressor 'fileobj_factory' attribute should implement "
        "the file object interface"
    )
    with raises(ValueError) as exc_info:
        register_compressor("invalid", InvalidFileObjectWrapper())

    exc_info.match(expected_message)
|
|
|
|
|
|
class AnotherZlibCompressorWrapper(CompressorWrapper):
    """Test compressor wrapper around BinaryZlibFile with a dummy prefix."""

    def __init__(self):
        super().__init__(obj=BinaryZlibFile, prefix=b"prefix")
|
|
|
|
|
|
class StandardLibGzipCompressorWrapper(CompressorWrapper):
    """Test compressor wrapper around the stdlib gzip.GzipFile with a dummy prefix."""

    def __init__(self):
        super().__init__(obj=gzip.GzipFile, prefix=b"prefix")
|
|
|
|
|
|
def test_register_compressor_already_registered():
    """Check how registering an already-registered compressor name is handled.

    Re-registering without ``force`` raises ValueError; ``force=True``
    replaces the existing entry.
    """
    compressor_name = "test-name"

    # register a test compressor
    register_compressor(compressor_name, AnotherZlibCompressorWrapper())
    try:
        with raises(ValueError) as excinfo:
            register_compressor(compressor_name, StandardLibGzipCompressorWrapper())
        # Escape the literal message: excinfo.match interprets its argument
        # as a regular expression.
        excinfo.match(
            re.escape("Compressor '{}' already registered.".format(compressor_name))
        )

        register_compressor(
            compressor_name, StandardLibGzipCompressorWrapper(), force=True
        )

        assert compressor_name in _COMPRESSORS
        assert _COMPRESSORS[compressor_name].fileobj_factory == gzip.GzipFile
    finally:
        # Always remove this dummy compressor from the global registry, even
        # when an assertion above fails: other tests might fail otherwise.
        _COMPRESSORS.pop(compressor_name, None)
|
|
|
|
|
|
@with_lz4
def test_lz4_compression(tmpdir):
    """Check that lz4 can be used when the dependency is available.

    LZ4 is checked both when requested explicitly through ``compress`` and
    when inferred from a ``.lz4`` file extension.
    """
    import lz4.frame

    compressor = "lz4"
    assert compressor in _COMPRESSORS
    assert _COMPRESSORS[compressor].fileobj_factory == lz4.frame.LZ4FrameFile

    fname = tmpdir.join("test.pkl").strpath
    # Named `payload` to avoid shadowing the imported `joblib.test.data` module.
    payload = "test data"
    numpy_pickle.dump(payload, fname, compress=compressor)

    with open(fname, "rb") as f:
        assert f.read(len(_LZ4_PREFIX)) == _LZ4_PREFIX
    assert numpy_pickle.load(fname) == payload

    # Test that LZ4 is applied based on file extension. Inspect and reload the
    # ".lz4" file itself, not the file written above, which is already
    # known to be LZ4-compressed.
    fname_lz4 = fname + ".lz4"
    numpy_pickle.dump(payload, fname_lz4)
    with open(fname_lz4, "rb") as f:
        assert f.read(len(_LZ4_PREFIX)) == _LZ4_PREFIX
    assert numpy_pickle.load(fname_lz4) == payload
|
|
|
|
|
|
@without_lz4
def test_lz4_compression_without_lz4(tmpdir):
    """Check that lz4 cannot be used when the dependency is not available.

    Both an explicit ``compress="lz4"`` and a ``.lz4`` file extension must
    raise ValueError.
    """
    fname = tmpdir.join("test.nolz4").strpath
    # Named `payload` to avoid shadowing the imported `joblib.test.data` module.
    payload = "test data"
    # excinfo.match treats its argument as a regex; escape the literal
    # message so that any special characters it contains match verbatim.
    msg = re.escape(LZ4_NOT_INSTALLED_ERROR)

    with raises(ValueError) as excinfo:
        numpy_pickle.dump(payload, fname, compress="lz4")
    excinfo.match(msg)

    with raises(ValueError) as excinfo:
        numpy_pickle.dump(payload, fname + ".lz4")
    excinfo.match(msg)
|
|
|
|
|
|
# Pickle protocols to exercise: the default one, followed by the highest one
# when it differs (dict.fromkeys drops the duplicate while keeping order).
protocols = list(dict.fromkeys([pickle.DEFAULT_PROTOCOL, pickle.HIGHEST_PROTOCOL]))
|
|
|
|
|
|
@with_numpy
@parametrize("protocol", protocols)
def test_memmap_alignment_padding(tmpdir, protocol):
    """Check that memmapped arrays returned by numpy_pickle.load are aligned."""

    def check_aligned_memmap(expected, loaded):
        # A reloaded array must be a memmap, equal to the original, and start
        # on a NUMPY_ARRAY_ALIGNMENT_BYTES boundary.
        assert isinstance(loaded, np.memmap)
        np.testing.assert_array_equal(expected, loaded)
        assert loaded.ctypes.data % numpy_pickle.NUMPY_ARRAY_ALIGNMENT_BYTES == 0
        assert loaded.flags.aligned

    # A single array.
    fname = tmpdir.join("test.mmap").strpath
    single = np.random.randn(2)
    numpy_pickle.dump(single, fname, protocol=protocol)
    check_aligned_memmap(single, numpy_pickle.load(fname, mmap_mode="r"))

    # A list of arrays.
    array_list = [np.random.randn(2) for _ in range(4)]
    # On Windows OSError 22 if reusing the same path for memmap ...
    fname = tmpdir.join("test1.mmap").strpath
    numpy_pickle.dump(array_list, fname, protocol=protocol)
    list_reloaded = numpy_pickle.load(fname, mmap_mode="r")
    for expected, loaded in zip(array_list, list_reloaded):
        check_aligned_memmap(expected, loaded)

    # A dict of uint8 arrays whose odd lengths force padding between them.
    array_dict = {
        "a0": np.arange(2, dtype=np.uint8),
        "a1": np.arange(3, dtype=np.uint8),
        "a2": np.arange(5, dtype=np.uint8),
        "a3": np.arange(7, dtype=np.uint8),
        "a4": np.arange(11, dtype=np.uint8),
        "a5": np.arange(13, dtype=np.uint8),
        "a6": np.arange(17, dtype=np.uint8),
        "a7": np.arange(19, dtype=np.uint8),
        "a8": np.arange(23, dtype=np.uint8),
    }
    # On Windows OSError 22 if reusing the same path for memmap ...
    fname = tmpdir.join("test2.mmap").strpath
    numpy_pickle.dump(array_dict, fname, protocol=protocol)
    dict_reloaded = numpy_pickle.load(fname, mmap_mode="r")
    for key, loaded in dict_reloaded.items():
        check_aligned_memmap(array_dict[key], loaded)
|