Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions pyiceberg/encryption/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Iceberg table encryption support."""
116 changes: 116 additions & 0 deletions pyiceberg/encryption/ciphers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""AES-GCM encryption/decryption primitives and AGS1 stream decryption."""

from __future__ import annotations

import os
import struct

from cryptography.hazmat.primitives.ciphers.aead import AESGCM

# Size in bytes of the random nonce that prefixes every AES-GCM ciphertext in this module.
NONCE_LENGTH = 12
# Size in bytes of the GCM authentication tag; used for the minimum-length checks on ciphertexts.
GCM_TAG_LENGTH = 16


def aes_gcm_encrypt(key: bytes, plaintext: bytes, aad: bytes | None = None) -> bytes:
    """Encrypt ``plaintext`` with AES-GCM under a freshly generated random nonce.

    Args:
        key: AES key handed to ``AESGCM``.
        plaintext: Data to encrypt.
        aad: Optional additional authenticated data passed through to ``AESGCM.encrypt``.

    Returns:
        The concatenation ``nonce || ciphertext || tag``.
    """
    fresh_nonce = os.urandom(NONCE_LENGTH)
    # AESGCM.encrypt already returns ciphertext and tag joined together.
    sealed = AESGCM(key).encrypt(fresh_nonce, plaintext, aad)
    return b"".join((fresh_nonce, sealed))


def aes_gcm_decrypt(key: bytes, ciphertext: bytes, aad: bytes | None = None) -> bytes:
    """Decrypt a buffer laid out as ``nonce || ciphertext || tag``.

    Args:
        key: AES key handed to ``AESGCM``.
        ciphertext: The full encrypted buffer, including the leading nonce.
        aad: Optional additional authenticated data passed through to ``AESGCM.decrypt``.

    Returns:
        The decrypted plaintext.

    Raises:
        ValueError: If the buffer is too short to hold a nonce and a tag.
    """
    total = len(ciphertext)
    if total < NONCE_LENGTH + GCM_TAG_LENGTH:
        raise ValueError(f"Ciphertext too short: {total} bytes")
    # Split off the nonce; the remainder is ciphertext followed by the tag.
    nonce, sealed = ciphertext[:NONCE_LENGTH], ciphertext[NONCE_LENGTH:]
    return AESGCM(key).decrypt(nonce, sealed, aad)


# AGS1 stream constants
# Magic bytes expected at the very start of an AGS1 stream.
GCM_STREAM_MAGIC = b"AGS1"
# Header size in bytes: 4-byte magic followed by a 4-byte little-endian plaintext block size.
GCM_STREAM_HEADER_LENGTH = 8  # 4 magic + 4 block size


def stream_block_aad(aad_prefix: bytes, block_index: int) -> bytes:
    """Build the additional authenticated data for one AGS1 block.

    The result is ``aad_prefix`` followed by ``block_index`` packed as a
    4-byte little-endian unsigned integer. A falsy prefix yields just the
    packed index.

    Args:
        aad_prefix: Prefix shared by every block of the stream; may be empty.
        block_index: Zero-based position of the block within the stream.

    Returns:
        The per-block AAD.
    """
    packed_index = struct.pack("<I", block_index)
    return aad_prefix + packed_index if aad_prefix else packed_index


def decrypt_ags1_stream(key: bytes, encrypted_data: bytes, aad_prefix: bytes) -> bytes:
    """Decrypt a complete AGS1 stream held in memory and return its plaintext.

    Layout of an AGS1 stream:
        - Header: the magic ``"AGS1"`` (4 bytes) followed by the plaintext
          block size (4 bytes, little-endian).
        - Blocks: each is ``nonce (12) || ciphertext || tag (16)``; every block
          except possibly the last carries exactly ``plain_block_size`` bytes
          of ciphertext.
        - The AAD of each block is ``aad_prefix`` followed by the block index
          (4 bytes, little-endian), as built by ``stream_block_aad``.

    Args:
        key: AES key handed to ``AESGCM``.
        encrypted_data: The full stream, header included.
        aad_prefix: AAD prefix shared by all blocks of the stream.

    Returns:
        The concatenated plaintext of all blocks; empty if the stream has no
        data after the header.

    Raises:
        ValueError: If the stream is shorter than the header, the magic does
            not match, or the final block is too short to hold a nonce and tag.
        Exceptions raised by ``AESGCM.decrypt`` propagate unchanged.
    """
    total_length = len(encrypted_data)
    if total_length < GCM_STREAM_HEADER_LENGTH:
        raise ValueError(f"AGS1 stream too short: {total_length} bytes")

    magic = encrypted_data[:4]
    if magic != GCM_STREAM_MAGIC:
        raise ValueError(f"Invalid AGS1 magic: {magic!r}, expected {GCM_STREAM_MAGIC!r}")

    (plain_block_size,) = struct.unpack_from("<I", encrypted_data, 4)
    # A full block on the wire is the plaintext size plus nonce and tag overhead.
    full_block_length = plain_block_size + NONCE_LENGTH + GCM_TAG_LENGTH

    body = encrypted_data[GCM_STREAM_HEADER_LENGTH:]
    if not body:
        return b""

    cipher = AESGCM(key)
    pieces = []
    # Blocks sit back to back at a fixed stride; only the final one may be
    # shorter, and slicing clamps it to whatever bytes remain.
    for block_index, start in enumerate(range(0, len(body), full_block_length)):
        block = body[start : start + full_block_length]
        if len(block) < NONCE_LENGTH + GCM_TAG_LENGTH:
            raise ValueError(
                f"Truncated AGS1 block at offset {start}: {len(block)} bytes (minimum {NONCE_LENGTH + GCM_TAG_LENGTH})"
            )

        block_aad = stream_block_aad(aad_prefix, block_index)
        pieces.append(cipher.decrypt(block[:NONCE_LENGTH], block[NONCE_LENGTH:], block_aad))

    return b"".join(pieces)
