diff --git a/vector/src/main/java/org/apache/arrow/vector/ipc/message/ArrowMessage.java b/vector/src/main/java/org/apache/arrow/vector/ipc/message/ArrowMessage.java index 6f8e893405..06a507fa9c 100644 --- a/vector/src/main/java/org/apache/arrow/vector/ipc/message/ArrowMessage.java +++ b/vector/src/main/java/org/apache/arrow/vector/ipc/message/ArrowMessage.java @@ -16,6 +16,8 @@ */ package org.apache.arrow.vector.ipc.message; +import java.util.Map; + /** Interface for Arrow IPC messages (https://arrow.apache.org/docs/format/IPC.html). */ public interface ArrowMessage extends FBSerializable, AutoCloseable { @@ -26,6 +28,15 @@ public interface ArrowMessage extends FBSerializable, AutoCloseable { /** Returns the flatbuffer enum value indicating the type of the message. */ byte getMessageType(); + /** + * Returns custom metadata for this message; may be null or empty if there is none. + * + * @return custom metadata map; null (the default) or an empty map if no custom metadata is present + */ + default Map getCustomMetadata() { + return null; + } + /** * Visitor interface for implementations of {@link ArrowMessage}. 
* diff --git a/vector/src/main/java/org/apache/arrow/vector/ipc/message/ArrowRecordBatch.java b/vector/src/main/java/org/apache/arrow/vector/ipc/message/ArrowRecordBatch.java index bc6bfa8c86..82d7f33758 100644 --- a/vector/src/main/java/org/apache/arrow/vector/ipc/message/ArrowRecordBatch.java +++ b/vector/src/main/java/org/apache/arrow/vector/ipc/message/ArrowRecordBatch.java @@ -19,7 +19,9 @@ import com.google.flatbuffers.FlatBufferBuilder; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import org.apache.arrow.flatbuf.RecordBatch; import org.apache.arrow.memory.ArrowBuf; @@ -52,6 +54,8 @@ public class ArrowRecordBatch implements ArrowMessage { private final List variadicBufferCounts; + private final Map customMetadata; + private boolean closed = false; public ArrowRecordBatch(int length, List nodes, List buffers) { @@ -66,6 +70,30 @@ public ArrowRecordBatch( this(length, nodes, buffers, bodyCompression, null, true); } + /** + * Construct a record batch from nodes with custom metadata. + * + * @param length how many rows in this batch + * @param nodes field level info + * @param buffers will be retained until this recordBatch is closed + * @param customMetadata custom metadata for this record batch + */ + public ArrowRecordBatch( + int length, + List nodes, + List buffers, + Map customMetadata) { + this( + length, + nodes, + buffers, + NoCompressionCodec.DEFAULT_BODY_COMPRESSION, + null, + true, + true, + customMetadata); + } + /** * Construct a record batch from nodes. * @@ -152,6 +180,39 @@ public ArrowRecordBatch( List variadicBufferCounts, boolean alignBuffers, boolean retainBuffers) { + this( + length, + nodes, + buffers, + bodyCompression, + variadicBufferCounts, + alignBuffers, + retainBuffers, + null); + } + + /** + * Construct a record batch from nodes. 
+ * + * @param length how many rows in this batch + * @param nodes field level info + * @param buffers will be retained until this recordBatch is closed + * @param bodyCompression compression info. + * @param variadicBufferCounts the number of buffers in each variadic section. + * @param alignBuffers Whether to align buffers to an 8 byte boundary. + * @param retainBuffers Whether to retain() each source buffer in the constructor. If false, the + * caller is responsible for retaining the buffers beforehand. + * @param customMetadata custom metadata for this record batch. + */ + public ArrowRecordBatch( + int length, + List nodes, + List buffers, + ArrowBodyCompression bodyCompression, + List variadicBufferCounts, + boolean alignBuffers, + boolean retainBuffers, + Map customMetadata) { super(); this.length = length; this.nodes = nodes; @@ -159,6 +220,10 @@ public ArrowRecordBatch( Preconditions.checkArgument(bodyCompression != null, "body compression cannot be null"); this.bodyCompression = bodyCompression; this.variadicBufferCounts = variadicBufferCounts; + this.customMetadata = + customMetadata == null + ? Collections.emptyMap() + : Collections.unmodifiableMap(new HashMap<>(customMetadata)); List arrowBuffers = new ArrayList<>(buffers.size()); long offset = 0; for (ArrowBuf arrowBuf : buffers) { @@ -188,13 +253,18 @@ private ArrowRecordBatch( List nodes, List buffers, ArrowBodyCompression bodyCompression, - List variadicBufferCounts) { + List variadicBufferCounts, + Map customMetadata) { this.length = length; this.nodes = nodes; this.buffers = buffers; Preconditions.checkArgument(bodyCompression != null, "body compression cannot be null"); this.bodyCompression = bodyCompression; this.variadicBufferCounts = variadicBufferCounts; + this.customMetadata = + customMetadata == null + ? 
Collections.emptyMap() + : Collections.unmodifiableMap(new HashMap<>(customMetadata)); this.closed = false; List arrowBuffers = new ArrayList<>(); long offset = 0; @@ -218,6 +288,16 @@ public ArrowBodyCompression getBodyCompression() { return bodyCompression; } + /** + * Get the custom metadata for this record batch. + * + * @return the custom metadata as an unmodifiable map; empty (never null) if none was supplied + */ + @Override + public Map getCustomMetadata() { + return customMetadata; + } + /** * Get the nodes in this record batch. * @@ -268,7 +348,7 @@ public ArrowRecordBatch cloneWithTransfer(final BufferAllocator allocator) { .collect(Collectors.toList()); close(); return new ArrowRecordBatch( - false, length, nodes, newBufs, bodyCompression, variadicBufferCounts); + false, length, nodes, newBufs, bodyCompression, variadicBufferCounts, customMetadata); } /** diff --git a/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageSerializer.java b/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageSerializer.java index 36f6ea449b..99d11d29b9 100644 --- a/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageSerializer.java +++ b/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageSerializer.java @@ -22,10 +22,13 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.apache.arrow.flatbuf.Buffer; import org.apache.arrow.flatbuf.DictionaryBatch; import org.apache.arrow.flatbuf.FieldNode; +import org.apache.arrow.flatbuf.KeyValue; import org.apache.arrow.flatbuf.Message; import org.apache.arrow.flatbuf.MessageHeader; import org.apache.arrow.flatbuf.MetadataVersion; @@ -326,7 +329,12 @@ public static ByteBuffer serializeMetadata(ArrowMessage message, IpcOption write FlatBufferBuilder builder = new FlatBufferBuilder(); int batchOffset = message.writeTo(builder); return serializeMessage( - builder, message.getMessageType(), batchOffset, 
message.computeBodyLength(), writeOption); + builder, + message.getMessageType(), + batchOffset, + message.computeBodyLength(), + writeOption, + message.getCustomMetadata()); } /** @@ -340,7 +348,18 @@ public static ByteBuffer serializeMetadata(ArrowMessage message, IpcOption write public static ArrowRecordBatch deserializeRecordBatch( Message recordBatchMessage, ArrowBuf bodyBuffer) throws IOException { RecordBatch recordBatchFB = (RecordBatch) recordBatchMessage.header(new RecordBatch()); - return deserializeRecordBatch(recordBatchFB, bodyBuffer); + // Extract custom metadata from the Message + Map customMetadata = null; + if (recordBatchMessage.customMetadataLength() > 0) { + customMetadata = new HashMap<>(); + for (int i = 0; i < recordBatchMessage.customMetadataLength(); i++) { + KeyValue kv = recordBatchMessage.customMetadata(i); + String key = kv.key(); + String value = kv.value(); + customMetadata.put(key == null ? "" : key, value == null ? "" : value); + } + } + return deserializeRecordBatch(recordBatchFB, bodyBuffer, customMetadata); } /** @@ -395,10 +414,22 @@ public static ArrowRecordBatch deserializeRecordBatch( RecordBatch recordBatchFB = (RecordBatch) messageFB.header(new RecordBatch()); + // Extract custom metadata from the Message + Map customMetadata = null; + if (messageFB.customMetadataLength() > 0) { + customMetadata = new HashMap<>(); + for (int i = 0; i < messageFB.customMetadataLength(); i++) { + KeyValue kv = messageFB.customMetadata(i); + String key = kv.key(); + String value = kv.value(); + customMetadata.put(key == null ? "" : key, value == null ? 
"" : value); + } + } + // Now read the body final ArrowBuf body = buffer.slice(block.getMetadataLength(), totalLen - block.getMetadataLength()); - return deserializeRecordBatch(recordBatchFB, body); + return deserializeRecordBatch(recordBatchFB, body, customMetadata); } /** @@ -411,6 +442,22 @@ public static ArrowRecordBatch deserializeRecordBatch( */ public static ArrowRecordBatch deserializeRecordBatch(RecordBatch recordBatchFB, ArrowBuf body) throws IOException { + return deserializeRecordBatch(recordBatchFB, body, null); + } + + /** + * Deserializes an ArrowRecordBatch given the Flatbuffer metadata, in-memory body, and custom + * metadata. + * + * @param recordBatchFB Deserialized FlatBuffer record batch + * @param body Read body of the record batch + * @param customMetadata Custom metadata from the Message + * @return ArrowRecordBatch from metadata and in-memory body + * @throws IOException on error + */ + public static ArrowRecordBatch deserializeRecordBatch( + RecordBatch recordBatchFB, ArrowBuf body, Map customMetadata) + throws IOException { // Now read the body int nodesLength = recordBatchFB.nodesLength(); List nodes = new ArrayList<>(); @@ -452,7 +499,9 @@ public static ArrowRecordBatch deserializeRecordBatch(RecordBatch recordBatchFB, buffers, bodyCompression, variadicBufferCounts, - /*alignBuffers*/ true); + /*alignBuffers*/ true, + /*retainBuffers*/ true, + customMetadata); body.getReferenceManager().release(); return arrowRecordBatch; } @@ -676,11 +725,47 @@ public static ByteBuffer serializeMessage( int headerOffset, long bodyLength, IpcOption writeOption) { + return serializeMessage(builder, headerType, headerOffset, bodyLength, writeOption, null); + } + + /** + * Serializes an Arrow message with metadata and custom metadata into a ByteBuffer. 
+ * + * @param builder to write the flatbuf to + * @param headerType the type of the header + * @param headerOffset the offset in the buffer where the header starts + * @param bodyLength the length of the body + * @param writeOption IPC write options + * @param customMetadata custom metadata to attach to the message + * @return the corresponding ByteBuffer + */ + public static ByteBuffer serializeMessage( + FlatBufferBuilder builder, + byte headerType, + int headerOffset, + long bodyLength, + IpcOption writeOption, + Map customMetadata) { + int customMetadataOffset = 0; + if (customMetadata != null && !customMetadata.isEmpty()) { + int[] metadataOffsets = new int[customMetadata.size()]; + int i = 0; + for (Map.Entry entry : customMetadata.entrySet()) { + int keyOffset = builder.createString(entry.getKey()); + int valueOffset = builder.createString(entry.getValue()); + metadataOffsets[i++] = KeyValue.createKeyValue(builder, keyOffset, valueOffset); + } + customMetadataOffset = Message.createCustomMetadataVector(builder, metadataOffsets); + } + Message.startMessage(builder); Message.addHeaderType(builder, headerType); Message.addHeader(builder, headerOffset); Message.addVersion(builder, writeOption.metadataVersion.toFlatbufID()); Message.addBodyLength(builder, bodyLength); + if (customMetadataOffset != 0) { + Message.addCustomMetadata(builder, customMetadataOffset); + } builder.finish(Message.endMessage(builder)); return builder.dataBuffer(); } diff --git a/vector/src/test/java/org/apache/arrow/vector/ipc/MessageSerializerTest.java b/vector/src/test/java/org/apache/arrow/vector/ipc/MessageSerializerTest.java index b529ca645a..6ce48bc51b 100644 --- a/vector/src/test/java/org/apache/arrow/vector/ipc/MessageSerializerTest.java +++ b/vector/src/test/java/org/apache/arrow/vector/ipc/MessageSerializerTest.java @@ -31,7 +31,9 @@ import java.nio.channels.Channels; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import 
java.util.Map; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; @@ -242,4 +244,42 @@ public static void verifyBatch(ArrowRecordBatch batch, byte[] validity, byte[] v assertArrayEquals(validity, MessageSerializerTest.array(buffers.get(0))); assertArrayEquals(values, MessageSerializerTest.array(buffers.get(1))); } + + @Test + public void testRecordBatchCustomMetadata() throws Exception { + byte[] validity = new byte[] {(byte) 255, 0}; + byte[] values = new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + + BufferAllocator alloc = new RootAllocator(Long.MAX_VALUE); + ArrowBuf validityb = buf(alloc, validity); + ArrowBuf valuesb = buf(alloc, values); + + Map customMetadata = new HashMap<>(); + customMetadata.put("key1", "value1"); + customMetadata.put("key2", "value2"); + + ArrowRecordBatch batch = + new ArrowRecordBatch( + 16, asList(new ArrowFieldNode(16, 8)), asList(validityb, valuesb), customMetadata); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), batch); + + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + ReadChannel channel = new ReadChannel(Channels.newChannel(in)); + ArrowMessage deserialized = MessageSerializer.deserializeMessageBatch(channel, alloc); + + assertEquals(ArrowRecordBatch.class, deserialized.getClass()); + ArrowRecordBatch deserializedBatch = (ArrowRecordBatch) deserialized; + verifyBatch(deserializedBatch, validity, values); + + Map deserializedMetadata = deserializedBatch.getCustomMetadata(); + assertEquals(customMetadata, deserializedMetadata); + + validityb.close(); + valuesb.close(); + batch.close(); + deserialized.close(); + alloc.close(); + } }