import asyncio
import gc
import shutil

import pytest

from joblib.memory import (
    AsyncMemorizedFunc,
    AsyncNotMemorizedFunc,
    MemorizedResult,
    Memory,
    NotMemorizedResult,
)
from joblib.test.common import np, with_numpy
from joblib.testing import raises

from .test_memory import corrupt_single_cache_item, monkeypatch_cached_func_warn


async def check_identity_lazy_async(func, accumulator, location):
    """Check that a cached coroutine behaves as a lazily evaluated identity.

    ``func`` is wrapped by a ``Memory`` rooted at ``location`` and awaited
    twice for each of the arguments 0, 1 and 2.  Every call must return its
    argument, and ``accumulator`` must contain exactly one entry per distinct
    argument seen so far (``func`` is expected to append to it each time it
    really runs).
    """
    cached = Memory(location=location, verbose=0).cache(func)
    for arg in range(3):
        for _ in range(2):
            assert await cached(arg) == arg
            assert len(accumulator) == arg + 1


@pytest.mark.asyncio
async def test_memory_integration_async(tmpdir):
    """Exercise caching, clearing and ``Memory.eval`` on a coroutine."""
    accumulator = []

    async def f(n):
        await asyncio.sleep(0.1)
        accumulator.append(1)
        return n

    cache_dir = tmpdir.strpath
    await check_identity_lazy_async(f, accumulator, cache_dir)

    # Now test clearing
    for compress in (False, True):
        for mmap_mode in ("r", None):
            memory = Memory(
                location=cache_dir,
                verbose=10,
                mmap_mode=mmap_mode,
                compress=compress,
            )
            # First clear the cache directory, to check that our code can
            # handle that
            # NOTE: this line would raise an exception, as the database
            # file is still open; we ignore the error since we want to
            # test what happens if the directory disappears
            shutil.rmtree(cache_dir, ignore_errors=True)
            g = memory.cache(f)
            await g(1)
            g.clear(warn=False)
            current_accumulator = len(accumulator)
            out = await g(1)

        # NOTE(review): these checks run once per ``compress`` value, on the
        # ``memory``/``out`` left behind by the last ``mmap_mode`` iteration.
        assert len(accumulator) == current_accumulator + 1
        # Also, check that Memory.eval works similarly
        evaled = await memory.eval(f, 1)
        assert evaled == out
        assert len(accumulator) == current_accumulator + 1

    # Now do a smoke test with a function defined in __main__, as the name
    # mangling rules are more complex
    f.__module__ = "__main__"
    memory = Memory(location=cache_dir, verbose=0)
    await memory.cache(f)(1)


@pytest.mark.asyncio
async def test_no_memory_async():
    """With ``location=None`` every call must execute the coroutine."""
    accumulator = []

    async def ff(x):
        await asyncio.sleep(0.1)
        accumulator.append(1)
        return x

    gg = Memory(location=None, verbose=0).cache(ff)
    for _ in range(4):
        calls_before = len(accumulator)
        await gg(1)
        assert len(accumulator) == calls_before + 1


@with_numpy
@pytest.mark.asyncio
async def test_memory_numpy_check_mmap_mode_async(tmpdir, monkeypatch):
    """Check that mmap_mode is respected even at the first call"""
    memory = Memory(location=tmpdir.strpath, mmap_mode="r", verbose=0)

    @memory.cache()
    async def twice(a):
        return a * 2

    a = np.ones(3)
    b = await twice(a)
    c = await twice(a)

    # Keep these as direct assertions on ``b`` and ``c``: binding the arrays
    # to any extra name would keep a memmap alive past the ``del`` below.
    assert isinstance(c, np.memmap)
    assert c.mode == "r"
    assert isinstance(b, np.memmap)
    assert b.mode == "r"

    # Corrupt the cached file. Deleting the b and c mmaps first
    # is necessary to be able to edit the file
    del b, c
    gc.collect()
    corrupt_single_cache_item(memory)

    # Make sure that corrupting the file causes recomputation and that
    # a warning is issued.
    recorded_warnings = monkeypatch_cached_func_warn(twice, monkeypatch)
    d = await twice(a)
    assert len(recorded_warnings) == 1
    assert "Exception while loading results" in recorded_warnings[0]

    # Asserts that the recomputation returns a mmap
    assert isinstance(d, np.memmap)
    assert d.mode == "r"


@pytest.mark.asyncio
async def test_call_and_shelve_async(tmpdir):
    """``call_and_shelve`` must return the result wrapper matching each func."""

    async def f(x, y=1):
        await asyncio.sleep(0.1)
        return x**2 + y

    cache_dir = tmpdir.strpath

    # Test MemorizedFunc outputting a reference to cache.
    # Each entry pairs a wrapped coroutine with the result type it must yield.
    cases = (
        (AsyncMemorizedFunc(f, cache_dir), MemorizedResult),
        (AsyncNotMemorizedFunc(f), NotMemorizedResult),
        (Memory(location=cache_dir, verbose=0).cache(f), MemorizedResult),
        (Memory(location=None).cache(f), NotMemorizedResult),
    )
    for func, Result in cases:
        # Call twice in a row with the same argument.
        for _ in range(2):
            result = await func.call_and_shelve(2)
            assert isinstance(result, Result)
            assert result.get() == 5

        result.clear()
        with raises(KeyError):
            result.get()
        result.clear()  # Do nothing if there is no cache.


@pytest.mark.asyncio
async def test_memorized_func_call_async(memory):
    """Check ``call`` on a cached coroutine: it re-runs it and returns metadata.

    ``ff`` counts its executions per ``x`` in ``counter``, which is excluded
    from the cache key.  Two regular calls must both yield 1 (the second one
    does not re-run ``ff``), while ``call`` must execute ``ff`` again and
    return an ``(output, metadata)`` pair.
    """

    async def ff(x, counter):
        await asyncio.sleep(0.1)
        hits = counter.get(x, 0) + 1
        counter[x] = hits
        return hits

    cached_ff = memory.cache(ff, ignore=["counter"])
    counter = {}

    # The second iteration must not execute ``ff`` again.
    for _ in range(2):
        assert await cached_ff(2, counter) == 1

    output, metadata = await cached_ff.call(2, counter)
    assert output == 2, "f has not been called properly"
    assert isinstance(metadata, dict), (
        "Metadata are not returned by MemorizedFunc.call."
    )