80 changes: 80 additions & 0 deletions pyiceberg/encryption/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""InputFile implementation backed by in-memory bytes."""

from __future__ import annotations

import io
from types import TracebackType

from pyiceberg.io import InputFile, InputStream


class BytesInputStream(InputStream):
    """Seekable, read-only stream over an in-memory bytes object.

    All operations are delegated to an internal ``io.BytesIO``.
    """

    def __init__(self, data: bytes) -> None:
        """Wrap ``data`` in an ``io.BytesIO`` buffer."""
        self._buffer = io.BytesIO(data)

    def read(self, size: int = 0) -> bytes:
        """Read up to ``size`` bytes; a non-positive ``size`` reads to the end."""
        return self._buffer.read(size) if size > 0 else self._buffer.read()

    def seek(self, offset: int, whence: int = 0) -> int:
        """Move the stream position and return the value from ``BytesIO.seek``."""
        return self._buffer.seek(offset, whence)

    def tell(self) -> int:
        """Return the current stream position."""
        return self._buffer.tell()

    def close(self) -> None:
        """Close the underlying buffer."""
        self._buffer.close()

    def __enter__(self) -> BytesInputStream:
        """Return the stream itself for use in a ``with`` statement."""
        return self

    def __exit__(
        self,
        exctype: type[BaseException] | None,
        excinst: BaseException | None,
        exctb: TracebackType | None,
    ) -> None:
        """Close the stream when the ``with`` block ends."""
        self.close()


class BytesInputFile(InputFile):
    """An ``InputFile`` whose content is an in-memory bytes object.

    Intended for wrapping already-decrypted data so that AvroFile and other
    readers expecting an ``InputFile`` can consume it.
    """

    def __init__(self, location: str, data: bytes) -> None:
        """Record the nominal ``location`` and keep a reference to ``data``."""
        super().__init__(location)
        self._data = data

    def __len__(self) -> int:
        """Return the number of bytes held by this file."""
        return len(self._data)

    def exists(self) -> bool:
        """Report that the file exists; the data is always present in memory."""
        return True

    def open(self, seekable: bool = True) -> InputStream:
        """Return a new stream over the data; ``seekable`` is not consulted."""
        return BytesInputStream(self._data)
152 changes: 152 additions & 0 deletions pyiceberg/encryption/key_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""StandardKeyMetadata Avro serialization.

Wire format: ``0x01 version byte || Avro-encoded fields``

Avro schema:
- encryption_key: bytes (required)
- aad_prefix: union[null, bytes] (optional)
- file_length: union[null, long] (optional)
"""

from __future__ import annotations

from dataclasses import dataclass

# Version byte that leads the serialized key metadata; the only version deserialize() accepts.
V1 = 0x01


def _read_avro_long(data: bytes, offset: int) -> tuple[int, int]:
"""Read a zigzag-encoded Avro long from data at offset. Returns (value, new_offset)."""
result = 0
shift = 0
while True:
if offset >= len(data):
raise ValueError("Unexpected end of Avro data reading long")
b = data[offset]
offset += 1
result |= (b & 0x7F) << shift
if (b & 0x80) == 0:
break
shift += 7
# Zigzag decode
return (result >> 1) ^ -(result & 1), offset


