# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


import pytest

try:
    import numpy as np
except ImportError:
    np = None

import pyarrow as pa
from pyarrow import compute as pc

# UDFs are all tested with a dataset scan
pytestmark = pytest.mark.dataset

# For convenience, most of the tests here don't care about UDF function docs
empty_udf_doc = {"summary": "", "description": ""}

try:
    import pyarrow.dataset as ds
except ImportError:
    ds = None


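# A minimal sketch (illustrative only; never called by the tests below) of
# the registration/call round trip that every fixture in this file follows:
# register a Python callable under a name, then invoke it by that name.
# The helper and UDF names here are hypothetical, not part of the test suite.
def _example_scalar_udf_round_trip():
    def add_one(ctx, x):
        # ctx carries the memory pool and batch length supplied by the engine
        return pc.call_function("add", [x, 1], memory_pool=ctx.memory_pool)

    pc.register_scalar_function(add_one, "example_add_one", empty_udf_doc,
                                {"x": pa.int64()}, pa.int64())
    return pc.call_function("example_add_one", [pa.array([1, 2], pa.int64())])

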
def mock_udf_context(batch_length=10):
    # Build a UDF context equivalent to the one the engine passes to UDFs
    from pyarrow._compute import _get_udf_context
    return _get_udf_context(pa.default_memory_pool(), batch_length)


class MyError(RuntimeError):
    pass


@pytest.fixture(scope="session")
def sum_agg_func_fixture():
    """
    Register a unary aggregate function (sum)
    """
    def func(ctx, x, *args):
        return pa.scalar(np.nansum(x))

    func_name = "sum_udf"
    func_doc = empty_udf_doc

    pc.register_aggregate_function(func,
                                   func_name,
                                   func_doc,
                                   {
                                       "x": pa.float64(),
                                   },
                                   pa.float64()
                                   )
    return func, func_name


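# Illustrative sketch (not exercised directly; assumes the session fixture
# above has registered "sum_udf"): an aggregate UDF can be invoked like any
# built-in, either as a scalar aggregation or per group, as the
# hash-aggregation tests further below do.
def _example_sum_udf_usage():
    scalar_sum = pc.call_function(
        "sum_udf", [pa.array([1.0, 2.0, 3.0], pa.float64())])  # 6.0
    table = pa.table({"id": [1, 1, 2], "value": [1.0, 2.0, 3.0]})
    per_group = table.group_by("id").aggregate([("value", "sum_udf")])
    return scalar_sum, per_group

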
@pytest.fixture(scope="session")
def exception_agg_func_fixture():
    def func(ctx, x):
        raise RuntimeError("Oops")
        return pa.scalar(len(x))

    func_name = "y=exception_len(x)"
    func_doc = empty_udf_doc

    pc.register_aggregate_function(func,
                                   func_name,
                                   func_doc,
                                   {
                                       "x": pa.int64(),
                                   },
                                   pa.int64()
                                   )
    return func, func_name


@pytest.fixture(scope="session")
def wrong_output_dtype_agg_func_fixture():
    def func(ctx, x):
        return pa.scalar(len(x), pa.int32())

    func_name = "y=wrong_output_dtype(x)"
    func_doc = empty_udf_doc

    pc.register_aggregate_function(func,
                                   func_name,
                                   func_doc,
                                   {
                                       "x": pa.int64(),
                                   },
                                   pa.int64()
                                   )
    return func, func_name


@pytest.fixture(scope="session")
def wrong_output_type_agg_func_fixture():
    def func(ctx, x):
        return len(x)

    func_name = "y=wrong_output_type(x)"
    func_doc = empty_udf_doc

    pc.register_aggregate_function(func,
                                   func_name,
                                   func_doc,
                                   {
                                       "x": pa.int64(),
                                   },
                                   pa.int64()
                                   )
    return func, func_name


@pytest.fixture(scope="session")
def binary_func_fixture():
    """
    Register a binary scalar function.
    """
    def binary_function(ctx, m, x):
        return pc.call_function("multiply", [m, x],
                                memory_pool=ctx.memory_pool)
    func_name = "y=mx"
    binary_doc = {"summary": "y=mx",
                  "description": "find y from y = mx"}
    pc.register_scalar_function(binary_function,
                                func_name,
                                binary_doc,
                                {"m": pa.int64(),
                                 "x": pa.int64(),
                                 },
                                pa.int64())
    return binary_function, func_name


@pytest.fixture(scope="session")
def ternary_func_fixture():
    """
    Register a ternary scalar function.
    """
    def ternary_function(ctx, m, x, c):
        mx = pc.call_function("multiply", [m, x],
                              memory_pool=ctx.memory_pool)
        return pc.call_function("add", [mx, c],
                                memory_pool=ctx.memory_pool)
    ternary_doc = {"summary": "y=mx+c",
                   "description": "find y from y = mx + c"}
    func_name = "y=mx+c"
    pc.register_scalar_function(ternary_function,
                                func_name,
                                ternary_doc,
                                {
                                    "array1": pa.int64(),
                                    "array2": pa.int64(),
                                    "array3": pa.int64(),
                                },
                                pa.int64())
    return ternary_function, func_name


@pytest.fixture(scope="session")
def varargs_func_fixture():
    """
    Register a varargs scalar function with at least two arguments.
    """
    def varargs_function(ctx, first, *values):
        acc = first
        for val in values:
            acc = pc.call_function("add", [acc, val],
                                   memory_pool=ctx.memory_pool)
        return acc
    func_name = "z=ax+by+c"
    varargs_doc = {"summary": "z=ax+by+c",
                   "description": "find z from z = ax + by + c"
                   }
    pc.register_scalar_function(varargs_function,
                                func_name,
                                varargs_doc,
                                {
                                    "array1": pa.int64(),
                                    "array2": pa.int64(),
                                },
                                pa.int64())
    return varargs_function, func_name


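# Illustrative sketch (hypothetical helper, not part of the test suite): a
# VarArgs registration declares the minimum signature (two int64 inputs
# above); additional trailing arguments matching the last declared type are
# accepted at call time, which is why test_udf_array_varargs below can pass
# five arrays to a function registered with only two input types.
def _example_varargs_call():
    args = [pa.array([1, 2], pa.int64()) for _ in range(5)]
    return pc.call_function("z=ax+by+c", args)  # element-wise sum of all five

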
@pytest.fixture(scope="session")
def nullary_func_fixture():
    """
    Register a nullary scalar function.
    """
    def nullary_func(context):
        return pa.array([42] * context.batch_length, type=pa.int64(),
                        memory_pool=context.memory_pool)

    func_doc = {
        "summary": "nullary function",
        "description": "generates a constant value"
    }
    func_name = "test_nullary_func"
    pc.register_scalar_function(nullary_func,
                                func_name,
                                func_doc,
                                {},
                                pa.int64())

    return nullary_func, func_name


@pytest.fixture(scope="session")
def ephemeral_nullary_func_fixture():
    """
    Register a nullary scalar function with an ephemeral Python function.
    This stresses that the Python function object is properly kept alive by the
    registered function.
    """
    def nullary_func(context):
        return pa.array([42] * context.batch_length, type=pa.int64(),
                        memory_pool=context.memory_pool)

    func_doc = {
        "summary": "nullary function",
        "description": "generates a constant value"
    }
    func_name = "test_ephemeral_nullary_func"
    pc.register_scalar_function(nullary_func,
                                func_name,
                                func_doc,
                                {},
                                pa.int64())

    return func_name


@pytest.fixture(scope="session")
def wrong_output_type_func_fixture():
    """
    Register a scalar function which returns something that is neither
    an Arrow scalar nor an array.
    """
    def wrong_output_type(ctx):
        return 42

    func_name = "test_wrong_output_type"
    in_types = {}
    out_type = pa.int64()
    doc = {
        "summary": "return wrong output type",
        "description": ""
    }
    pc.register_scalar_function(wrong_output_type, func_name, doc,
                                in_types, out_type)
    return wrong_output_type, func_name


@pytest.fixture(scope="session")
def wrong_output_datatype_func_fixture():
    """
    Register a scalar function whose actual output DataType doesn't
    match the declared output DataType.
    """
    def wrong_output_datatype(ctx, array):
        return pc.call_function("add", [array, 1])
    func_name = "test_wrong_output_datatype"
    in_types = {"array": pa.int64()}
    # The actual output DataType will be int64.
    out_type = pa.int16()
    doc = {
        "summary": "return wrong output datatype",
        "description": ""
    }
    pc.register_scalar_function(wrong_output_datatype, func_name, doc,
                                in_types, out_type)
    return wrong_output_datatype, func_name


@pytest.fixture(scope="session")
def wrong_signature_func_fixture():
    """
    Register a scalar function with the wrong signature.
    """
    # Missing the context argument
    def wrong_signature():
        return pa.scalar(1, type=pa.int64())

    func_name = "test_wrong_signature"
    in_types = {}
    out_type = pa.int64()
    doc = {
        "summary": "UDF with wrong signature",
        "description": ""
    }
    pc.register_scalar_function(wrong_signature, func_name, doc,
                                in_types, out_type)
    return wrong_signature, func_name


@pytest.fixture(scope="session")
def raising_func_fixture():
    """
    Register a scalar function which raises a custom exception.
    """
    def raising_func(ctx):
        raise MyError("error raised by scalar UDF")
    func_name = "test_raise"
    doc = {
        "summary": "raising function",
        "description": ""
    }
    pc.register_scalar_function(raising_func, func_name, doc,
                                {}, pa.int64())
    return raising_func, func_name


@pytest.fixture(scope="session")
def unary_vector_func_fixture():
    """
    Register a vector function
    """
    def pct_rank(ctx, x):
        # copy here to get around pandas 1.0 issue
        return pa.array(x.to_pandas().copy().rank(pct=True))

    func_name = "y=pct_rank(x)"
    doc = empty_udf_doc
    pc.register_vector_function(pct_rank, func_name, doc, {
        'x': pa.float64()}, pa.float64())

    return pct_rank, func_name


@pytest.fixture(scope="session")
def struct_vector_func_fixture():
    """
    Register a vector function that returns a struct array
    """
    def pivot(ctx, k, v, c):
        df = pa.RecordBatch.from_arrays([k, v, c], names=['k', 'v', 'c']).to_pandas()
        df_pivot = df.pivot(columns='c', values='v', index='k').reset_index()
        return pa.RecordBatch.from_pandas(df_pivot).to_struct_array()

    func_name = "y=pivot(x)"
    doc = empty_udf_doc
    pc.register_vector_function(
        pivot, func_name, doc,
        {'k': pa.int64(), 'v': pa.float64(), 'c': pa.utf8()},
        pa.struct([('k', pa.int64()), ('v1', pa.float64()), ('v2', pa.float64())])
    )

    return pivot, func_name


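# For example (cf. test_vector_struct below): with k = [1, 1, 2, 2],
# v = [1.0, 2.0, 3.0, 4.0] and c = ['v1', 'v2', 'v1', 'v2'], the pivot above
# produces one struct per key: {k: 1, v1: 1.0, v2: 2.0} and
# {k: 2, v1: 3.0, v2: 4.0}.

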
def check_scalar_function(func_fixture,
                          inputs, *,
                          run_in_dataset=True,
                          batch_length=None):
    function, name = func_fixture
    if batch_length is None:
        all_scalar = True
        for arg in inputs:
            if isinstance(arg, pa.Array):
                all_scalar = False
                batch_length = len(arg)
        if all_scalar:
            batch_length = 1

    func = pc.get_function(name)
    assert func.name == name

    result = pc.call_function(name, inputs, length=batch_length)
    expected_output = function(mock_udf_context(batch_length), *inputs)
    assert result == expected_output
    # At the moment there is an issue when handling nullary functions.
    # See: ARROW-15286 and ARROW-16290.
    if run_in_dataset:
        field_names = [f'field{index}' for index, in_arr in enumerate(inputs)]
        table = pa.Table.from_arrays(inputs, field_names)
        dataset = ds.dataset(table)
        func_args = [ds.field(field_name) for field_name in field_names]
        result_table = dataset.to_table(
            columns={'result': ds.field('')._call(name, func_args)})
        assert result_table.column(0).chunks[0] == expected_output


def test_udf_array_unary(unary_func_fixture):
    check_scalar_function(unary_func_fixture,
                          [
                              pa.array([10, 20], pa.int64())
                          ]
                          )


def test_udf_array_binary(binary_func_fixture):
    check_scalar_function(binary_func_fixture,
                          [
                              pa.array([10, 20], pa.int64()),
                              pa.array([2, 4], pa.int64())
                          ]
                          )


def test_udf_array_ternary(ternary_func_fixture):
    check_scalar_function(ternary_func_fixture,
                          [
                              pa.array([10, 20], pa.int64()),
                              pa.array([2, 4], pa.int64()),
                              pa.array([5, 10], pa.int64())
                          ]
                          )


def test_udf_array_varargs(varargs_func_fixture):
    check_scalar_function(varargs_func_fixture,
                          [
                              pa.array([2, 3], pa.int64()),
                              pa.array([10, 20], pa.int64()),
                              pa.array([3, 7], pa.int64()),
                              pa.array([20, 30], pa.int64()),
                              pa.array([5, 10], pa.int64())
                          ]
                          )


def test_registration_errors():
    # validate function name
    doc = {
        "summary": "test udf input",
        "description": "parameters are validated"
    }
    in_types = {"scalar": pa.int64()}
    out_type = pa.int64()

    def test_reg_function(context):
        return pa.array([10])

    with pytest.raises(TypeError):
        pc.register_scalar_function(test_reg_function,
                                    None, doc, in_types,
                                    out_type)

    # validate function
    with pytest.raises(TypeError, match="func must be a callable"):
        pc.register_scalar_function(None, "test_none_function", doc, in_types,
                                    out_type)

    # validate output type
    expected_expr = "DataType expected, got <class 'NoneType'>"
    with pytest.raises(TypeError, match=expected_expr):
        pc.register_scalar_function(test_reg_function,
                                    "test_output_function", doc, in_types,
                                    None)

    # validate input type
    expected_expr = "in_types must be a dictionary of DataType"
    with pytest.raises(TypeError, match=expected_expr):
        pc.register_scalar_function(test_reg_function,
                                    "test_input_function", doc, None,
                                    out_type)

    # register an already registered function
    # first registration
    pc.register_scalar_function(test_reg_function,
                                "test_reg_function", doc, {},
                                out_type)
    # second registration
    expected_expr = "Already have a function registered with name:" \
        + " test_reg_function"
    with pytest.raises(KeyError, match=expected_expr):
        pc.register_scalar_function(test_reg_function,
                                    "test_reg_function", doc, {},
                                    out_type)


def test_varargs_function_validation(varargs_func_fixture):
    _, func_name = varargs_func_fixture

    error_msg = r"VarArgs function 'z=ax\+by\+c' needs at least 2 arguments"

    with pytest.raises(ValueError, match=error_msg):
        pc.call_function(func_name, [42])


def test_function_doc_validation():
    # validate arity
    in_types = {"scalar": pa.int64()}
    out_type = pa.int64()

    # doc with no summary
    func_doc = {
        "description": "desc"
    }

    def add_const(ctx, scalar):
        return pc.call_function("add", [scalar, 1])

    with pytest.raises(ValueError,
                       match="Function doc must contain a summary"):
        pc.register_scalar_function(add_const, "test_no_summary",
                                    func_doc, in_types,
                                    out_type)

    # doc with no description
    func_doc = {
        "summary": "test summary"
    }

    with pytest.raises(ValueError,
                       match="Function doc must contain a description"):
        pc.register_scalar_function(add_const, "test_no_desc",
                                    func_doc, in_types,
                                    out_type)


def test_nullary_function(nullary_func_fixture):
    # XXX the Python compute layer API doesn't let us override batch_length,
    # so only test with the default value of 1.
    check_scalar_function(nullary_func_fixture, [], run_in_dataset=False,
                          batch_length=1)


def test_ephemeral_function(ephemeral_nullary_func_fixture):
    name = ephemeral_nullary_func_fixture
    result = pc.call_function(name, [], length=1)
    assert result.to_pylist() == [42]


def test_wrong_output_type(wrong_output_type_func_fixture):
    _, func_name = wrong_output_type_func_fixture

    with pytest.raises(TypeError,
                       match="Unexpected output type: int"):
        pc.call_function(func_name, [], length=1)


def test_wrong_output_datatype(wrong_output_datatype_func_fixture):
    _, func_name = wrong_output_datatype_func_fixture

    expected_expr = ("Expected output datatype int16, "
                     "but function returned datatype int64")

    with pytest.raises(TypeError, match=expected_expr):
        pc.call_function(func_name, [pa.array([20, 30])])


def test_wrong_signature(wrong_signature_func_fixture):
    _, func_name = wrong_signature_func_fixture

    expected_expr = (r"wrong_signature\(\) takes 0 positional arguments "
                     "but 1 was given")

    with pytest.raises(TypeError, match=expected_expr):
        pc.call_function(func_name, [], length=1)


def test_wrong_datatype_declaration():
    def identity(ctx, val):
        return val

    func_name = "test_wrong_datatype_declaration"
    in_types = {"array": pa.int64()}
    out_type = {}
    doc = {
        "summary": "test output value",
        "description": "test output"
    }
    with pytest.raises(TypeError,
                       match="DataType expected, got <class 'dict'>"):
        pc.register_scalar_function(identity, func_name,
                                    doc, in_types, out_type)


def test_wrong_input_type_declaration():
    def identity(ctx, val):
        return val

    func_name = "test_wrong_input_type_declaration"
    in_types = {"array": None}
    out_type = pa.int64()
    doc = {
        "summary": "test invalid input type",
        "description": "invalid input function"
    }
    with pytest.raises(TypeError,
                       match="DataType expected, got <class 'NoneType'>"):
        pc.register_scalar_function(identity, func_name, doc,
                                    in_types, out_type)


def test_scalar_udf_context(unary_func_fixture):
    # Check the memory_pool argument is properly propagated
    proxy_pool = pa.proxy_memory_pool(pa.default_memory_pool())
    _, func_name = unary_func_fixture

    res = pc.call_function(func_name,
                           [pa.array([1] * 1000, type=pa.int64())],
                           memory_pool=proxy_pool)
    assert res == pa.array([2] * 1000, type=pa.int64())
    assert proxy_pool.bytes_allocated() == 1000 * 8
    # Destroying Python array should destroy underlying C++ memory
    res = None
    assert proxy_pool.bytes_allocated() == 0


def test_raising_func(raising_func_fixture):
    _, func_name = raising_func_fixture
    with pytest.raises(MyError, match="error raised by scalar UDF"):
        pc.call_function(func_name, [], length=1)


def test_scalar_input(unary_func_fixture):
    function, func_name = unary_func_fixture
    res = pc.call_function(func_name, [pa.scalar(10)])
    assert res == pa.scalar(11)


def test_input_lifetime(unary_func_fixture):
    function, func_name = unary_func_fixture

    proxy_pool = pa.proxy_memory_pool(pa.default_memory_pool())
    assert proxy_pool.bytes_allocated() == 0

    v = pa.array([1] * 1000, type=pa.int64(), memory_pool=proxy_pool)
    assert proxy_pool.bytes_allocated() == 1000 * 8
    pc.call_function(func_name, [v])
    assert proxy_pool.bytes_allocated() == 1000 * 8
    # Calling a UDF should not have kept `v` alive longer than required
    v = None
    assert proxy_pool.bytes_allocated() == 0


def _record_batch_from_iters(schema, *iters):
    arrays = [pa.array(list(v), type=schema[i].type)
              for i, v in enumerate(iters)]
    return pa.RecordBatch.from_arrays(arrays=arrays, schema=schema)


def _record_batch_for_range(schema, n):
    return _record_batch_from_iters(schema,
                                    range(n, n + 10),
                                    range(n + 1, n + 11))


def make_udt_func(schema, batch_gen):
    def udf_func(ctx):
        class UDT:
            def __init__(self):
                self.caller = None

            def __call__(self, ctx):
                try:
                    if self.caller is None:
                        self.caller, ctx = batch_gen(ctx).send, None
                    batch = self.caller(ctx)
                except StopIteration:
                    arrays = [pa.array([], type=field.type)
                              for field in schema]
                    batch = pa.RecordBatch.from_arrays(
                        arrays=arrays, schema=schema)
                return batch.to_struct_array()
        return UDT()
    return udf_func


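# Note on make_udt_func above: on the first call, UDT captures gen.send from
# batch_gen(ctx) and primes the generator with None (the equivalent of
# next(gen)); subsequent calls forward the engine-supplied ctx via
# gen.send(ctx). Once the generator is exhausted, StopIteration is mapped to
# an empty struct array with the declared schema, the same end-of-stream
# signal that datasource1_direct below produces explicitly.

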
def datasource1_direct():
    """A short dataset"""
    schema = datasource1_schema()

    class Generator:
        def __init__(self):
            self.n = 3

        def __call__(self, ctx):
            if self.n == 0:
                batch = _record_batch_from_iters(schema, [], [])
            else:
                self.n -= 1
                batch = _record_batch_for_range(schema, self.n)
            return batch.to_struct_array()
    return lambda ctx: Generator()


def datasource1_generator():
    schema = datasource1_schema()

    def batch_gen(ctx):
        for n in range(3, 0, -1):
            # the ctx sent in via gen.send() is ignored here
            yield _record_batch_for_range(schema, n - 1)
    return make_udt_func(schema, batch_gen)


def datasource1_exception():
    schema = datasource1_schema()

    def batch_gen(ctx):
        for n in range(3, 0, -1):
            # the ctx sent in via gen.send() is ignored here
            yield _record_batch_for_range(schema, n - 1)
        raise RuntimeError("datasource1_exception")
    return make_udt_func(schema, batch_gen)


def datasource1_schema():
    return pa.schema([('', pa.int32()), ('', pa.int32())])


def datasource1_args(func, func_name):
    func_doc = {"summary": f"{func_name} UDT",
                "description": f"test {func_name} UDT"}
    in_types = {}
    out_type = pa.struct([("", pa.int32()), ("", pa.int32())])
    return func, func_name, func_doc, in_types, out_type


def _test_datasource1_udt(func_maker):
    schema = datasource1_schema()
    func = func_maker()
    func_name = func_maker.__name__
    func_args = datasource1_args(func, func_name)
    pc.register_tabular_function(*func_args)
    n = 3
    for item in pc.call_tabular_function(func_name):
        n -= 1
        assert item == _record_batch_for_range(schema, n)


def test_udt_datasource1_direct():
    _test_datasource1_udt(datasource1_direct)


def test_udt_datasource1_generator():
    _test_datasource1_udt(datasource1_generator)


def test_udt_datasource1_exception():
    with pytest.raises(RuntimeError, match='datasource1_exception'):
        _test_datasource1_udt(datasource1_exception)


@pytest.mark.numpy
def test_scalar_agg_basic(unary_agg_func_fixture):
    arr = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64())
    result = pc.call_function("mean_udf", [arr])
    expected = pa.scalar(30.0)
    assert result == expected


@pytest.mark.numpy
def test_scalar_agg_empty(unary_agg_func_fixture):
    empty = pa.array([], pa.float64())

    with pytest.raises(pa.ArrowInvalid, match='empty inputs'):
        pc.call_function("mean_udf", [empty])


def test_scalar_agg_wrong_output_dtype(wrong_output_dtype_agg_func_fixture):
    arr = pa.array([10, 20, 30, 40, 50], pa.int64())
    with pytest.raises(pa.ArrowTypeError, match="output datatype"):
        pc.call_function("y=wrong_output_dtype(x)", [arr])


def test_scalar_agg_wrong_output_type(wrong_output_type_agg_func_fixture):
    arr = pa.array([10, 20, 30, 40, 50], pa.int64())
    with pytest.raises(pa.ArrowTypeError, match="output type"):
        pc.call_function("y=wrong_output_type(x)", [arr])


@pytest.mark.numpy
def test_scalar_agg_varargs(varargs_agg_func_fixture):
    arr1 = pa.array([10, 20, 30, 40, 50], pa.int64())
    arr2 = pa.array([1.0, 2.0, 3.0, 4.0, 5.0], pa.float64())

    result = pc.call_function(
        "sum_mean", [arr1, arr2]
    )
    expected = pa.scalar(33.0)
    assert result == expected


@pytest.mark.numpy
def test_scalar_agg_exception(exception_agg_func_fixture):
    arr = pa.array([10, 20, 30, 40, 50, 60], pa.int64())

    with pytest.raises(RuntimeError, match='Oops'):
        pc.call_function("y=exception_len(x)", [arr])


@pytest.mark.numpy
def test_hash_agg_basic(unary_agg_func_fixture):
    arr1 = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64())
    arr2 = pa.array([4, 2, 1, 2, 1], pa.int32())

    arr3 = pa.array([60.0, 70.0, 80.0, 90.0, 100.0], pa.float64())
    arr4 = pa.array([5, 1, 1, 4, 1], pa.int32())

    table1 = pa.table([arr2, arr1], names=["id", "value"])
    table2 = pa.table([arr4, arr3], names=["id", "value"])
    table = pa.concat_tables([table1, table2])

    result = table.group_by("id").aggregate([("value", "mean_udf")])
    expected = table.group_by("id").aggregate(
        [("value", "mean")]).rename_columns(['id', 'value_mean_udf'])

    assert result.sort_by('id') == expected.sort_by('id')


@pytest.mark.numpy
def test_hash_agg_empty(unary_agg_func_fixture):
    arr1 = pa.array([], pa.float64())
    arr2 = pa.array([], pa.int32())
    table = pa.table([arr2, arr1], names=["id", "value"])

    result = table.group_by("id").aggregate([("value", "mean_udf")])
    expected = pa.table([pa.array([], pa.int32()), pa.array(
        [], pa.float64())], names=['id', 'value_mean_udf'])

    assert result == expected


def test_hash_agg_wrong_output_dtype(wrong_output_dtype_agg_func_fixture):
    arr1 = pa.array([10, 20, 30, 40, 50], pa.int64())
    arr2 = pa.array([4, 2, 1, 2, 1], pa.int32())

    table = pa.table([arr2, arr1], names=["id", "value"])
    with pytest.raises(pa.ArrowTypeError, match="output datatype"):
        table.group_by("id").aggregate([("value", "y=wrong_output_dtype(x)")])


def test_hash_agg_wrong_output_type(wrong_output_type_agg_func_fixture):
    arr1 = pa.array([10, 20, 30, 40, 50], pa.int64())
    arr2 = pa.array([4, 2, 1, 2, 1], pa.int32())
    table = pa.table([arr2, arr1], names=["id", "value"])

    with pytest.raises(pa.ArrowTypeError, match="output type"):
        table.group_by("id").aggregate([("value", "y=wrong_output_type(x)")])


@pytest.mark.numpy
def test_hash_agg_exception(exception_agg_func_fixture):
    arr1 = pa.array([10, 20, 30, 40, 50], pa.int64())
    arr2 = pa.array([4, 2, 1, 2, 1], pa.int32())
    table = pa.table([arr2, arr1], names=["id", "value"])

    with pytest.raises(RuntimeError, match='Oops'):
        table.group_by("id").aggregate([("value", "y=exception_len(x)")])


@pytest.mark.numpy
def test_hash_agg_random(sum_agg_func_fixture):
    """Test hash aggregate udf with randomly sampled data"""

    value_num = 1000000
    group_num = 1000

    arr1 = pa.array(np.repeat(1, value_num), pa.float64())
    arr2 = pa.array(np.random.choice(group_num, value_num), pa.int32())

    table = pa.table([arr2, arr1], names=['id', 'value'])

    result = table.group_by("id").aggregate([("value", "sum_udf")])
    expected = table.group_by("id").aggregate(
        [("value", "sum")]).rename_columns(['id', 'value_sum_udf'])

    assert result.sort_by('id') == expected.sort_by('id')


@pytest.mark.pandas
def test_vector_basic(unary_vector_func_fixture):
    arr = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64())
    result = pc.call_function("y=pct_rank(x)", [arr])
    expected = unary_vector_func_fixture[0](None, arr)
    assert result == expected


@pytest.mark.pandas
def test_vector_empty(unary_vector_func_fixture):
    arr = pa.array([1], pa.float64())
    result = pc.call_function("y=pct_rank(x)", [arr])
    expected = unary_vector_func_fixture[0](None, arr)
    assert result == expected


@pytest.mark.pandas
def test_vector_struct(struct_vector_func_fixture):
    k = pa.array(
        [1, 1, 2, 2], pa.int64()
    )
    v = pa.array(
        [1.0, 2.0, 3.0, 4.0], pa.float64()
    )
    c = pa.array(
        ['v1', 'v2', 'v1', 'v2']
    )
    result = pc.call_function("y=pivot(x)", [k, v, c])
    expected = struct_vector_func_fixture[0](None, k, v, c)
    assert result == expected