From 852d12ee00561916f91519df851f76aa01d2a3ae Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Tue, 26 May 2026 14:41:03 +0000 Subject: [PATCH] feat(table): add SessionContext.registerStreamingTable for push-mode batch ingest --- .../org/apache/datafusion/SessionContext.java | 112 +++ .../java/org/apache/datafusion/TableSink.java | 220 ++++++ .../SessionContextStreamingTableTest.java | 685 ++++++++++++++++++ native/Cargo.lock | 1 + native/Cargo.toml | 5 +- native/src/lib.rs | 143 ++++ native/src/streaming_table.rs | 390 ++++++++++ 7 files changed, 1555 insertions(+), 1 deletion(-) create mode 100644 core/src/main/java/org/apache/datafusion/TableSink.java create mode 100644 core/src/test/java/org/apache/datafusion/SessionContextStreamingTableTest.java create mode 100644 native/src/streaming_table.rs diff --git a/core/src/main/java/org/apache/datafusion/SessionContext.java b/core/src/main/java/org/apache/datafusion/SessionContext.java index ffc58dd..9b23cfd 100644 --- a/core/src/main/java/org/apache/datafusion/SessionContext.java +++ b/core/src/main/java/org/apache/datafusion/SessionContext.java @@ -588,6 +588,115 @@ public void registerTable(String name, TableProvider provider) { registerTableNative(nativeHandle, name, schemaIpc, provider); } + /** + * Register a push-mode streaming table and return a {@link TableSink} the caller writes batches + * into. Companion to {@link #registerTable(String, TableProvider)} for event-driven producers + * that don't materialise an {@link org.apache.arrow.vector.ipc.ArrowReader} up front. + * + *

Producers (any thread) call {@link TableSink#write(VectorSchemaRoot)} to push a batch; + * DataFusion's {@code StreamingTableExec} polls the underlying mpsc channel during query + * execution. When the producer is done it calls {@link TableSink#close()} to signal end-of-stream + * (or {@link TableSink#fail(Throwable)} to signal an error). The {@code capacity} argument is the + * channel buffer size in batches; producers writing faster than the consumer block on {@code + * write} until space frees up. + * + *

Single-scan semantics. The registered table can be queried at most once. + * Subsequent scans against the same registration throw a {@link RuntimeException}. This matches + * the natural semantics for an event-driven producer (the data is consumed as it arrives) and + * keeps the implementation from buffering every batch internally. Callers who need re-scan should + * use {@link #registerTable(String, TableProvider)} with {@link SimpleTableProvider} instead. + * + *

The registered table outlives the returned {@link TableSink}: the sink represents the + * producer side of the channel, and dropping it (via {@link TableSink#close()}) signals + * end-of-stream without unregistering the table. + * + *

Schema constraint -- non-empty. {@code schema} must declare at least one + * column. The sink derives its Arrow C Data Interface scratch allocator from each incoming + * batch's first field vector (so the FFI scratch shares an allocator-root with the producer's + * buffers and the export call can transfer ownership), which has no allocator to borrow when the + * schema is empty. Zero-column streaming tables are uncommon in practice -- they would only arise + * from a planner that requested row counts from a streaming source -- and supporting them would + * add a second allocator-management path for a use case the OpenSearch prior art does not + * exercise. + * + *

Schema constraint -- no dictionary-encoded fields. Dictionary-encoded + * fields are not supported in v1. {@link + * org.apache.arrow.c.Data#exportVectorSchemaRoot(org.apache.arrow.memory.BufferAllocator, + * VectorSchemaRoot, org.apache.arrow.vector.dictionary.DictionaryProvider, + * org.apache.arrow.c.ArrowArray, org.apache.arrow.c.ArrowSchema)} requires a non-null {@code + * DictionaryProvider} when any field is dictionary-encoded, and {@link TableSink#write} passes + * {@code null} -- so a dictionary-encoded schema would NPE on first write. Rejecting at + * registration makes that breakage visible up front. A future overload that accepts a {@code + * DictionaryProvider} would lift this restriction; out of scope for v1. + * + * @param name the table name to register under. + * @param schema the fixed schema of all batches the producer will push. Must have at least one + * column and no dictionary-encoded fields. + * @param capacity the channel buffer size in batches; must be {@code > 0}. Tune based on the + * producer's burstiness vs. the consumer's drain rate. + * @return a {@link TableSink} the caller writes batches into. Owns native resources; must be + * closed. + * @throws IllegalArgumentException if {@code name} or {@code schema} is {@code null}; if {@code + * schema} has zero columns or any dictionary-encoded field; or if {@code capacity <= 0}. + * @throws IllegalStateException if this context is closed. + * @throws RuntimeException if native registration fails (e.g. a table with the same name is + * already registered). 
+ */ + public TableSink registerStreamingTable(String name, Schema schema, int capacity) { + if (nativeHandle == 0) { + throw new IllegalStateException("SessionContext is closed"); + } + if (name == null) { + throw new IllegalArgumentException("registerStreamingTable name must be non-null"); + } + if (schema == null) { + throw new IllegalArgumentException("registerStreamingTable schema must be non-null"); + } + if (schema.getFields().isEmpty()) { + throw new IllegalArgumentException( + "registerStreamingTable schema must have at least one column " + + "(zero-column streaming tables are not supported -- see Javadoc)"); + } + // Walk recursively: Data.exportVectorSchemaRoot recurses into Struct/List/Map/Union + // children and dereferences the DictionaryProvider for any dictionary-encoded field at any + // depth. A null provider would NPE on first write for a nested dictionary, so we have to + // reject the schema if *any* descendant field carries an encoding -- top-level isn't enough. + for (Field field : schema.getFields()) { + checkNoDictionaryEncoding(field, field.getName()); + } + if (capacity <= 0) { + throw new IllegalArgumentException( + "registerStreamingTable capacity must be positive, was " + capacity); + } + byte[] schemaIpc = serializeSchemaIpc(schema); + long sinkHandle = registerStreamingTableNative(nativeHandle, name, schemaIpc, capacity); + // The TableSink derives its FFI scratch allocator from each incoming batch so that the + // sink's exported buffers share a root with the producer's vectors -- otherwise Arrow's + // C-Data export rejects the cross-root transfer. + return new TableSink(sinkHandle); + } + + /** + * Recursively reject any field that carries a {@code DictionaryEncoding}. {@code path} carries + * the dotted path from the schema root for the error message ("a.b.c"); on the first hit the + * caller sees which descendant tripped the check. 
+ */ + private static void checkNoDictionaryEncoding(Field field, String path) { + if (field.getDictionary() != null) { + throw new IllegalArgumentException( + "registerStreamingTable does not support dictionary-encoded fields " + + "(field '" + + path + + "' has dictionary id " + + field.getDictionary().getId() + + "); v1 cannot supply a DictionaryProvider to Data.exportVectorSchemaRoot. " + + "See Javadoc for the planned follow-up."); + } + for (Field child : field.getChildren()) { + checkNoDictionaryEncoding(child, path + "." + child.getName()); + } + } + private static byte[] serializeSchemaIpc(Schema schema) { ByteArrayOutputStream baos = new ByteArrayOutputStream(); try (BufferAllocator allocator = new RootAllocator(); @@ -664,4 +773,7 @@ private static native void registerScalarUdf( private static native void registerTableNative( long handle, String name, byte[] schemaIpcBytes, TableProvider provider); + + private static native long registerStreamingTableNative( + long handle, String name, byte[] schemaIpcBytes, int capacity); } diff --git a/core/src/main/java/org/apache/datafusion/TableSink.java b/core/src/main/java/org/apache/datafusion/TableSink.java new file mode 100644 index 0000000..7fdebb2 --- /dev/null +++ b/core/src/main/java/org/apache/datafusion/TableSink.java @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion; + +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.ReentrantReadWriteLock; + +import org.apache.arrow.c.ArrowArray; +import org.apache.arrow.c.ArrowSchema; +import org.apache.arrow.c.Data; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorSchemaRoot; + +/** + * Push-mode batch sink for a streaming table registered via {@link + * SessionContext#registerStreamingTable(String, org.apache.arrow.vector.types.pojo.Schema, int)}. + * Producers push {@link VectorSchemaRoot} batches; the registered table's consumer (DataFusion) + * polls them out via the standard streaming-table plumbing. + * + *

Lifecycle: + * + *

+ * + *

Single-scan semantics. The streaming table this sink feeds can be queried at + * most once. After the consumer side has drained, the sink stops being useful and the registered + * table cannot be re-scanned. Callers who need re-scan should use the pull-mode {@link + * SessionContext#registerTable(String, TableProvider)} path instead. + * + *

Thread safety. Concurrent {@link #write(VectorSchemaRoot)} calls from + * multiple producer threads are safe but their relative ordering is not specified; callers that + * care about batch ordering should confine writes to a single thread. Concurrent {@link #close()} + * (or {@link #fail(Throwable)}) racing with in-flight {@link #write(VectorSchemaRoot)} calls is + * also safe: the close path drops the underlying mpsc sender first (which unblocks any write parked + * on backpressure with an error), then waits for in-flight writes to drain before freeing the + * native handle. New writes started after {@link #close()} began throw {@link + * IllegalStateException} promptly without touching the native side. + */ +public final class TableSink implements AutoCloseable { + static { + NativeLibraryLoader.loadLibrary(); + } + + /** + * Native handle to the {@code Box<Arc<TableSinkHandle>>}. Atomic: zeroed atomically by {@link + * #close()} / {@link #fail(Throwable)} to forbid new writes; the lifecycle lock below makes the + * eventual {@code dropHandleNative} wait for in-flight writes to drain. + */ + private final AtomicLong nativeHandle; + + /** + * Read-write lock guarding the native pointer's lifetime. {@link #write(VectorSchemaRoot)} holds + * the read lock for the duration of the native call so a concurrent {@link #close()} cannot drop + * the box while we're still using it. {@link #close()} / {@link #fail(Throwable)} take the write + * lock around {@code dropHandleNative} so it waits for all in-flight writes to release. + * + *

Crucial detail: {@code closeSinkNative} (drops the producer-side {@code Sender}) runs + * before we acquire the write lock. Otherwise a write parked on {@code + * Sender::blocking_send} would never release the read lock and {@code close()} would deadlock. + * Dropping the sender first lets the parked send return {@code Err} immediately, the write + * returns, releases the read lock, and {@code close()} can acquire the write lock and drop the + * box. + */ + private final ReentrantReadWriteLock lifecycle = new ReentrantReadWriteLock(); + + TableSink(long nativeHandle) { + if (nativeHandle == 0) { + throw new IllegalArgumentException("TableSink native handle is null"); + } + this.nativeHandle = new AtomicLong(nativeHandle); + } + + /** + * Send a batch through the channel. Blocks if the channel is at capacity until the consumer + * drains a batch or until the consumer side is dropped. + * + *

The batch's schema must match the schema this sink was registered with. This method + * exports the batch via Arrow's C Data Interface; the underlying buffers are reference-counted + * and the caller can safely close {@code batch} immediately after this method returns. + * + * @throws IllegalArgumentException if {@code batch} is {@code null}. + * @throws IllegalStateException if this sink has been closed. + * @throws RuntimeException if the consumer side is gone (query cancelled or finished), if the + * batch's schema doesn't match the registered schema, or if the Arrow C Data export fails. + */ + public void write(VectorSchemaRoot batch) { + if (batch == null) { + throw new IllegalArgumentException("write batch must be non-null"); + } + // Hold the lifecycle read lock for the entire native call so a concurrent close() cannot + // free the native box while we're parked inside Sender::blocking_send. + lifecycle.readLock().lock(); + try { + long h = nativeHandle.get(); + if (h == 0) { + throw new IllegalStateException("TableSink is closed"); + } + // The FFI scratch (ArrowArray + ArrowSchema) must share an allocator-root with the batch's + // vectors so Data.exportVectorSchemaRoot can transfer ownership without crossing root + // boundaries. We derive the allocator from the batch itself; an empty batch (no vectors) + // has no allocator to borrow, but registerStreamingTable's schema requires at least one + // column, so this path is unreachable in practice.
+ BufferAllocator batchAllocator = batchAllocator(batch); + try (ArrowArray ffiArray = ArrowArray.allocateNew(batchAllocator); + ArrowSchema ffiSchema = ArrowSchema.allocateNew(batchAllocator)) { + Data.exportVectorSchemaRoot(batchAllocator, batch, null, ffiArray, ffiSchema); + writeBatchNative(h, ffiArray.memoryAddress(), ffiSchema.memoryAddress()); + } + } finally { + lifecycle.readLock().unlock(); + } + } + + private static BufferAllocator batchAllocator(VectorSchemaRoot batch) { + if (batch.getFieldVectors().isEmpty()) { + throw new IllegalArgumentException( + "TableSink.write requires a batch with at least one column"); + } + FieldVector first = batch.getFieldVectors().get(0); + BufferAllocator allocator = first.getAllocator(); + if (allocator == null) { + throw new IllegalStateException( + "VectorSchemaRoot's field vectors have no allocator; was the batch initialised?"); + } + return allocator; + } + + /** + * End-of-stream signal: the consumer's next poll sees a clean EOF. Idempotent. Subsequent calls + * to {@link #write(VectorSchemaRoot)} throw {@link IllegalStateException}. + * + *

Safe to call from {@code try-with-resources}. Note that {@code close()} alone signals a + * clean end-of-stream; if the producer encountered an error, use {@link + * #fail(Throwable)} instead so the consumer observes the error. + */ + @Override + public void close() { + long h = nativeHandle.getAndSet(0); + if (h == 0) { + return; // already closed; idempotent. + } + // 1. Drop the producer-side Sender so any concurrent write parked on backpressure unblocks + // with an error and releases its read lock. Doing this BEFORE acquiring the write lock + // is what avoids the close-deadlocks-on-stuck-write race. + closeSinkNative(h); + // 2. Wait for in-flight writes to release the read lock, then drop the box. The Arc clones + // those writes hold keep the inner TableSinkHandle alive across this drop; the Box itself + // is what we're freeing. + lifecycle.writeLock().lock(); + try { + dropHandleNative(h); + } finally { + lifecycle.writeLock().unlock(); + } + } + + /** + * End-of-stream with error: the consumer's next poll surfaces a {@link RuntimeException} whose + * message includes the supplied cause's message. After this call the sink is closed; further + * {@link #write(VectorSchemaRoot)} or {@link #fail(Throwable)} calls throw. + * + * @throws IllegalArgumentException if {@code cause} is {@code null}. + * @throws IllegalStateException if this sink has already been closed. + */ + public void fail(Throwable cause) { + if (cause == null) { + throw new IllegalArgumentException("fail cause must be non-null"); + } + long h = nativeHandle.getAndSet(0); + if (h == 0) { + throw new IllegalStateException("TableSink is closed"); + } + String message = cause.getMessage(); + if (message == null) { + message = cause.getClass().getName(); + } + // Same two-phase shape as close(): record the terminal error + drop the sender first so + // any parked write unblocks; then take the write lock and drop the box. 
+ failSinkNative(h, message); + lifecycle.writeLock().lock(); + try { + dropHandleNative(h); + } finally { + lifecycle.writeLock().unlock(); + } + } + + private static native void writeBatchNative(long handle, long arrayAddr, long schemaAddr); + + private static native void closeSinkNative(long handle); + + private static native void failSinkNative(long handle, String message); + + private static native void dropHandleNative(long handle); +} diff --git a/core/src/test/java/org/apache/datafusion/SessionContextStreamingTableTest.java b/core/src/test/java/org/apache/datafusion/SessionContextStreamingTableTest.java new file mode 100644 index 0000000..71bc603 --- /dev/null +++ b/core/src/test/java/org/apache/datafusion/SessionContextStreamingTableTest.java @@ -0,0 +1,685 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.datafusion; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Collections; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.jupiter.api.Test; + +class SessionContextStreamingTableTest { + + /** Build a `(id: int32 not-null)` schema. */ + private static Schema oneIntSchema() { + return new Schema( + Collections.singletonList( + new Field("id", FieldType.notNullable(new ArrowType.Int(32, true)), null))); + } + + /** + * Build a single-column int batch with the given values, allocated against the supplied root + * allocator. Caller closes the returned root. 
+ */ + private static VectorSchemaRoot makeIntBatch(BufferAllocator allocator, int[] values) { + VectorSchemaRoot root = VectorSchemaRoot.create(oneIntSchema(), allocator); + IntVector vec = (IntVector) root.getVector(0); + vec.allocateNew(values.length); + for (int i = 0; i < values.length; i++) { + vec.set(i, values[i]); + } + vec.setValueCount(values.length); + root.setRowCount(values.length); + return root; + } + + @Test + void writeAndScanFromSameThread() throws Exception { + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + try (TableSink sink = ctx.registerStreamingTable("t", oneIntSchema(), 4)) { + try (VectorSchemaRoot b1 = makeIntBatch(allocator, new int[] {1, 2, 3}); + VectorSchemaRoot b2 = makeIntBatch(allocator, new int[] {4, 5})) { + sink.write(b1); + sink.write(b2); + } + } // close() = EOF before the SQL query runs. + try (DataFrame df = ctx.sql("SELECT count(*) FROM t"); + ArrowReader reader = df.collect(allocator)) { + assertTrue(reader.loadNextBatch()); + BigIntVector count = (BigIntVector) reader.getVectorSchemaRoot().getVector(0); + assertEquals(5L, count.get(0)); + } + } + } + + @Test + void producerOnSeparateThread() throws Exception { + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + TableSink sink = ctx.registerStreamingTable("t", oneIntSchema(), 2); + Thread producer = + new Thread( + () -> { + try (VectorSchemaRoot b1 = makeIntBatch(allocator, new int[] {10, 20, 30}); + VectorSchemaRoot b2 = makeIntBatch(allocator, new int[] {40})) { + sink.write(b1); + sink.write(b2); + } finally { + sink.close(); + } + }); + producer.start(); + try (DataFrame df = ctx.sql("SELECT sum(id) FROM t"); + ArrowReader reader = df.collect(allocator)) { + assertTrue(reader.loadNextBatch()); + BigIntVector sum = (BigIntVector) reader.getVectorSchemaRoot().getVector(0); + assertEquals(100L, sum.get(0)); + } + producer.join(5_000); + 
assertTrue(!producer.isAlive(), "producer thread did not finish"); + } + } + + @Test + void closeSignalsEndOfStream() throws Exception { + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + try (TableSink sink = ctx.registerStreamingTable("t", oneIntSchema(), 4)) { + try (VectorSchemaRoot b = makeIntBatch(allocator, new int[] {1, 2})) { + sink.write(b); + } + sink.close(); + } + try (DataFrame df = ctx.sql("SELECT count(*) FROM t"); + ArrowReader reader = df.collect(allocator)) { + assertTrue(reader.loadNextBatch()); + BigIntVector count = (BigIntVector) reader.getVectorSchemaRoot().getVector(0); + assertEquals(2L, count.get(0)); + } + } + } + + @Test + void failPropagatesError() { + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + TableSink sink = ctx.registerStreamingTable("t", oneIntSchema(), 4); + try (VectorSchemaRoot b = makeIntBatch(allocator, new int[] {1, 2})) { + sink.write(b); + } + sink.fail(new RuntimeException("producer-side boom")); + RuntimeException e = + assertThrows( + RuntimeException.class, + () -> { + try (DataFrame df = ctx.sql("SELECT count(*) FROM t"); + ArrowReader reader = df.collect(allocator)) { + // collect() may either throw directly or surface on first loadNextBatch. + while (reader.loadNextBatch()) { + // drain + } + } + }); + assertTrue( + e.getMessage().contains("producer-side boom"), + () -> "unexpected error message: " + e.getMessage()); + } + } + + /** + * Regression test for the close()-vs-blocked-write() use-after-free. With capacity 1 and one + * successful write, the data channel is full; a second write parks inside {@code + * Sender::blocking_send}. We then close() from another thread. close() must: + * + *

    + *
  1. Drop the sender so the parked write returns Err and releases its read lock. + *
  2. Wait for in-flight writes to release the read lock (by taking the write lock) before freeing the native box. + *
+ * + *

If either step is missing, dropHandleNative would free the box while writeBatchNative still + * holds a borrowed pointer -- use-after-free. + */ + @Test + void closeDuringBlockedWriteIsSafe() throws Exception { + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + TableSink sink = ctx.registerStreamingTable("t", oneIntSchema(), 1); + AtomicReference producerError = new AtomicReference<>(); + CountDownLatch firstWriteDone = new CountDownLatch(1); + Thread producer = + new Thread( + () -> { + try (VectorSchemaRoot b1 = makeIntBatch(allocator, new int[] {1}); + VectorSchemaRoot b2 = makeIntBatch(allocator, new int[] {2})) { + sink.write(b1); // fills the capacity-1 channel + firstWriteDone.countDown(); + sink.write(b2); // parks on blocking_send; close() unblocks with Err + } catch (Throwable t) { + producerError.set(t); + } + }); + producer.start(); + assertTrue(firstWriteDone.await(5, TimeUnit.SECONDS)); + // Give the producer a moment to actually park on the second write. + Thread.sleep(100); + // close() from this (main) thread races with the producer's parked write. Must not UAF. + sink.close(); + producer.join(5_000); + assertTrue(!producer.isAlive(), "producer thread did not unblock after close()"); + // The producer's second write should have surfaced the consumer-closed error. + assertNotNull( + producerError.get(), + "producer's parked write should have thrown when close() dropped the receiver"); + assertTrue( + producerError.get() instanceof RuntimeException, + () -> "expected RuntimeException, got " + producerError.get()); + } + } + + /** + * Regression test for the "fail() blocks on full channel pre-query" deadlock. With capacity 1 and + * one successful write, the data channel is full. The producer then calls fail() before any + * consumer starts -- if fail() were implemented as `tx.blocking_send(Err(...))`, it would park + * forever because no consumer is draining. 
The sideband-error implementation must terminate + * synchronously and the consumer must observe the error when it eventually runs. + */ + @Test + void failOnFullChannelBeforeConsumer() throws Exception { + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + TableSink sink = ctx.registerStreamingTable("t", oneIntSchema(), 1); + try (VectorSchemaRoot b = makeIntBatch(allocator, new int[] {1})) { + sink.write(b); // channel now full + } + // Run fail() on a separate thread with a hard timeout so a regression doesn't hang the + // whole test suite. The fix must let fail() return promptly even when the channel is full. + Thread failer = new Thread(() -> sink.fail(new RuntimeException("pre-consumer boom"))); + failer.start(); + failer.join(5_000); + assertTrue( + !failer.isAlive(), + "fail() blocked on a full channel; sideband-error path is missing or broken"); + // Consumer now runs; observes the queued batch first, then the terminal error. + RuntimeException e = + assertThrows( + RuntimeException.class, + () -> { + try (DataFrame df = ctx.sql("SELECT count(*) FROM t"); + ArrowReader reader = df.collect(allocator)) { + while (reader.loadNextBatch()) { + // drain + } + } + }); + assertTrue( + e.getMessage().contains("pre-consumer boom"), + () -> "unexpected error message: " + e.getMessage()); + } + } + + @Test + void backpressureBlocksProducer() throws Exception { + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + // Capacity 1: first write goes through; second blocks until the consumer drains. 
+ TableSink sink = ctx.registerStreamingTable("t", oneIntSchema(), 1); + CountDownLatch firstWriteDone = new CountDownLatch(1); + CountDownLatch secondWriteStarted = new CountDownLatch(1); + AtomicReference producerError = new AtomicReference<>(); + Thread producer = + new Thread( + () -> { + try (VectorSchemaRoot b1 = makeIntBatch(allocator, new int[] {1}); + VectorSchemaRoot b2 = makeIntBatch(allocator, new int[] {2}); + VectorSchemaRoot b3 = makeIntBatch(allocator, new int[] {3})) { + sink.write(b1); + firstWriteDone.countDown(); + // Two more writes; one of these must block because the channel is at capacity. + secondWriteStarted.countDown(); + sink.write(b2); + sink.write(b3); + } catch (Throwable t) { + producerError.set(t); + } finally { + sink.close(); + } + }); + producer.start(); + assertTrue(firstWriteDone.await(5, TimeUnit.SECONDS)); + assertTrue(secondWriteStarted.await(5, TimeUnit.SECONDS)); + // The producer must still be alive at this point because it's blocked on the second + // (or third) write. No way to assert "thread is parked" cleanly from JUnit, so rely on + // the consumer side draining and observing all three batches. 
+ try (DataFrame df = ctx.sql("SELECT count(*) FROM t"); + ArrowReader reader = df.collect(allocator)) { + assertTrue(reader.loadNextBatch()); + BigIntVector count = (BigIntVector) reader.getVectorSchemaRoot().getVector(0); + assertEquals(3L, count.get(0)); + } + producer.join(5_000); + assertTrue(!producer.isAlive(), "producer thread did not finish"); + assertEquals(null, producerError.get(), () -> "producer error: " + producerError.get()); + } + } + + @Test + void secondScanThrows() throws Exception { + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + try (TableSink sink = ctx.registerStreamingTable("t", oneIntSchema(), 4)) { + try (VectorSchemaRoot b = makeIntBatch(allocator, new int[] {7})) { + sink.write(b); + } + } + try (DataFrame df = ctx.sql("SELECT count(*) FROM t"); + ArrowReader reader = df.collect(allocator)) { + assertTrue(reader.loadNextBatch()); + } + // Second scan against the same registered streaming table must fail. + RuntimeException e = + assertThrows( + RuntimeException.class, + () -> { + try (DataFrame df2 = ctx.sql("SELECT count(*) FROM t"); + ArrowReader r2 = df2.collect(allocator)) { + while (r2.loadNextBatch()) { + // drain + } + } + }); + assertTrue( + e.getMessage().toLowerCase().contains("single-scan") + || e.getMessage().toLowerCase().contains("already consumed"), + () -> "unexpected error message: " + e.getMessage()); + } + } + + @Test + void writeAfterCloseThrows() { + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + TableSink sink = ctx.registerStreamingTable("t", oneIntSchema(), 4); + sink.close(); + try (VectorSchemaRoot b = makeIntBatch(allocator, new int[] {1})) { + assertThrows(IllegalStateException.class, () -> sink.write(b)); + } + // close() is idempotent. 
+ sink.close(); + } + } + + @Test + void schemaMismatchRejected() { + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + // Register with int32; try to write a batch with a different schema (int64 column). + try (TableSink sink = ctx.registerStreamingTable("t", oneIntSchema(), 4)) { + Schema otherSchema = + new Schema( + Collections.singletonList( + new Field("id", FieldType.notNullable(new ArrowType.Int(64, true)), null))); + try (VectorSchemaRoot bad = VectorSchemaRoot.create(otherSchema, allocator)) { + BigIntVector v = (BigIntVector) bad.getVector(0); + v.allocateNew(1); + v.set(0, 42L); + v.setValueCount(1); + bad.setRowCount(1); + assertThrows(RuntimeException.class, () -> sink.write(bad)); + } + } + } + } + + @Test + void nullArgumentValidation() { + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + assertThrows( + IllegalArgumentException.class, + () -> ctx.registerStreamingTable(null, oneIntSchema(), 4)); + assertThrows(IllegalArgumentException.class, () -> ctx.registerStreamingTable("t", null, 4)); + assertThrows( + IllegalArgumentException.class, () -> ctx.registerStreamingTable("t", oneIntSchema(), 0)); + assertThrows( + IllegalArgumentException.class, + () -> ctx.registerStreamingTable("t", oneIntSchema(), -1)); + try (TableSink sink = ctx.registerStreamingTable("t", oneIntSchema(), 4)) { + assertThrows(IllegalArgumentException.class, () -> sink.write(null)); + assertThrows(IllegalArgumentException.class, () -> sink.fail(null)); + } + // Verify the sink's allocator gets cleaned up implicitly via close(). + assertNotNull(allocator); + } + } + + /** + * Regression stress test for the lost-wakeup race in {@code close()} vs. blocked {@code write()}. 
+ * With {@code Notify::notify_waiters} alone (the wake is dropped if the writer hasn't registered + * its {@code Notified} future yet), a writer preempted between cloning the sender and entering + * {@code select!} would park forever on a full channel after a concurrent {@code close()}. This + * test runs many iterations with no consumer ever reading, so any occurrence of the race + * deterministically deadlocks (caught by {@code producer.join(timeout)} returning while the + * thread is still alive). + * + *

Each iteration: capacity 1, two writes (second parks on backpressure), {@code close()} from + * a separate thread immediately after the first write completes, no consumer. The fix (durable + * {@code closed_flag} re-check after registering the {@code Notified} future) prevents the parked + * writer from missing the wake. + */ + @Test + void closeWakesBlockedWriteUnderStress() throws Exception { + final int iterations = 100; + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + for (int i = 0; i < iterations; i++) { + TableSink sink = ctx.registerStreamingTable("t" + i, oneIntSchema(), 1); + AtomicReference producerError = new AtomicReference<>(); + CountDownLatch readyToBlock = new CountDownLatch(1); + Thread producer = + new Thread( + () -> { + try (VectorSchemaRoot b1 = makeIntBatch(allocator, new int[] {1}); + VectorSchemaRoot b2 = makeIntBatch(allocator, new int[] {2})) { + sink.write(b1); + readyToBlock.countDown(); + sink.write(b2); // parks on backpressure; close() must wake it + } catch (Throwable t) { + producerError.set(t); + } + }); + producer.start(); + readyToBlock.await(5, TimeUnit.SECONDS); + // Race: close() may run before, during, or after the producer registers its Notified + // future. With the broken notify_waiters() approach, "before registration" loses the + // wake. With the closed_flag re-check, all interleavings terminate. + sink.close(); + producer.join(5_000); + assertTrue( + !producer.isAlive(), + "iteration " + (i + 1) + ": producer did not unblock after close() (lost wakeup)"); + assertNotNull( + producerError.get(), + "iteration " + (i + 1) + ": producer's parked write should have thrown"); + } + } + } + + /** + * Regression test for zero-column schemas. {@link TableSink#write} derives its FFI scratch + * allocator from the batch's first field vector, which means it cannot handle a zero-column + * batch. 
Rather than fail at first write (after registration succeeds), {@code + * registerStreamingTable} rejects empty schemas up front so the broken state is visible to the + * caller immediately. + */ + @Test + void zeroColumnSchemaRejectedAtRegistration() { + try (SessionContext ctx = new SessionContext()) { + Schema empty = new Schema(java.util.Collections.emptyList()); + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> ctx.registerStreamingTable("t", empty, 4)); + assertTrue( + e.getMessage().contains("at least one column"), + () -> "expected 'at least one column' in: " + e.getMessage()); + } + } + + /** + * Regression test for top-level Schema metadata being dropped during the per-batch FFI import. On + * the Rust side {@code RecordBatch::from(StructArray)} rebuilds the schema as {@code + * Schema::new(fields)}, losing the metadata even though it was correctly delivered by the C Data + * Interface. Without the fields-only comparison + re-attach in {@code TableSinkHandle::write}, + * any caller registering a schema with top-level metadata gets a confusing "schema does not + * match" error on the first write even though the fields and buffers are identical. 
+ */ + @Test + void writeAcceptsSchemaWithTopLevelMetadata() throws Exception { + java.util.Map meta = java.util.Map.of("source", "shard-A", "version", "1"); + Schema schemaWithMeta = + new Schema( + Collections.singletonList( + new Field("id", FieldType.notNullable(new ArrowType.Int(32, true)), null)), + meta); + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext(); + TableSink sink = ctx.registerStreamingTable("t", schemaWithMeta, 4)) { + try (VectorSchemaRoot root = VectorSchemaRoot.create(schemaWithMeta, allocator)) { + IntVector vec = (IntVector) root.getVector(0); + vec.allocateNew(2); + vec.set(0, 7); + vec.set(1, 9); + vec.setValueCount(2); + root.setRowCount(2); + sink.write(root); // would throw "schema does not match" before the fix + } + sink.close(); + try (DataFrame df = ctx.sql("SELECT count(*) FROM t"); + ArrowReader reader = df.collect(allocator)) { + assertTrue(reader.loadNextBatch()); + BigIntVector count = (BigIntVector) reader.getVectorSchemaRoot().getVector(0); + assertEquals(2L, count.get(0)); + } + } + } + + /** + * Regression test for the cross-lifecycle drop hazard. Closing the {@link SessionContext} before + * {@link TableSink#close()} should drop the registered table (and its receiver) so that + * subsequent {@code write()} calls fail with a "consumer side closed" error. Earlier the sink + * also held an {@code Arc} to the receiver mutex, which kept the receiver alive even after the + * table was gone -- a producer that tried to write under backpressure would then park forever + * with no path to wake. The fix is to keep the receiver owned exclusively by the partition + * stream. + */ + @Test + void writeAfterSessionContextClosedFailsPromptly() throws Exception { + BufferAllocator allocator = new RootAllocator(); + TableSink sink = null; + try { + try (SessionContext ctx = new SessionContext()) { + // capacity 1 + one queued write so the next write would otherwise park. 
+ sink = ctx.registerStreamingTable("t", oneIntSchema(), 1); + try (VectorSchemaRoot b1 = makeIntBatch(allocator, new int[] {1})) { + sink.write(b1); + } + } // SessionContext closes here, dropping the registered table and its receiver. + // The next write must surface an error promptly rather than parking on a defunct channel. + final TableSink finalSink = sink; + Thread writer = + new Thread( + () -> { + try (VectorSchemaRoot b2 = makeIntBatch(allocator, new int[] {2})) { + finalSink.write(b2); + } catch (RuntimeException ignored) { + // expected: consumer side closed. + } + }); + writer.start(); + writer.join(5_000); + assertTrue( + !writer.isAlive(), + "write() parked after the SessionContext (and its receiver) was dropped"); + } finally { + // Clean up: TableSink owns native resources independent of the session. + if (sink != null) { + sink.close(); + } + allocator.close(); + } + } + + /** + * Regression test for dictionary-encoded fields. {@link TableSink#write} calls {@code + * Data.exportVectorSchemaRoot} with a null {@code DictionaryProvider}; if the registered schema + * declares any dictionary-encoded field, the export path requires a non-null provider and would + * otherwise throw NPE on first write. We reject at registration so the breakage is visible up + * front; lifting this restriction would mean an overload that accepts a {@code + * DictionaryProvider}, which is out of scope for v1. 
+ */ + @Test + void dictionaryEncodedSchemaRejectedAtRegistration() { + try (SessionContext ctx = new SessionContext()) { + org.apache.arrow.vector.types.pojo.DictionaryEncoding enc = + new org.apache.arrow.vector.types.pojo.DictionaryEncoding( + /* id */ 0L, /* ordered */ false, /* indexType */ null); + Field dictField = + new Field( + "color", new FieldType(true, new ArrowType.Int(32, true), enc), /* children */ null); + Schema schema = new Schema(Collections.singletonList(dictField)); + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> ctx.registerStreamingTable("t", schema, 4)); + String msg = e.getMessage(); + assertTrue( + msg.contains("dictionary-encoded") && msg.contains("color"), + () -> "expected mention of 'dictionary-encoded' and 'color' in: " + msg); + } + } + + /** + * Regression test for nested dictionary-encoded fields. {@code Data.exportVectorSchemaRoot} + * recurses into Struct/List/Map/Union children, so a dictionary at any depth would NPE on first + * write if the validation only inspected top-level fields. The error message must name the dotted + * path so the caller can find the offending child. 
+ */ + @Test + void nestedDictionaryEncodedSchemaRejectedAtRegistration() { + try (SessionContext ctx = new SessionContext()) { + org.apache.arrow.vector.types.pojo.DictionaryEncoding enc = + new org.apache.arrow.vector.types.pojo.DictionaryEncoding( + /* id */ 7L, /* ordered */ false, /* indexType */ null); + Field nestedDict = + new Field("code", new FieldType(true, new ArrowType.Int(32, true), enc), null); + Field structField = + new Field( + "row", + FieldType.notNullable(new ArrowType.Struct()), + Collections.singletonList(nestedDict)); + Schema schema = new Schema(Collections.singletonList(structField)); + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> ctx.registerStreamingTable("t", schema, 4)); + String msg = e.getMessage(); + assertTrue( + msg.contains("dictionary-encoded") && msg.contains("row.code"), + () -> "expected mention of 'dictionary-encoded' and 'row.code' in: " + msg); + } + } + + /** + * Regression test for re-entering the Tokio runtime from {@link TableSink#write}. A {@code + * TableProvider.scan} callback runs on a Tokio worker thread (the executor invokes it from inside + * the same {@code runtime.block_on} that drives query execution). If {@code write()} naively + * calls {@code Runtime::block_on} from there, Tokio panics with "Cannot start a runtime from + * within a runtime". The fix detects an existing runtime context via {@code + * Handle::try_current()} and uses {@code block_in_place} + {@code Handle::block_on} -- the + * supported pattern for synchronously waiting on a future from inside a worker. + * + *

Test shape: a {@link TableProvider} whose {@code scan} (a) returns its own one-row reader + * (so the trigger query can finish) and (b) writes one batch into a sibling streaming sink. We + * then SELECT from the trigger table to invoke the scan callback, close the sink, and SELECT the + * count from the streaming table -- it should be 1, not panic. + */ + @Test + void writeFromInsideTableProviderScanDoesNotPanic() throws Exception { + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + TableSink streamingSink = + ctx.registerStreamingTable("streamed", oneIntSchema(), /* capacity */ 4); + try { + // Trigger table: schema is a single int "id". Its scan() implementation writes one row + // into `streamingSink` while DataFusion is on a Tokio worker thread. + TableProvider trigger = + new TableProvider() { + @Override + public Schema schema() { + return oneIntSchema(); + } + + @Override + public org.apache.arrow.vector.ipc.ArrowReader scan(BufferAllocator scanAllocator) { + // (a) Push one batch into the streaming sink. This is the call that previously + // panicked with "Cannot start a runtime from within a runtime". + try (VectorSchemaRoot pushed = makeIntBatch(allocator, new int[] {42})) { + streamingSink.write(pushed); + } + // (b) Build and serialise a one-row trigger batch as IPC bytes; return a reader + // backed by those bytes. Same shape as PR #65's InMemoryTableProvider fixture. 
+ java.io.ByteArrayOutputStream baos = new java.io.ByteArrayOutputStream(); + try (BufferAllocator tmp = new RootAllocator(); + VectorSchemaRoot root = VectorSchemaRoot.create(oneIntSchema(), tmp)) { + IntVector v = (IntVector) root.getVector(0); + v.allocateNew(1); + v.set(0, 0); + v.setValueCount(1); + root.setRowCount(1); + try (org.apache.arrow.vector.ipc.ArrowStreamWriter writer = + new org.apache.arrow.vector.ipc.ArrowStreamWriter( + root, null, java.nio.channels.Channels.newChannel(baos))) { + writer.start(); + writer.writeBatch(); + writer.end(); + } catch (java.io.IOException e) { + throw new RuntimeException(e); + } + } + return new org.apache.arrow.vector.ipc.ArrowStreamReader( + new java.io.ByteArrayInputStream(baos.toByteArray()), scanAllocator); + } + }; + ctx.registerTable("trigger", trigger); + // Drive the scan callback. This runs on a Tokio worker; the inner sink.write must not + // panic the JVM. + try (DataFrame df = ctx.sql("SELECT count(*) FROM trigger"); + ArrowReader r = df.collect(allocator)) { + assertTrue(r.loadNextBatch()); + assertEquals(1L, ((BigIntVector) r.getVectorSchemaRoot().getVector(0)).get(0)); + } + // Close the sink (no more writes coming). + } finally { + streamingSink.close(); + } + // Now drain the streaming table; we should see the single row written from inside the + // scan callback. 
+ try (DataFrame df = ctx.sql("SELECT count(*) FROM streamed"); + ArrowReader r = df.collect(allocator)) { + assertTrue(r.loadNextBatch()); + assertEquals(1L, ((BigIntVector) r.getVectorSchemaRoot().getVector(0)).get(0)); + } + } + } +} diff --git a/native/Cargo.lock b/native/Cargo.lock index 8c56280..71af013 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -1316,6 +1316,7 @@ dependencies = [ "protoc-bin-vendored", "tokio", "tokio-metrics", + "tokio-stream", "url", ] diff --git a/native/Cargo.toml b/native/Cargo.toml index c462408..f11b7d0 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -80,11 +80,14 @@ jni = "0.21" # so we share the same `dyn ObjectStore` vtable and don't double-link. object_store = { version = "0.13", default-features = false } prost = "0.14" -tokio = { version = "1", features = ["rt-multi-thread"] } +tokio = { version = "1", features = ["rt-multi-thread", "sync"] } # Tokio runtime metrics. Optional + cfg-gated: this crate's API surface lives # behind `--cfg tokio_unstable`, so enabling the `runtime-metrics` feature also # requires the caller to set `RUSTFLAGS="--cfg tokio_unstable"` at build time. tokio-metrics = { version = "0.5", optional = true } +# Already pulled in transitively by datafusion-physical-plan; declared here so +# `streaming_table` can use `ReceiverStream` directly. 
+tokio-stream = "0.1" +url = "2" [build-dependencies] diff --git a/native/src/lib.rs b/native/src/lib.rs index 4fd7a8a..455ce18 100644 --- a/native/src/lib.rs +++ b/native/src/lib.rs @@ -27,6 +27,7 @@ mod object_store; mod proto; mod runtime_metrics; mod schema; +mod streaming_table; mod table_provider; mod udf; @@ -1294,3 +1295,145 @@ pub extern "system" fn Java_org_apache_datafusion_SessionContext_registerTableNa Ok(()) }) }

/// JNI entry point behind `SessionContext.registerStreamingTable`.
///
/// Decodes the IPC-encoded schema, builds a streaming table / sink pair that
/// share a bounded channel of `capacity` batches, registers the table on the
/// session under `name`, and returns an opaque handle to the sink.
///
/// Returns `0` (with a pending Java exception raised by
/// `try_unwrap_or_throw`) when the session handle is null, `capacity` is not
/// positive, the schema bytes are null, or registration fails.
///
/// The value previously returned by `register_table` for the same name is
/// discarded.
#[no_mangle]
pub extern "system" fn Java_org_apache_datafusion_SessionContext_registerStreamingTableNative<
    'local,
>(
    mut env: JNIEnv<'local>,
    _class: JClass<'local>,
    handle: jlong,
    name: JString<'local>,
    schema_ipc_bytes: JByteArray<'local>,
    capacity: jint,
) -> jlong {
    try_unwrap_or_throw(&mut env, 0, |env| -> JniResult<jlong> {
        if handle == 0 {
            return Err("SessionContext handle is null".into());
        }
        if capacity <= 0 {
            return Err(format!("capacity must be positive, was {}", capacity).into());
        }
        // SAFETY: handle is a valid Box<SessionContext> allocated by
        // createSessionContext.
        let ctx = unsafe { &*(handle as *const SessionContext) };
        let name: String = env.get_string(&name)?.into();

        let schema = crate::schema::decode_optional_schema(env, schema_ipc_bytes)?
            .ok_or("schema bytes were null")?;
        let schema = Arc::new(schema);

        // `capacity` was checked to be positive above, so the cast to usize
        // cannot wrap.
        let (table, sink) =
            crate::streaming_table::make_streaming_table(schema, capacity as usize)?;
        let _ = ctx.register_table(name.as_str(), table)?;

        // Hand the sink off to Java as an opaque pointer to a boxed Arc of
        // the handle. We Box+leak the Arc so Java can release exactly one
        // strong ref via dropHandleNative when TableSink.close() runs --
        // any in-flight writeBatchNative call holds its own Arc clone, so
        // the inner TableSinkHandle is never freed mid-call.
        let sink_arc: std::sync::Arc<crate::streaming_table::TableSinkHandle> =
            std::sync::Arc::new(sink);
        let sink_box = Box::new(sink_arc);
        Ok(Box::into_raw(sink_box) as jlong)
    })
}

/// Clone the producer-side `Arc<TableSinkHandle>` referenced by `handle`.
///
/// Cloning the Arc here gives every native call its own strong reference
/// for the duration of the Rust work (which can park inside
/// `TableSinkHandle::write` while the channel is full); the eventual
/// `dropHandleNative` only releases the Java-side strong ref. The inner
/// `TableSinkHandle` is freed when the last clone drops, never mid-call.
///
/// # Safety
///
/// `handle` must be the live pointer returned by
/// `registerStreamingTableNative` and must not yet have been freed by
/// `dropHandleNative`. The clone protects the *inner* handle only once it
/// has been taken; the boxed Arc itself must still be alive when this
/// function dereferences it.
///
/// NOTE(review): the Java side is described as zeroing its `nativeHandle`
/// field before calling `dropHandleNative`. That alone does not stop a
/// thread that read a non-zero handle earlier from arriving here after the
/// box was freed -- confirm that `TableSink` serialises handle reads against
/// `dropHandleNative` (not visible from this file).
unsafe fn clone_sink_arc(
    handle: jlong,
) -> std::sync::Arc<crate::streaming_table::TableSinkHandle> {
    let arc_ref = &*(handle as *const std::sync::Arc<crate::streaming_table::TableSinkHandle>);
    std::sync::Arc::clone(arc_ref)
}

/// JNI entry point behind `TableSink.write`: imports the exported batch and
/// pushes it into the sink, blocking while the channel is full.
///
/// `array_addr` / `schema_addr` are the addresses of the C Data Interface
/// structs the Java side populated. A null `handle` is rejected *before* the
/// structs are imported, so in that case ownership of the exported data
/// stays with the caller.
#[no_mangle]
pub extern "system" fn Java_org_apache_datafusion_TableSink_writeBatchNative<'local>(
    mut env: JNIEnv<'local>,
    _class: JClass<'local>,
    handle: jlong,
    array_addr: jlong,
    schema_addr: jlong,
) {
    try_unwrap_or_throw(&mut env, (), |_env| -> JniResult<()> {
        if handle == 0 {
            return Err("TableSink handle is null".into());
        }
        // SAFETY: see `clone_sink_arc`. The Arc clone keeps the inner
        // TableSinkHandle alive even if Java drops its strong ref while
        // we're parked in `sink.write` below waiting for channel capacity.
        let sink = unsafe { clone_sink_arc(handle) };
        // SAFETY: array_addr and schema_addr point at FFI structs the Java
        // side just populated via Data.exportVectorSchemaRoot. We take
        // ownership and release them when the resulting RecordBatch drops.
        let batch =
            unsafe { crate::streaming_table::import_batch_from_ffi(array_addr, schema_addr)? };
        sink.write(batch).map_err(|e| e.into())
    })
}

/// JNI entry point behind `TableSink.close`: signals clean end-of-stream.
/// A null `handle` is a no-op so that closing twice is harmless.
#[no_mangle]
pub extern "system" fn Java_org_apache_datafusion_TableSink_closeSinkNative<'local>(
    mut env: JNIEnv<'local>,
    _class: JClass<'local>,
    handle: jlong,
) {
    try_unwrap_or_throw(&mut env, (), |_env| -> JniResult<()> {
        if handle == 0 {
            return Ok(());
        }
        // SAFETY: see `clone_sink_arc`.
        let sink = unsafe { clone_sink_arc(handle) };
        sink.close();
        Ok(())
    })
}

/// JNI entry point behind `TableSink.fail`: records `message` as the
/// stream's terminal error and closes the sink. Unlike `closeSinkNative`, a
/// null `handle` is reported as an error.
#[no_mangle]
pub extern "system" fn Java_org_apache_datafusion_TableSink_failSinkNative<'local>(
    mut env: JNIEnv<'local>,
    _class: JClass<'local>,
    handle: jlong,
    message: JString<'local>,
) {
    try_unwrap_or_throw(&mut env, (), |env| -> JniResult<()> {
        if handle == 0 {
            return Err("TableSink handle is null".into());
        }
        // SAFETY: see `clone_sink_arc`.
        let sink = unsafe { clone_sink_arc(handle) };
        let msg: String = env.get_string(&message)?.into();
        sink.fail(msg).map_err(|e| e.into())
    })
}

/// JNI entry point that releases the Java-side strong reference to the sink.
/// A null `handle` is a no-op.
#[no_mangle]
pub extern "system" fn Java_org_apache_datafusion_TableSink_dropHandleNative<'local>(
    mut env: JNIEnv<'local>,
    _class: JClass<'local>,
    handle: jlong,
) {
    try_unwrap_or_throw(&mut env, (), |_env| -> JniResult<()> {
        if handle != 0 {
            // SAFETY: handle is the Box<Arc<TableSinkHandle>> allocated by
            // registerStreamingTableNative. Reclaiming the Box drops the
            // Java-side strong ref; the inner TableSinkHandle is freed only
            // when the last in-flight write/close/fail clone drops.
            unsafe {
                drop(Box::from_raw(
                    handle as *mut std::sync::Arc<crate::streaming_table::TableSinkHandle>,
                ));
            }
        }
        Ok(())
    })
}
diff --git a/native/src/streaming_table.rs b/native/src/streaming_table.rs new file mode 100644 index 0000000..f782c3a --- /dev/null +++ b/native/src/streaming_table.rs @@ -0,0 +1,390 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Push-mode streaming table backed by a `tokio::mpsc` channel. +//! +//! Companion to PR #65's pull-mode `JavaTableProvider`. Producers (Java +//! threads) push batches into a `TableSink` that owns the sender end of the +//! channel; the registered table holds the receiver end and exposes it as a +//! `StreamingTable` over a single `PartitionStream`. Single-scan: a +//! registered streaming table can be queried at most once. + +use std::fmt; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex}; + +use datafusion::arrow::array::{make_array, RecordBatch, StructArray}; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}; +use datafusion::catalog::streaming::StreamingTable; +use datafusion::error::{DataFusionError, Result as DfResult}; +use datafusion::execution::context::TaskContext; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion::physical_plan::streaming::PartitionStream; +use datafusion::physical_plan::SendableRecordBatchStream; +use futures::stream::StreamExt; +use jni::sys::jlong; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; + +/// Sideband slot for a terminal error reported by `TableSink::fail()`. Shared +/// between the producer (writer) and consumer (partition stream). 
Using a +/// sideband slot rather than `tx.send(Err(...))` lets `fail()` always +/// terminate immediately: the producer drops the sender so the consumer +/// observes end-of-stream, and the consumer-wrapping stream consults the +/// slot at EOS to surface the error. Otherwise `fail()` would have to call +/// `blocking_send(Err(...))` which deadlocks if the data channel is full +/// and no consumer is reading yet (e.g. failure happens before any query +/// runs against the table). +type TerminalError = Arc>>; + +/// State of the partition-stream receiver. Two reachable values: +/// +/// - `Available(Receiver)` — initial state. The first `execute()` call moves +/// it to `Taken`. +/// - `Taken` — a consumer has scanned the table. Subsequent scans throw +/// the single-scan error. +/// +/// Note: `close()` and `fail()` do **not** mutate this slot. close() is +/// the happy-path EOF and any queued batches must still be observable by +/// a not-yet-started consumer; fail() merely sets the sideband terminal +/// error. The close-vs-blocked-write wakeup uses a dedicated `closed` +/// notify on the sink (see `TableSinkHandle::closed`). +/// +/// **Owned exclusively by `JavaPartitionStream`.** `TableSinkHandle` holds +/// only the `Sender`. If the sink also held an `Arc` to this state, the +/// `Receiver` would outlive the registered table -- a producer that called +/// `write()` after dropping the `SessionContext` would park forever +/// because the channel still has a live receiver but no path to consume. +enum ReceiverState { + Available(mpsc::Receiver>), + Taken, +} + +struct JavaPartitionStream { + schema: SchemaRef, + rx: Mutex, + /// Set by the producer's `fail()` before the sender is dropped. The + /// receiver-wrapping stream consults this on end-of-stream and surfaces + /// it as the terminal item, so the consumer sees the producer's error + /// even when the data channel was full at fail-time. 
+ terminal_error: TerminalError, +} + +impl fmt::Debug for JavaPartitionStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("JavaPartitionStream") + .field("schema", &self.schema) + .finish() + } +} + +impl PartitionStream for JavaPartitionStream { + fn schema(&self) -> &SchemaRef { + &self.schema + } + + fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { + // We can't return a Result from execute(), so on a re-scan attempt + // we synthesise an error stream that yields a single Err. + let receiver = { + let mut slot = self.rx.lock().expect("JavaPartitionStream lock poisoned"); + let prev = std::mem::replace(&mut *slot, ReceiverState::Taken); + match prev { + ReceiverState::Available(rx) => rx, + ReceiverState::Taken => { + let err = DataFusionError::Execution( + "streaming table is single-scan and was already consumed; \ + re-register the table with a new TableSink to scan again" + .to_string(), + ); + let stream = futures::stream::once(async move { Err(err) }); + return Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream)); + } + } + }; + // The data path is the receiver stream (yields DfResult). + // After it drains, append a one-shot tail that pulls the terminal + // error out of the sideband slot (populated by `fail()`) and emits + // it as the final Err item, or nothing if the producer closed + // cleanly. Both halves yield `DfResult` so they can be + // chained directly. + let terminal = Arc::clone(&self.terminal_error); + let data_stream = ReceiverStream::new(receiver); + let terminal_stream = futures::stream::unfold(Some(terminal), |state| async move { + let terminal = state?; + let maybe_err = terminal + .lock() + .expect("terminal_error lock poisoned") + .take(); + // Map into the (item, next-state) tuple the unfold expects: + // emit one terminal-error item and end, or end immediately. 
+ maybe_err.map(|err| (Err::(err), None)) + }); + let stream = StreamExt::chain(data_stream, terminal_stream); + Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream)) + } +} + +/// Public-shaped handle owned by the Java `TableSink`. +/// +/// Lifecycle slots: +/// +/// - `tx` (`Mutex>`): `close()` / `fail()` take it out so +/// subsequent `write()` calls fail synchronously with "sink closed". A +/// `write()` already in flight may have already cloned the sender before +/// `close()` ran; the cloned sender keeps the channel alive until the +/// parked `send` wakes. +/// - `closed_flag` (`AtomicBool`): durable close signal. Set to `true` by +/// `close()` / `fail()` once and never cleared. `write()` checks it +/// *after* registering its `Notify::notified()` future, which closes the +/// lost-wakeup window inherent in `Notify::notify_waiters` (which only +/// delivers permits to waiters that are *already* registered). The +/// `AtomicBool` is what `notify_waiters` is missing: persistence. +/// - `closed_notify` (`Notify`): wakes any `write()` parked on backpressure. +/// `notify_waiters()` is fire-and-forget, so the `closed_flag` re-check +/// after registering the `Notified` future is load-bearing -- without +/// it, a `write()` preempted between cloning `tx` and constructing the +/// future would miss the wake and park forever on a full channel with +/// no consumer. +pub(crate) struct TableSinkHandle { + schema: SchemaRef, + tx: Mutex>>>, + /// Durable close flag. Set once by close() / fail(); checked by write() + /// after registering its Notified future to defeat the lost-wakeup race. + closed_flag: Arc, + /// Wakes any `write()` parked on backpressure. Used together with + /// `closed_flag`: writers register a `notified()` future first and then + /// re-check the flag, so any `notify_waiters()` raised between the two + /// either delivers the wake (because the future is already registered) + /// or is observed by the flag re-check. 
+ closed_notify: Arc, + /// Sideband slot for the terminal error reported by `fail()`. Shared + /// with the receiver-wrapping stream in `JavaPartitionStream::execute`. + terminal_error: TerminalError, +} + +impl TableSinkHandle { + fn new( + schema: SchemaRef, + tx: mpsc::Sender>, + terminal_error: TerminalError, + ) -> Self { + Self { + schema, + tx: Mutex::new(Some(tx)), + closed_flag: Arc::new(AtomicBool::new(false)), + closed_notify: Arc::new(tokio::sync::Notify::new()), + terminal_error, + } + } + + /// Send a batch through the channel, blocking on backpressure. The schema + /// of the imported `RecordBatch` is validated against the registered + /// schema before sending; a mismatch fails synchronously without + /// consuming channel capacity. + /// + /// Internally enters the shared Tokio runtime to `select!` between the + /// actual send and a close signal raised by `close()` / `fail()`. + /// + /// Lost-wakeup defence: `Notify::notify_waiters()` only wakes waiters + /// that have already registered their `Notified` future at the time of + /// the call. A naive implementation that creates the future *inside* + /// `select!` would miss the wake if a concurrent close fired between + /// the sender clone and the future registration -- the writer would + /// then park on a full channel with no consumer and no other path to + /// wake. The fix is to construct the `Notified` future first, *then* + /// re-check a durable `closed_flag` AtomicBool. Any close that fired + /// in between either set the flag (caught by the re-check) or notified + /// the future (caught by the select), so the wakeup cannot be lost. + pub(crate) fn write(&self, batch: RecordBatch) -> Result<(), String> { + // The imported batch came through `RecordBatch::from(StructArray)`, + // which rebuilds the schema as `Schema::new(fields)` -- i.e. it + // drops top-level Schema metadata even when the C-Data Interface + // delivered it correctly through field metadata. 
Compare fields + // only, then re-attach the registered SchemaRef so the consumer + // sees the original metadata. + if batch.schema().fields() != self.schema.fields() { + return Err(format!( + "TableSink batch schema {:?} does not match registered schema {:?}", + batch.schema(), + self.schema + )); + } + let batch = RecordBatch::try_new(Arc::clone(&self.schema), batch.columns().to_vec()) + .map_err(|e| format!("failed to re-attach registered schema to imported batch: {e}"))?; + let tx = self + .tx + .lock() + .expect("TableSinkHandle lock poisoned") + .as_ref() + .ok_or_else(|| "TableSink is closed".to_string())? + .clone(); + let closed_flag = Arc::clone(&self.closed_flag); + let closed_notify = Arc::clone(&self.closed_notify); + let send_future = async move { + // Register the notified future BEFORE the closed_flag re-check. + // Notify's contract: a notification arriving between this + // construction and the await is delivered to this future. + let notified = closed_notify.notified(); + tokio::pin!(notified); + // Acquire ordering pairs with Release in close()/fail() so any + // notify_waiters() that ran before the flag was set is + // guaranteed to have delivered to a waiter registered before + // this load. + if closed_flag.load(Ordering::Acquire) { + return Err("TableSink was closed concurrently".to_string()); + } + tokio::select! { + send_res = tx.send(Ok(batch)) => send_res.map_err(|_| { + "consumer side of streaming table is closed (query cancelled or completed)" + .to_string() + }), + _ = &mut notified => { + Err("TableSink was closed concurrently".to_string()) + } + } + }; + // `Runtime::block_on` panics with "Cannot start a runtime from within + // a runtime" if `write()` is invoked from a thread that is already + // inside a Tokio worker (e.g. a Java `TableProvider.scan` or UDF + // callback dispatched by DataFusion's executor while the consumer + // side of the same JNI library is driving a query). 
Detect that + // case via `Handle::try_current()` and use `block_in_place` + + // `Handle::block_on`, which is the supported pattern for + // synchronously waiting on a future from inside a multi-thread + // runtime worker -- it tells the scheduler this worker is about + // to block so it can spawn a replacement. + match tokio::runtime::Handle::try_current() { + Ok(handle) => tokio::task::block_in_place(|| handle.block_on(send_future)), + Err(_) => crate::runtime().block_on(send_future), + } + } + + /// Drop the sender and notify any `write()` parked on backpressure. + /// Idempotent. The receiver is left intact: a not-yet-started consumer + /// is still entitled to read every batch the producer queued before + /// closing. + /// + /// Order is load-bearing: set the durable flag *before* notifying. + /// `Notify::notify_waiters` only wakes already-registered waiters; + /// the flag's `Release` ordering pairs with the `Acquire` re-check + /// inside `write()` so a writer whose `Notified` future hadn't yet + /// registered when notify_waiters ran is guaranteed to observe the + /// flag and bail out instead of parking on a full channel. + pub(crate) fn close(&self) { + let _ = self + .tx + .lock() + .expect("TableSinkHandle lock poisoned") + .take(); + self.closed_flag.store(true, Ordering::Release); + self.closed_notify.notify_waiters(); + } + + /// Record a terminal error in the sideband slot, then drop the sender. + /// The receiver-wrapping stream observes end-of-stream on the channel + /// and surfaces the sideband error as its final item. + /// + /// Doing it this way (sideband slot + sender drop) rather than + /// `tx.blocking_send(Err(...))` is mandatory: `blocking_send` parks if + /// the data channel is full and the receiver hasn't been started yet, + /// which would deadlock the producer in pre-query failure paths + /// (capacity 1, one successful write, then fail before any consumer). 
+    pub(crate) fn fail(&self, message: String) -> Result<(), String> {
+        // The `tx` guard is held until the end of this method, so a
+        // concurrent `close()` or a second `fail()` (both of which lock
+        // `tx` first) serialises behind this call.
+        let mut tx_slot = self.tx.lock().expect("TableSinkHandle lock poisoned");
+        if tx_slot.is_none() {
+            return Err("TableSink is already closed".to_string());
+        }
+        // Stash the error *before* dropping the sender so the consumer can
+        // never observe end-of-stream without the error being available.
+        *self
+            .terminal_error
+            .lock()
+            .expect("terminal_error lock poisoned") = Some(DataFusionError::Execution(message));
+        // Drop the sender, then signal close (flag-then-notify, same
+        // ordering as close()). The consumer's ReceiverStream observes EOS
+        // once the last sender clone drops; the chained terminal-error
+        // tail emits the Err we just stashed.
+        tx_slot.take();
+        self.closed_flag.store(true, Ordering::Release);
+        self.closed_notify.notify_waiters();
+        Ok(())
+    }
+}
+
+/// Construct a `(StreamingTable, TableSinkHandle)` pair sharing an mpsc
+/// channel of the given capacity. Caller registers the `StreamingTable` on a
+/// `SessionContext` and hands the `TableSinkHandle` back to Java.
+///
+/// # Errors
+///
+/// Returns an `Execution` error when `capacity` is zero (a bounded tokio
+/// mpsc channel cannot be created with an empty buffer), and propagates any
+/// error from `StreamingTable::try_new`.
+pub(crate) fn make_streaming_table(
+    schema: SchemaRef,
+    capacity: usize,
+) -> DfResult<(Arc<StreamingTable>, TableSinkHandle)> {
+    // `mpsc::channel(0)` panics. This function is reached from JNI entry
+    // points, where a panic must not unwind into the JVM, so reject the
+    // value with an ordinary error instead.
+    if capacity == 0 {
+        return Err(DataFusionError::Execution(
+            "streaming table capacity must be at least 1".to_string(),
+        ));
+    }
+    let (tx, rx) = mpsc::channel(capacity);
+    let terminal_error: TerminalError = Arc::new(Mutex::new(None));
+    // The Receiver lives entirely inside JavaPartitionStream. When the
+    // registered StreamingTable is dropped (e.g. SessionContext.close()
+    // before TableSink.close()), the partition stream drops and the
+    // Receiver with it, so any subsequent producer-side `Sender::send`
+    // returns Err immediately rather than parking on a dangling channel.
+    let partition = Arc::new(JavaPartitionStream {
+        schema: Arc::clone(&schema),
+        rx: Mutex::new(ReceiverState::Available(rx)),
+        terminal_error: Arc::clone(&terminal_error),
+    });
+    let table = StreamingTable::try_new(Arc::clone(&schema), vec![partition])?;
+    Ok((
+        Arc::new(table),
+        TableSinkHandle::new(schema, tx, terminal_error),
+    ))
+}
+
+/// Decode a Java-exported batch from a `(FFI_ArrowArray, FFI_ArrowSchema)`
+/// pair into a `RecordBatch`.
+///
+/// The exported root surfaces as a `StructArray` whose fields are the row
+/// columns; `RecordBatch::from(StructArray)` re-projects those into the
+/// expected shape.
+///
+/// # Safety
+///
+/// Each address must be either zero or the address of a properly aligned,
+/// readable and writable FFI struct that was freshly populated by
+/// `Data.exportVectorSchemaRoot` and has not yet been imported elsewhere.
+/// When both addresses are non-zero this function takes ownership of both
+/// structs and marks the caller's copies as released (release callback set
+/// to NULL), whether or not the import subsequently succeeds.
+///
+/// # Errors
+///
+/// Returns an error if either address is zero, if Arrow rejects the
+/// exported data, if the imported array is not a `StructArray`, or if the
+/// struct carries top-level nulls (which a `RecordBatch` cannot represent).
+pub(crate) unsafe fn import_batch_from_ffi(
+    array_addr: jlong,
+    schema_addr: jlong,
+) -> DfResult<RecordBatch> {
+    if array_addr == 0 || schema_addr == 0 {
+        return Err(DataFusionError::Execution(
+            "FFI array or schema address is null".to_string(),
+        ));
+    }
+    // Take ownership of the structs Java populated. `from_raw` moves the
+    // contents out and leaves an empty, already-released struct behind, as
+    // the Arrow C Data Interface requires when moving. A plain bit-copy
+    // would leave the Java-side release callback armed and risk a double
+    // release. From here we are responsible for releasing them via Drop.
+    //
+    // SAFETY: both addresses are non-zero and, per this function's
+    // contract, point to valid, aligned, writable, initialised structs.
+    let (ffi_array, ffi_schema) = unsafe {
+        (
+            FFI_ArrowArray::from_raw(array_addr as *mut FFI_ArrowArray),
+            FFI_ArrowSchema::from_raw(schema_addr as *mut FFI_ArrowSchema),
+        )
+    };
+    // SAFETY: the array and schema were exported together by the caller,
+    // so the schema describes the array's layout.
+    let array_data = unsafe { from_ffi(ffi_array, &ffi_schema) }
+        .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
+    let array = make_array(array_data);
+    let struct_array = array
+        .as_any()
+        .downcast_ref::<StructArray>()
+        .ok_or_else(|| {
+            DataFusionError::Execution(
+                "exported VectorSchemaRoot did not import as StructArray".to_string(),
+            )
+        })?;
+    // `RecordBatch::from(StructArray)` asserts that the struct has no
+    // top-level nulls; surface that as an error rather than a panic.
+    if array.null_count() != 0 {
+        return Err(DataFusionError::Execution(
+            "exported VectorSchemaRoot imported as a StructArray with top-level nulls"
+                .to_string(),
+        ));
+    }
+    Ok(RecordBatch::from(struct_array.clone()))
+}