diff --git a/pyiceberg/encryption/__init__.py b/pyiceberg/encryption/__init__.py new file mode 100644 index 0000000000..9c86adbdce --- /dev/null +++ b/pyiceberg/encryption/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Iceberg table encryption support.""" diff --git a/pyiceberg/encryption/ciphers.py b/pyiceberg/encryption/ciphers.py new file mode 100644 index 0000000000..ed023ce53e --- /dev/null +++ b/pyiceberg/encryption/ciphers.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""AES-GCM encryption/decryption primitives and AGS1 stream decryption.""" + +from __future__ import annotations + +import os +import struct + +from cryptography.hazmat.primitives.ciphers.aead import AESGCM + +NONCE_LENGTH = 12 +GCM_TAG_LENGTH = 16 + + +def aes_gcm_encrypt(key: bytes, plaintext: bytes, aad: bytes | None = None) -> bytes: + """Encrypt using AES-GCM. Returns nonce || ciphertext || tag.""" + nonce = os.urandom(NONCE_LENGTH) + aesgcm = AESGCM(key) + ciphertext_with_tag = aesgcm.encrypt(nonce, plaintext, aad) + return nonce + ciphertext_with_tag + + +def aes_gcm_decrypt(key: bytes, ciphertext: bytes, aad: bytes | None = None) -> bytes: + """Decrypt AES-GCM data in format: nonce || ciphertext || tag.""" + if len(ciphertext) < NONCE_LENGTH + GCM_TAG_LENGTH: + raise ValueError(f"Ciphertext too short: {len(ciphertext)} bytes") + nonce = ciphertext[:NONCE_LENGTH] + encrypted_data = ciphertext[NONCE_LENGTH:] + aesgcm = AESGCM(key) + return aesgcm.decrypt(nonce, encrypted_data, aad) + + +# AGS1 stream constants +GCM_STREAM_MAGIC = b"AGS1" +GCM_STREAM_HEADER_LENGTH = 8 # 4 magic + 4 block size + + +def stream_block_aad(aad_prefix: bytes, block_index: int) -> bytes: + """Construct per-block AAD for AGS1 stream encryption. + + Format: aad_prefix || block_index (4 bytes, little-endian). + """ + index_bytes = struct.pack(" bytes: + """Decrypt an entire AGS1 stream and return the plaintext. + + AGS1 format: + - Header: "AGS1" (4 bytes) + plain_block_size (4 bytes LE) + - Blocks: each block is nonce(12) + ciphertext(up to 1MB) + tag(16) + - Each block's AAD = aad_prefix + block_index (4 bytes LE) + + """ + if len(encrypted_data) < GCM_STREAM_HEADER_LENGTH: + raise ValueError(f"AGS1 stream too short: {len(encrypted_data)} bytes") + + magic = encrypted_data[:4] + if magic != GCM_STREAM_MAGIC: + raise ValueError(f"Invalid AGS1 magic: {magic!r}, expected {GCM_STREAM_MAGIC!r}") + + plain_block_size = struct.unpack_from("= cipher_block_size: + block_cipher_size = cipher_block_size + else: + block_cipher_size = remaining + + if block_cipher_size < NONCE_LENGTH + GCM_TAG_LENGTH: + raise ValueError( + f"Truncated AGS1 block at offset {offset}: {block_cipher_size} bytes (minimum {NONCE_LENGTH + GCM_TAG_LENGTH})" + ) + + block_data = stream_data[offset : offset + block_cipher_size] + nonce = block_data[:NONCE_LENGTH] + ciphertext_with_tag = block_data[NONCE_LENGTH:] + + aad = stream_block_aad(aad_prefix, block_index) + plaintext = aesgcm.decrypt(nonce, ciphertext_with_tag, aad) + result.extend(plaintext) + + offset += block_cipher_size + block_index += 1 + + return bytes(result) diff --git a/pyiceberg/encryption/io.py b/pyiceberg/encryption/io.py new file mode 100644 index 0000000000..310cd842a3 --- /dev/null +++ b/pyiceberg/encryption/io.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""InputFile implementation backed by in-memory bytes.""" + +from __future__ import annotations + +import io +from types import TracebackType + +from pyiceberg.io import InputFile, InputStream + + +class BytesInputStream(InputStream): + """InputStream implementation backed by a bytes buffer.""" + + def __init__(self, data: bytes) -> None: + self._buffer = io.BytesIO(data) + + def read(self, size: int = 0) -> bytes: + if size <= 0: + return self._buffer.read() + return self._buffer.read(size) + + def seek(self, offset: int, whence: int = 0) -> int: + return self._buffer.seek(offset, whence) + + def tell(self) -> int: + return self._buffer.tell() + + def close(self) -> None: + self._buffer.close() + + def __enter__(self) -> BytesInputStream: + """Enter the context manager.""" + return self + + def __exit__( + self, + exctype: type[BaseException] | None, + excinst: BaseException | None, + exctb: TracebackType | None, + ) -> None: + """Exit the context manager and close the stream.""" + self.close() + + +class BytesInputFile(InputFile): + """InputFile implementation backed by in-memory bytes. + + Used to wrap decrypted data so that it can be read by + AvroFile and other readers that expect an InputFile. + """ + + def __init__(self, location: str, data: bytes) -> None: + super().__init__(location) + self._data = data + + def __len__(self) -> int: + """Return the length of the underlying data.""" + return len(self._data) + + def exists(self) -> bool: + return True + + def open(self, seekable: bool = True) -> InputStream: + return BytesInputStream(self._data) diff --git a/pyiceberg/encryption/key_metadata.py b/pyiceberg/encryption/key_metadata.py new file mode 100644 index 0000000000..4d22778d5c --- /dev/null +++ b/pyiceberg/encryption/key_metadata.py @@ -0,0 +1,152 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""StandardKeyMetadata Avro serialization. + +Wire format: ``0x01 version byte || Avro-encoded fields`` + +Avro schema: + - encryption_key: bytes (required) + - aad_prefix: union[null, bytes] (optional) + - file_length: union[null, long] (optional) +""" + +from __future__ import annotations + +from dataclasses import dataclass + +V1 = 0x01 + + +def _read_avro_long(data: bytes, offset: int) -> tuple[int, int]: + """Read a zigzag-encoded Avro long from data at offset. Returns (value, new_offset).""" + result = 0 + shift = 0 + while True: + if offset >= len(data): + raise ValueError("Unexpected end of Avro data reading long") + b = data[offset] + offset += 1 + result |= (b & 0x7F) << shift + if (b & 0x80) == 0: + break + shift += 7 + # Zigzag decode + return (result >> 1) ^ -(result & 1), offset + + +def _read_avro_bytes(data: bytes, offset: int) -> tuple[bytes, int]: + """Read Avro bytes (length-prefixed). Returns (bytes_value, new_offset).""" + length, offset = _read_avro_long(data, offset) + if length < 0: + raise ValueError(f"Negative Avro bytes length: {length}") + end = offset + length + if end > len(data): + raise ValueError("Unexpected end of Avro data reading bytes") + return data[offset:end], end + + +@dataclass(frozen=True) +class StandardKeyMetadata: + """Standard key metadata for Iceberg table encryption. + + Contains the plaintext encryption key (DEK), AAD prefix, and optional file length. + """ + + encryption_key: bytes + aad_prefix: bytes = b"" + file_length: int | None = None + + @staticmethod + def deserialize(data: bytes) -> StandardKeyMetadata: + """Deserialize from wire format: ``0x01 version || Avro-encoded fields``.""" + if not data: + raise ValueError("Empty key metadata buffer") + + version = data[0] + if version != V1: + raise ValueError(f"Unsupported key metadata version: {version}") + + offset = 1 + + # Read encryption_key (required bytes) + encryption_key, offset = _read_avro_bytes(data, offset) + + # Read aad_prefix (optional: union[null, bytes]) + union_index, offset = _read_avro_long(data, offset) + if union_index == 0: + aad_prefix = b"" + elif union_index == 1: + aad_prefix, offset = _read_avro_bytes(data, offset) + else: + raise ValueError(f"Invalid union index for aad_prefix: {union_index}") + + # Read file_length (optional: union[null, long]) + file_length = None + if offset < len(data): + union_index, offset = _read_avro_long(data, offset) + if union_index == 0: + file_length = None + elif union_index == 1: + file_length, offset = _read_avro_long(data, offset) + else: + raise ValueError(f"Invalid union index for file_length: {union_index}") + + return StandardKeyMetadata( + encryption_key=encryption_key, + aad_prefix=aad_prefix, + file_length=file_length, + ) + + def serialize(self) -> bytes: + """Serialize to wire format: ``0x01 version || Avro-encoded fields``.""" + parts = [bytes([V1])] + + # encryption_key (required bytes) + parts.append(_encode_avro_bytes(self.encryption_key)) + + # aad_prefix (union[null, bytes]) + if self.aad_prefix: + parts.append(_encode_avro_long(1)) # union index 1 = bytes + parts.append(_encode_avro_bytes(self.aad_prefix)) + else: + parts.append(_encode_avro_long(0)) # union index 0 = null + + # file_length (union[null, long]) + if self.file_length is not None: + parts.append(_encode_avro_long(1)) # union index 1 = long + parts.append(_encode_avro_long(self.file_length)) + else: + parts.append(_encode_avro_long(0)) # union index 0 = null + + return b"".join(parts) + + +def _encode_avro_long(value: int) -> bytes: + """Encode a long as zigzag-encoded Avro varint.""" + # Zigzag encode + n = (value << 1) ^ (value >> 63) + result = bytearray() + while n & ~0x7F: + result.append((n & 0x7F) | 0x80) + n >>= 7 + result.append(n & 0x7F) + return bytes(result) + + +def _encode_avro_bytes(data: bytes) -> bytes: + """Encode bytes with Avro length prefix.""" + return _encode_avro_long(len(data)) + data diff --git a/pyiceberg/encryption/kms.py b/pyiceberg/encryption/kms.py new file mode 100644 index 0000000000..2ba1a5e1ac --- /dev/null +++ b/pyiceberg/encryption/kms.py @@ -0,0 +1,114 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Key Management Service interfaces and implementations.""" + +from __future__ import annotations + +import importlib +import logging +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +from pyiceberg.encryption.ciphers import aes_gcm_decrypt, aes_gcm_encrypt + +if TYPE_CHECKING: + from pyiceberg.typedef import Properties + +logger = logging.getLogger(__name__) + +PY_KMS_IMPL = "py-kms-impl" + + +class KeyManagementClient(ABC): + """Abstract base class for key management operations.""" + + @abstractmethod + def wrap_key(self, key: bytes, wrapping_key_id: str) -> bytes: + """Wrap (encrypt) a key using the master key identified by wrapping_key_id.""" + + @abstractmethod + def unwrap_key(self, wrapped_key: bytes, wrapping_key_id: str) -> bytes: + """Unwrap (decrypt) a wrapped key using the master key identified by wrapping_key_id.""" + + def initialize(self, properties: dict[str, str]) -> None: # noqa: B027 + """Initialize the KMS client from catalog/table properties.""" + + +class InMemoryKms(KeyManagementClient): + """In-memory KMS for testing. NOT for production use.""" + + def __init__(self, master_keys: dict[str, bytes] | None = None) -> None: + self._master_keys: dict[str, bytes] = dict(master_keys) if master_keys else {} + + def initialize(self, properties: dict[str, str]) -> None: + for key, value in properties.items(): + if key.startswith("encryption.kms.key."): + key_id = key[len("encryption.kms.key.") :] + self._master_keys[key_id] = bytes.fromhex(value) + + def wrap_key(self, key: bytes, wrapping_key_id: str) -> bytes: + master_key = self._master_keys.get(wrapping_key_id) + if master_key is None: + raise ValueError(f"Wrapping key not found: {wrapping_key_id}") + return aes_gcm_encrypt(master_key, key, aad=None) + + def unwrap_key(self, wrapped_key: bytes, wrapping_key_id: str) -> bytes: + master_key = self._master_keys.get(wrapping_key_id) + if master_key is None: + raise ValueError(f"Wrapping key not found: {wrapping_key_id}") + return aes_gcm_decrypt(master_key, wrapped_key, aad=None) + + +def load_kms_client(properties: Properties) -> KeyManagementClient | None: + """Load a KMS client from properties using py-kms-impl. + + Follows the same pattern as py-io-impl for FileIO. + + The property 'py-kms-impl' should be a fully qualified Python class name + (e.g., 'pyiceberg.encryption.kms.InMemoryKms'). The class must be a + subclass of KeyManagementClient. + + Args: + properties: Catalog and/or table properties. + + Returns: + An initialized KeyManagementClient, or None if py-kms-impl is not set. + """ + kms_impl = properties.get(PY_KMS_IMPL) + if kms_impl is None: + return None + + path_parts = kms_impl.split(".") + if len(path_parts) < 2: + raise ValueError(f"py-kms-impl should be a full path (module.ClassName), got: {kms_impl}") + + module_name, class_name = ".".join(path_parts[:-1]), path_parts[-1] + try: + module = importlib.import_module(module_name) + except ModuleNotFoundError as e: + raise ValueError(f"Could not import KMS module: {module_name}") from e + + kms_class = getattr(module, class_name, None) + if kms_class is None: + raise ValueError(f"KMS class {class_name} not found in module {module_name}") + + if not (isinstance(kms_class, type) and issubclass(kms_class, KeyManagementClient)): + raise ValueError(f"{kms_impl} is not a subclass of KeyManagementClient") + + client = kms_class() + client.initialize(dict(properties)) + return client diff --git a/pyiceberg/encryption/manager.py b/pyiceberg/encryption/manager.py new file mode 100644 index 0000000000..07219d35ad --- /dev/null +++ b/pyiceberg/encryption/manager.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Encryption manager implementing two-layer envelope key management. + +Key hierarchy: + - Master Key (in KMS) wraps KEK + - KEK wraps DEK (using local AES-GCM) + - DEK encrypts data (manifest lists, manifests, data files) + +The KEK timestamp is used as AAD when wrapping/unwrapping DEKs +to prevent timestamp tampering attacks. +""" + +from __future__ import annotations + +import logging + +from pyiceberg.encryption.ciphers import aes_gcm_decrypt, decrypt_ags1_stream +from pyiceberg.encryption.key_metadata import StandardKeyMetadata +from pyiceberg.encryption.kms import KeyManagementClient + +logger = logging.getLogger(__name__) + +KEK_CREATED_AT_PROPERTY = "KEY_TIMESTAMP" + + +class EncryptedKey: + """Represents an encrypted key entry from table metadata.""" + + def __init__( + self, + key_id: str, + encrypted_key_metadata: bytes, + encrypted_by_id: str | None = None, + properties: dict[str, str] | None = None, + ) -> None: + self.key_id = key_id + self.encrypted_key_metadata = encrypted_key_metadata + self.encrypted_by_id = encrypted_by_id + self.properties = properties or {} + + def __repr__(self) -> str: + """Return a string representation of the EncryptedKey.""" + return ( + f"EncryptedKey(key_id={self.key_id!r}, " + f"encrypted_by_id={self.encrypted_by_id!r}, " + f"metadata_len={len(self.encrypted_key_metadata)})" + ) + + +class EncryptionManager: + """Manages encryption/decryption for an Iceberg table. + + Orchestrates the two-layer envelope key management: + 1. Unwrap KEK via KMS using master key + 2. Use KEK to decrypt manifest list/manifest key metadata (with timestamp AAD) + 3. Parse StandardKeyMetadata to get DEK + AAD prefix + 4. Decrypt AGS1 streams or provide FileDecryptionProperties for Parquet + """ + + def __init__( + self, + kms_client: KeyManagementClient, + encryption_keys: dict[str, EncryptedKey] | None = None, + ) -> None: + self._kms = kms_client + self._encryption_keys = encryption_keys or {} + self._kek_cache: dict[str, bytes] = {} + + def _unwrap_kek(self, kek: EncryptedKey) -> bytes: + """Unwrap a KEK using the KMS, with caching.""" + if kek.key_id in self._kek_cache: + return self._kek_cache[kek.key_id] + + if not kek.encrypted_by_id: + raise ValueError(f"KEK '{kek.key_id}' has no encrypted_by_id") + + plaintext = self._kms.unwrap_key(kek.encrypted_key_metadata, kek.encrypted_by_id) + self._kek_cache[kek.key_id] = plaintext + return plaintext + + def _unwrap_dek(self, wrapped_dek: bytes, kek_key_id: str) -> bytes: + """Unwrap a DEK using the specified KEK. + + Uses the KEK timestamp as AAD to prevent timestamp tampering. + """ + kek = self._encryption_keys.get(kek_key_id) + if kek is None: + raise ValueError(f"KEK not found in encryption keys: {kek_key_id}") + + kek_bytes = self._unwrap_kek(kek) + + # Use KEK timestamp as AAD to prevent tampering + aad = kek.properties.get(KEK_CREATED_AT_PROPERTY) + aad_bytes = aad.encode("utf-8") if aad else None + + return aes_gcm_decrypt(kek_bytes, wrapped_dek, aad=aad_bytes) + + def unwrap_key_metadata(self, encrypted_key: EncryptedKey) -> bytes: + """Unwrap key metadata that was KEK-wrapped. + + Given an EncryptedKey entry (e.g., from a snapshot's key-id mapping), + unwrap it using the KEK identified by encrypted_by_id. + """ + if not encrypted_key.encrypted_by_id: + raise ValueError(f"EncryptedKey '{encrypted_key.key_id}' has no encrypted_by_id") + + return self._unwrap_dek( + encrypted_key.encrypted_key_metadata, + encrypted_key.encrypted_by_id, + ) + + def decrypt_manifest_list(self, encrypted_data: bytes, snapshot_key_id: str) -> bytes: + """Decrypt an AGS1-encrypted manifest list. + + 1. Look up the EncryptedKey for the snapshot's key_id + 2. Unwrap the key metadata using the KEK + 3. Parse StandardKeyMetadata to get DEK + AAD prefix + 4. Decrypt the AGS1 stream + """ + encrypted_key = self._encryption_keys.get(snapshot_key_id) + if encrypted_key is None: + raise ValueError(f"Snapshot key not found in encryption keys: {snapshot_key_id}") + + # Unwrap the key metadata + key_metadata_bytes = self.unwrap_key_metadata(encrypted_key) + key_metadata = StandardKeyMetadata.deserialize(key_metadata_bytes) + + return decrypt_ags1_stream( + key=key_metadata.encryption_key, + encrypted_data=encrypted_data, + aad_prefix=key_metadata.aad_prefix, + ) + + def decrypt_manifest(self, encrypted_data: bytes, key_metadata_bytes: bytes) -> bytes: + """Decrypt an AGS1-encrypted manifest file. + + The key_metadata_bytes are from ManifestFile.key_metadata -- these contain + the plaintext DEK and AAD prefix (NOT wrapped by KEK, since they're already + stored inside the encrypted manifest list). + """ + key_metadata = StandardKeyMetadata.deserialize(key_metadata_bytes) + + return decrypt_ags1_stream( + key=key_metadata.encryption_key, + encrypted_data=encrypted_data, + aad_prefix=key_metadata.aad_prefix, + ) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 8f22261f5d..851d83c96e 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -185,6 +185,7 @@ from pyiceberg.utils.truncate import truncate_upper_bound_binary_string, truncate_upper_bound_text_string if TYPE_CHECKING: + from pyiceberg.encryption.manager import EncryptionManager from pyiceberg.table import FileScanTask, WriteTask logger = logging.getLogger(__name__) @@ -1106,12 +1107,47 @@ def _get_file_format(file_format: FileFormat, **kwargs: dict[str, Any]) -> ds.Fi raise ValueError(f"Unsupported file format: {file_format}") -def _read_deletes(io: FileIO, data_file: DataFile) -> dict[str, pa.ChunkedArray]: +def _get_decryption_properties(key_metadata_bytes: bytes) -> Any: + """Build FileDecryptionProperties from Iceberg key metadata. + + Requires a custom PyArrow build with pyarrow.parquet.encryption support. + """ + try: + import pyarrow.parquet.encryption as pe + except ImportError as e: + raise ImportError( + "Parquet Modular Encryption requires a PyArrow build with encryption support. " + "See PYARROW_ENCRYPTION_HANDOFF.md for build instructions." + ) from e + + from pyiceberg.encryption.key_metadata import StandardKeyMetadata + + key_metadata = StandardKeyMetadata.deserialize(key_metadata_bytes) + return pe.create_decryption_properties( + footer_key=key_metadata.encryption_key, + aad_prefix=key_metadata.aad_prefix if key_metadata.aad_prefix else None, + ) + + +def _read_deletes( + io: FileIO, data_file: DataFile, encryption_manager: EncryptionManager | None = None +) -> dict[str, pa.ChunkedArray]: if data_file.file_format == FileFormat.PARQUET: + arrow_format = _get_file_format( + data_file.file_format, dictionary_columns=("file_path",), pre_buffer=True, buffer_size=ONE_MEGABYTE + ) + + if data_file.key_metadata is not None and encryption_manager is not None: + decryption_properties = _get_decryption_properties(data_file.key_metadata) + scan_options = ds.ParquetFragmentScanOptions( + decryption_properties=decryption_properties, + pre_buffer=True, + buffer_size=ONE_MEGABYTE, + ) + arrow_format = ds.ParquetFileFormat(default_fragment_scan_options=scan_options) + with io.new_input(data_file.file_path).open() as fi: - delete_fragment = _get_file_format( - data_file.file_format, dictionary_columns=("file_path",), pre_buffer=True, buffer_size=ONE_MEGABYTE - ).make_fragment(fi) + delete_fragment = arrow_format.make_fragment(fi) table = ds.Scanner.from_fragment(fragment=delete_fragment).to_table() table = table.unify_dictionaries() return { @@ -1614,8 +1650,21 @@ def _task_to_record_batches( partition_spec: PartitionSpec | None = None, format_version: TableVersion = TableProperties.DEFAULT_FORMAT_VERSION, downcast_ns_timestamp_to_us: bool | None = None, + encryption_manager: EncryptionManager | None = None, ) -> Iterator[pa.RecordBatch]: arrow_format = _get_file_format(task.file.file_format, pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8)) + + # For encrypted files, create a ParquetFileFormat with decryption properties + # so that make_fragment can read the encrypted metadata + if task.file.key_metadata is not None and encryption_manager is not None: + decryption_properties = _get_decryption_properties(task.file.key_metadata) + scan_options = ds.ParquetFragmentScanOptions( + decryption_properties=decryption_properties, + pre_buffer=True, + buffer_size=(ONE_MEGABYTE * 8), + ) + arrow_format = ds.ParquetFileFormat(default_fragment_scan_options=scan_options) + with io.new_input(task.file.file_path).open() as fin: fragment = arrow_format.make_fragment(fin) physical_schema = fragment.physical_schema @@ -1645,14 +1694,15 @@ def _task_to_record_batches( file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False) - fragment_scanner = ds.Scanner.from_fragment( - fragment=fragment, - schema=physical_schema, + scanner_kwargs: dict[str, Any] = { + "fragment": fragment, + "schema": physical_schema, # This will push down the query to Arrow. # But in case there are positional deletes, we have to apply them first - filter=pyarrow_filter if not positional_deletes else None, - columns=[col.name for col in file_project_schema.columns], - ) + "filter": pyarrow_filter if not positional_deletes else None, + "columns": [col.name for col in file_project_schema.columns], + } + fragment_scanner = ds.Scanner.from_fragment(**scanner_kwargs) next_index = 0 batches = fragment_scanner.to_batches() @@ -1691,14 +1741,16 @@ def _task_to_record_batches( ) -def _read_all_delete_files(io: FileIO, tasks: Iterable[FileScanTask]) -> dict[str, list[ChunkedArray]]: +def _read_all_delete_files( + io: FileIO, tasks: Iterable[FileScanTask], encryption_manager: EncryptionManager | None = None +) -> dict[str, list[ChunkedArray]]: deletes_per_file: dict[str, list[ChunkedArray]] = {} unique_deletes = set(itertools.chain.from_iterable([task.delete_files for task in tasks])) if len(unique_deletes) > 0: executor = ExecutorFactory.get_or_create() deletes_per_files: Iterator[dict[str, ChunkedArray]] = executor.map( lambda args: _read_deletes(*args), - [(io, delete_file) for delete_file in unique_deletes], + [(io, delete_file, encryption_manager) for delete_file in unique_deletes], ) for delete in deletes_per_files: for file, arr in delete.items(): @@ -1718,6 +1770,7 @@ class ArrowScan: _case_sensitive: bool _limit: int | None _downcast_ns_timestamp_to_us: bool | None + _encryption_manager: EncryptionManager | None """Scan the Iceberg Table and create an Arrow construct. Attributes: @@ -1727,6 +1780,7 @@ class ArrowScan: _bound_row_filter: Schema bound row expression to filter the data with _case_sensitive: Case sensitivity when looking up column names _limit: Limit the number of records. + _encryption_manager: Optional encryption manager for decrypting data files. """ def __init__( @@ -1737,6 +1791,7 @@ def __init__( row_filter: BooleanExpression, case_sensitive: bool = True, limit: int | None = None, + encryption_manager: EncryptionManager | None = None, ) -> None: self._table_metadata = table_metadata self._io = io @@ -1745,6 +1800,7 @@ def __init__( self._case_sensitive = case_sensitive self._limit = limit self._downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) + self._encryption_manager = encryption_manager @property def _projected_field_ids(self) -> set[int]: @@ -1807,7 +1863,7 @@ def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.Record ResolveError: When a required field cannot be found in the file ValueError: When a field type in the file cannot be projected to the schema type """ - deletes_per_file = _read_all_delete_files(self._io, tasks) + deletes_per_file = _read_all_delete_files(self._io, tasks, self._encryption_manager) total_row_count = 0 executor = ExecutorFactory.get_or_create() @@ -1855,6 +1911,7 @@ def _record_batches_from_scan_tasks_and_deletes( self._table_metadata.specs().get(task.file.spec_id), self._table_metadata.format_version, self._downcast_ns_timestamp_to_us, + encryption_manager=self._encryption_manager, ) for batch in batches: if self._limit is not None: diff --git a/pyiceberg/manifest.py b/pyiceberg/manifest.py index cca0af7628..319c95c524 100644 --- a/pyiceberg/manifest.py +++ b/pyiceberg/manifest.py @@ -24,6 +24,7 @@ from enum import Enum from types import TracebackType from typing import ( + TYPE_CHECKING, Any, Literal, ) @@ -39,6 +40,9 @@ from pyiceberg.partitioning import PartitionSpec from pyiceberg.schema import Schema from pyiceberg.typedef import Record, TableVersion + +if TYPE_CHECKING: + from pyiceberg.encryption.manager import EncryptionManager from pyiceberg.types import ( BinaryType, BooleanType, @@ -858,18 +862,34 @@ def has_added_files(self) -> bool: def has_existing_files(self) -> bool: return self.existing_files_count is None or self.existing_files_count > 0 - def fetch_manifest_entry(self, io: FileIO, discard_deleted: bool = True) -> list[ManifestEntry]: + def fetch_manifest_entry( + self, + io: FileIO, + discard_deleted: bool = True, + encryption_manager: EncryptionManager | None = None, + ) -> list[ManifestEntry]: """ Read the manifest entries from the manifest file. Args: io: The FileIO to fetch the file. discard_deleted: Filter on live entries. + encryption_manager: Optional encryption manager for decrypting encrypted manifests. Returns: An Iterator of manifest entries. """ input_file = io.new_input(self.manifest_path) + + # If this manifest has key_metadata, it's AGS1-encrypted + if self.key_metadata is not None and encryption_manager is not None: + from pyiceberg.encryption.io import BytesInputFile + + with input_file.open() as f: + encrypted_data = f.read() + decrypted_data = encryption_manager.decrypt_manifest(encrypted_data, self.key_metadata) + input_file = BytesInputFile(self.manifest_path, decrypted_data) + with AvroFile[ManifestEntry]( input_file, MANIFEST_ENTRY_SCHEMAS[DEFAULT_READ_VERSION], @@ -900,7 +920,12 @@ def __hash__(self) -> int: _manifest_cache_lock = threading.RLock() -def _manifests(io: FileIO, manifest_list: str) -> tuple[ManifestFile, ...]: +def _manifests( + io: FileIO, + manifest_list: str, + encryption_manager: EncryptionManager | None = None, + snapshot_key_id: str | None = None, +) -> tuple[ManifestFile, ...]: """Read manifests from a manifest list, deduplicating ManifestFile objects via cache. Caches individual ManifestFile objects by manifest_path. This is memory-efficient @@ -920,12 +945,14 @@ def _manifests(io: FileIO, manifest_list: str) -> tuple[ManifestFile, ...]: Args: io: FileIO instance for reading the manifest list. manifest_list: Path to the manifest list file. + encryption_manager: Optional encryption manager for decrypting encrypted manifest lists. + snapshot_key_id: Optional key ID from snapshot for manifest list decryption. Returns: A tuple of ManifestFile objects. """ file = io.new_input(manifest_list) - manifest_files = list(read_manifest_list(file)) + manifest_files = list(read_manifest_list(file, encryption_manager=encryption_manager, snapshot_key_id=snapshot_key_id)) result = [] with _manifest_cache_lock: @@ -940,16 +967,31 @@ def _manifests(io: FileIO, manifest_list: str) -> tuple[ManifestFile, ...]: return tuple(result) -def read_manifest_list(input_file: InputFile) -> Iterator[ManifestFile]: +def read_manifest_list( + input_file: InputFile, + encryption_manager: EncryptionManager | None = None, + snapshot_key_id: str | None = None, +) -> Iterator[ManifestFile]: """ Read the manifests from the manifest list. Args: input_file: The input file where the stream can be read from. + encryption_manager: Optional encryption manager for decrypting encrypted manifest lists. + snapshot_key_id: Optional key ID from snapshot for manifest list decryption. Returns: An iterator of ManifestFiles that are part of the list. """ + # If we have encryption info, decrypt the manifest list first + if snapshot_key_id is not None and encryption_manager is not None: + from pyiceberg.encryption.io import BytesInputFile + + with input_file.open() as f: + encrypted_data = f.read() + decrypted_data = encryption_manager.decrypt_manifest_list(encrypted_data, snapshot_key_id) + input_file = BytesInputFile(input_file.location, decrypted_data) + with AvroFile[ManifestFile]( input_file, MANIFEST_LIST_FILE_SCHEMAS[DEFAULT_READ_VERSION], diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index bb8765b651..f846b5d07f 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -107,6 +107,8 @@ from pyiceberg.catalog import Catalog from pyiceberg.catalog.rest.scan_planning import RESTContentFile, RESTDeleteFile, RESTFileScanTask + from pyiceberg.encryption.kms import KeyManagementClient + from pyiceberg.encryption.manager import EncryptionManager ALWAYS_TRUE = AlwaysTrue() DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE = "downcast-ns-timestamp-to-us-on-write" @@ -1890,6 +1892,7 @@ def _open_manifest( manifest: ManifestFile, partition_filter: Callable[[DataFile], bool], metrics_evaluator: Callable[[DataFile], bool], + encryption_manager: EncryptionManager | None = None, ) -> list[ManifestEntry]: """Open a manifest file and return matching manifest entries. @@ -1898,7 +1901,7 @@ def _open_manifest( """ return [ manifest_entry - for manifest_entry in manifest.fetch_manifest_entry(io, discard_deleted=True) + for manifest_entry in manifest.fetch_manifest_entry(io, discard_deleted=True, encryption_manager=encryption_manager) if partition_filter(manifest_entry.data_file) and metrics_evaluator(manifest_entry.data_file) ] @@ -1987,6 +1990,48 @@ def _check_sequence_number(min_sequence_number: int, manifest: ManifestFile) -> and (manifest.sequence_number or INITIAL_SEQUENCE_NUMBER) >= min_sequence_number ) + @cached_property + def _encryption_manager(self) -> EncryptionManager | None: + """Create an EncryptionManager if the table has encryption configured.""" + encryption_keys = getattr(self.table_metadata, "encryption_keys", []) + table_key_id = self.table_metadata.properties.get("encryption.key-id") + + if not encryption_keys or not table_key_id: + return None + + from pyiceberg.encryption.manager import EncryptedKey + from pyiceberg.encryption.manager import EncryptionManager as EncryptionManagerClass + + # Build the encryption keys map + enc_keys_map: dict[str, EncryptedKey] = {} + for ek in encryption_keys: + enc_keys_map[ek.key_id] = EncryptedKey( + key_id=ek.key_id, + encrypted_key_metadata=ek.encrypted_key_metadata_bytes, + encrypted_by_id=ek.encrypted_by_id, + properties=dict(ek.properties), + ) + + # Get the KMS client from the catalog or create one + kms_client = self._get_kms_client() + if kms_client is None: + return None + + return EncryptionManagerClass( + kms_client=kms_client, + encryption_keys=enc_keys_map, + ) + + def _get_kms_client(self) -> KeyManagementClient | None: + """Load a KMS client via py-kms-impl from catalog/table properties.""" + from pyiceberg.encryption.kms import load_kms_client + + all_properties = {**self.table_metadata.properties} + if self.catalog is not None: + all_properties = {**self.catalog.properties, **all_properties} + + return load_kms_client(all_properties) + def scan_plan_helper(self) -> Iterator[list[ManifestEntry]]: """Filter and return manifest entries based on partition and metrics evaluators. @@ -1997,6 +2042,8 @@ def scan_plan_helper(self) -> Iterator[list[ManifestEntry]]: if not snapshot: return iter([]) + encryption_manager = self._encryption_manager + # step 1: filter manifests using partition summaries # the filter depends on the partition spec used to write the manifest file, so create a cache of filters for each spec id @@ -2004,7 +2051,7 @@ def scan_plan_helper(self) -> Iterator[list[ManifestEntry]]: manifests = [ manifest_file - for manifest_file in snapshot.manifests(self.io) + for manifest_file in snapshot.manifests(self.io, encryption_manager=encryption_manager) if manifest_evaluators[manifest_file.partition_spec_id](manifest_file) ] @@ -2025,6 +2072,7 @@ def scan_plan_helper(self) -> Iterator[list[ManifestEntry]]: manifest, partition_evaluators[manifest.partition_spec_id], self._build_metrics_evaluator(), + encryption_manager, ) for manifest in manifests if self._check_sequence_number(min_sequence_number, manifest) @@ -2113,7 +2161,13 @@ def to_arrow(self) -> pa.Table: from pyiceberg.io.pyarrow import ArrowScan return ArrowScan( - self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit + self.table_metadata, + self.io, + self.projection(), + self.row_filter, + self.case_sensitive, + self.limit, + encryption_manager=self._encryption_manager, ).to_table(self.plan_files()) def to_arrow_batch_reader(self) -> pa.RecordBatchReader: @@ -2133,7 +2187,13 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader: target_schema = schema_to_pyarrow(self.projection()) batches = ArrowScan( - self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit + self.table_metadata, + self.io, + self.projection(), + self.row_filter, + self.case_sensitive, + self.limit, + encryption_manager=self._encryption_manager, ).to_record_batches(self.plan_files()) return pa.RecordBatchReader.from_batches( diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py index 26b6e3d3ad..ed07177af4 100644 --- a/pyiceberg/table/metadata.py +++ b/pyiceberg/table/metadata.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import base64 import datetime import uuid from collections.abc import Iterable @@ -48,6 +49,25 @@ from pyiceberg.utils.config import Config from pyiceberg.utils.datetime import datetime_to_millis + +class EncryptedKeyModel(IcebergBaseModel): + """An encrypted key entry in table metadata. + + Matches the EncryptedKey schema in the REST API spec. + """ + + key_id: str = Field(alias="key-id") + encrypted_key_metadata: str = Field(alias="encrypted-key-metadata") + """Base64-encoded encrypted key metadata bytes.""" + encrypted_by_id: str | None = Field(alias="encrypted-by-id", default=None) + properties: dict[str, str] = Field(default_factory=dict) + + @property + def encrypted_key_metadata_bytes(self) -> bytes: + """Decode the base64-encoded encrypted key metadata.""" + return base64.b64decode(self.encrypted_key_metadata) + + CURRENT_SNAPSHOT_ID = "current-snapshot-id" CURRENT_SCHEMA_ID = "current-schema-id" SCHEMAS = "schemas" @@ -574,6 +594,10 @@ def construct_refs(self) -> TableMetadata: next_row_id: int | None = Field(alias="next-row-id", default=None) """A long higher than all assigned row IDs; the next snapshot's `first-row-id`.""" + encryption_keys: list[EncryptedKeyModel] = Field(alias="encryption-keys", default_factory=list) + """Encryption key entries for the two-layer envelope encryption scheme. + Only valid for format version 3 and higher.""" + def model_dump_json(self, exclude_none: bool = True, exclude: Any | None = None, by_alias: bool = True, **kwargs: Any) -> str: raise NotImplementedError("Writing V3 is not yet supported, see: https://github.com/apache/iceberg-python/issues/1551") diff --git a/pyiceberg/table/snapshots.py b/pyiceberg/table/snapshots.py index 7e4c6eb1ec..52bbc54852 100644 --- a/pyiceberg/table/snapshots.py +++ b/pyiceberg/table/snapshots.py @@ -31,6 +31,7 @@ from pyiceberg.schema import Schema if TYPE_CHECKING: + from pyiceberg.encryption.manager import EncryptionManager from pyiceberg.table.metadata import TableMetadata from pyiceberg.typedef import IcebergBaseModel @@ -252,6 +253,9 @@ class Snapshot(IcebergBaseModel): added_rows: int | None = Field( alias="added-rows", default=None, description="The upper bound of the number of rows with assigned row IDs" ) + key_id: str | None = Field( + alias="key-id", default=None, description="ID of the encryption key used to encrypt this snapshot's manifest list" + ) def __str__(self) -> str: """Return the string representation of the Snapshot class.""" @@ -277,9 +281,16 @@ def __repr__(self) -> str: filtered_fields = [field for field in fields if field is not None] return f"Snapshot({', '.join(filtered_fields)})" - def manifests(self, io: FileIO) -> list[ManifestFile]: + def manifests(self, io: FileIO, encryption_manager: EncryptionManager | None = None) -> list[ManifestFile]: """Return the manifests for the given snapshot.""" - return list(_manifests(io, self.manifest_list)) + return list( + _manifests( + io, + self.manifest_list, + encryption_manager=encryption_manager, + snapshot_key_id=self.key_id, + ) + ) class MetadataLogEntry(IcebergBaseModel): diff --git a/tests/encryption/__init__.py b/tests/encryption/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/tests/encryption/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/encryption/test_ciphers.py b/tests/encryption/test_ciphers.py new file mode 100644 index 0000000000..ccb7d058d5 --- /dev/null +++ b/tests/encryption/test_ciphers.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os +import struct + +import pytest + +from pyiceberg.encryption.ciphers import ( + GCM_STREAM_MAGIC, + GCM_TAG_LENGTH, + NONCE_LENGTH, + aes_gcm_decrypt, + aes_gcm_encrypt, + decrypt_ags1_stream, + stream_block_aad, +) + + +class TestAesGcm: + def test_roundtrip(self) -> None: + key = os.urandom(16) + plaintext = b"hello, encryption" + ciphertext = aes_gcm_encrypt(key, plaintext) + assert aes_gcm_decrypt(key, ciphertext) == plaintext + + def test_roundtrip_with_aad(self) -> None: + key = os.urandom(16) + plaintext = b"hello, encryption" + aad = b"additional-data" + ciphertext = aes_gcm_encrypt(key, plaintext, aad=aad) + assert aes_gcm_decrypt(key, ciphertext, aad=aad) == plaintext + + def test_wrong_key_fails(self) -> None: + from cryptography.exceptions import InvalidTag + + key = os.urandom(16) + wrong_key = os.urandom(16) + ciphertext = aes_gcm_encrypt(key, b"secret") + with pytest.raises(InvalidTag): + aes_gcm_decrypt(wrong_key, ciphertext) + + def test_wrong_aad_fails(self) -> None: + from cryptography.exceptions import InvalidTag + + key = os.urandom(16) + ciphertext = aes_gcm_encrypt(key, b"secret", aad=b"correct") + with pytest.raises(InvalidTag): + aes_gcm_decrypt(key, ciphertext, aad=b"wrong") + + def test_wire_format(self) -> None: + """Verify the wire format: nonce(12) || ciphertext || tag(16).""" + key = os.urandom(16) + plaintext = b"test" + ciphertext = aes_gcm_encrypt(key, plaintext) + # Minimum size: nonce + tag + at least len(plaintext) of ciphertext + assert len(ciphertext) == NONCE_LENGTH + len(plaintext) + GCM_TAG_LENGTH + + def test_ciphertext_too_short(self) -> None: + key = os.urandom(16) + with pytest.raises(ValueError, match="Ciphertext too short"): + aes_gcm_decrypt(key, b"short") + + +class TestStreamBlockAad: + def test_with_prefix(self) -> None: + aad = stream_block_aad(b"prefix", 0) + assert aad == b"prefix" + struct.pack(" None: + aad = stream_block_aad(b"prefix", 42) + assert aad == b"prefix" + struct.pack(" None: + aad = stream_block_aad(b"", 7) + assert aad == struct.pack(" bytes: + """Build an AGS1 encrypted stream for testing.""" + from cryptography.hazmat.primitives.ciphers.aead import AESGCM + + aesgcm = AESGCM(key) + header = GCM_STREAM_MAGIC + struct.pack(" None: + key = os.urandom(16) + plaintext = b"hello AGS1 stream" + aad_prefix = b"test-aad" + encrypted = _encrypt_ags1_stream(key, plaintext, aad_prefix) + assert decrypt_ags1_stream(key, encrypted, aad_prefix) == plaintext + + def test_roundtrip_multi_block(self) -> None: + """Test with a small block size to force multiple blocks.""" + key = os.urandom(16) + plaintext = b"A" * 200 + aad_prefix = b"multi" + # Use a 64-byte block size to get multiple blocks + encrypted = _encrypt_ags1_stream(key, plaintext, aad_prefix, plain_block_size=64) + assert decrypt_ags1_stream(key, encrypted, aad_prefix) == plaintext + + def test_roundtrip_empty_payload(self) -> None: + key = os.urandom(16) + header = GCM_STREAM_MAGIC + struct.pack(" None: + data = b"XXXX" + struct.pack(" None: + with pytest.raises(ValueError, match="AGS1 stream too short"): + decrypt_ags1_stream(os.urandom(16), b"AGS1", b"") + + def test_custom_block_size(self) -> None: + """Verify the block size from the header is respected, not hardcoded.""" + key = os.urandom(16) + plaintext = b"B" * 300 + aad_prefix = b"custom" + # Encrypt with a 100-byte block size + encrypted = _encrypt_ags1_stream(key, plaintext, aad_prefix, plain_block_size=100) + # Verify the header contains 100 + assert struct.unpack_from(" None: + key = os.urandom(16) + # Header + a few bytes that are too short for even nonce+tag + data = GCM_STREAM_MAGIC + struct.pack(" None: + stream = BytesInputStream(b"hello") + assert stream.read() == b"hello" + + def test_read_partial(self) -> None: + stream = BytesInputStream(b"hello world") + assert stream.read(5) == b"hello" + assert stream.read(6) == b" world" + + def test_seek_and_tell(self) -> None: + stream = BytesInputStream(b"abcdef") + assert stream.tell() == 0 + stream.seek(3) + assert stream.tell() == 3 + assert stream.read(2) == b"de" + + def test_context_manager(self) -> None: + with BytesInputStream(b"data") as stream: + assert stream.read() == b"data" + + def test_implements_input_stream_protocol(self) -> None: + stream = BytesInputStream(b"test") + assert isinstance(stream, InputStream) + + +class TestBytesInputFile: + def test_len(self) -> None: + f = BytesInputFile("file://test", b"hello") + assert len(f) == 5 + + def test_exists(self) -> None: + f = BytesInputFile("file://test", b"") + assert f.exists() is True + + def test_open_and_read(self) -> None: + f = BytesInputFile("file://test", b"content") + with f.open() as stream: + assert stream.read() == b"content" + + def test_location(self) -> None: + f = BytesInputFile("s3://bucket/path", b"data") + assert f.location == "s3://bucket/path" diff --git a/tests/encryption/test_key_metadata.py b/tests/encryption/test_key_metadata.py new file mode 100644 index 0000000000..e9f05fc078 --- /dev/null +++ b/tests/encryption/test_key_metadata.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os + +import pytest + +from pyiceberg.encryption.key_metadata import StandardKeyMetadata + + +class TestStandardKeyMetadata: + def test_roundtrip_key_only(self) -> None: + key = os.urandom(16) + original = StandardKeyMetadata(encryption_key=key) + serialized = original.serialize() + restored = StandardKeyMetadata.deserialize(serialized) + assert restored.encryption_key == key + assert restored.aad_prefix == b"" + assert restored.file_length is None + + def test_roundtrip_with_aad_prefix(self) -> None: + key = os.urandom(16) + aad = os.urandom(8) + original = StandardKeyMetadata(encryption_key=key, aad_prefix=aad) + serialized = original.serialize() + restored = StandardKeyMetadata.deserialize(serialized) + assert restored.encryption_key == key + assert restored.aad_prefix == aad + assert restored.file_length is None + + def test_roundtrip_all_fields(self) -> None: + key = os.urandom(32) + aad = os.urandom(16) + original = StandardKeyMetadata(encryption_key=key, aad_prefix=aad, file_length=12345) + serialized = original.serialize() + restored = StandardKeyMetadata.deserialize(serialized) + assert restored.encryption_key == key + assert restored.aad_prefix == aad + assert restored.file_length == 12345 + + def test_version_byte(self) -> None: + """First byte should always be 0x01.""" + key = os.urandom(16) + serialized = StandardKeyMetadata(encryption_key=key).serialize() + assert serialized[0] == 0x01 + + def test_deserialize_empty(self) -> None: + with pytest.raises(ValueError, match="Empty key metadata"): + StandardKeyMetadata.deserialize(b"") + + def test_deserialize_wrong_version(self) -> None: + with pytest.raises(ValueError, match="Unsupported key metadata version"): + StandardKeyMetadata.deserialize(b"\x02\x00") + + def test_frozen(self) -> None: + """StandardKeyMetadata is a frozen dataclass.""" + skm = StandardKeyMetadata(encryption_key=b"key") + with pytest.raises(AttributeError): + skm.encryption_key = b"other" # type: ignore[misc] + + def test_roundtrip_large_file_length(self) -> None: + """Zigzag encoding should handle large values correctly.""" + key = os.urandom(16) + original = StandardKeyMetadata(encryption_key=key, file_length=2**40) + serialized = original.serialize() + restored = StandardKeyMetadata.deserialize(serialized) + assert restored.file_length == 2**40 diff --git a/tests/encryption/test_kms.py b/tests/encryption/test_kms.py new file mode 100644 index 0000000000..d445f27201 --- /dev/null +++ b/tests/encryption/test_kms.py @@ -0,0 +1,144 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os + +import pytest + +from pyiceberg.encryption.kms import InMemoryKms, KeyManagementClient, load_kms_client + + +class TestInMemoryKms: + def test_wrap_unwrap_roundtrip(self) -> None: + master_key = os.urandom(16) + kms = InMemoryKms(master_keys={"keyA": master_key}) + plaintext_key = os.urandom(16) + wrapped = kms.wrap_key(plaintext_key, "keyA") + unwrapped = kms.unwrap_key(wrapped, "keyA") + assert unwrapped == plaintext_key + + def test_unknown_wrapping_key(self) -> None: + kms = InMemoryKms() + with pytest.raises(ValueError, match="Wrapping key not found"): + kms.wrap_key(b"key", "nonexistent") + + def test_unknown_unwrapping_key(self) -> None: + kms = InMemoryKms() + with pytest.raises(ValueError, match="Wrapping key not found"): + kms.unwrap_key(b"wrapped", "nonexistent") + + def test_initialize_from_properties(self) -> None: + kms = InMemoryKms() + key_hex = os.urandom(16).hex() + kms.initialize({"encryption.kms.key.testKey": key_hex}) + # Should be able to wrap/unwrap with the initialized key + wrapped = kms.wrap_key(b"secret", "testKey") + assert kms.unwrap_key(wrapped, "testKey") == b"secret" + + def test_initialize_ignores_unrelated_properties(self) -> None: + kms = InMemoryKms() + kms.initialize({"some.other.prop": "value", "encryption.kms.key.k1": os.urandom(16).hex()}) + with pytest.raises(ValueError, match="Wrapping key not found"): + kms.wrap_key(b"key", "nonexistent") + + def test_wrap_unwrap_with_standard_test_keys(self) -> None: + """Wrap/unwrap with the standard Iceberg test master keys.""" + kms = InMemoryKms( + master_keys={ + "keyA": b"0123456789012345", + "keyB": b"1123456789012345", + } + ) + plaintext = os.urandom(16) + wrapped_a = kms.wrap_key(plaintext, "keyA") + assert kms.unwrap_key(wrapped_a, "keyA") == plaintext + wrapped_b = kms.wrap_key(plaintext, "keyB") + assert kms.unwrap_key(wrapped_b, "keyB") == plaintext + + def test_wrong_master_key_fails_unwrap(self) -> None: + from cryptography.exceptions import InvalidTag + + kms = InMemoryKms( + master_keys={ + "keyA": os.urandom(16), + "keyB": os.urandom(16), + } + ) + wrapped = kms.wrap_key(b"secret", "keyA") + with pytest.raises(InvalidTag): + kms.unwrap_key(wrapped, "keyB") + + +class TestLoadKmsClient: + def test_returns_none_when_not_configured(self) -> None: + assert load_kms_client({}) is None + + def test_loads_in_memory_kms(self) -> None: + client = load_kms_client( + { + "py-kms-impl": "pyiceberg.encryption.kms.InMemoryKms", + "encryption.kms.key.myKey": os.urandom(16).hex(), + } + ) + assert client is not None + assert isinstance(client, InMemoryKms) + # Should be initialized — the key should be usable + wrapped = client.wrap_key(b"data", "myKey") + assert client.unwrap_key(wrapped, "myKey") == b"data" + + def test_invalid_short_path(self) -> None: + with pytest.raises(ValueError, match="full path"): + load_kms_client({"py-kms-impl": "InMemoryKms"}) + + def test_nonexistent_module(self) -> None: + with pytest.raises(ValueError, match="Could not import"): + load_kms_client({"py-kms-impl": "nonexistent.module.Kms"}) + + def test_nonexistent_class(self) -> None: + with pytest.raises(ValueError, match="not found in module"): + load_kms_client({"py-kms-impl": "pyiceberg.encryption.kms.NonexistentClass"}) + + def test_not_a_subclass(self) -> None: + with pytest.raises(ValueError, match="not a subclass"): + # AESGCM is a real class but not a KeyManagementClient + load_kms_client({"py-kms-impl": "cryptography.hazmat.primitives.ciphers.aead.AESGCM"}) + + def test_custom_kms_impl(self) -> None: + """Verify that a custom KMS implementation can be loaded by module path.""" + + class _TestKms(KeyManagementClient): + initialized_with: dict[str, str] = {} + + def wrap_key(self, key: bytes, wrapping_key_id: str) -> bytes: + return key + + def unwrap_key(self, wrapped_key: bytes, wrapping_key_id: str) -> bytes: + return wrapped_key + + def initialize(self, properties: dict[str, str]) -> None: + _TestKms.initialized_with = properties + + # Register in the module namespace so importlib can find it + import pyiceberg.encryption.kms as kms_module + + kms_module._TestKms = _TestKms # type: ignore[attr-defined] + try: + client = load_kms_client({"py-kms-impl": "pyiceberg.encryption.kms._TestKms", "foo": "bar"}) + assert client is not None + assert isinstance(client, _TestKms) + assert _TestKms.initialized_with.get("foo") == "bar" + finally: + delattr(kms_module, "_TestKms") diff --git a/tests/encryption/test_manager.py b/tests/encryption/test_manager.py new file mode 100644 index 0000000000..fed30cca84 --- /dev/null +++ b/tests/encryption/test_manager.py @@ -0,0 +1,161 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os +import struct + +import pytest + +from pyiceberg.encryption.ciphers import ( + GCM_STREAM_MAGIC, + NONCE_LENGTH, + aes_gcm_encrypt, + stream_block_aad, +) +from pyiceberg.encryption.key_metadata import StandardKeyMetadata +from pyiceberg.encryption.kms import InMemoryKms +from pyiceberg.encryption.manager import EncryptedKey, EncryptionManager + + +def _make_ags1_stream(key: bytes, plaintext: bytes, aad_prefix: bytes, plain_block_size: int = 1024 * 1024) -> bytes: + """Build an AGS1 encrypted stream for testing.""" + from cryptography.hazmat.primitives.ciphers.aead import AESGCM + + aesgcm = AESGCM(key) + header = GCM_STREAM_MAGIC + struct.pack(" tuple[EncryptionManager, bytes, bytes, str]: + """Build an EncryptionManager with test keys, mimicking the REST catalog flow. + + Returns (manager, dek, aad_prefix, manifest_list_key_id). + """ + master_key = b"0123456789012345" # 16 bytes, standard Iceberg test key "keyA" + kms = InMemoryKms(master_keys={"keyA": master_key}) + + # Create a KEK (simulating what the REST catalog would provide) + kek_bytes = os.urandom(16) + kek_wrapped = kms.wrap_key(kek_bytes, "keyA") + kek_timestamp = "1234567890" + + # Create a DEK for the manifest list + dek = os.urandom(16) + aad_prefix = os.urandom(8) + key_metadata = StandardKeyMetadata(encryption_key=dek, aad_prefix=aad_prefix) + key_metadata_bytes = key_metadata.serialize() + + # Wrap the DEK key metadata with the KEK (using timestamp as AAD) + wrapped_dek = aes_gcm_encrypt(kek_bytes, key_metadata_bytes, aad=kek_timestamp.encode("utf-8")) + + encryption_keys = { + "kek-1": EncryptedKey( + key_id="kek-1", + encrypted_key_metadata=kek_wrapped, + encrypted_by_id="keyA", + properties={"KEY_TIMESTAMP": kek_timestamp}, + ), + "mlk-1": EncryptedKey( + key_id="mlk-1", + encrypted_key_metadata=wrapped_dek, + encrypted_by_id="kek-1", + ), + } + + manager = EncryptionManager(kms_client=kms, encryption_keys=encryption_keys) + return manager, dek, aad_prefix, "mlk-1" + + +class TestEncryptionManager: + def test_decrypt_manifest_list(self) -> None: + manager, dek, aad_prefix, mlk_id = _build_test_encryption_manager() + plaintext = b"manifest list content here" + encrypted_stream = _make_ags1_stream(dek, plaintext, aad_prefix) + result = manager.decrypt_manifest_list(encrypted_stream, mlk_id) + assert result == plaintext + + def test_decrypt_manifest(self) -> None: + dek = os.urandom(16) + aad_prefix = os.urandom(8) + plaintext = b"manifest content here" + encrypted_stream = _make_ags1_stream(dek, plaintext, aad_prefix) + key_metadata = StandardKeyMetadata(encryption_key=dek, aad_prefix=aad_prefix) + + # The manager only needs a KMS for KEK unwrapping; manifest decryption + # uses the plaintext key metadata directly (from inside the encrypted manifest list) + kms = InMemoryKms() + manager = EncryptionManager(kms_client=kms) + result = manager.decrypt_manifest(encrypted_stream, key_metadata.serialize()) + assert result == plaintext + + def test_kek_caching(self) -> None: + """KEK should be unwrapped once and cached.""" + manager, dek, aad_prefix, mlk_id = _build_test_encryption_manager() + plaintext = b"test" + encrypted = _make_ags1_stream(dek, plaintext, aad_prefix) + + # Decrypt twice + manager.decrypt_manifest_list(encrypted, mlk_id) + manager.decrypt_manifest_list(encrypted, mlk_id) + + # KEK should be cached + assert "kek-1" in manager._kek_cache + + def test_missing_snapshot_key(self) -> None: + kms = InMemoryKms() + manager = EncryptionManager(kms_client=kms) + with pytest.raises(ValueError, match="Snapshot key not found"): + manager.decrypt_manifest_list(b"data", "nonexistent-key") + + def test_missing_kek(self) -> None: + kms = InMemoryKms() + encryption_keys = { + "mlk-1": EncryptedKey( + key_id="mlk-1", + encrypted_key_metadata=b"wrapped", + encrypted_by_id="missing-kek", + ), + } + manager = EncryptionManager(kms_client=kms, encryption_keys=encryption_keys) + with pytest.raises(ValueError, match="KEK not found"): + manager.decrypt_manifest_list(b"data", "mlk-1") + + def test_kek_without_encrypted_by_id(self) -> None: + kms = InMemoryKms(master_keys={"keyA": os.urandom(16)}) + encryption_keys = { + "kek-1": EncryptedKey(key_id="kek-1", encrypted_key_metadata=b"data"), + "mlk-1": EncryptedKey( + key_id="mlk-1", + encrypted_key_metadata=b"wrapped", + encrypted_by_id="kek-1", + ), + } + manager = EncryptionManager(kms_client=kms, encryption_keys=encryption_keys) + with pytest.raises(ValueError, match="has no encrypted_by_id"): + manager.decrypt_manifest_list(b"data", "mlk-1")