Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.arrow.vector.ipc.message;

import java.util.Map;

/** Interface for Arrow IPC messages (https://arrow.apache.org/docs/format/IPC.html). */
public interface ArrowMessage extends FBSerializable, AutoCloseable {

Expand All @@ -26,6 +28,15 @@ public interface ArrowMessage extends FBSerializable, AutoCloseable {
/** Returns the flatbuffer enum value indicating the type of the message. */
byte getMessageType();

/**
 * Returns the custom key/value metadata attached to this message, or null if none is present.
 *
 * <p>This default implementation reports no metadata; message types that carry custom metadata
 * (e.g. record batches deserialized from a Message envelope) override it.
 *
 * @return custom metadata map, or null if no custom metadata is present
 */
default Map<String, String> getCustomMetadata() {
  return null;
}

/**
* Visitor interface for implementations of {@link ArrowMessage}.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
import com.google.flatbuffers.FlatBufferBuilder;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.arrow.flatbuf.RecordBatch;
import org.apache.arrow.memory.ArrowBuf;
Expand Down Expand Up @@ -52,6 +54,8 @@ public class ArrowRecordBatch implements ArrowMessage {

private final List<Long> variadicBufferCounts;

private final Map<String, String> customMetadata;

private boolean closed = false;

public ArrowRecordBatch(int length, List<ArrowFieldNode> nodes, List<ArrowBuf> buffers) {
Expand All @@ -66,6 +70,30 @@ public ArrowRecordBatch(
this(length, nodes, buffers, bodyCompression, null, true);
}

/**
 * Construct a record batch carrying custom key/value metadata, using the default (no-op) body
 * compression, 8-byte buffer alignment, and retaining each source buffer.
 *
 * @param length number of rows in this batch
 * @param nodes field-level validity/length info
 * @param buffers backing buffers; each is retained until this recordBatch is closed
 * @param customMetadata custom metadata to attach to this record batch (may be null)
 */
public ArrowRecordBatch(
    int length,
    List<ArrowFieldNode> nodes,
    List<ArrowBuf> buffers,
    Map<String, String> customMetadata) {
  this(
      length,
      nodes,
      buffers,
      NoCompressionCodec.DEFAULT_BODY_COMPRESSION,
      /* variadicBufferCounts= */ null,
      /* alignBuffers= */ true,
      /* retainBuffers= */ true,
      customMetadata);
}

/**
* Construct a record batch from nodes.
*
Expand Down Expand Up @@ -152,13 +180,50 @@ public ArrowRecordBatch(
List<Long> variadicBufferCounts,
boolean alignBuffers,
boolean retainBuffers) {
this(
length,
nodes,
buffers,
bodyCompression,
variadicBufferCounts,
alignBuffers,
retainBuffers,
null);
}

/**
* Construct a record batch from nodes.
*
* @param length how many rows in this batch
* @param nodes field level info
* @param buffers will be retained until this recordBatch is closed
* @param bodyCompression compression info.
* @param variadicBufferCounts the number of buffers in each variadic section.
* @param alignBuffers Whether to align buffers to an 8 byte boundary.
* @param retainBuffers Whether to retain() each source buffer in the constructor. If false, the
* caller is responsible for retaining the buffers beforehand.
* @param customMetadata custom metadata for this record batch.
*/
public ArrowRecordBatch(
int length,
List<ArrowFieldNode> nodes,
List<ArrowBuf> buffers,
ArrowBodyCompression bodyCompression,
List<Long> variadicBufferCounts,
boolean alignBuffers,
boolean retainBuffers,
Map<String, String> customMetadata) {
super();
this.length = length;
this.nodes = nodes;
this.buffers = buffers;
Preconditions.checkArgument(bodyCompression != null, "body compression cannot be null");
this.bodyCompression = bodyCompression;
this.variadicBufferCounts = variadicBufferCounts;
this.customMetadata =
customMetadata == null
? Collections.emptyMap()
: Collections.unmodifiableMap(new HashMap<>(customMetadata));
List<ArrowBuffer> arrowBuffers = new ArrayList<>(buffers.size());
long offset = 0;
for (ArrowBuf arrowBuf : buffers) {
Expand Down Expand Up @@ -188,13 +253,18 @@ private ArrowRecordBatch(
List<ArrowFieldNode> nodes,
List<ArrowBuf> buffers,
ArrowBodyCompression bodyCompression,
List<Long> variadicBufferCounts) {
List<Long> variadicBufferCounts,
Map<String, String> customMetadata) {
this.length = length;
this.nodes = nodes;
this.buffers = buffers;
Preconditions.checkArgument(bodyCompression != null, "body compression cannot be null");
this.bodyCompression = bodyCompression;
this.variadicBufferCounts = variadicBufferCounts;
this.customMetadata =
customMetadata == null
? Collections.emptyMap()
: Collections.unmodifiableMap(new HashMap<>(customMetadata));
this.closed = false;
List<ArrowBuffer> arrowBuffers = new ArrayList<>();
long offset = 0;
Expand All @@ -218,6 +288,16 @@ public ArrowBodyCompression getBodyCompression() {
return bodyCompression;
}

/**
 * Get the custom metadata for this record batch.
 *
 * <p>The map is populated in the constructor and is never null: it is either empty or an
 * unmodifiable defensive copy of the map supplied at construction time.
 *
 * @return the custom metadata as an unmodifiable map
 */
@Override
public Map<String, String> getCustomMetadata() {
  return customMetadata;
}

/**
* Get the nodes in this record batch.
*
Expand Down Expand Up @@ -268,7 +348,7 @@ public ArrowRecordBatch cloneWithTransfer(final BufferAllocator allocator) {
.collect(Collectors.toList());
close();
return new ArrowRecordBatch(
false, length, nodes, newBufs, bodyCompression, variadicBufferCounts);
false, length, nodes, newBufs, bodyCompression, variadicBufferCounts, customMetadata);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.arrow.flatbuf.Buffer;
import org.apache.arrow.flatbuf.DictionaryBatch;
import org.apache.arrow.flatbuf.FieldNode;
import org.apache.arrow.flatbuf.KeyValue;
import org.apache.arrow.flatbuf.Message;
import org.apache.arrow.flatbuf.MessageHeader;
import org.apache.arrow.flatbuf.MetadataVersion;
Expand Down Expand Up @@ -326,7 +329,12 @@ public static ByteBuffer serializeMetadata(ArrowMessage message, IpcOption write
FlatBufferBuilder builder = new FlatBufferBuilder();
int batchOffset = message.writeTo(builder);
return serializeMessage(
builder, message.getMessageType(), batchOffset, message.computeBodyLength(), writeOption);
builder,
message.getMessageType(),
batchOffset,
message.computeBodyLength(),
writeOption,
message.getCustomMetadata());
}

/**
Expand All @@ -340,7 +348,18 @@ public static ByteBuffer serializeMetadata(ArrowMessage message, IpcOption write
public static ArrowRecordBatch deserializeRecordBatch(
    Message recordBatchMessage, ArrowBuf bodyBuffer) throws IOException {
  RecordBatch recordBatchFB = (RecordBatch) recordBatchMessage.header(new RecordBatch());
  // Custom metadata lives on the Message envelope, not on the RecordBatch header itself.
  return deserializeRecordBatch(
      recordBatchFB, bodyBuffer, readCustomMetadata(recordBatchMessage));
}

/**
 * Reads the custom key/value metadata from a Message envelope.
 *
 * @param message the flatbuffer Message to read from
 * @return a mutable map of the metadata entries, or null if the message carries none; null
 *     flatbuffer keys/values are normalized to empty strings so the map is null-free
 */
private static Map<String, String> readCustomMetadata(Message message) {
  // Hoist the length lookup; it is used both as the presence check and the loop bound.
  final int count = message.customMetadataLength();
  if (count <= 0) {
    return null;
  }
  Map<String, String> customMetadata = new HashMap<>();
  for (int i = 0; i < count; i++) {
    KeyValue kv = message.customMetadata(i);
    String key = kv.key();
    String value = kv.value();
    customMetadata.put(key == null ? "" : key, value == null ? "" : value);
  }
  return customMetadata;
}

/**
Expand Down Expand Up @@ -395,10 +414,22 @@ public static ArrowRecordBatch deserializeRecordBatch(

RecordBatch recordBatchFB = (RecordBatch) messageFB.header(new RecordBatch());

// Extract custom metadata from the Message
Map<String, String> customMetadata = null;
if (messageFB.customMetadataLength() > 0) {
customMetadata = new HashMap<>();
for (int i = 0; i < messageFB.customMetadataLength(); i++) {
KeyValue kv = messageFB.customMetadata(i);
String key = kv.key();
String value = kv.value();
customMetadata.put(key == null ? "" : key, value == null ? "" : value);
}
}

// Now read the body
final ArrowBuf body =
buffer.slice(block.getMetadataLength(), totalLen - block.getMetadataLength());
return deserializeRecordBatch(recordBatchFB, body);
return deserializeRecordBatch(recordBatchFB, body, customMetadata);
}

/**
Expand All @@ -411,6 +442,22 @@ public static ArrowRecordBatch deserializeRecordBatch(
*/
public static ArrowRecordBatch deserializeRecordBatch(RecordBatch recordBatchFB, ArrowBuf body)
    throws IOException {
  // Delegate with no custom metadata: a bare RecordBatch flatbuffer carries none — custom
  // metadata is only present on the enclosing Message envelope.
  return deserializeRecordBatch(recordBatchFB, body, null);
}

/**
* Deserializes an ArrowRecordBatch given the Flatbuffer metadata, in-memory body, and custom
* metadata.
*
* @param recordBatchFB Deserialized FlatBuffer record batch
* @param body Read body of the record batch
* @param customMetadata Custom metadata from the Message
* @return ArrowRecordBatch from metadata and in-memory body
* @throws IOException on error
*/
public static ArrowRecordBatch deserializeRecordBatch(
RecordBatch recordBatchFB, ArrowBuf body, Map<String, String> customMetadata)
throws IOException {
// Now read the body
int nodesLength = recordBatchFB.nodesLength();
List<ArrowFieldNode> nodes = new ArrayList<>();
Expand Down Expand Up @@ -452,7 +499,9 @@ public static ArrowRecordBatch deserializeRecordBatch(RecordBatch recordBatchFB,
buffers,
bodyCompression,
variadicBufferCounts,
/*alignBuffers*/ true);
/*alignBuffers*/ true,
/*retainBuffers*/ true,
customMetadata);
body.getReferenceManager().release();
return arrowRecordBatch;
}
Expand Down Expand Up @@ -676,11 +725,47 @@ public static ByteBuffer serializeMessage(
int headerOffset,
long bodyLength,
IpcOption writeOption) {
return serializeMessage(builder, headerType, headerOffset, bodyLength, writeOption, null);
}

/**
 * Serializes an Arrow message, together with optional custom key/value metadata, into a
 * ByteBuffer.
 *
 * @param builder flatbuffer builder to write into
 * @param headerType flatbuffer enum value identifying the type of the header
 * @param headerOffset offset of the already-written header within the builder
 * @param bodyLength the length in bytes of the message body
 * @param writeOption IPC write options (controls the metadata version)
 * @param customMetadata custom metadata to attach to the message; null or empty means none
 * @return the corresponding ByteBuffer
 */
public static ByteBuffer serializeMessage(
    FlatBufferBuilder builder,
    byte headerType,
    int headerOffset,
    long bodyLength,
    IpcOption writeOption,
    Map<String, String> customMetadata) {
  final boolean hasCustomMetadata = customMetadata != null && !customMetadata.isEmpty();
  int customMetadataOffset = 0;
  if (hasCustomMetadata) {
    // Write each KeyValue table first; flatbuffers requires child objects to be serialized
    // before the vector that references them.
    final int[] kvOffsets = new int[customMetadata.size()];
    int next = 0;
    for (Map.Entry<String, String> entry : customMetadata.entrySet()) {
      final int keyOffset = builder.createString(entry.getKey());
      final int valueOffset = builder.createString(entry.getValue());
      kvOffsets[next] = KeyValue.createKeyValue(builder, keyOffset, valueOffset);
      next++;
    }
    customMetadataOffset = Message.createCustomMetadataVector(builder, kvOffsets);
  }

  Message.startMessage(builder);
  Message.addHeaderType(builder, headerType);
  Message.addHeader(builder, headerOffset);
  Message.addVersion(builder, writeOption.metadataVersion.toFlatbufID());
  Message.addBodyLength(builder, bodyLength);
  // Offset 0 means "no vector was written"; only attach the field when metadata exists.
  if (customMetadataOffset != 0) {
    Message.addCustomMetadata(builder, customMetadataOffset);
  }
  builder.finish(Message.endMessage(builder));
  return builder.dataBuffer();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
import java.nio.channels.Channels;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
Expand Down Expand Up @@ -242,4 +244,42 @@ public static void verifyBatch(ArrowRecordBatch batch, byte[] validity, byte[] v
assertArrayEquals(validity, MessageSerializerTest.array(buffers.get(0)));
assertArrayEquals(values, MessageSerializerTest.array(buffers.get(1)));
}

@Test
public void testRecordBatchCustomMetadata() throws Exception {
  byte[] validity = new byte[] {(byte) 255, 0};
  byte[] values = new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};

  Map<String, String> customMetadata = new HashMap<>();
  customMetadata.put("key1", "value1");
  customMetadata.put("key2", "value2");

  // Use try-with-resources so the allocator, buffers, and batch are released even when an
  // assertion fails; the original manual close() calls leaked on failure.
  try (BufferAllocator alloc = new RootAllocator(Long.MAX_VALUE);
      ArrowBuf validityb = buf(alloc, validity);
      ArrowBuf valuesb = buf(alloc, values);
      ArrowRecordBatch batch =
          new ArrowRecordBatch(
              16, asList(new ArrowFieldNode(16, 8)), asList(validityb, valuesb), customMetadata)) {

    ByteArrayOutputStream out = new ByteArrayOutputStream();
    MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), batch);

    ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
    ReadChannel channel = new ReadChannel(Channels.newChannel(in));
    try (ArrowMessage deserialized = MessageSerializer.deserializeMessageBatch(channel, alloc)) {
      assertEquals(ArrowRecordBatch.class, deserialized.getClass());
      ArrowRecordBatch deserializedBatch = (ArrowRecordBatch) deserialized;
      verifyBatch(deserializedBatch, validity, values);

      // Custom metadata must survive the serialize/deserialize round trip.
      assertEquals(customMetadata, deserializedBatch.getCustomMetadata());
    }
  }
}
}