team-10/env/Lib/site-packages/joblib/test/test_memory_async.py
2025-08-02 07:34:44 +02:00

180 lines
5.1 KiB
Python

import asyncio
import gc
import shutil
import pytest
from joblib.memory import (
AsyncMemorizedFunc,
AsyncNotMemorizedFunc,
MemorizedResult,
Memory,
NotMemorizedResult,
)
from joblib.test.common import np, with_numpy
from joblib.testing import raises
from .test_memory import corrupt_single_cache_item, monkeypatch_cached_func_warn
async def check_identity_lazy_async(func, accumulator, location):
    """Async variant of the synchronous ``check_identity_lazy`` helper.

    Wrap ``func`` with a fresh :class:`Memory` rooted at ``location``. Then
    await it twice for each argument in ``range(3)`` and check two things:

    - the result equals the argument;
    - ``accumulator`` has grown by exactly one entry per distinct argument,
      i.e. the repeated call with the same argument did not execute ``func``.

    Parameters
    ----------
    func : coroutine function
        Function under test. It is expected to append one item to
        ``accumulator`` each time it really runs and to return its argument.
    accumulator : list
        Shared list used to count real executions of ``func``.
    location : str
        Cache directory handed to :class:`Memory`.
    """
    memory = Memory(location=location, verbose=0)
    func = memory.cache(func)
    for i in range(3):
        for _ in range(2):
            value = await func(i)
            assert value == i
            # Only the first call with a given ``i`` may execute ``func``.
            assert len(accumulator) == i + 1
@pytest.mark.asyncio
async def test_memory_integration_async(tmpdir):
    """Check lazy evaluation, cache clearing and ``Memory.eval`` for coroutines.

    The clear/recompute/eval checks run for every ``compress`` x ``mmap_mode``
    combination. The test then smoke-tests a coroutine whose ``__module__``
    is ``"__main__"``.
    """
    accumulator = list()

    async def f(n):
        await asyncio.sleep(0.1)
        accumulator.append(1)
        return n

    await check_identity_lazy_async(f, accumulator, tmpdir.strpath)

    # Now test clearing
    for compress in (False, True):
        for mmap_mode in ("r", None):
            memory = Memory(
                location=tmpdir.strpath,
                verbose=10,
                mmap_mode=mmap_mode,
                compress=compress,
            )
            # First wipe the cache directory, to check that our code can
            # handle it disappearing underneath an existing Memory object.
            # NOTE: rmtree may fail to remove some files (e.g. if one is
            # still held open); errors are ignored on purpose because only
            # the "directory vanished" scenario matters here.
            shutil.rmtree(tmpdir.strpath, ignore_errors=True)
            g = memory.cache(f)
            await g(1)
            g.clear(warn=False)
            current_accumulator = len(accumulator)
            # After clear() the result must be recomputed exactly once.
            out = await g(1)
            assert len(accumulator) == current_accumulator + 1

            # Memory.eval is expected to reuse the result g just cached, so
            # it must return the same value without running f again.
            evaled = await memory.eval(f, 1)
            assert evaled == out
            assert len(accumulator) == current_accumulator + 1

    # Now do a smoke test with a function defined in __main__, as the name
    # mangling rules are more complex
    f.__module__ = "__main__"
    memory = Memory(location=tmpdir.strpath, verbose=0)
    await memory.cache(f)(1)
@pytest.mark.asyncio
async def test_no_memory_async():
    """A Memory without a location must not cache: every await runs ``ff``."""
    calls = []

    async def ff(x):
        await asyncio.sleep(0.1)
        calls.append(1)
        return x

    gg = Memory(location=None, verbose=0).cache(ff)

    for _attempt in range(4):
        before = len(calls)
        await gg(1)
        # Each call must have executed ff exactly once more.
        assert len(calls) == before + 1
