# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
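
# Dataset-level tests for Parquet modular encryption: write an encrypted
# dataset with the pyarrow.dataset API and read it back with a matching
# decryption configuration.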

import base64
from datetime import timedelta
import random
import pyarrow.fs as fs
import pyarrow as pa

import pytest

encryption_unavailable = False

try:
    import pyarrow.parquet as pq
    import pyarrow.dataset as ds
except ImportError:
    pq = None
    ds = None

try:
    from pyarrow.tests.parquet.encryption import InMemoryKmsClient
    import pyarrow.parquet.encryption as pe
except ImportError:
    encryption_unavailable = True


# Marks all of the tests in this module
pytestmark = pytest.mark.dataset
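

# 16-byte (128-bit) master keys used by the in-memory KMS below; the KMS
# resolves the key names ("footer_key", "col_key") to these raw key bytes.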
FOOTER_KEY = b"0123456789112345"
FOOTER_KEY_NAME = "footer_key"
COL_KEY = b"1234567890123450"
COL_KEY_NAME = "col_key"


def create_sample_table():
    return pa.table(
        {
            "year": [2020, 2022, 2021, 2022, 2019, 2021],
            "n_legs": [2, 2, 4, 4, 5, 100],
            "animal": [
                "Flamingo",
                "Parrot",
                "Dog",
                "Horse",
                "Brittle stars",
                "Centipede",
            ],
        }
    )


def create_encryption_config():
    return pe.EncryptionConfiguration(
        footer_key=FOOTER_KEY_NAME,
        plaintext_footer=False,
        column_keys={COL_KEY_NAME: ["n_legs", "animal"]},
        encryption_algorithm="AES_GCM_V1",
        # requires timedelta or an assertion is raised
        cache_lifetime=timedelta(minutes=5.0),
        data_key_length_bits=256,
    )


def create_decryption_config():
    return pe.DecryptionConfiguration(cache_lifetime=300)


def create_kms_connection_config():
    return pe.KmsConnectionConfig(
        custom_kms_conf={
            FOOTER_KEY_NAME: FOOTER_KEY.decode("UTF-8"),
            COL_KEY_NAME: COL_KEY.decode("UTF-8"),
        }
    )


def kms_factory(kms_connection_configuration):
    return InMemoryKmsClient(kms_connection_configuration)
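

# End-to-end flow of the test below: write an encrypted dataset to an
# in-memory filesystem, verify that reading it without a decryption config
# fails, then read it back both with a decryption config and with explicit
# decryption properties.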
@pytest.mark.skipif(
    encryption_unavailable, reason="Parquet Encryption is not currently enabled"
)
def test_dataset_encryption_decryption():
    table = create_sample_table()

    encryption_config = create_encryption_config()
    decryption_config = create_decryption_config()
    kms_connection_config = create_kms_connection_config()

    crypto_factory = pe.CryptoFactory(kms_factory)
    parquet_encryption_cfg = ds.ParquetEncryptionConfig(
        crypto_factory, kms_connection_config, encryption_config
    )
    parquet_decryption_cfg = ds.ParquetDecryptionConfig(
        crypto_factory, kms_connection_config, decryption_config
    )

    # create write_options with dataset encryption config
    pformat = pa.dataset.ParquetFileFormat()
    write_options = pformat.make_write_options(
        encryption_config=parquet_encryption_cfg
    )
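
    # write the encrypted dataset to an in-memory mock filesystem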
    mockfs = fs._MockFileSystem()
    mockfs.create_dir("/")

    ds.write_dataset(
        data=table,
        base_dir="sample_dataset",
        format=pformat,
        file_options=write_options,
        filesystem=mockfs,
    )

    # read without decryption config -> should error if dataset was properly encrypted
    pformat = pa.dataset.ParquetFileFormat()
    with pytest.raises(IOError, match=r"no decryption"):
        ds.dataset("sample_dataset", format=pformat, filesystem=mockfs)

    # set decryption config for parquet fragment scan options
    pq_scan_opts = ds.ParquetFragmentScanOptions(
        decryption_config=parquet_decryption_cfg
    )
    pformat = pa.dataset.ParquetFileFormat(default_fragment_scan_options=pq_scan_opts)
    dataset = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs)

    assert table.equals(dataset.to_table())

    # set decryption properties for parquet fragment scan options
    decryption_properties = crypto_factory.file_decryption_properties(
        kms_connection_config, decryption_config)
    pq_scan_opts = ds.ParquetFragmentScanOptions(
        decryption_properties=decryption_properties
    )

    pformat = pa.dataset.ParquetFileFormat(default_fragment_scan_options=pq_scan_opts)
    dataset = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs)

    assert table.equals(dataset.to_table())


@pytest.mark.skipif(
    not encryption_unavailable, reason="Parquet Encryption is currently enabled"
)
def test_write_dataset_parquet_without_encryption():
    """Test write_dataset with ParquetFileFormat and check that an exception
    is raised if you try to set encryption_config using make_write_options"""

    # Set the encryption configuration using ParquetFileFormat
    # and make_write_options
    pformat = pa.dataset.ParquetFileFormat()

    with pytest.raises(NotImplementedError):
        _ = pformat.make_write_options(encryption_config="some value")


@pytest.mark.skipif(
    encryption_unavailable, reason="Parquet Encryption is not currently enabled"
)
def test_large_row_encryption_decryption():
    """Test encryption and decryption of a large number of rows."""

    # KMS client that "wraps" keys by simply base64-encoding them, so no
    # external key management service is needed for this test.
    class NoOpKmsClient(pe.KmsClient):
        def wrap_key(self, key_bytes: bytes, _: str) -> bytes:
            b = base64.b64encode(key_bytes)
            return b

        def unwrap_key(self, wrapped_key: bytes, _: str) -> bytes:
            b = base64.b64decode(wrapped_key)
            return b
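
    # single-column table of 2**15 + 1 random float32 values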
    row_count = 2**15 + 1
    table = pa.Table.from_arrays(
        [pa.array(
            [random.random() for _ in range(row_count)],
            type=pa.float32()
        )], names=["foo"]
    )

    kms_config = pe.KmsConnectionConfig()
    crypto_factory = pe.CryptoFactory(lambda _: NoOpKmsClient())
    encryption_config = pe.EncryptionConfiguration(
        footer_key="UNIMPORTANT_KEY",
        column_keys={"UNIMPORTANT_KEY": ["foo"]},
        double_wrapping=True,
        plaintext_footer=False,
        data_key_length_bits=128,
    )
    pqe_config = ds.ParquetEncryptionConfig(
        crypto_factory, kms_config, encryption_config
    )
    pqd_config = ds.ParquetDecryptionConfig(
        crypto_factory, kms_config, pe.DecryptionConfiguration()
    )
    scan_options = ds.ParquetFragmentScanOptions(decryption_config=pqd_config)
    file_format = ds.ParquetFileFormat(default_fragment_scan_options=scan_options)
    write_options = file_format.make_write_options(encryption_config=pqe_config)
    file_decryption_properties = crypto_factory.file_decryption_properties(
        kms_config)

    mockfs = fs._MockFileSystem()
    mockfs.create_dir("/")

    path = "large-row-test-dataset"
    ds.write_dataset(table, path, format=file_format,
                     file_options=write_options, filesystem=mockfs)
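
    # Read the written file back directly with ParquetFile and via the
    # dataset API; both reads should round-trip the original table.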
    file_path = path + "/part-0.parquet"
    new_table = pq.ParquetFile(
        file_path, decryption_properties=file_decryption_properties,
        filesystem=mockfs
    ).read()
    assert table == new_table

    dataset = ds.dataset(path, format=file_format, filesystem=mockfs)
    new_table = dataset.to_table()
    assert table == new_table