def _read_avro_bytes(data: bytes, offset: int) -> tuple[bytes, int]:
    """Read a length-prefixed Avro ``bytes`` value from ``data`` at ``offset``.

    Returns:
        A ``(payload, new_offset)`` tuple, where ``new_offset`` points just
        past the payload.

    Raises:
        ValueError: If the decoded length is negative or the payload would
            extend past the end of ``data``.
    """
    size, start = _read_avro_long(data, offset)
    if size < 0:
        raise ValueError(f"Negative Avro bytes length: {size}")
    stop = start + size
    if stop > len(data):
        raise ValueError("Unexpected end of Avro data reading bytes")
    payload = data[start:stop]
    return payload, stop


@dataclass(frozen=True)
class StandardKeyMetadata:
    """Standard key metadata for Iceberg table encryption.

    Contains the plaintext encryption key (DEK), AAD prefix, and optional file length.

    Wire format: a single version byte (``0x01``) followed by the Avro-encoded
    fields ``encryption_key`` (bytes), ``aad_prefix`` (union[null, bytes]) and
    ``file_length`` (union[null, long]).
    """

    # Plaintext data encryption key.
    encryption_key: bytes
    # AAD prefix; an empty value is written as Avro null by serialize().
    aad_prefix: bytes = b""
    # Optional file length; None is written as Avro null by serialize().
    file_length: int | None = None

    def __repr__(self) -> str:
        """Return a debug representation that never includes the key bytes.

        The dataclass-generated ``__repr__`` would print the plaintext key,
        which could then leak into logs and tracebacks, so only its length is
        shown here.
        """
        return (
            f"{type(self).__name__}("
            f"encryption_key=<redacted, {len(self.encryption_key)} bytes>, "
            f"aad_prefix={self.aad_prefix!r}, "
            f"file_length={self.file_length!r})"
        )

    @staticmethod
    def deserialize(data: bytes) -> StandardKeyMetadata:
        """Deserialize from wire format: ``0x01 version || Avro-encoded fields``.

        The trailing ``file_length`` union is read only if bytes remain after
        ``aad_prefix``; a buffer that ends there yields ``file_length=None``.

        Raises:
            ValueError: If the buffer is empty, the version byte is not
                ``0x01``, a union index is neither 0 nor 1, or the Avro data
                is truncated.
        """
        if not data:
            raise ValueError("Empty key metadata buffer")

        version = data[0]
        if version != V1:
            raise ValueError(f"Unsupported key metadata version: {version}")

        offset = 1

        # Read encryption_key (required bytes)
        encryption_key, offset = _read_avro_bytes(data, offset)

        # Read aad_prefix (optional: union[null, bytes])
        union_index, offset = _read_avro_long(data, offset)
        if union_index == 0:
            aad_prefix = b""
        elif union_index == 1:
            aad_prefix, offset = _read_avro_bytes(data, offset)
        else:
            raise ValueError(f"Invalid union index for aad_prefix: {union_index}")

        # Read file_length (optional: union[null, long]); tolerated as absent
        # when the buffer ends right after aad_prefix.
        file_length = None
        if offset < len(data):
            union_index, offset = _read_avro_long(data, offset)
            if union_index == 0:
                file_length = None
            elif union_index == 1:
                file_length, offset = _read_avro_long(data, offset)
            else:
                raise ValueError(f"Invalid union index for file_length: {union_index}")

        return StandardKeyMetadata(
            encryption_key=encryption_key,
            aad_prefix=aad_prefix,
            file_length=file_length,
        )

    def serialize(self) -> bytes:
        """Serialize to wire format: ``0x01 version || Avro-encoded fields``.

        An empty ``aad_prefix`` and a ``None`` ``file_length`` are each written
        as the null branch (index 0) of their union.
        """
        parts = [bytes([V1])]

        # encryption_key (required bytes)
        parts.append(_encode_avro_bytes(self.encryption_key))

        # aad_prefix (union[null, bytes])
        if self.aad_prefix:
            parts.append(_encode_avro_long(1))  # union index 1 = bytes
            parts.append(_encode_avro_bytes(self.aad_prefix))
        else:
            parts.append(_encode_avro_long(0))  # union index 0 = null

        # file_length (union[null, long])
        if self.file_length is not None:
            parts.append(_encode_avro_long(1))  # union index 1 = long
            parts.append(_encode_avro_long(self.file_length))
        else:
            parts.append(_encode_avro_long(0))  # union index 0 = null

        return b"".join(parts)


def _encode_avro_long(value: int) -> bytes:
"""Encode a long as zigzag-encoded Avro varint."""
# Zigzag encode
n = (value << 1) ^ (value >> 63)
result = bytearray()
while n & ~0x7F:
result.append((n & 0x7F) | 0x80)
n >>= 7
result.append(n & 0x7F)
return bytes(result)


def _encode_avro_bytes(data: bytes) -> bytes:
    """Return ``data`` preceded by its length encoded as an Avro long."""
    length_prefix = _encode_avro_long(len(data))
    return length_prefix + data
Loading
Loading