diff --git a/core/src/main/java/org/apache/datafusion/ColumnarValue.java b/core/src/main/java/org/apache/datafusion/ColumnarValue.java
new file mode 100644
index 0000000..cbe0703
--- /dev/null
+++ b/core/src/main/java/org/apache/datafusion/ColumnarValue.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datafusion;
+
+import java.util.Objects;
+
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+
+/**
+ * The value of a scalar UDF argument or result: either a per-row {@link Array} of length {@code
+ * rowCount}, or a {@link Scalar} (length-1 vector) that the framework broadcasts.
+ *
+ * <p>Mirrors DataFusion's {@code datafusion::logical_expr::ColumnarValue} enum. Use {@link
+ * #array(FieldVector)} and {@link #scalar(FieldVector)} factories rather than constructing the
+ * records directly so the length invariants are enforced consistently.
+ */
+public sealed interface ColumnarValue permits ColumnarValue.Array, ColumnarValue.Scalar {
+
+ /**
+ * Returns the Arrow vector backing this value. For a {@link Scalar} the vector had {@code
+ * valueCount == 1} when the value was constructed.
+ */
+ FieldVector vector();
+
+ /** Returns the Arrow type declared on the backing vector's field. */
+ default ArrowType dataType() {
+ FieldVector backing = vector();
+ return backing.getField().getType();
+ }
+
+ /**
+ * Wraps a vector of any length as an {@link Array}.
+ *
+ * @throws NullPointerException if {@code vector} is null
+ */
+ static ColumnarValue array(FieldVector vector) {
+ // The null check lives in the record's compact constructor.
+ return new Array(vector);
+ }
+
+ /**
+ * Wraps a single-element vector as a {@link Scalar}.
+ *
+ * @throws NullPointerException if {@code vector} is null
+ * @throws IllegalArgumentException if {@code vector.getValueCount() != 1}
+ */
+ static ColumnarValue scalar(FieldVector vector) {
+ // Null and length validation both live in the record's compact constructor.
+ return new Scalar(vector);
+ }
+
+ /**
+ * Per-row variant. The vector is expected to be as long as the batch row count, but only
+ * non-nullness is checked here.
+ */
+ record Array(FieldVector vector) implements ColumnarValue {
+ public Array {
+ Objects.requireNonNull(vector, "vector");
+ }
+ }
+
+ /** Single-value variant: a length-1 vector whose value is broadcast across all rows. */
+ record Scalar(FieldVector vector) implements ColumnarValue {
+ public Scalar {
+ Objects.requireNonNull(vector, "vector");
+ int count = vector.getValueCount();
+ if (count != 1) {
+ throw new IllegalArgumentException(
+ "Scalar vector must have valueCount == 1, got " + count);
+ }
+ }
+ }
+}
diff --git a/core/src/main/java/org/apache/datafusion/ScalarFunction.java b/core/src/main/java/org/apache/datafusion/ScalarFunction.java
index 676154e..b83c636 100644
--- a/core/src/main/java/org/apache/datafusion/ScalarFunction.java
+++ b/core/src/main/java/org/apache/datafusion/ScalarFunction.java
@@ -22,7 +22,6 @@
import java.util.List;
import org.apache.arrow.memory.BufferAllocator;
-import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.types.pojo.ArrowType;
/**
@@ -46,7 +45,9 @@ public interface ScalarFunction {
*/
List argTypes();
- /** Declared return type. The returned {@link FieldVector} must have this exact type. */
+ /**
+ * Declared return type. The returned {@link ColumnarValue}'s vector must have this exact type.
+ */
ArrowType returnType();
/**
@@ -59,14 +60,16 @@ public interface ScalarFunction {
/**
* Compute the function result for one input batch.
*
- * @param allocator the {@link BufferAllocator} that MUST be used for any new {@link FieldVector}
+ * @param allocator the {@link BufferAllocator} that MUST be used for any new Arrow vector
* allocation, including the result. Buffers allocated from other allocators will not survive
* the JNI handoff.
- * @param args one {@link FieldVector} per declared argument, all of the same length. These are
- * read-only views; the implementation must NOT close them.
- * @return a {@link FieldVector} of the declared return type and the same length as the inputs.
- * Ownership transfers to the framework on return; the implementation must NOT close the
- * returned vector.
+ * @param args the per-arg {@link ColumnarValue}s and the batch row count. Each {@link
+ * ColumnarValue} is a read-only view; the implementation must NOT close its underlying
+ * vector.
+ * @return a {@link ColumnarValue} of the declared return type. If {@link ColumnarValue.Array},
+ * the underlying vector must have length {@code args.rowCount()}; if {@link
+ * ColumnarValue.Scalar}, length 1. Ownership of the returned vector transfers to the
+ * framework; the implementation must NOT close it.
*/
- FieldVector evaluate(BufferAllocator allocator, List args);
+ ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args);
}
diff --git a/core/src/main/java/org/apache/datafusion/ScalarFunctionArgs.java b/core/src/main/java/org/apache/datafusion/ScalarFunctionArgs.java
new file mode 100644
index 0000000..927fcb1
--- /dev/null
+++ b/core/src/main/java/org/apache/datafusion/ScalarFunctionArgs.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datafusion;
+
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * Bundle of inputs passed to {@link ScalarFunction#evaluate}: the per-arg {@link ColumnarValue}s
+ * (in declared order) and the batch row count DataFusion is driving.
+ *
+ * <p>Mirrors DataFusion's {@code datafusion::logical_expr::ScalarFunctionArgs}. {@code rowCount}
+ * is the only channel by which an array-returning UDF without array-typed inputs (all-scalar
+ * args, or nullary) can size its output. Nullary UDFs that prefer to broadcast a single value
+ * should return {@link ColumnarValue#scalar(org.apache.arrow.vector.FieldVector)
+ * ColumnarValue.scalar(...)} instead, which removes the need to consult {@code rowCount}.
+ *
+ * @param args the argument values in declared order; stored as an unmodifiable copy
+ * @param rowCount the number of rows in the batch being evaluated; never negative
+ */
+public record ScalarFunctionArgs(List<ColumnarValue> args, int rowCount) {
+ /**
+ * Validates the components and snapshots {@code args} so later changes to the caller's list are
+ * not observed.
+ *
+ * @throws NullPointerException if {@code args} is null or contains a null element
+ * @throws IllegalArgumentException if {@code rowCount} is negative
+ */
+ public ScalarFunctionArgs {
+ // List.copyOf yields an unmodifiable list and rejects null elements.
+ args = List.copyOf(Objects.requireNonNull(args, "args"));
+ if (rowCount < 0) {
+ throw new IllegalArgumentException("rowCount must be >= 0, got " + rowCount);
+ }
+ }
+}
diff --git a/core/src/main/java/org/apache/datafusion/internal/JniBridge.java b/core/src/main/java/org/apache/datafusion/internal/JniBridge.java
index 210e996..8248357 100644
--- a/core/src/main/java/org/apache/datafusion/internal/JniBridge.java
+++ b/core/src/main/java/org/apache/datafusion/internal/JniBridge.java
@@ -19,6 +19,7 @@
package org.apache.datafusion.internal;
+import java.util.ArrayList;
import java.util.List;
import org.apache.arrow.c.ArrowArray;
@@ -27,7 +28,9 @@
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.datafusion.ColumnarValue;
import org.apache.datafusion.ScalarFunction;
+import org.apache.datafusion.ScalarFunctionArgs;
/** Internal trampoline invoked from native code on every UDF call. Not part of the public API. */
public final class JniBridge {
@@ -40,54 +43,100 @@ public final class JniBridge {
private JniBridge() {}
+ /** argKind byte signalling a {@link ColumnarValue.Array} arg. */
+ private static final byte KIND_ARRAY = 0;
+
+ /** argKind byte signalling a {@link ColumnarValue.Scalar} arg. */
+ private static final byte KIND_SCALAR = 1;
+
/**
* Invoke a scalar UDF for one batch. Called from native code; not for application use.
*
- * @param impl the registered {@link ScalarFunction} implementation
- * @param argsArrayAddr address of a populated {@code FFI_ArrowArray} struct holding the input
- * batch as a struct array (one field per UDF argument)
- * @param argsSchemaAddr address of the matching {@code FFI_ArrowSchema}
- * @param resultArrayAddr address of an empty {@code FFI_ArrowArray} the bridge writes into
- * @param resultSchemaAddr address of an empty {@code FFI_ArrowSchema} the bridge writes into
- * @param expectedRowCount the row count the result vector must have
+ * <p>Args arrive split into two struct arrays: {@code arrayArgs*} of length {@code rowCount}
+ * holding the {@link ColumnarValue.Array} arguments in their relative order, and {@code
+ * scalarArgs*} of length 1 holding the {@link ColumnarValue.Scalar} arguments. {@code argKinds}
+ * records the original positional order so the bridge can interleave them back into a single
+ * {@code List<ColumnarValue>} for the user.
+ *
+ * @param impl the registered {@link ScalarFunction} implementation
+ * @param arrayArgsArrayAddr address of the {@code FFI_ArrowArray} struct holding the Array
+ * arguments as a struct array
+ * @param arrayArgsSchemaAddr address of the matching {@code FFI_ArrowSchema}
+ * @param scalarArgsArrayAddr address of the {@code FFI_ArrowArray} struct holding the Scalar
+ * arguments as a struct array
+ * @param scalarArgsSchemaAddr address of the matching {@code FFI_ArrowSchema}
+ * @param argKinds one byte per UDF argument in positional order, each either {@link
+ * #KIND_ARRAY} or {@link #KIND_SCALAR}
+ * @param resultArrayAddr address of the {@code FFI_ArrowArray} the bridge exports the result into
+ * @param resultSchemaAddr address of the {@code FFI_ArrowSchema} the bridge exports the result
+ * into
+ * @param rowCount the batch row count; forwarded to the UDF and required as the length of an
+ * Array result
+ * @return {@link #KIND_ARRAY} if the UDF returned an Array, {@link #KIND_SCALAR} if it returned a
+ * Scalar. The native caller uses this to reconstruct the right {@code ColumnarValue} variant.
+ * @throws IllegalStateException if {@code argKinds} contains an unknown byte, the UDF returns
+ * {@code null}, or the result vector's length does not match its variant
*/
- public static void invokeScalarUdf(
+ public static byte invokeScalarUdf(
ScalarFunction impl,
- long argsArrayAddr,
- long argsSchemaAddr,
+ long arrayArgsArrayAddr,
+ long arrayArgsSchemaAddr,
+ long scalarArgsArrayAddr,
+ long scalarArgsSchemaAddr,
+ byte[] argKinds,
long resultArrayAddr,
long resultSchemaAddr,
- int expectedRowCount) {
- ArrowArray argsArr = ArrowArray.wrap(argsArrayAddr);
- ArrowSchema argsSch = ArrowSchema.wrap(argsSchemaAddr);
+ int rowCount) {
+ ArrowArray arrayArr = ArrowArray.wrap(arrayArgsArrayAddr);
+ ArrowSchema arraySch = ArrowSchema.wrap(arrayArgsSchemaAddr);
+ ArrowArray scalarArr = ArrowArray.wrap(scalarArgsArrayAddr);
+ ArrowSchema scalarSch = ArrowSchema.wrap(scalarArgsSchemaAddr);
ArrowArray resultArr = ArrowArray.wrap(resultArrayAddr);
ArrowSchema resultSch = ArrowSchema.wrap(resultSchemaAddr);
- try (VectorSchemaRoot root = Data.importVectorSchemaRoot(ALLOCATOR, argsArr, argsSch, null)) {
- List argVectors = root.getFieldVectors();
+ try (VectorSchemaRoot arrayRoot =
+ Data.importVectorSchemaRoot(ALLOCATOR, arrayArr, arraySch, null);
+ VectorSchemaRoot scalarRoot =
+ Data.importVectorSchemaRoot(ALLOCATOR, scalarArr, scalarSch, null)) {
- FieldVector result = impl.evaluate(ALLOCATOR, argVectors);
+ List<FieldVector> arrayFields = arrayRoot.getFieldVectors();
+ List<FieldVector> scalarFields = scalarRoot.getFieldVectors();
+
+ // Re-interleave the two imported groups into the original positional order.
+ List<ColumnarValue> args = new ArrayList<>(argKinds.length);
+ int arrayIdx = 0;
+ int scalarIdx = 0;
+ for (byte kind : argKinds) {
+ if (kind == KIND_ARRAY) {
+ args.add(ColumnarValue.array(arrayFields.get(arrayIdx++)));
+ } else if (kind == KIND_SCALAR) {
+ args.add(ColumnarValue.scalar(scalarFields.get(scalarIdx++)));
+ } else {
+ throw new IllegalStateException("Unknown argKind byte: " + kind);
+ }
+ }
+
+ ColumnarValue result = impl.evaluate(ALLOCATOR, new ScalarFunctionArgs(args, rowCount));
if (result == null) {
throw new IllegalStateException("ScalarFunction.evaluate returned null");
}
- if (result.getValueCount() != expectedRowCount) {
+
+ FieldVector resultVec = result.vector();
+ byte resultKind;
+ int expectedLen;
+ // An Array result must cover the whole batch; a Scalar result is a single value.
+ if (result instanceof ColumnarValue.Array) {
+ resultKind = KIND_ARRAY;
+ expectedLen = rowCount;
+ } else {
+ resultKind = KIND_SCALAR;
+ expectedLen = 1;
+ }
+
+ if (resultVec.getValueCount() != expectedLen) {
try {
throw new IllegalStateException(
- "ScalarFunction.evaluate returned vector with "
- + result.getValueCount()
+ "ScalarFunction.evaluate returned "
+ + (resultKind == KIND_ARRAY ? "Array" : "Scalar")
+ + " vector with "
+ + resultVec.getValueCount()
+ " rows; expected "
- + expectedRowCount);
+ + expectedLen);
} finally {
- result.close();
+ resultVec.close();
}
}
try {
- Data.exportVector(ALLOCATOR, result, null, resultArr, resultSch);
+ Data.exportVector(ALLOCATOR, resultVec, null, resultArr, resultSch);
} finally {
- result.close();
+ resultVec.close();
}
+
+ return resultKind;
}
}
}
diff --git a/core/src/test/java/org/apache/datafusion/ColumnarValueTest.java b/core/src/test/java/org/apache/datafusion/ColumnarValueTest.java
new file mode 100644
index 0000000..42879e0
--- /dev/null
+++ b/core/src/test/java/org/apache/datafusion/ColumnarValueTest.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datafusion;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertSame;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.junit.jupiter.api.Test;
+
+class ColumnarValueTest {
+
+ private static final ArrowType INT32 = new ArrowType.Int(32, true);
+
+ /** {@code array(...)} yields the Array variant wrapping the very same vector instance. */
+ @Test
+ void array_factory_returnsArrayVariant() {
+ try (BufferAllocator allocator = new RootAllocator();
+ IntVector v = new IntVector("v", allocator)) {
+ v.allocateNew(3);
+ v.setValueCount(3);
+ ColumnarValue cv = ColumnarValue.array(v);
+ // The test name promises a variant check, so pin the concrete record type too.
+ assertEquals(ColumnarValue.Array.class, cv.getClass());
+ assertSame(v, cv.vector());
+ assertEquals(INT32, cv.dataType());
+ }
+ }
+
+ /** {@code scalar(...)} yields the Scalar variant wrapping the very same vector instance. */
+ @Test
+ void scalar_factory_returnsScalarVariant() {
+ try (BufferAllocator allocator = new RootAllocator();
+ IntVector v = new IntVector("v", allocator)) {
+ v.allocateNew(1);
+ v.set(0, 42);
+ v.setValueCount(1);
+ ColumnarValue cv = ColumnarValue.scalar(v);
+ // The test name promises a variant check, so pin the concrete record type too.
+ assertEquals(ColumnarValue.Scalar.class, cv.getClass());
+ assertSame(v, cv.vector());
+ assertEquals(INT32, cv.dataType());
+ }
+ }
+
+ /** A vector whose value count is not 1 is rejected with the documented message. */
+ @Test
+ void scalar_factory_rejectsNonOneLength() {
+ try (BufferAllocator allocator = new RootAllocator();
+ IntVector v = new IntVector("v", allocator)) {
+ v.allocateNew(2);
+ v.setValueCount(2);
+ IllegalArgumentException ex =
+ assertThrows(IllegalArgumentException.class, () -> ColumnarValue.scalar(v));
+ assertEquals("Scalar vector must have valueCount == 1, got 2", ex.getMessage());
+ }
+ }
+
+ @Test
+ void array_factory_rejectsNull() {
+ assertThrows(NullPointerException.class, () -> ColumnarValue.array(null));
+ }
+
+ @Test
+ void scalar_factory_rejectsNull() {
+ assertThrows(NullPointerException.class, () -> ColumnarValue.scalar(null));
+ }
+}
diff --git a/core/src/test/java/org/apache/datafusion/ScalarFunctionArgsTest.java b/core/src/test/java/org/apache/datafusion/ScalarFunctionArgsTest.java
new file mode 100644
index 0000000..df9f480
--- /dev/null
+++ b/core/src/test/java/org/apache/datafusion/ScalarFunctionArgsTest.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datafusion;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.IntVector;
+import org.junit.jupiter.api.Test;
+
+class ScalarFunctionArgsTest {
+
+ @Test
+ void construct_emptyArgs_zeroRows_ok() {
+ ScalarFunctionArgs a = new ScalarFunctionArgs(List.of(), 0);
+ assertEquals(List.of(), a.args());
+ assertEquals(0, a.rowCount());
+ }
+
+ @Test
+ void construct_rejectsNullArgs() {
+ assertThrows(NullPointerException.class, () -> new ScalarFunctionArgs(null, 3));
+ }
+
+ @Test
+ void construct_rejectsNegativeRowCount() {
+ assertThrows(IllegalArgumentException.class, () -> new ScalarFunctionArgs(List.of(), -1));
+ }
+
+ /** Clearing the caller's list must not affect the record, and the stored list is read-only. */
+ @Test
+ void construct_copiesArgsDefensively() {
+ try (BufferAllocator allocator = new RootAllocator();
+ IntVector v = new IntVector("v", allocator)) {
+ v.allocateNew(1);
+ v.setValueCount(1);
+ List<ColumnarValue> source = new ArrayList<>();
+ source.add(ColumnarValue.scalar(v));
+ ScalarFunctionArgs a = new ScalarFunctionArgs(source, 1);
+ source.clear();
+ assertEquals(1, a.args().size());
+ assertThrows(UnsupportedOperationException.class, () -> a.args().clear());
+ }
+ }
+
+ /** A single Scalar argument keeps its variant and its value; rowCount is independent of it. */
+ @Test
+ void args_singletonCase_preservesValue() {
+ try (BufferAllocator allocator = new RootAllocator();
+ IntVector v = new IntVector("v", allocator)) {
+ v.allocateNew(1);
+ v.set(0, 7);
+ v.setValueCount(1);
+ ScalarFunctionArgs a =
+ new ScalarFunctionArgs(Collections.singletonList(ColumnarValue.scalar(v)), 5);
+ assertEquals(5, a.rowCount());
+ assertTrue(a.args().get(0) instanceof ColumnarValue.Scalar);
+ // The name says "preservesValue", so check the stored value as well as the variant.
+ assertEquals(7, ((IntVector) a.args().get(0).vector()).get(0));
+ }
+ }
+}
diff --git a/core/src/test/java/org/apache/datafusion/ScalarUdfTest.java b/core/src/test/java/org/apache/datafusion/ScalarUdfTest.java
index 97f5f52..3e14580 100644
--- a/core/src/test/java/org/apache/datafusion/ScalarUdfTest.java
+++ b/core/src/test/java/org/apache/datafusion/ScalarUdfTest.java
@@ -86,8 +86,8 @@ static final class AddOne extends AbstractScalarFunction {
}
@Override
- public FieldVector evaluate(BufferAllocator allocator, List args) {
- IntVector in = (IntVector) args.get(0);
+ public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) {
+ IntVector in = (IntVector) args.args().get(0).vector();
IntVector out = new IntVector("add_one_out", allocator);
int n = in.getValueCount();
out.allocateNew(n);
@@ -99,7 +99,7 @@ public FieldVector evaluate(BufferAllocator allocator, List args) {
}
}
out.setValueCount(n);
- return out;
+ return ColumnarValue.array(out);
}
}
@@ -133,11 +133,11 @@ static final class Concat extends AbstractScalarFunction {
}
@Override
- public FieldVector evaluate(BufferAllocator allocator, List args) {
+ public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) {
org.apache.arrow.vector.VarCharVector left =
- (org.apache.arrow.vector.VarCharVector) args.get(0);
+ (org.apache.arrow.vector.VarCharVector) args.args().get(0).vector();
org.apache.arrow.vector.VarCharVector right =
- (org.apache.arrow.vector.VarCharVector) args.get(1);
+ (org.apache.arrow.vector.VarCharVector) args.args().get(1).vector();
org.apache.arrow.vector.VarCharVector out =
new org.apache.arrow.vector.VarCharVector("concat_out", allocator);
int n = left.getValueCount();
@@ -155,7 +155,7 @@ public FieldVector evaluate(BufferAllocator allocator, List args) {
}
}
out.setValueCount(n);
- return out;
+ return ColumnarValue.array(out);
}
}
@@ -188,8 +188,9 @@ static final class Square extends AbstractScalarFunction {
}
@Override
- public FieldVector evaluate(BufferAllocator allocator, List args) {
- org.apache.arrow.vector.Float8Vector in = (org.apache.arrow.vector.Float8Vector) args.get(0);
+ public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) {
+ org.apache.arrow.vector.Float8Vector in =
+ (org.apache.arrow.vector.Float8Vector) args.args().get(0).vector();
org.apache.arrow.vector.Float8Vector out =
new org.apache.arrow.vector.Float8Vector("square_out", allocator);
int n = in.getValueCount();
@@ -203,7 +204,7 @@ public FieldVector evaluate(BufferAllocator allocator, List args) {
}
}
out.setValueCount(n);
- return out;
+ return ColumnarValue.array(out);
}
}
@@ -252,7 +253,7 @@ static final class ReturnsNull extends AbstractScalarFunction {
}
@Override
- public FieldVector evaluate(BufferAllocator allocator, List args) {
+ public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) {
return null;
}
}
@@ -283,13 +284,13 @@ static final class WrongRowCount extends AbstractScalarFunction {
}
@Override
- public FieldVector evaluate(BufferAllocator allocator, List args) {
- IntVector in = (IntVector) args.get(0);
+ public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) {
+ IntVector in = (IntVector) args.args().get(0).vector();
IntVector out = new IntVector("out", allocator);
out.allocateNew(in.getValueCount() + 1); // off by one
for (int i = 0; i < in.getValueCount() + 1; i++) out.set(i, 0);
out.setValueCount(in.getValueCount() + 1);
- return out;
+ return ColumnarValue.array(out);
}
}
@@ -319,14 +320,15 @@ static final class WrongType extends AbstractScalarFunction {
}
@Override
- public FieldVector evaluate(BufferAllocator allocator, List args) {
+ public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) {
// Declared return type is Int32; return Float64.
org.apache.arrow.vector.Float8Vector out =
new org.apache.arrow.vector.Float8Vector("out", allocator);
- out.allocateNew(args.get(0).getValueCount());
- for (int i = 0; i < args.get(0).getValueCount(); i++) out.set(i, 0.0);
- out.setValueCount(args.get(0).getValueCount());
- return out;
+ FieldVector in = args.args().get(0).vector();
+ out.allocateNew(in.getValueCount());
+ for (int i = 0; i < in.getValueCount(); i++) out.set(i, 0.0);
+ out.setValueCount(in.getValueCount());
+ return ColumnarValue.array(out);
}
}
@@ -356,7 +358,7 @@ static final class ThrowsIAE extends AbstractScalarFunction {
}
@Override
- public FieldVector evaluate(BufferAllocator allocator, List args) {
+ public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) {
throw new IllegalArgumentException("custom boom from UDF");
}
}
@@ -448,6 +450,152 @@ void udfAppliedToMultiRowQuery_processesAllRows() throws Exception {
}
}
+ /**
+ * Zero-argument UDF that yields pi as a single-element Float8 vector wrapped in {@link
+ * ColumnarValue#scalar(FieldVector)}. It is declared VOLATILE so that DataFusion's constant
+ * folder does not collapse the call before it reaches the Java side. This covers the case the
+ * abandoned PR #57 introduced a separate rowCount parameter for: the UDF hands back one value and
+ * the framework performs the per-row expansion.
+ */
+ static final class JavaPi extends AbstractScalarFunction {
+ JavaPi() {
+ super("java_pi", List.of(), FLOAT64, Volatility.VOLATILE);
+ }
+
+ @Override
+ public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) {
+ // One slot only: the result is returned as a Scalar, not sized to the batch.
+ org.apache.arrow.vector.Float8Vector pi =
+ new org.apache.arrow.vector.Float8Vector("pi_out", allocator);
+ pi.allocateNew(1);
+ pi.set(0, Math.PI);
+ pi.setValueCount(1);
+ return ColumnarValue.scalar(pi);
+ }
+ }
+
+ /** A Scalar returned by a nullary UDF must show up once per row of a three-row query. */
+ @Test
+ void nullaryScalarReturnUdf_overMultiRowQuery_broadcasts() throws Exception {
+ final int expectedRows = 3;
+ try (SessionContext ctx = new SessionContext();
+ BufferAllocator allocator = new RootAllocator()) {
+ ctx.registerUdf(new ScalarUdf(new JavaPi()));
+
+ try (DataFrame df = ctx.sql("SELECT java_pi() AS p FROM (VALUES (1), (2), (3)) AS t(x)");
+ ArrowReader reader = df.collect(allocator)) {
+ assertEquals(true, reader.loadNextBatch());
+ VectorSchemaRoot batch = reader.getVectorSchemaRoot();
+ org.apache.arrow.vector.Float8Vector piColumn =
+ (org.apache.arrow.vector.Float8Vector) batch.getVector("p");
+ assertEquals(expectedRows, piColumn.getValueCount());
+ // Every row carries the broadcast value exactly (zero tolerance).
+ for (int row = 0; row < expectedRows; row++) {
+ assertEquals(Math.PI, piColumn.get(row), 0.0);
+ }
+ }
+ }
+ }
+
+ /**
+ * Two-argument UDF intended for an (int column, int literal) call. Each invocation fails with an
+ * {@link AssertionError} unless arg 0 is an Array and arg 1 is a Scalar backed by a length-1
+ * vector; otherwise it adds the scalar to every non-null row. Used to show that the FFI protocol
+ * keeps a literal as a Scalar end-to-end instead of materialising it to a length-N array on the
+ * native side.
+ */
+ static final class AssertSecondArgIsScalar extends AbstractScalarFunction {
+ AssertSecondArgIsScalar() {
+ super("assert_scalar_arg", List.of(INT32, INT32), INT32, Volatility.IMMUTABLE);
+ }
+
+ @Override
+ public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) {
+ ColumnarValue first = args.args().get(0);
+ if (!(first instanceof ColumnarValue.Array)) {
+ throw new AssertionError("arg 0 expected Array, got " + first.getClass().getSimpleName());
+ }
+ ColumnarValue second = args.args().get(1);
+ if (!(second instanceof ColumnarValue.Scalar)) {
+ throw new AssertionError("arg 1 expected Scalar, got " + second.getClass().getSimpleName());
+ }
+ IntVector column = (IntVector) first.vector();
+ IntVector literal = (IntVector) second.vector();
+ int literalLength = literal.getValueCount();
+ if (literalLength != 1) {
+ throw new AssertionError("Scalar arg vector should have length 1, got " + literalLength);
+ }
+ int addend = literal.get(0);
+ int rows = column.getValueCount();
+ IntVector sum = new IntVector("out", allocator);
+ sum.allocateNew(rows);
+ for (int row = 0; row < rows; row++) {
+ // Null inputs stay null; everything else gets the literal added.
+ if (!column.isNull(row)) {
+ sum.set(row, column.get(row) + addend);
+ } else {
+ sum.setNull(row);
+ }
+ }
+ sum.setValueCount(rows);
+ return ColumnarValue.array(sum);
+ }
+ }
+
+ /** Runs the asserting UDF over a three-row column plus a literal and checks the sums. */
+ @Test
+ void scalarLiteralArg_arrivesAsScalarColumnarValue() throws Exception {
+ final int[] expected = {101, 102, 103};
+ try (SessionContext ctx = new SessionContext();
+ BufferAllocator allocator = new RootAllocator()) {
+ ctx.registerUdf(new ScalarUdf(new AssertSecondArgIsScalar()));
+
+ String sql =
+ "SELECT assert_scalar_arg(x, CAST(100 AS INT)) AS y"
+ + " FROM (VALUES (CAST(1 AS INT)), (CAST(2 AS INT)), (CAST(3 AS INT)))"
+ + " AS t(x)";
+ try (DataFrame df = ctx.sql(sql);
+ ArrowReader reader = df.collect(allocator)) {
+ assertEquals(true, reader.loadNextBatch());
+ VectorSchemaRoot batch = reader.getVectorSchemaRoot();
+ IntVector sums = (IntVector) batch.getVector("y");
+ assertEquals(expected.length, sums.getValueCount());
+ for (int row = 0; row < expected.length; row++) {
+ assertEquals(expected[row], sums.get(row));
+ }
+ }
+ }
+ }
+
+ /** One-argument UDF that never reads its argument and always answers with the Scalar 42. */
+ static final class IgnoreInputReturnFortyTwo extends AbstractScalarFunction {
+ IgnoreInputReturnFortyTwo() {
+ super("forty_two", List.of(INT32), INT32, Volatility.IMMUTABLE);
+ }
+
+ @Override
+ public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) {
+ // Single-slot vector; it is returned as a Scalar rather than sized to the batch.
+ IntVector constant = new IntVector("out", allocator);
+ constant.allocateNew(1);
+ constant.set(0, 42);
+ constant.setValueCount(1);
+ return ColumnarValue.scalar(constant);
+ }
+ }
+
+ /** A Scalar result from a one-argument UDF must appear on each of the five input rows. */
+ @Test
+ void udfReturningScalar_isBroadcastByFramework() throws Exception {
+ final int expectedRows = 5;
+ try (SessionContext ctx = new SessionContext();
+ BufferAllocator allocator = new RootAllocator()) {
+ ctx.registerUdf(new ScalarUdf(new IgnoreInputReturnFortyTwo()));
+
+ String sql =
+ "SELECT forty_two(x) AS y"
+ + " FROM (VALUES (CAST(1 AS INT)), (CAST(2 AS INT)),"
+ + " (CAST(3 AS INT)), (CAST(4 AS INT)), (CAST(5 AS INT))) AS t(x)";
+ try (DataFrame df = ctx.sql(sql);
+ ArrowReader reader = df.collect(allocator)) {
+ assertEquals(true, reader.loadNextBatch());
+ VectorSchemaRoot batch = reader.getVectorSchemaRoot();
+ IntVector answers = (IntVector) batch.getVector("y");
+ assertEquals(expectedRows, answers.getValueCount());
+ for (int row = 0; row < expectedRows; row++) {
+ assertEquals(42, answers.get(row));
+ }
+ }
+ }
+ }
+
@Test
void volatilityBytesRoundTrip_forAllThreeKinds() throws Exception {
for (Volatility v : Volatility.values()) {
diff --git a/docs/source/user-guide/scalar-udf.md b/docs/source/user-guide/scalar-udf.md
index 8b78410..b53b32f 100644
--- a/docs/source/user-guide/scalar-udf.md
+++ b/docs/source/user-guide/scalar-udf.md
@@ -19,9 +19,10 @@ under the License.
# Scalar UDFs
-A scalar UDF is a Java-implemented SQL function that operates on one row at a
-time, expressed in vectorised form: each invocation receives a batch of input
-columns and returns a single output column of the same length.
+A scalar UDF is a Java-implemented SQL function that operates on one row at a time,
+expressed in vectorised form: each invocation receives a batch of input columns
+and returns either a per-row output column of the same length (`Array`) or a
+single value broadcast to every row (`Scalar`).
## Implement
@@ -32,10 +33,11 @@ per-batch `evaluate` body:
```java
import java.util.List;
import org.apache.arrow.memory.BufferAllocator;
-import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.datafusion.ColumnarValue;
import org.apache.datafusion.ScalarFunction;
+import org.apache.datafusion.ScalarFunctionArgs;
import org.apache.datafusion.Volatility;
public final class AddOne implements ScalarFunction {
@@ -47,8 +49,8 @@ public final class AddOne implements ScalarFunction {
@Override public Volatility volatility() { return Volatility.IMMUTABLE; }
@Override
- public FieldVector evaluate(BufferAllocator allocator, List args) {
- IntVector in = (IntVector) args.get(0);
+ public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) {
+ IntVector in = (IntVector) args.args().get(0).vector();
IntVector out = new IntVector("add_one", allocator);
out.allocateNew(in.getValueCount());
for (int i = 0; i < in.getValueCount(); i++) {
@@ -56,15 +58,50 @@ public final class AddOne implements ScalarFunction {
else out.set(i, in.get(i) + 1);
}
out.setValueCount(in.getValueCount());
- return out;
+ return ColumnarValue.array(out);
}
}
```
+Each entry in `args.args()` is a `ColumnarValue` — either `ColumnarValue.Array`
+(a per-row vector of length `args.rowCount()`) or `ColumnarValue.Scalar` (a
+length-1 vector representing a single literal or folded constant). Access the
+underlying Arrow vector with `.vector()`.
+
Allocate any new vectors — including the result — from the supplied
`BufferAllocator`. The input vectors are read-only views; do not close them.
Ownership of the returned vector transfers to the framework on return.
+## Returning a Scalar
+
+Functions that yield a single value (nullary constants like `pi()`, or any
+function that wants the framework to broadcast a result across the batch) can
+return `ColumnarValue.scalar(...)` over a length-1 vector:
+
+```java
+public final class JavaPi implements ScalarFunction {
+ private static final ArrowType FLOAT64 =
+ new ArrowType.FloatingPoint(org.apache.arrow.vector.types.FloatingPointPrecision.DOUBLE);
+
+ @Override public String name() { return "java_pi"; }
+ @Override public List<ArrowType> argTypes() { return List.of(); }
+ @Override public ArrowType returnType() { return FLOAT64; }
+ @Override public Volatility volatility() { return Volatility.VOLATILE; }
+
+ @Override
+ public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) {
+ org.apache.arrow.vector.Float8Vector out =
+ new org.apache.arrow.vector.Float8Vector("pi", allocator);
+ out.allocateNew(1);
+ out.set(0, Math.PI);
+ out.setValueCount(1);
+ return ColumnarValue.scalar(out);
+ }
+}
+```
+
+The framework expands the scalar across `args.rowCount()` rows automatically.
+
## Register
Wrap the implementation in a `ScalarUdf` and pass it to
@@ -90,9 +127,11 @@ non-deterministic functions.
## Errors
If the UDF throws, the exception class and message surface in the
-`RuntimeException` raised from `collect()`. If the returned vector is `null`,
-has the wrong row count, or the wrong type, the runtime raises a
-`RuntimeException` with a descriptive message.
+`RuntimeException` raised from `collect()`. If the returned `ColumnarValue` is
+`null`, an Array result's vector length does not equal `args.rowCount()`, or
+the result's Arrow type differs from the declared return type, the runtime
+raises a `RuntimeException` with a descriptive message. A Scalar result whose
+vector is not length-1 is rejected at the `ColumnarValue.scalar` factory.
## Threading
diff --git a/examples/src/main/java/org/apache/datafusion/examples/AddOneExample.java b/examples/src/main/java/org/apache/datafusion/examples/AddOneExample.java
index c27bff0..d9416b1 100644
--- a/examples/src/main/java/org/apache/datafusion/examples/AddOneExample.java
+++ b/examples/src/main/java/org/apache/datafusion/examples/AddOneExample.java
@@ -23,13 +23,14 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
-import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.datafusion.ColumnarValue;
import org.apache.datafusion.DataFrame;
import org.apache.datafusion.ScalarFunction;
+import org.apache.datafusion.ScalarFunctionArgs;
import org.apache.datafusion.ScalarUdf;
import org.apache.datafusion.SessionContext;
import org.apache.datafusion.Volatility;
@@ -62,8 +63,8 @@ public Volatility volatility() {
}
@Override
- public FieldVector evaluate(BufferAllocator allocator, List<FieldVector> args) {
- IntVector in = (IntVector) args.get(0);
+ public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) {
+ IntVector in = (IntVector) args.args().get(0).vector();
IntVector out = new IntVector("add_one_out", allocator);
int n = in.getValueCount();
out.allocateNew(n);
@@ -75,7 +76,7 @@ public FieldVector evaluate(BufferAllocator allocator, List<FieldVector> args) {
}
}
out.setValueCount(n);
- return out;
+ return ColumnarValue.array(out);
}
}
diff --git a/native/src/lib.rs b/native/src/lib.rs
index f6f16d3..fe46a07 100644
--- a/native/src/lib.rs
+++ b/native/src/lib.rs
@@ -663,7 +663,7 @@ pub extern "system" fn Java_org_apache_datafusion_SessionContext_registerScalarU
let invoke_method = env.get_static_method_id(
&bridge_class_local,
"invokeScalarUdf",
- "(Lorg/apache/datafusion/ScalarFunction;JJJJI)V",
+ "(Lorg/apache/datafusion/ScalarFunction;JJJJ[BJJI)B",
)?;
let java_udf = crate::udf::JavaScalarUdf {
diff --git a/native/src/udf.rs b/native/src/udf.rs
index 62d0e24..d2b18b4 100644
--- a/native/src/udf.rs
+++ b/native/src/udf.rs
@@ -19,18 +19,18 @@
use std::any::Any;
use std::fmt;
-use std::sync::Arc;
use datafusion::arrow::array::{make_array, Array, ArrayRef, StructArray};
use datafusion::arrow::datatypes::{DataType, Field, Fields};
use datafusion::arrow::ffi::{from_ffi, to_ffi, FFI_ArrowArray, FFI_ArrowSchema};
+use datafusion::common::ScalarValue;
use datafusion::error::DataFusionError;
use datafusion::logical_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
use jni::objects::{GlobalRef, JStaticMethodID, JThrowable};
use jni::signature::{Primitive, ReturnType};
-use jni::sys::{jlong, jvalue};
+use jni::sys::{jbyte, jlong, jvalue};
use jni::JNIEnv;
pub(crate) struct JavaScalarUdf {
@@ -99,21 +99,8 @@ impl ScalarUDFImpl for JavaScalarUdf {
) -> datafusion::error::Result<ColumnarValue> {
let number_rows = args.number_rows;
- // 1. Materialise scalars to arrays so all columns are length-N.
- let arrays: Vec<ArrayRef> = args
- .args
- .iter()
- .map(|cv| cv.clone().into_array(number_rows))
- .collect::<datafusion::error::Result<Vec<_>>>()?;
-
- // 2. Build a single struct array carrying all arg columns. Field names/types come
- // from the signature's Exact type list (matches what the Java caller declared).
- let signature_fields: Vec<Arc<Field>> = match &self.signature.type_signature {
- TypeSignature::Exact(types) => types
- .iter()
- .enumerate()
- .map(|(i, ty)| Arc::new(Field::new(format!("arg{}", i), ty.clone(), true)))
- .collect(),
+ let signature_types: &[DataType] = match &self.signature.type_signature {
+ TypeSignature::Exact(types) => types,
_ => {
return Err(DataFusionError::Internal(
"JavaScalarUdf signature is not Exact; only Signature::exact is supported"
@@ -122,43 +109,87 @@ impl ScalarUDFImpl for JavaScalarUdf {
}
};
- let fields = Fields::from(
- signature_fields
- .iter()
- .map(|f| f.as_ref().clone())
- .collect::<Vec<_>>(),
- );
- let struct_array = StructArray::try_new_with_length(fields, arrays, None, number_rows)
+ if args.args.len() != signature_types.len() {
+ return Err(DataFusionError::Internal(format!(
+ "Java UDF '{}' called with {} args; signature declares {}",
+ self.name,
+ args.args.len(),
+ signature_types.len()
+ )));
+ }
+
+ // 1. Partition args by kind. ColumnarValue::Scalar stays as a length-1 array so the Java
+ // side observes it as a Scalar; ColumnarValue::Array passes through at full length.
+ let mut array_arrays: Vec<ArrayRef> = Vec::new();
+ let mut array_fields: Vec<Field> = Vec::new();
+ let mut scalar_arrays: Vec<ArrayRef> = Vec::new();
+ let mut scalar_fields: Vec<Field> = Vec::new();
+ let mut arg_kinds: Vec<u8> = Vec::with_capacity(args.args.len());
+
+ for (i, cv) in args.args.iter().enumerate() {
+ let ty = signature_types[i].clone();
+ match cv {
+ ColumnarValue::Array(a) => {
+ array_fields.push(Field::new(format!("arg{}", array_arrays.len()), ty, true));
+ array_arrays.push(a.clone());
+ arg_kinds.push(0);
+ }
+ ColumnarValue::Scalar(s) => {
+ let arr = s.to_array_of_size(1)?;
+ scalar_fields.push(Field::new(format!("arg{}", scalar_arrays.len()), ty, true));
+ scalar_arrays.push(arr);
+ arg_kinds.push(1);
+ }
+ }
+ }
+
+ // 2. Build the two struct arrays. Empty field+array vectors with the appropriate length
+ // cover nullary and all-one-kind cases.
+ let array_struct = StructArray::try_new_with_length(
+ Fields::from(array_fields),
+ array_arrays,
+ None,
+ number_rows,
+ )
+ .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
+ let scalar_struct =
+ StructArray::try_new_with_length(Fields::from(scalar_fields), scalar_arrays, None, 1)
+ .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
+
+ let (array_ffi_arr, array_ffi_sch) = to_ffi(&array_struct.into_data())
+ .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
+ let (scalar_ffi_arr, scalar_ffi_sch) = to_ffi(&scalar_struct.into_data())
.map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
- let args_data = struct_array.into_data();
- let (args_ffi_array, args_ffi_schema) =
- to_ffi(&args_data).map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
// 3. Pre-allocate empty FFI structs for the result.
- let result_ffi_array = FFI_ArrowArray::empty();
- let result_ffi_schema = FFI_ArrowSchema::empty();
+ let result_ffi_arr = FFI_ArrowArray::empty();
+ let result_ffi_sch = FFI_ArrowSchema::empty();
// 4. Box for stable addresses across the JNI call.
- let mut args_array_box = Box::new(args_ffi_array);
- let mut args_schema_box = Box::new(args_ffi_schema);
- let mut result_array_box = Box::new(result_ffi_array);
- let mut result_schema_box = Box::new(result_ffi_schema);
-
- let args_array_addr = args_array_box.as_mut() as *mut _ as jlong;
- let args_schema_addr = args_schema_box.as_mut() as *mut _ as jlong;
- let result_array_addr = result_array_box.as_mut() as *mut _ as jlong;
- let result_schema_addr = result_schema_box.as_mut() as *mut _ as jlong;
+ let mut array_arr_box = Box::new(array_ffi_arr);
+ let mut array_sch_box = Box::new(array_ffi_sch);
+ let mut scalar_arr_box = Box::new(scalar_ffi_arr);
+ let mut scalar_sch_box = Box::new(scalar_ffi_sch);
+ let mut result_arr_box = Box::new(result_ffi_arr);
+ let mut result_sch_box = Box::new(result_ffi_sch);
+
+ let array_arr_addr = array_arr_box.as_mut() as *mut _ as jlong;
+ let array_sch_addr = array_sch_box.as_mut() as *mut _ as jlong;
+ let scalar_arr_addr = scalar_arr_box.as_mut() as *mut _ as jlong;
+ let scalar_sch_addr = scalar_sch_box.as_mut() as *mut _ as jlong;
+ let result_arr_addr = result_arr_box.as_mut() as *mut _ as jlong;
+ let result_sch_addr = result_sch_box.as_mut() as *mut _ as jlong;
// 5. Attach JNI to current thread.
let mut env = crate::jvm()
.attach_current_thread()
.map_err(|e| DataFusionError::Execution(format!("JNI attach failed: {}", e)))?;
- // 6. Call JniBridge.invokeScalarUdf(udf, args*, result*, expectedRowCount).
- //
- // Build the jvalue argument array for call_static_method_unchecked.
- // SAFETY: we build the args inline and pass them immediately; the JObject
- // pointed to by udf_global_ref is alive for the duration of this call.
+ // 6. Build the byte[] for argKinds inside the JVM heap. JNI local; freed when env drops.
+ let arg_kinds_array = env.byte_array_from_slice(&arg_kinds).map_err(|e| {
+ DataFusionError::Execution(format!("byte_array_from_slice failed: {}", e))
+ })?;
+
let expected_rows = i32::try_from(number_rows).map_err(|_| {
DataFusionError::Execution(format!(
"batch row count {} exceeds i32::MAX; UDFs cannot handle batches larger than 2^31 - 1 rows",
@@ -167,29 +198,20 @@ impl ScalarUDFImpl for JavaScalarUdf {
})?;
let udf_jobject = self.udf_global_ref.as_obj();
- // SAFETY: udf_jobject is derived from a GlobalRef alive for the duration of this
- // function. The raw pointer is only read by the JNI call below, which happens
- // before any code that could drop udf_global_ref.
- let call_args: [jvalue; 6] = [
- // ScalarFunction instance
+ // SAFETY: udf_global_ref and arg_kinds_array are alive for the duration of this call.
+ let call_args: [jvalue; 9] = [
jvalue {
l: udf_jobject.as_raw(),
},
- // argsArrayAddr
- jvalue { j: args_array_addr },
- // argsSchemaAddr
+ jvalue { j: array_arr_addr },
+ jvalue { j: array_sch_addr },
+ jvalue { j: scalar_arr_addr },
+ jvalue { j: scalar_sch_addr },
jvalue {
- j: args_schema_addr,
+ l: arg_kinds_array.as_raw(),
},
- // resultArrayAddr
- jvalue {
- j: result_array_addr,
- },
- // resultSchemaAddr
- jvalue {
- j: result_schema_addr,
- },
- // expectedRowCount
+ jvalue { j: result_arr_addr },
+ jvalue { j: result_sch_addr },
jvalue { i: expected_rows },
];
@@ -197,12 +219,12 @@ impl ScalarUDFImpl for JavaScalarUdf {
env.call_static_method_unchecked(
&self.bridge_class,
self.invoke_method,
- ReturnType::Primitive(Primitive::Void),
+ ReturnType::Primitive(Primitive::Byte),
&call_args,
)
};
- // 7. If Java threw, translate to DataFusionError. Always check exception_check first.
+ // 7. Java-exception path: translate to DataFusionError.
if env.exception_check().unwrap_or(false) {
let throwable = env.exception_occurred().map_err(|e| {
DataFusionError::Execution(format!("exception_occurred failed: {}", e))
@@ -211,19 +233,22 @@ impl ScalarUDFImpl for JavaScalarUdf {
let message = jthrowable_to_string(&mut env, &throwable, &self.name);
return Err(DataFusionError::Execution(message));
}
- call_result.map_err(|e| DataFusionError::Execution(format!("JNI call failed: {}", e)))?;
-
- // 8. Import result. from_ffi consumes the FFI_ArrowArray.
- let result_array = *result_array_box;
- let result_schema = *result_schema_box;
- // SAFETY: Java's `Data.exportVector` populated `result_array_box` and
- // `result_schema_box` in place via the C Data Interface, and the
- // exception check above guarantees the call succeeded without
- // throwing — so the FFI structs are fully initialized.
- let result_data = unsafe { from_ffi(result_array, &result_schema) }
+
+ let result_kind: jbyte = call_result
+ .map_err(|e| DataFusionError::Execution(format!("JNI call failed: {}", e)))?
+ .b()
+ .map_err(|e| {
+ DataFusionError::Execution(format!("invokeScalarUdf return decode failed: {}", e))
+ })?;
+
+ // 8. Import the result vector. from_ffi consumes the FFI_ArrowArray.
+ let result_array_ffi = *result_arr_box;
+ let result_schema_ffi = *result_sch_box;
+ // SAFETY: bridge populated both structs via Arrow C Data Interface; the exception check
+ // above confirmed no Java exception, so the FFI structs are fully initialised.
+ let result_data = unsafe { from_ffi(result_array_ffi, &result_schema_ffi) }
.map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
- // 9. Validate type.
if result_data.data_type() != &self.return_type {
return Err(DataFusionError::Execution(format!(
"Java UDF '{}' returned vector of type {:?}; declared return type was {:?}",
@@ -234,7 +259,25 @@ impl ScalarUDFImpl for JavaScalarUdf {
}
let array: ArrayRef = make_array(result_data);
- Ok(ColumnarValue::Array(array))
+
+ match result_kind {
+ 0 => Ok(ColumnarValue::Array(array)),
+ 1 => {
+ if array.len() != 1 {
+ return Err(DataFusionError::Internal(format!(
+ "Java UDF '{}' returned Scalar with length {} (expected 1)",
+ self.name,
+ array.len()
+ )));
+ }
+ let scalar = ScalarValue::try_from_array(&array, 0)?;
+ Ok(ColumnarValue::Scalar(scalar))
+ }
+ other => Err(DataFusionError::Internal(format!(
+ "Java UDF '{}' returned unknown kind byte: {}",
+ self.name, other
+ ))),
+ }
}
}