@with_numpy
@pytest.mark.asyncio
async def test_memory_numpy_check_mmap_mode_async(tmpdir, monkeypatch):
    """Check that mmap_mode is respected even at the first call.

    The test also checks three further things:

    - the memory-mapped results hold the expected values;
    - corrupting the cached file triggers a warning and a recomputation;
    - that recomputation is again returned as a read-only memmap.
    """
    memory = Memory(location=tmpdir.strpath, mmap_mode="r", verbose=0)

    @memory.cache()
    async def twice(a):
        return a * 2

    a = np.ones(3)
    expected = a * 2

    b = await twice(a)  # first call
    c = await twice(a)  # repeated call with the same argument

    assert isinstance(c, np.memmap)
    assert c.mode == "r"
    np.testing.assert_array_equal(c, expected)

    assert isinstance(b, np.memmap)
    assert b.mode == "r"
    np.testing.assert_array_equal(b, expected)

    # Corrupt the cached file. The b and c memmaps must be deleted (and
    # collected) first so that nothing keeps the file open while editing it.
    del b
    del c
    gc.collect()
    corrupt_single_cache_item(memory)

    # Make sure that corrupting the file causes recomputation and that
    # a warning is issued.
    recorded_warnings = monkeypatch_cached_func_warn(twice, monkeypatch)
    d = await twice(a)
    assert len(recorded_warnings) == 1
    exception_msg = "Exception while loading results"
    assert exception_msg in recorded_warnings[0]
    # Asserts that the recomputation returns a mmap with the right content
    assert isinstance(d, np.memmap)
    assert d.mode == "r"
    np.testing.assert_array_equal(d, expected)
@pytest.mark.asyncio
async def test_call_and_shelve_async(tmpdir):
    """``call_and_shelve`` returns a result reference of the right type.

    Memorized wrappers must yield a MemorizedResult; non-memorized wrappers
    must yield a NotMemorizedResult. In every case:

    - ``get()`` returns the computed value;
    - after ``clear()``, ``get()`` raises KeyError;
    - a second ``clear()`` is harmless.
    """
    async def f(x, y=1):
        await asyncio.sleep(0.1)
        return x**2 + y

    # Test MemorizedFunc outputting a reference to cache.
    # Each wrapper is paired with the result type it should shelve.
    cases = [
        (AsyncMemorizedFunc(f, tmpdir.strpath), MemorizedResult),
        (AsyncNotMemorizedFunc(f), NotMemorizedResult),
        (Memory(location=tmpdir.strpath, verbose=0).cache(f), MemorizedResult),
        (Memory(location=None).cache(f), NotMemorizedResult),
    ]
    for wrapped, expected_type in cases:
        # Call twice; memorized wrappers may reuse the stored result.
        for _ in range(2):
            shelved = await wrapped.call_and_shelve(2)
            assert isinstance(shelved, expected_type)
            assert shelved.get() == 5

        shelved.clear()
        with raises(KeyError):
            shelved.get()
        shelved.clear()  # Do nothing if there is no cache.
@pytest.mark.asyncio
async def test_memorized_func_call_async(memory):
    """``call`` executes the coroutine again even when a result is cached.

    ``counter`` is excluded from the cache key through ``ignore``. Repeated
    ``gg(2, counter)`` calls therefore keep returning the first result, while
    ``gg.call`` runs ``ff`` once more and also returns a metadata dict.
    ``memory`` is a pytest fixture, presumably a cache-enabled Memory
    provided by the test suite's conftest (verify).
    """
    async def ff(x, counter):
        await asyncio.sleep(0.1)
        hits = counter.get(x, 0) + 1
        counter[x] = hits
        return hits

    gg = memory.cache(ff, ignore=["counter"])

    counter = {}
    # Both plain calls must see the first (cached) result.
    for _ in range(2):
        assert await gg(2, counter) == 1

    x, meta = await gg.call(2, counter)
    assert x == 2, "f has not been called properly"
    assert isinstance(meta, dict), "Metadata are not returned by MemorizedFunc.call."